---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
---
## ⚙️ Usage
Our pretrained model is available through the `rshf` and `transformers` packages for easy inference.
Load and initialize:
```python
from rshf.prom3e import ProM3E
model = ProM3E.from_pretrained("MVRL/ProM3E")
```
Inference:
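```python
import torch

# `model` is the ProM3E instance loaded above.
# 1. Get precomputed TaxaBind embeddings for image, sat, loc, env, text, audio.
# 2. Replace missing modalities with any vector of the same shape.
# 3. Stack embeddings in the order: image, sat, loc, env, text, audio.
# 4. Pass the stack through the model together with the indices of the observed modalities.
# Example with random placeholder embeddings (batch size 2, embedding dim 512):
image_embeds = torch.randn(2, 512)
sat_embeds = torch.randn(2, 512)
loc_embeds = torch.randn(2, 512)
env_embeds = torch.randn(2, 512)
text_embeds = torch.randn(2, 512)
audio_embeds = torch.randn(2, 512)
# Shape (2, 6, 512); each embedding is then L2-normalized along the last dimension
modalities = torch.stack((image_embeds, sat_embeds, loc_embeds, env_embeds, text_embeds, audio_embeds), dim=1)
modalities = torch.nn.functional.normalize(modalities, dim=-1)
# Positions (in the stacking order above) of the observed modalities: 0 = image, 2 = loc
unmasked_modalities = [0, 2]
reconstructions, mu, log_var, hidden_repr = model.forward_inference(modalities, unmasked_modalities)
```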