Upload folder using huggingface_hub
- .gitignore +4 -0
- README.md +127 -0
- model/ovlc/base/gnn_512_c.pt +3 -0
- model/ovlc/base/gnn_512_c.safetensors +3 -0
- model/ovlc/base/gnn_state_dict_512_c.pth +3 -0
- model/ovlc/base/olf_encoder_512_c.pt +3 -0
- model/ovlc/base/olf_encoder_512_c.safetensors +3 -0
- model/ovlc/base/olf_encoder_state_dict_512_c.pth +3 -0
- model/ovlc/gat/gat_gnn_512_c.pt +3 -0
- model/ovlc/gat/gat_gnn_512_c.safetensors +3 -0
- model/ovlc/gat/gat_gnn_state_dict_512_c.pth +3 -0
- model/ovlc/gat/gat_olf_encoder_512_c.safetensors +3 -0
- model/ovlc/gat/gat_olf_encoder_state_dict_512_c.pth +3 -0
- model/ovlc/gat/olf_encoder_gat_512_c.pt +3 -0
- model_cards/ovl_classifier.md +110 -0
- notebooks/olfaction_vision_language_classifier_inference.ipynb +235 -0
- requirements.txt +33 -0
- src/base_model.py +52 -0
- src/constants.py +19 -0
- src/graph_model.py +48 -0
- src/main.py +41 -0
- src/utils.py +29 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
*__pycache__
*.idea
*.DS_Store
*data/
README.md
ADDED
@@ -0,0 +1,127 @@
---
language:
- en
tags:
- classifier
- multimodal
- olfaction-vision-language
- olfaction
- olfactory
- scentience
- neural-network
- graph-neural-network
- gnn
- vision-language
- vision
- language
- robotics
- smell
license: mit
datasets:
- kordelfrance/olfaction-vision-language-dataset
- detection-datasets/coco
- seyonec/goodscents_leffingwell
base_model: Scentience-OVL-Classifiers-Base
---

# Olfaction-Vision-Language Classifiers

[License: MIT](#license)
[Open in Colab](https://colab.research.google.com/drive/1H5OSeO43YfhAT9MqcJKaaSknFYhjimvg?usp=sharing)
[arXiv:2506.00398](https://arxiv.org/abs/2506.00398)
[Embeddings Models on Hugging Face](https://huggingface.co/kordelfrance/Olfaction-Vision-Language-Embeddings)

---

## Description

This repository provides a foundational series of multimodal joint classifier models trained on olfaction, vision, and language data.
It is meant as a quick start for loading the olfaction-vision-language models and obtaining the probability/logits that observed chemical compounds are present in a visual scene, given a set of aroma descriptors.
For example, given an input image and a set of observed aromas (fruity, musky, etc.), what is the probability that acetone is present?

Built on the original series of [embeddings models here](https://huggingface.co/kordelfrance/Olfaction-Vision-Language-Embeddings), these models are designed specifically for prototyping and exploratory tasks within AR/VR, robotics, and embodied artificial intelligence.
Analogous to how CLIP and SigLIP embeddings capture vision-language relationships, these classifier models capture olfaction-vision-language (OVL) relationships.

Whether these models are used for vision-scent navigation with drones, triangulating the source of an odor in an image, extracting aromas from a scene, or augmenting a VR experience with scent, we hope their release will catalyze further research in olfaction, especially olfactory robotics.
We especially hope these models encourage the community to contribute to building standardized datasets and evaluation protocols for olfaction-vision-language learning.

## Models
We offer two olfaction-vision-language (OVL) classifier models with this repository (a minimal loading sketch follows the list):
- (1) `ovlc-gat`: The OVL model built around a graph-attention network. This model is best suited to tasks where accuracy is paramount and inference time is less critical.
- (2) `ovlc-base`: The original OVL base model optimized for faster inference and edge-based robotics. This model is optimized for export to common frameworks that run on Android, iOS, Rust, and others.
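
Below is a minimal end-to-end sketch adapted from `src/main.py`. It assumes a local clone of this repository with the `model/` weights pulled via Git LFS and the dependencies from `requirements.txt` installed; the inputs are dummy data, so the printed logits are not meaningful.

```python
# Minimal usage sketch (assumes a local clone with model weights pulled via Git LFS).
import torch
from PIL import Image

from src import constants
from src import base_model as bm

# Dummy inputs: a blank RGB image and a random aroma vector of the expected length.
image = Image.new("RGB", (constants.IMG_DIM, constants.IMG_DIM))
olf_vec = torch.randn(constants.AROMA_VEC_LENGTH)

# Load CLIP plus the TorchScript olfactory encoder and classifier head, then score.
clip_model, olf_encoder, graph_model = bm.load_model()
logits = bm.run_inference(
    vision_lang_encoder=clip_model,
    olf_encoder=olf_encoder,
    graph_model=graph_model,
    image=image,
    olf_vec=olf_vec,
)
print(f"Olfaction-Vision-Language logits: {logits}")
```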

## Training Data
A sample dataset is included, but the full datasets are linked in the `Datasets` pane of this repo.
Training code for replicating the full construction of all models will be released soon.

Please refer to the original series of [embeddings models here](https://huggingface.co/kordelfrance/Olfaction-Vision-Language-Embeddings) for more information.


## Directory Structure

```text
Olfaction-Vision-Language-Classifier-Models/
├── data/              # Sample training dataset
├── requirements.txt   # Python dependencies
├── model/             # Classifier models
├── model_cards/       # Specifications for each classifier model
├── notebooks/         # Notebooks for loading the models for inference
├── src/               # Source code for inference, model loading, utils
└── README.md          # Overview of repository contributions and usage
```

---

## Citation
If you use any of these models, please cite:
```
@misc{france2025ovlembeddings,
  title        = {Scentience-OVLE-v1: Joint Olfaction-Vision-Language Embeddings},
  author       = {Kordel Kade France},
  year         = {2025},
  howpublished = {Hugging Face},
  url          = {https://huggingface.co/kordelfrance/Olfaction-Vision-Language-Embeddings}
}
```

```
@misc{france2025olfactionstandards,
  title         = {Position: Olfaction Standardization is Essential for the Advancement of Embodied Artificial Intelligence},
  author        = {Kordel K. France and Rohith Peddi and Nik Dennler and Ovidiu Daescu},
  year          = {2025},
  eprint        = {2506.00398},
  archivePrefix = {arXiv},
  primaryClass  = {cs.AI},
  url           = {https://arxiv.org/abs/2506.00398}
}
```


If you leverage the CLIP or SigLIP models, please cite:
```
@misc{radford2021clip,
  title         = {Learning Transferable Visual Models From Natural Language Supervision},
  author        = {Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
  year          = {2021},
  eprint        = {2103.00020},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV},
  url           = {https://arxiv.org/abs/2103.00020}
}
```

```
@misc{zhai2023siglip,
  title         = {Sigmoid Loss for Language Image Pre-Training},
  author        = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
  year          = {2023},
  eprint        = {2303.15343},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV},
  url           = {https://arxiv.org/abs/2303.15343}
}
```
model/ovlc/base/gnn_512_c.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb4b6457aabc3b952367066bf1dc651d8ff04612b7e921a6c1110b5d880ccce1
size 6316648
model/ovlc/base/gnn_512_c.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:413f346ff9acfec2e4070a66ac54dd80a2486a81ff2e253a3c0a5a68ee950310
size 6304684
model/ovlc/base/gnn_state_dict_512_c.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:47203e8bc45a421252e1f2224d68ad11a58952b5c5cf321f36ca924864e601b9
size 6308841
model/ovlc/base/olf_encoder_512_c.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5abeab85ebef72e56b1554cba0e03eb37e99e67adf0093ef67c9d46fcca23db8
size 1284568
model/ovlc/base/olf_encoder_512_c.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c01e4857b307039fa54657c6840783788dd24c09da6bce909cf084925ac6f966
size 1269944
model/ovlc/base/olf_encoder_state_dict_512_c.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00d8b4b52c8a49d21bca5e6c3610f9c7b567a39f1838ad86d7363eddfb109627
size 1277657
model/ovlc/gat/gat_gnn_512_c.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5d19648560bf0d48a3e63daf9e3c610f29d0471cfdab0d3e70098a2c34d16fb8
size 34801880
model/ovlc/gat/gat_gnn_512_c.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5533a3a457bfa0e95f3a7cffe122343dc25cbdfedabb31cc381ad65a12fb6c24
size 34712556
model/ovlc/gat/gat_gnn_state_dict_512_c.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e0878cd5f89dee08aa87860a724a1c13420994b2c7fee651b053764aaed026d
size 34716913
model/ovlc/gat/gat_olf_encoder_512_c.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1fdef4dd630e6ec48f34ea74a3b9a4d003a93f884d6ce707fe6a3ae81ca69d11
size 1269944
model/ovlc/gat/gat_olf_encoder_state_dict_512_c.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0bd47e1c2e63a17c9c84ce7eeb5956f004857c012e46f2bb365b6496a332c54
size 1278625
model/ovlc/gat/olf_encoder_gat_512_c.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9419610adcb252387f080625c23e5aff62d0d4218ca7caae9f7aac8c5077f28a
size 1284920
model_cards/ovl_classifier.md
ADDED
@@ -0,0 +1,110 @@
# Model Card: Scentience-OVLE-Large-v1


## Model Details
- **Model Name:** `Scentience OVLE Large v1`
- **Developed by:** Kordel K. France
- **Date:** September 2025
- **Architecture:**
  - **Olfaction encoder:** 138-sensor embedding
  - **Vision encoder:** CLIP-based
  - **Language encoder:** CLIP-based
  - **Fusion strategy:** Joint embedding space via multimodal contrastive training
- **Parameter Count (Base):** 29.9M (without CLIP), 181.2M (with CLIP)
- **Parameter Count (GAT):** 143.2M (without CLIP), 294.5M (with CLIP)
- **Embedding Dimension:** 2048
- **License:** MIT
- **Contact:** kordel@scentience.ai, kordel.france@utdallas.edu

---

## Intended Use
- **Primary purpose:** Research in multimodal machine learning involving olfaction, vision, and language.
- **Example applications:**
  - Cross-modal retrieval (odor → image, odor → text, etc.); a minimal retrieval sketch follows this list
  - Robotics and UAV navigation guided by chemical cues
  - Chemical dataset exploration and visualization
- **Intended users:** Researchers, developers, and educators working in ML, robotics, chemistry, and HCI.
- **Out of scope:** Not intended for safety-critical tasks (e.g., gas leak detection, medical diagnosis, or regulatory use).
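
The cross-modal retrieval application can be illustrated with a brief sketch. It is illustrative only: the tensors below are random placeholders standing in for outputs of the olfaction encoder and the CLIP/SigLIP image encoder projected into the shared space, and the ranking is plain cosine similarity.

```python
# Illustrative odor -> image retrieval over a shared embedding space (placeholder tensors).
import torch
import torch.nn.functional as F

embed_dim = 512                                   # classifier embedding width used in this repo
odor_embedding = torch.randn(embed_dim)           # placeholder: output of the olfaction encoder
image_embeddings = torch.randn(1000, embed_dim)   # placeholder: one row per candidate image

# Cosine similarity = dot product of L2-normalized vectors.
query = F.normalize(odor_embedding, dim=0)
gallery = F.normalize(image_embeddings, dim=1)
scores = gallery @ query                          # (1000,) similarity per candidate image

top5 = torch.topk(scores, k=5).indices            # indices of the 5 most odor-consistent images
print(top5.tolist())
```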

---

## Training Data
- **Olfaction data:** Language-aligned olfactory data curated from the GoodScents and Leffingwell datasets.
- **Vision data:** COCO dataset.
- **Language data:** Smell descriptors and text annotations curated from the literature.

For more information on how the training data was collected, please see the [Hugging Face dataset here](https://huggingface.co/datasets/kordelfrance/olfaction-vision-language-dataset).

---

## Evaluation
- Retrieval tasks: odor → image (Top-5 recall = 62%); a recall sketch follows this list
- Odor descriptor classification accuracy = 71%
- Cross-modal embedding alignment qualitatively verified on 200 sample triplets.
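
The Top-5 recall figure above can be computed as in the following sketch. This is a generic recipe rather than the original evaluation script: `scores` is assumed to be an N×M similarity matrix between odor queries and candidate images, and `targets` holds the index of the ground-truth image for each query.

```python
# Generic Top-k recall for odor -> image retrieval (assumed inputs, not the original eval code).
import torch

def top_k_recall(scores: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
    """scores: (N, M) query-to-gallery similarities; targets: (N,) ground-truth gallery indices."""
    topk = scores.topk(k, dim=1).indices               # (N, k) highest-scoring gallery items per query
    hits = (topk == targets.unsqueeze(1)).any(dim=1)   # True where the true image appears in the top k
    return hits.float().mean().item()

# Toy example with random scores; a real run would use model similarities.
scores = torch.randn(200, 1000)
targets = torch.randint(0, 1000, (200,))
print(f"Top-5 recall: {top_k_recall(scores, targets, k=5):.2%}")
```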

---

## Limitations of Evaluation
To the best of our knowledge, there are currently no open-source datasets that provide aligned olfactory, visual, and linguistic annotations. A "true" multimodal evaluation would require measuring the chemical composition of scenes (e.g., using gas chromatography-mass spectrometry) while simultaneously capturing images and collecting perceptual descriptors from human olfactory judges. Such a benchmark would demand substantial new data collection efforts and instrumentation.
Consequently, we evaluate our models indirectly, using surrogate metrics (e.g., cross-modal retrieval performance, odor descriptor classification accuracy, clustering quality). While these evaluations do not provide ground-truth verification of odor presence in images, they offer a first step toward demonstrating alignment between modalities.
We draw an analogy from past successes in ML datasets, such as precursors to CLIP that lacked large paired datasets and were evaluated on retrieval-like tasks.
As a result, we release this model to catalyze further research and encourage the community to contribute to building standardized datasets and evaluation protocols for olfaction-vision-language learning.

---

## Limitations
- Limited odor diversity (approx. 5000 unique compounds).
- Embeddings depend on sensor calibration; not guaranteed to transfer across devices.
- Cultural subjectivity in smell annotations may bias embeddings.

---

## Ethical Considerations
- Not to be used for covert detection of substances or surveillance.
- Unreliable in safety-critical contexts (e.g., gas leak detection).
- Users should be mindful of cultural sensitivity in smell perception.

---

## Environmental Impact
- Trained on 4×A100 GPUs for 48 hours (~200 kg CO2eq).
- Sensor dataset collection required ~500 lab hours.

---

## Citation
If you use this model, please cite:
```
@misc{france2025ovlembeddings,
  title        = {Scentience-OVLE-Base-v1: Joint Olfaction-Vision-Language Embeddings},
  author       = {Kordel Kade France},
  year         = {2025},
  howpublished = {Hugging Face},
  url          = {https://huggingface.co/kordelfrance/Olfaction-Vision-Language-Embeddings}
}
```

```
@misc{radford2021clip,
  title         = {Learning Transferable Visual Models From Natural Language Supervision},
  author        = {Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
  year          = {2021},
  eprint        = {2103.00020},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV},
  url           = {https://arxiv.org/abs/2103.00020}
}
```

```
@misc{zhai2023siglip,
  title         = {Sigmoid Loss for Language Image Pre-Training},
  author        = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
  year          = {2023},
  eprint        = {2303.15343},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV},
  url           = {https://arxiv.org/abs/2303.15343}
}
```
notebooks/olfaction_vision_language_classifier_inference.ipynb
ADDED
@@ -0,0 +1,235 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Olfaction-Vision-Language-Classifier"
      ],
      "metadata": {
        "id": "SHn5L_NdNL-V"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This is a quick start on loading the olfaction-vision-language models and getting the probability/logits of the presence of observed chemical compounds in a visual scene given a set of aroma descriptors."
      ],
      "metadata": {
        "id": "_40H8e0QNCj_"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Install Libraries"
      ],
      "metadata": {
        "id": "g0qE7ci4M_V6"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "cxabQxw9LSzM",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a29b3012-09f1-4ef6-ac54-8f30494fea9f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (4.57.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers) (3.20.0)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.36.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2024.11.6)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers) (2.32.4)\n",
            "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.1)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.6.2)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.1)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (2025.3.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (1.2.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (2025.10.5)\n",
            "Requirement already satisfied: safetensors in /usr/local/lib/python3.12/dist-packages (0.6.2)\n"
          ]
        }
      ],
      "source": [
        "!pip install transformers\n",
        "!pip install safetensors"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Import and Configure"
      ],
      "metadata": {
        "id": "7HB6A3RDMrQ0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from safetensors.torch import load_file\n",
        "from torchvision import transforms\n",
        "from transformers import CLIPProcessor, CLIPModel\n",
        "from PIL import Image\n",
        "\n",
        "\n",
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "EMBED_DIM = 512  # Embedding dims = 512 for classifiers\n",
        "ENCODER_FILE_PATH = f\"./olf_encoder_{EMBED_DIM}_c.pt\"\n",
        "GNN_FILE_PATH = f\"./gnn_{EMBED_DIM}_c.pt\""
      ],
      "metadata": {
        "id": "BHUd6n3bLbqo"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Embeddings Function"
      ],
      "metadata": {
        "id": "Tad9Pu6PMn9g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_embeddings(clip_model, olf_encoder, graph_model, image, olf_vec):\n",
        "    \"\"\"\n",
        "    Gets joint olfaction-vision-language embeddings for a given image and olfaction vector.\n",
        "\n",
        "    :param clip_model: vision-language model\n",
        "    :param olf_encoder: olfactory encoder from aromas/molecules\n",
        "    :param graph_model: cross-modal associator\n",
        "    :param image: PIL image\n",
        "    :param olf_vec: olfaction vector\n",
        "    :return: joint olfaction-vision-language embeddings\n",
        "    \"\"\"\n",
        "    clip_model.eval()\n",
        "    olf_encoder.eval()\n",
        "    graph_model.eval()\n",
        "\n",
        "    transform = transforms.Compose([\n",
        "        transforms.Resize((224, 224)),\n",
        "        transforms.ToTensor(),\n",
        "    ])\n",
        "\n",
        "    image_tensor = transform(image).unsqueeze(0).to(DEVICE)\n",
        "    olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        vision_embeds = clip_model.get_image_features(pixel_values=image_tensor)\n",
        "        if EMBED_DIM != 768 and EMBED_DIM != 512:\n",
        "            projection = nn.Linear(vision_embeds.shape[-1], EMBED_DIM).to(DEVICE)\n",
        "            vision_embeds = projection(vision_embeds).to(DEVICE)\n",
        "        vision_embeds = vision_embeds.to(DEVICE)\n",
        "        olf_embeds = olf_encoder(olf_tensor).to(DEVICE)\n",
        "        ovl_logits = graph_model(vision_embeds, olf_embeds).squeeze()\n",
        "\n",
        "    return ovl_logits"
      ],
      "metadata": {
        "id": "UzhGG8CzMmBs"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get Joint Embeddings from a Data Sample"
      ],
      "metadata": {
        "id": "fQIb0yeiMwRN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the models\n",
        "olf_encoder = torch.jit.load(ENCODER_FILE_PATH)\n",
        "graph_model = torch.jit.load(GNN_FILE_PATH)\n",
        "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(DEVICE)\n",
        "\n",
        "# Build example vision-olfaction sample with dummy data\n",
        "example_image = Image.new('RGB', (224, 224))\n",
        "example_image.save(f\"/tmp/image_example.jpg\")\n",
        "example_olf_vec = torch.randn(112)\n",
        "\n",
        "# Run inference\n",
        "logits = get_embeddings(\n",
        "    clip_model,\n",
        "    olf_encoder,\n",
        "    graph_model,\n",
        "    example_image,\n",
        "    example_olf_vec\n",
        ")\n",
        "print(\"Logits\", logits)"
      ],
      "metadata": {
        "id": "_U5qqxn8Mibo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "23b27505-539d-4262-d1a2-45186367afcb"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n",
            "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n",
            "You will not be requested again.\n",
            "Please restart the session if you want to be prompted again.\n",
            "  warnings.warn(\n",
            "/tmp/ipython-input-2356812152.py:22: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Logits tensor(3.8031e+23, device='cuda:0')\n"
          ]
        }
      ]
    }
  ]
}
requirements.txt
ADDED
@@ -0,0 +1,33 @@
hf-xet==1.1.9
hf_transfer==0.1.9
huggingface-hub==0.34.4
matplotlib==3.10.0
numpy==2.0.2
openai==1.106.1
opencv-contrib-python==4.12.0.88
opencv-python==4.12.0.88
opencv-python-headless==4.12.0.88
openpyxl==3.1.5
opt_einsum==3.4.0
pandas==2.2.2
pandas-datareader==0.10.0
requests==2.32.4
safetensors==0.6.2
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.16.1
sklearn-pandas==2.2.0
tiktoken==0.11.0
tokenizers==0.22.0
torch==2.8.0+cu126
torch-geometric==2.6.1
torchao==0.10.0
torchaudio==2.8.0+cu126
torchdata==0.11.0
torchsummary==1.5.1
torchtune==0.6.1
torchvision==0.23.0+cu126
tqdm==4.67.1
transformers==4.56.1
urllib3==2.5.0
uvicorn==0.35.0
src/base_model.py
ADDED
@@ -0,0 +1,52 @@
import torch
from torchvision import transforms
from transformers import CLIPModel, SiglipModel

from src import constants


# INFERENCE
def run_inference(vision_lang_encoder, olf_encoder, graph_model, image, olf_vec):
    vision_lang_encoder.eval()
    olf_encoder.eval()
    graph_model.eval()

    transform = transforms.Compose([
        transforms.Resize((constants.IMG_DIM, constants.IMG_DIM)),
        transforms.ToTensor(),
    ])

    image_tensor = transform(image).unsqueeze(0).to(constants.DEVICE)
    olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(constants.DEVICE)

    with torch.no_grad():
        vision_embed = vision_lang_encoder.get_image_features(pixel_values=image_tensor)
        olf_embed = olf_encoder(olf_tensor)
        # Build a fully connected graph over the vision and olfaction nodes, then classify.
        nodes = torch.cat([vision_embed, olf_embed], dim=0)
        edge_index = torch.cartesian_prod(torch.arange(nodes.size(0)), torch.arange(nodes.size(0))).T.to(constants.DEVICE)
        logits = graph_model(nodes, edge_index)

    return logits


def load_model():
    # Use CLIP as default baseline
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(constants.DEVICE)
    clip_model.eval()
    """
    Or, you can also use SigLIP:
    SiglipModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        attn_implementation="flash_attention_2",
        dtype=torch.float16,
        device_map=constants.DEVICE,
    )
    """
    # Load the TorchScript graph-attention (GAT) encoder and GNN. The original code referenced
    # constants.ENCODER_SMALL_GRAPH_PATH / OVLE_SMALL_GRAPH_PATH, which are not defined in
    # src/constants.py; the defined graph-path constants are assumed here.
    olf_encoder = torch.jit.load(constants.ENCODER_LARGE_GRAPH_PATH).to(constants.DEVICE)
    olf_encoder.eval()
    graph_model = torch.jit.load(constants.OVLE_LARGE_GRAPH_PATH).to(constants.DEVICE)
    graph_model.eval()

    return clip_model, olf_encoder, graph_model
src/constants.py
ADDED
@@ -0,0 +1,19 @@
import torch


# Number of features in each aroma vector
AROMA_VEC_LENGTH: int = 138
# All images are normalized to this
IMG_DIM: int = 224
# Each model was trained with these hyperparams
BATCH_SIZE: int = 16
EMBED_DIM: int = 512

# CPU or GPU?
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths to models (relative to the repository root; the GAT weights live under model/ovlc/gat/)
OVLE_SMALL_BASE_PATH: str = f"./model/ovlc/base/gnn_{EMBED_DIM}_c.pt"
ENCODER_SMALL_BASE_PATH: str = f"./model/ovlc/base/olf_encoder_{EMBED_DIM}_c.pt"
OVLE_LARGE_GRAPH_PATH: str = f"./model/ovlc/gat/gat_gnn_{EMBED_DIM}_c.pt"
ENCODER_LARGE_GRAPH_PATH: str = f"./model/ovlc/gat/olf_encoder_gat_{EMBED_DIM}_c.pt"
src/graph_model.py
ADDED
@@ -0,0 +1,48 @@
import torch
from torchvision import transforms
from transformers import CLIPModel, SiglipModel

from src import constants


# INFERENCE
def run_inference(vision_lang_encoder, olf_encoder, graph_model, image, olf_vec):
    vision_lang_encoder.eval()
    olf_encoder.eval()
    graph_model.eval()

    transform = transforms.Compose([
        transforms.Resize((constants.IMG_DIM, constants.IMG_DIM)),
        transforms.ToTensor(),
    ])

    image_tensor = transform(image).unsqueeze(0).to(constants.DEVICE)
    olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(constants.DEVICE)

    with torch.no_grad():
        vision_embed = vision_lang_encoder.get_image_features(pixel_values=image_tensor)
        olf_embed = olf_encoder(olf_tensor)
        logits = graph_model(vision_embed, olf_embed).squeeze()

    return logits


def load_model():
    # Use CLIP as default baseline
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(constants.DEVICE)
    clip_model.eval()
    """
    Or, you can also use SigLIP:
    SiglipModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        attn_implementation="flash_attention_2",
        dtype=torch.float16,
        device_map=constants.DEVICE,
    )
    """
    olf_encoder = torch.jit.load(constants.ENCODER_SMALL_BASE_PATH).to(constants.DEVICE)
    olf_encoder.eval()
    graph_model = torch.jit.load(constants.OVLE_SMALL_BASE_PATH).to(constants.DEVICE)
    graph_model.eval()

    return clip_model, olf_encoder, graph_model
src/main.py
ADDED
@@ -0,0 +1,41 @@
import torch
from PIL import Image

from src import constants
from src import base_model as bm
from src import graph_model as gm


if __name__ == "__main__":

    # Build example vision-olfaction sample with dummy data
    example_image = Image.new('RGB', (constants.IMG_DIM, constants.IMG_DIM))
    example_image.save("/tmp/image_example.jpg")
    example_olf_vec = torch.randn(constants.AROMA_VEC_LENGTH)

    # -------- Option A --------
    # Load the base models
    vision_lang_encoder, olf_encoder, graph_model = bm.load_model()
    # Get probability from base models
    ovl_classifier_base = bm.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=example_image,
        olf_vec=example_olf_vec
    )
    print(f"Olfaction-Vision-Language Logits from Base Model: {ovl_classifier_base}")

    # -------- Option B --------
    # Load the graph attention models
    vision_lang_encoder, olf_encoder, graph_model = gm.load_model()
    # Get probability from graph attention models
    ovl_classifier_graph = gm.run_inference(
        vision_lang_encoder=vision_lang_encoder,
        olf_encoder=olf_encoder,
        graph_model=graph_model,
        image=example_image,
        olf_vec=example_olf_vec
    )
    print(f"Olfaction-Vision-Language Logits from Graph Attention Model: {ovl_classifier_graph}")
src/utils.py
ADDED
@@ -0,0 +1,29 @@
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

from src import constants


# DATASET EXAMPLE
class OlfactionVisionDataset(Dataset):
    def __init__(self, image_paths, olfaction_vectors, labels):
        self.image_paths = image_paths
        self.olfaction_vectors = olfaction_vectors
        self.labels = labels
        self.transform = transforms.Compose([
            transforms.Resize((constants.IMG_DIM, constants.IMG_DIM)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = self.transform(Image.open(img_path).convert('RGB'))
        olf_vec = self.olfaction_vectors[idx]
        label = self.labels[idx]
        return image, torch.tensor(olf_vec, dtype=torch.float32), label
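
The dataset class above is not exercised elsewhere in this repository, so the following sketch shows one way it might be wrapped in a `DataLoader`. The image paths, aroma vectors, and labels are placeholders; `BATCH_SIZE` and `AROMA_VEC_LENGTH` come from `src/constants.py`.

```python
# Hypothetical example of wrapping OlfactionVisionDataset in a DataLoader for training or evaluation.
import torch
from torch.utils.data import DataLoader

from src import constants
from src.utils import OlfactionVisionDataset

# Placeholder sample list: two image files, random aroma vectors, and binary labels.
image_paths = ["./data/sample_000.jpg", "./data/sample_001.jpg"]
olfaction_vectors = [torch.randn(constants.AROMA_VEC_LENGTH).tolist() for _ in image_paths]
labels = [0, 1]

dataset = OlfactionVisionDataset(image_paths, olfaction_vectors, labels)
loader = DataLoader(dataset, batch_size=constants.BATCH_SIZE, shuffle=True)

for images, olf_vecs, batch_labels in loader:
    # images: (B, 3, IMG_DIM, IMG_DIM); olf_vecs: (B, AROMA_VEC_LENGTH)
    print(images.shape, olf_vecs.shape, batch_labels)
```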