"""Example driver: run olfaction-vision-language inference with dummy inputs.

Builds a blank RGB image and a random olfaction (aroma) vector, then runs
inference twice — once through the base models (Option A) and once through
the graph-attention models (Option B) — printing the resulting logits.
"""
import torch
from PIL import Image

from src import constants
from src import base_model as bm
from src import graph_model as gm

if __name__ == "__main__":
    # Build an example vision-olfaction sample with dummy data:
    # a blank square RGB image and a random aroma feature vector.
    example_image = Image.new('RGB', (constants.IMG_DIM, constants.IMG_DIM))
    # Persist a copy for inspection; inference below uses the in-memory image.
    example_image.save("/tmp/image_example.jpg")
    example_olf_vec = torch.randn(constants.AROMA_VEC_LENGTH)

    # -------- Option A --------
    # Load the base models and get logits from them.
    vision_lang_encoder, olf_encoder, graph_model = bm.load_model()
    ovl_classifier_base = bm.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=example_image,
        olf_vec=example_olf_vec
    )
    print(f"Olfaction-Vision-Language Logits from Base Model: {ovl_classifier_base}")

    # -------- Option B --------
    # Load the graph attention models and get logits from them.
    vision_lang_encoder, olf_encoder, graph_model = gm.load_model()
    ovl_classifier_graph = gm.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=example_image,
        olf_vec=example_olf_vec
    )
    print(f"Olfaction-Vision-Language Logits from Graph Attention Model: {ovl_classifier_graph}")