File size: 1,732 Bytes
eff3d61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from torchvision import transforms
from transformers import CLIPModel, SiglipModel

from src import constants


# INFERENCE
def run_inference(vision_lang_encoder, olf_encoder, graph_model, image, olf_vec):
    """Run a single-sample forward pass through the multimodal graph pipeline.

    Args:
        vision_lang_encoder: model exposing ``get_image_features(pixel_values=...)``
            (e.g. a CLIP/SigLIP model).
        olf_encoder: encoder mapping an olfactory vector to an embedding.
        graph_model: graph network taking ``(node_features, edge_index)``.
        image: PIL image (anything accepted by torchvision transforms).
        olf_vec: olfactory feature vector (sequence of floats).

    Returns:
        The graph model's output logits (no gradients tracked).
    """
    # Ensure all three models are in evaluation mode.
    for model in (vision_lang_encoder, olf_encoder, graph_model):
        model.eval()

    preprocess = transforms.Compose([
        transforms.Resize((constants.IMG_DIM, constants.IMG_DIM)),
        transforms.ToTensor(),
    ])

    # Add a leading batch dimension of 1 and move inputs to the target device.
    image_batch = preprocess(image).unsqueeze(0).to(constants.DEVICE)
    olf_batch = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(constants.DEVICE)

    with torch.no_grad():
        img_embed = vision_lang_encoder.get_image_features(pixel_values=image_batch)
        smell_embed = olf_encoder(olf_batch)

        # Stack both embeddings as graph nodes (one row per modality).
        node_feats = torch.cat((img_embed, smell_embed), dim=0)
        idx = torch.arange(node_feats.size(0))
        # Fully-connected edge index (includes self-loops), shape (2, n*n).
        edge_index = torch.cartesian_prod(idx, idx).T.to(constants.DEVICE)
        logits = graph_model(node_feats, edge_index)

    return logits


def load_model():
    """Load the three inference models and place them on ``constants.DEVICE``.

    Returns:
        tuple: ``(clip_model, olf_encoder, graph_model)`` — the CLIP
        vision-language encoder, the TorchScript olfactory encoder, and the
        TorchScript graph model, all in eval mode.
    """
    # Use CLIP ViT-B/32 as the default vision-language baseline.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(constants.DEVICE)
    clip_model.eval()
    # NOTE: the original code held the text below in a bare triple-quoted
    # string after the statements above — a no-op expression, not a docstring.
    # Kept here as a comment instead. Alternative backbone (SigLIP):
    #     SiglipModel.from_pretrained(
    #         "google/siglip-so400m-patch14-384",
    #         attn_implementation="flash_attention_2",
    #         dtype=torch.float16,
    #         device_map=constants.DEVICE,
    #     )
    # Olfactory encoder and graph model ship as TorchScript archives.
    olf_encoder = torch.jit.load(constants.ENCODER_SMALL_GRAPH_PATH).to(constants.DEVICE)
    olf_encoder.eval()
    graph_model = torch.jit.load(constants.OVLE_SMALL_GRAPH_PATH).to(constants.DEVICE)
    graph_model.eval()

    return clip_model, olf_encoder, graph_model