File size: 1,404 Bytes
eff3d61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import tempfile

import torch
from PIL import Image

from src import constants
from src import base_model as bm
from src import graph_model as gm


def _run_option(model_module, image, olf_vec):
    """Load one model family and run it on a single vision-olfaction sample.

    ``model_module`` must expose ``load_model()``, returning
    ``(vision_lang_encoder, olf_encoder, graph_model)``, and
    ``run_inference(...)``, accepting those three plus ``image`` and
    ``olf_vec`` as keyword arguments. This matches how ``src.base_model``
    and ``src.graph_model`` are used in this script.

    The loaded models are only bound to locals here, so this function's
    references to them are dropped when it returns. A subsequent call
    therefore does not keep the previous model family referenced while
    loading the next one.

    Returns whatever ``model_module.run_inference`` returns.
    """
    vision_lang_encoder, olf_encoder, graph_model = model_module.load_model()
    return model_module.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=image,
        olf_vec=olf_vec,
    )


if __name__ == "__main__":

    # Build an example vision-olfaction sample from dummy data: a blank
    # IMG_DIM x IMG_DIM RGB image and a torch.randn aroma vector.
    example_image = Image.new('RGB', (constants.IMG_DIM, constants.IMG_DIM))
    # Save a copy to the platform temp directory (portable, unlike a
    # hard-coded "/tmp", which does not exist on Windows).
    example_image.save(os.path.join(tempfile.gettempdir(), "image_example.jpg"))
    example_olf_vec = torch.randn(constants.AROMA_VEC_LENGTH)

    # -------- Option A: base models --------
    ovl_classifier_base = _run_option(bm, example_image, example_olf_vec)
    print(f"Olfaction-Vision-Language Logits from Base Model: {ovl_classifier_base}")

    # -------- Option B: graph attention models --------
    # Loaded inside _run_option, so the Option A models are no longer
    # referenced by this script while these are being loaded.
    ovl_classifier_graph = _run_option(gm, example_image, example_olf_vec)
    print(f"Olfaction-Vision-Language Logits from Graph Attention Model: {ovl_classifier_graph}")