import torch
from torchvision import transforms
from transformers import CLIPModel, SiglipModel
from src import constants
# INFERENCE
def run_inference(vision_lang_encoder, olf_encoder, graph_model, image, olf_vec):
    """Embed one image/olfactory-vector pair and return the graph model's logits."""
    vision_lang_encoder.eval()
    olf_encoder.eval()
    graph_model.eval()

    # Resize to the model's expected resolution and convert to a CHW tensor.
    transform = transforms.Compose([
        transforms.Resize((constants.IMG_DIM, constants.IMG_DIM)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0).to(constants.DEVICE)
    olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(constants.DEVICE)

    with torch.no_grad():
        # Image embedding from the vision-language encoder (CLIP by default).
        vision_embed = vision_lang_encoder.get_image_features(pixel_values=image_tensor)
        # Olfactory embedding from the TorchScript encoder.
        olf_embed = olf_encoder(olf_tensor)
        # Fuse the two embeddings and drop the batch dimension.
        logits = graph_model(vision_embed, olf_embed).squeeze()
    return logits
def load_model():
    # Use CLIP as the default vision-language baseline.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(constants.DEVICE)
    clip_model.eval()
    # Alternatively, a SigLIP backbone can be used instead:
    # SiglipModel.from_pretrained(
    #     "google/siglip-so400m-patch14-384",
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     device_map=constants.DEVICE,
    # )

    # TorchScript checkpoints for the olfactory encoder and the fusion graph model.
    olf_encoder = torch.jit.load(constants.ENCODER_SMALL_BASE_PATH).to(constants.DEVICE)
    olf_encoder.eval()
    graph_model = torch.jit.load(constants.OVLE_SMALL_BASE_PATH).to(constants.DEVICE)
    graph_model.eval()
    return clip_model, olf_encoder, graph_model
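
# Example usage: a minimal sketch, assuming a PIL image on disk and an
# olfactory feature vector whose length matches what the TorchScript encoder
# was traced with. "sample.jpg" and the 128-dim placeholder vector below are
# illustrative values, not part of this repo.
if __name__ == "__main__":
    from PIL import Image

    clip_model, olf_encoder, graph_model = load_model()
    image = Image.open("sample.jpg").convert("RGB")
    olf_vec = [0.0] * 128  # placeholder; use the encoder's actual input dimensionality
    logits = run_inference(clip_model, olf_encoder, graph_model, image, olf_vec)
    print(logits)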