"""Demo script: run olfaction-vision-language inference on a dummy sample.

Builds a blank RGB image and a random aroma vector, then runs inference
through two model variants (base models via ``bm``, graph attention models
via ``gm``) and prints the resulting logits.
"""
import torch
from PIL import Image
from src import constants
from src import base_model as bm
from src import graph_model as gm
if __name__ == "__main__":
    # Build an example vision-olfaction sample with dummy data:
    # a blank square RGB image and a random aroma (olfaction) vector.
    example_image = Image.new('RGB', (constants.IMG_DIM, constants.IMG_DIM))
    # Plain string literal — the original used an f-string with no
    # placeholders (ruff F541). Path value is unchanged.
    example_image.save("/tmp/image_example.jpg")
    example_olf_vec = torch.randn(constants.AROMA_VEC_LENGTH)

    # -------- Option A: base models --------
    vision_lang_encoder, olf_encoder, graph_model = bm.load_model()
    ovl_classifier_base = bm.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=example_image,
        olf_vec=example_olf_vec,
    )
    print(f"Olfaction-Vision-Language Logits from Base Model: {ovl_classifier_base}")

    # -------- Option B: graph attention models --------
    # Same call shape as Option A; only the loading module differs.
    vision_lang_encoder, olf_encoder, graph_model = gm.load_model()
    ovl_classifier_graph = gm.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=example_image,
        olf_vec=example_olf_vec,
    )
    print(f"Olfaction-Vision-Language Logits from Graph Attention Model: {ovl_classifier_graph}")