---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
---

## ⚙️ Usage
Our pretrained models are available through the `rshf` and `transformers` packages for easy inference.

Load and initialize:
```python
from rshf.prom3e import ProM3E

model = ProM3E.from_pretrained("MVRL/ProM3E")
```

Inference:
```python
import torch

# Get precomputed TaxaBind embeddings for image, sat, loc, env, text, audio.
# Replace missing modalities with any vector.
# Stack embeddings in the order: image, sat, loc, env, text, audio.
# Pass the stacked tensor through the model.

# Example with random placeholders: batch of 2, 512-dim embeddings
image_embeds = torch.randn(2, 512)
sat_embeds = torch.randn(2, 512)
loc_embeds = torch.randn(2, 512)
env_embeds = torch.randn(2, 512)
text_embeds = torch.randn(2, 512)
audio_embeds = torch.randn(2, 512)

# Resulting shape: (batch, 6 modalities, 512)
modalities = torch.stack((image_embeds, sat_embeds, loc_embeds, env_embeds, text_embeds, audio_embeds), dim=1)

# L2-normalize each embedding along the feature dimension
modalities = torch.nn.functional.normalize(modalities, dim=-1)

# Positions in the stack order that are left unmasked (0 = image, 2 = loc)
unmasked_modalities = [0, 2]

reconstructions, mu, log_var, hidden_repr = model.forward_inference(modalities, unmasked_modalities)
```
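
The index list passed as `unmasked_modalities` appears to refer to positions in the stack order (so `[0, 2]` would be image and loc). Assuming that reading is correct, a small helper can build the list from modality names instead of hard-coded numbers. `MODALITY_ORDER` and `unmasked_indices` are hypothetical names used for illustration; they are not part of `rshf`.

```python
# Hypothetical helper, not part of rshf: maps modality names to stack positions.
# Assumes the indices in `unmasked_modalities` follow the stack order
# image, sat, loc, env, text, audio.
MODALITY_ORDER = ("image", "sat", "loc", "env", "text", "audio")


def unmasked_indices(available):
    """Return the sorted stack positions for the given modality names."""
    unknown = set(available) - set(MODALITY_ORDER)
    if unknown:
        raise ValueError(f"Unknown modalities: {sorted(unknown)}")
    return [i for i, name in enumerate(MODALITY_ORDER) if name in available]


unmasked_modalities = unmasked_indices({"image", "loc"})
print(unmasked_modalities)  # [0, 2]
```

The resulting list can be passed as the second argument to `model.forward_inference` in place of the literal `[0, 2]`.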