---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
---

## ⚙️ Usage

Our pretrained model is made available through the `rshf` and `transformers` packages for easy inference.

Load and initialize:

```python
from rshf.prom3e import ProM3E

model = ProM3E.from_pretrained("MVRL/ProM3E")
```

Inference:

```python
import torch

# Get precomputed embeddings from taxabind for each modality:
# image, sat, loc, env, text, audio.
# Missing modalities can be replaced with any vector; the indices passed in
# `unmasked_modalities` mark which modalities are actually observed.
# Stack the embeddings in the order image, sat, loc, env, text, audio,
# then pass them through the model.

# Example with a batch of 2 and random stand-in embeddings of size 512:
image_embeds = torch.randn(2, 512)
sat_embeds = torch.randn(2, 512)
loc_embeds = torch.randn(2, 512)
env_embeds = torch.randn(2, 512)
text_embeds = torch.randn(2, 512)
audio_embeds = torch.randn(2, 512)

# Shape (batch, 6 modalities, 512); each embedding is then L2-normalized.
modalities = torch.stack(
    (image_embeds, sat_embeds, loc_embeds, env_embeds, text_embeds, audio_embeds),
    dim=1,
)
modalities = torch.nn.functional.normalize(modalities, dim=-1)

# Indices into the stacking order: 0 = image, 2 = loc.
unmasked_modalities = [0, 2]
reconstructions, mu, log_var, hidden_repr = model.forward_inference(modalities, unmasked_modalities)
```
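When only some modalities are available, the stacking and index bookkeeping can be wrapped in a small helper. This is a minimal sketch under stated assumptions. `build_modalities` and `MODALITY_ORDER` are hypothetical names that are not part of `rshf`. Filling missing slots with zero vectors is also an assumption, since the usage notes only say that any vector may be used. The helper keeps the stacking order and L2 normalization from the inference example and returns the indices of the modalities that were actually provided.

```python
import torch
import torch.nn.functional as F

# Stacking order described in the usage notes (index = position in the tensor).
MODALITY_ORDER = ["image", "sat", "loc", "env", "text", "audio"]


def build_modalities(available):
    """Hypothetical helper (not part of rshf).

    `available` maps modality names to (batch, dim) embedding tensors.
    Returns an L2-normalized (batch, 6, dim) tensor and the sorted list of
    indices for the modalities that were actually provided.
    """
    if not available:
        raise ValueError("at least one modality embedding is required")
    unknown = set(available) - set(MODALITY_ORDER)
    if unknown:
        raise ValueError(f"unknown modalities: {sorted(unknown)}")

    # Use any provided embedding as the reference for shape, dtype and device.
    ref = next(iter(available.values()))
    slots, unmasked = [], []
    for idx, name in enumerate(MODALITY_ORDER):
        if name in available:
            emb = available[name]
            if emb.shape != ref.shape:
                raise ValueError(
                    f"{name}: expected shape {tuple(ref.shape)}, got {tuple(emb.shape)}"
                )
            slots.append(emb)
            unmasked.append(idx)
        else:
            # Assumption: a zero vector serves as the placeholder for a missing modality.
            slots.append(torch.zeros_like(ref))

    modalities = torch.stack(slots, dim=1)
    # F.normalize divides by max(norm, eps), so zero placeholders stay zero (no NaNs).
    modalities = F.normalize(modalities, dim=-1)
    return modalities, unmasked


if __name__ == "__main__":
    mods, idx = build_modalities({"image": torch.randn(2, 512), "loc": torch.randn(2, 512)})
    print(mods.shape, idx)  # torch.Size([2, 6, 512]) [0, 2]
```

Passing only image and loc embeddings reproduces the `[0, 2]` index list from the inference example. The returned pair can then be handed to `model.forward_inference`. Zero placeholders are a design choice rather than a requirement. Because the masked slots are identified by index, random vectors would work equally well under the same assumption.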