# model.py import torch import torch.nn as nn import torch.nn.functional as F import warnings from transformers import ( HubertModel, AutoProcessor, AutoTokenizer, AutoModel ) warnings.filterwarnings("ignore") import torchvision.transforms as transforms from PIL import Image from peft import ( LoraConfig, get_peft_model, TaskType, ) ################################################################# # Audio Embedder ################################################################# class AudioEmbedder(nn.Module): """ Uses a pre-trained HuBERT (or similar) to extract audio features from raw audio (16kHz). Projects them down to a desired embedding dimension. """ def __init__(self, embedding_dim=512, hubert_name="facebook/hubert-base-ls960"): super().__init__() self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") self.hubert = HubertModel.from_pretrained(hubert_name) self.projection = nn.Linear(self.hubert.config.hidden_size, embedding_dim) for param in self.hubert.parameters(): param.requires_grad = True for param in self.projection.parameters(): param.requires_grad = True def forward(self, audio_input: torch.Tensor) -> torch.Tensor: """ Args: audio_input: (B, T) raw audio waveform at 16kHz Returns: audio_feats: (B, Na, D) B = batch size Na = number of audio tokens (T/320 for Hubert) D = embedding_dim """ if len(audio_input.shape) == 3: # shape: [B, 1, T] audio_input = audio_input.squeeze(0) # squeeze first dim to get [B, T] inputs = self.processor( audio_input, return_tensors="pt", sampling_rate=16000, padding=True, return_attention_mask=True ).input_values.squeeze(0) device = next(self.parameters()).device inputs = inputs.to(device) hubert_output = self.hubert(inputs).last_hidden_state # (B, T', hidden_size) audio_feats = self.projection(hubert_output) # (B, T', D) return audio_feats ################################################################# # Text Embedder ################################################################# class TextEmbedder(nn.Module): """ Uses a pre-trained BERT-like model (ModernBERT or similar) to extract text features. Projects them down to a desired embedding dimension. """ def __init__(self, embedding_dim=512, model_name="answerdotai/ModernBERT-base"): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.encoder = AutoModel.from_pretrained(model_name) self.projection = nn.Linear(self.encoder.config.hidden_size, embedding_dim) print("Using text model: ", model_name) for param in self.encoder.parameters(): param.requires_grad = True for param in self.projection.parameters(): param.requires_grad = True def forward(self, text_list): """ Args: text_list: List[str], batch of text inputs Returns: text_feats: (B, Nt, D) attention_mask: (B, Nt) """ inputs = self.tokenizer( text_list, padding=True, truncation=True, add_special_tokens=False, max_length=128, return_tensors="pt" ) device = next(self.parameters()).device for k in inputs: inputs[k] = inputs[k].to(device) outputs = self.encoder(**inputs) # (B, Nt, hidden_size) hidden_states = outputs.last_hidden_state text_feats = self.projection(hidden_states) # (B, Nt, D) return text_feats, inputs["attention_mask"] ################################################################# # Visual Embedder ################################################################# class ViTLoRAEmbedder(nn.Module): def __init__(self, model_name='facebookresearch/dinov2', arch='dinov2_vitb14', embedding_dim=512, dropout_prob=0.1, lora_rank=16, lora_alpha=32): super().__init__() self.model = torch.hub.load(model_name, arch) #print(f"Using DINOv2 model with LoRA adapters: {arch}") for param in self.model.parameters(): param.requires_grad = False target_modules = ["attn.qkv", "attn.proj"] lora_config = LoraConfig( task_type=TaskType.FEATURE_EXTRACTION, inference_mode=True, r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=dropout_prob, ) self.model = get_peft_model(self.model, lora_config) #trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) #total_params = sum(p.numel() for p in self.model.parameters()) #print(f"ViTLoRAEmbedder - Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}% of total)") self.projection = nn.Linear(self.model.embed_dim, embedding_dim) self.dropout = nn.Dropout(p=dropout_prob) def forward(self, x): """ Args: x: (B, 3, H, W), e.g. (B,3,224,224) image batch Returns: visual_feats: (B, Nv, D) Nv = number of visual tokens D = embedding_dim """ if len(x.shape) == 5: # shape: [1, 1, 3, 224, 224] x = x.squeeze(0) # get [1, 3, 224, 224] if len(x.shape) == 3: x = x.unsqueeze(0) # Use intermediate layers for feature extraction - same as original patches = self.model.get_intermediate_layers(x, n=1)[0] feats = self.projection(patches) feats = self.dropout(feats) return feats class Triad(nn.Module): def __init__( self, audio_model_name="facebook/hubert-base-ls960", text_model_name="distilbert/distilbert-base-uncased", temperature=2.0, patch_sparsity_threshold=0.3, patch_sparsity_weight=0.1, visual_dropout_prob=0.1, lora_rank=8, lora_alpha=16 ): super().__init__() self.audio_embedder = AudioEmbedder(embedding_dim=512, hubert_name=audio_model_name) self.text_embedder = TextEmbedder(embedding_dim=512, model_name=text_model_name) self.visual_embedder = ViTLoRAEmbedder( arch='dinov2_vitb14', embedding_dim=512, dropout_prob=visual_dropout_prob, lora_rank=lora_rank, lora_alpha=lora_alpha ) self.temperature = nn.Parameter(torch.tensor(temperature)) self.patch_sparsity_threshold = patch_sparsity_threshold self.patch_sparsity_weight = patch_sparsity_weight def compute_similarity_matrix(self, feats1, feats2): """ Generic token-level dot-product similarity between feats1 and feats2. feats1: (B, N1, D) feats2: (B, N2, D) Returns sim: (B, N1, N2) """ sim = torch.bmm(feats1, feats2.transpose(1, 2)) return sim / self.temperature def forward(self, image=None, audio=None, text_list=None): assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided" if image is not None: assert image is not str, "Frames should be a path to an image" if audio is not None: assert isinstance(audio, torch.Tensor) and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (B, T)" if text_list is not None: assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1" if image is not None: device = next(self.parameters()).device # Handle batch of file paths if isinstance(image, list): # Process a list of image paths processed_images = [] for img_path in image: img = Image.open(img_path).convert('RGB') transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) processed_img = transform(img).to(device) processed_images.append(processed_img) image = torch.stack(processed_images, dim=0) # [B, 3, 224, 224] # Handle single file path elif isinstance(image, str): img = Image.open(image).convert('RGB') transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) image = transform(img).to(device).unsqueeze(0) # Add batch dimension [1, 3, 224, 224] # Handle tensor input (assume it's already processed but may need device transfer) elif isinstance(image, torch.Tensor): # If single image without batch dimension if image.dim() == 3: image = image.unsqueeze(0) # Add batch dimension image = image.to(device) embeddings = {} if image is not None: embeddings['visual_feats'] = self.visual_embedder(image) if audio is not None: embeddings['audio_feats'] = self.audio_embedder(audio) if text_list is not None: embeddings['text_feats'], _ = self.text_embedder(text_list) # if two or more modalities are present, we compute the similarity matrix if image is not None and text_list is not None: embeddings['vis_text_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['visual_feats']) if audio is not None and image is not None: embeddings['vis_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['audio_feats'], embeddings['visual_feats']) if text_list is not None and audio is not None: embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats']) return embeddings