File size: 9,935 Bytes
b6faa73 b0ea89d b6faa73 ce081c7 8948365 b6faa73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from transformers import (
HubertModel,
AutoProcessor,
AutoTokenizer,
AutoModel
)
warnings.filterwarnings("ignore")
import torchvision.transforms as transforms
from PIL import Image
#################################################################
# Audio Embedder
#################################################################
class AudioEmbedder(nn.Module):
    """
    Audio encoder: a pre-trained HuBERT backbone followed by a linear projection.

    Raw waveforms are passed through a Hugging Face processor, encoded by
    ``HubertModel`` and every output frame is projected to ``embedding_dim``.

    Args:
        embedding_dim: Size D of the projected per-frame embedding.
        hubert_name: Hugging Face checkpoint name loaded into ``HubertModel``.
    """

    def __init__(self, embedding_dim=512, hubert_name="facebook/hubert-base-ls960"):
        super().__init__()
        # NOTE(review): the processor checkpoint is hard-coded and is not derived
        # from ``hubert_name`` -- confirm that this mismatch is intentional.
        self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        self.hubert = HubertModel.from_pretrained(hubert_name)
        self.projection = nn.Linear(self.hubert.config.hidden_size, embedding_dim)
        # Both the backbone and the projection head are left trainable.
        for param in self.hubert.parameters():
            param.requires_grad = True
        for param in self.projection.parameters():
            param.requires_grad = True

    def forward(self, audio_input: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of raw waveforms.

        Args:
            audio_input: (B, T) raw audio waveform at 16kHz. A 3-D tensor with
                a singleton channel axis, (B, 1, T), is also accepted.
        Returns:
            audio_feats: (B, Na, D)
                B = batch size
                Na = number of audio tokens (T/320 for Hubert)
                D = embedding_dim
        """
        if audio_input.dim() == 3:
            # Remove the singleton channel axis of a (B, 1, T) batch. The previous
            # unconditional squeeze(0) was a no-op for B > 1 and left a 3-D tensor.
            # Inputs without a singleton at dim 1 keep the old squeeze(0) behaviour.
            channel_dim = 1 if audio_input.shape[1] == 1 else 0
            audio_input = audio_input.squeeze(channel_dim)

        # NOTE(review): the attention mask is requested but never used, and the
        # trailing squeeze(0) presumably drops a leading axis the processor adds
        # around the tensor -- confirm against the processor documentation.
        inputs = self.processor(
            audio_input,
            return_tensors="pt",
            sampling_rate=16000,
            padding=True,
            return_attention_mask=True
        ).input_values.squeeze(0)

        # Run the backbone on whichever device this module's weights live on.
        device = next(self.parameters()).device
        inputs = inputs.to(device)

        hubert_output = self.hubert(inputs).last_hidden_state  # (B, T', hidden_size)
        audio_feats = self.projection(hubert_output)  # (B, T', D)
        return audio_feats
#################################################################
# Text Embedder
#################################################################
class TextEmbedder(nn.Module):
    """
    Text encoder: a pre-trained Hugging Face model plus a linear projection.

    Strings are tokenized, encoded by the model named ``model_name`` and each
    token state is projected to ``embedding_dim``.

    Args:
        embedding_dim: Size D of the projected per-token embedding.
        model_name: Checkpoint used for both the tokenizer and the encoder.
    """

    def __init__(self, embedding_dim=512, model_name="answerdotai/ModernBERT-base"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.projection = nn.Linear(hidden_size, embedding_dim)
        print("Using text model: ", model_name)
        # Encoder first, then projection: every weight is marked trainable.
        for trainable_part in (self.encoder, self.projection):
            for weight in trainable_part.parameters():
                weight.requires_grad = True

    def forward(self, text_list):
        """
        Tokenize and encode a batch of strings.

        Args:
            text_list: List[str], batch of text inputs
        Returns:
            text_feats: (B, Nt, D)
            attention_mask: (B, Nt)
        """
        encoded = self.tokenizer(
            text_list,
            padding=True,
            truncation=True,
            add_special_tokens=False,
            max_length=128,
            return_tensors="pt"
        )

        # Move every tokenizer output onto the device that holds the weights.
        target_device = next(self.parameters()).device
        model_inputs = {
            name: tensor.to(target_device) for name, tensor in encoded.items()
        }

        token_states = self.encoder(**model_inputs).last_hidden_state  # (B, Nt, hidden_size)
        projected = self.projection(token_states)  # (B, Nt, D)
        return projected, model_inputs["attention_mask"]
#################################################################
# Visual Embedder
#################################################################
class ViTEmbedder(nn.Module):
    """
    Image encoder: a DINOv2 backbone loaded through ``torch.hub`` followed by a
    linear projection to a common dimension and dropout.

    Args:
        model_name: ``torch.hub`` repository to load the backbone from.
        arch: Entry point (architecture) name inside that repository.
        embedding_dim: Size D of the projected per-token embedding.
        dropout_prob: Dropout probability applied after the projection.
    """

    def __init__(self, model_name='facebookresearch/dinov2', arch='dinov2_vitb14',
                 embedding_dim=512, dropout_prob=0.1):
        super().__init__()
        self.model = torch.hub.load(model_name, arch)
        print("Using DINOv2 model: ", arch)
        backbone_width = self.model.embed_dim
        self.projection = nn.Linear(backbone_width, embedding_dim)
        self.dropout = nn.Dropout(p=dropout_prob)
        # The backbone stays trainable.
        for weight in self.model.parameters():
            weight.requires_grad = True

    def forward(self, x):
        """
        Embed a batch of images.

        Args:
            x: (B, 3, H, W), e.g. (B,3,224,224) image batch. A 5-D tensor has
                its leading axis squeezed and a 3-D tensor gains a batch axis.
        Returns:
            visual_feats: (B, Nv, D)
                Nv = number of visual tokens
                D = embedding_dim
        """
        rank = x.dim()
        if rank == 5:
            # e.g. [1, 1, 3, 224, 224] -> [1, 3, 224, 224]
            x = x.squeeze(0)
        elif rank == 3:
            # Single un-batched image: add the batch axis.
            x = x.unsqueeze(0)

        # Take the first (and, with n=1, only) entry the backbone returns.
        patch_tokens = self.model.get_intermediate_layers(x, n=1)[0]
        return self.dropout(self.projection(patch_tokens))
class Triad(nn.Module):
    """
    Tri-modal model that embeds images, audio and text into a shared 512-d
    token space and computes pairwise token-level similarity matrices.

    Args:
        audio_model_name: Checkpoint name forwarded to ``AudioEmbedder``.
        text_model_name: Checkpoint name forwarded to ``TextEmbedder``.
        temperature: Initial value of the learnable similarity temperature.
        patch_sparsity_threshold: Stored on the instance; not used by the
            methods of this class.
        patch_sparsity_weight: Stored on the instance; not used by the
            methods of this class.
        visual_dropout_prob: Dropout probability forwarded to ``ViTEmbedder``.
    """

    def __init__(
        self,
        audio_model_name="facebook/hubert-base-ls960",
        text_model_name="distilbert/distilbert-base-uncased",
        temperature=2.0,
        patch_sparsity_threshold=0.3,
        patch_sparsity_weight=0.1,
        visual_dropout_prob=0.1
    ):
        super().__init__()
        self.audio_embedder = AudioEmbedder(embedding_dim=512, hubert_name=audio_model_name)
        self.text_embedder = TextEmbedder(embedding_dim=512, model_name=text_model_name)
        self.visual_embedder = ViTEmbedder(arch='dinov2_vitb14',
                                           embedding_dim=512,
                                           dropout_prob=visual_dropout_prob)
        # Learnable scalar that divides every similarity matrix.
        self.temperature = nn.Parameter(torch.tensor(temperature))
        self.patch_sparsity_threshold = patch_sparsity_threshold
        self.patch_sparsity_weight = patch_sparsity_weight

    def compute_similarity_matrix(self, feats1, feats2):
        """
        Generic token-level dot-product similarity between feats1 and feats2.

        Args:
            feats1: (B, N1, D)
            feats2: (B, N2, D)
        Returns:
            sim: (B, N1, N2), the batched dot products divided by
            ``self.temperature``.
        """
        sim = torch.bmm(feats1, feats2.transpose(1, 2))
        return sim / self.temperature

    def _prepare_image(self, image, device):
        """Turn a path, a list of paths or a tensor into a batched tensor on ``device``."""
        if isinstance(image, torch.Tensor):
            # Assumed to be pre-processed already; only add a batch axis if missing.
            if image.dim() == 3:
                image = image.unsqueeze(0)
            return image.to(device)

        paths = [image] if isinstance(image, str) else image
        # Built once per call instead of once per image.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        frames = []
        for path in paths:
            # The context manager closes the file handle once the pixels are loaded.
            with Image.open(path) as img:
                frames.append(transform(img.convert('RGB')).to(device))
        return torch.stack(frames, dim=0)  # [B, 3, 224, 224]

    def forward(self, image=None, audio=None, text_list=None):
        """
        Embed whichever modalities are supplied and compare them pairwise.

        Args:
            image: Path to an image, a non-empty list of image paths, or a
                tensor of shape (3, H, W) / (B, 3, H, W).
            audio: Tensor of shape (B, T) holding raw waveforms.
            text_list: List containing exactly one string.
        Returns:
            dict with the keys that apply to the supplied modalities:
                'visual_feats' (B, Nv, D), 'audio_feats' (B, Na, D),
                'text_feats' (B, Nt, D),
                'vis_text_sim_matrix' (B, Nt, Nv),
                'vis_audio_sim_matrix' (B, Na, Nv),
                'text_audio_sim_matrix' (B, Nt, Na).
        Raises:
            AssertionError: If no modality is given or an input has the wrong
                type or shape.
        """
        assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
        if image is not None:
            # The former check `image is not str` compared against the type object
            # itself and therefore could never fail.
            assert isinstance(image, (str, list, torch.Tensor)), "Image must be a file path, a list of file paths, or a torch.Tensor"
            if isinstance(image, list):
                assert len(image) > 0, "Image list must contain at least one file path"
        if audio is not None:
            assert isinstance(audio, torch.Tensor) and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (B, T)"
        if text_list is not None:
            assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"

        embeddings = {}
        if image is not None:
            device = next(self.parameters()).device
            image = self._prepare_image(image, device)
            embeddings['visual_feats'] = self.visual_embedder(image)
        if audio is not None:
            embeddings['audio_feats'] = self.audio_embedder(audio)
        if text_list is not None:
            # The attention mask returned by the text embedder is not needed here.
            embeddings['text_feats'], _ = self.text_embedder(text_list)

        # if two or more modalities are present, we compute the similarity matrix
        if image is not None and text_list is not None:
            embeddings['vis_text_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['visual_feats'])
        if audio is not None and image is not None:
            embeddings['vis_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['audio_feats'], embeddings['visual_feats'])
        if text_list is not None and audio is not None:
            embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
        return embeddings
|