SajayR committed (verified)
Commit b6faa73 · 1 Parent(s): 8637fe4

Create hf_model.py

Files changed (1)
  1. hf_model.py +212 -0
hf_model.py ADDED
@@ -0,0 +1,212 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ from transformers import (
+     HubertModel,
+     AutoProcessor,
+     AutoTokenizer,
+     AutoModel
+ )
+ warnings.filterwarnings("ignore")
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+ #################################################################
+ # Audio Embedder
+ #################################################################
+ class AudioEmbedder(nn.Module):
+     """
+     Pre-trained HuBERT (or similar) to extract audio features from raw audio (16 kHz).
+     Projects them down to a desired embedding dimension.
+     """
+     def __init__(self, embedding_dim=512, hubert_name="facebook/hubert-base-ls960"):
+         super().__init__()
+         # The fine-tuned checkpoint ships a processor config; the encoder
+         # weights themselves come from `hubert_name`.
+         self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+         self.hubert = HubertModel.from_pretrained(hubert_name)
+         self.projection = nn.Linear(self.hubert.config.hidden_size, embedding_dim)
+
+         # Both the encoder and the projection stay trainable (full fine-tuning).
+         for param in self.hubert.parameters():
+             param.requires_grad = True
+         for param in self.projection.parameters():
+             param.requires_grad = True
+
+     def forward(self, audio_input: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             audio_input: (B, T) raw audio waveform at 16 kHz
+
+         Returns:
+             audio_feats: (B, Na, D)
+                 B  = batch size
+                 Na = number of audio tokens (~T/320 for HuBERT)
+                 D  = embedding_dim
+         """
+         if audio_input.dim() == 3:  # shape: (B, 1, T)
+             audio_input = audio_input.squeeze(1)  # drop the channel dim to get (B, T)
+         # The processor wraps the tensor in an extra batch dim; squeeze it back off.
+         inputs = self.processor(
+             audio_input,
+             return_tensors="pt",
+             sampling_rate=16000,
+             padding=True,
+             return_attention_mask=True
+         ).input_values.squeeze(0)
+         device = next(self.parameters()).device
+         inputs = inputs.to(device)
+
+         hubert_output = self.hubert(inputs).last_hidden_state  # (B, T', hidden_size)
+         audio_feats = self.projection(hubert_output)           # (B, T', D)
+
+         return audio_feats
+
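+ # --- Illustrative usage (editor's sketch, not part of the original file) ---
+ # HuBERT's convolutional frontend has a total stride of 320 at 16 kHz, so a
+ # 5 s clip (~80000 samples) yields on the order of 249 tokens:
+ #
+ #     embedder = AudioEmbedder(embedding_dim=512)
+ #     wav = torch.randn(1, 80000)   # (B=1, T) raw waveform at 16 kHz
+ #     feats = embedder(wav)         # -> roughly (1, 249, 512)
+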
+ #################################################################
+ # Text Embedder
+ #################################################################
+ class TextEmbedder(nn.Module):
+     """
+     Pre-trained BERT-like model (ModernBERT or similar) to extract text features.
+     Projects them down to a desired embedding dimension.
+     """
+     def __init__(self, embedding_dim=512, model_name="answerdotai/ModernBERT-base"):
+         super().__init__()
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.encoder = AutoModel.from_pretrained(model_name)
+         self.projection = nn.Linear(self.encoder.config.hidden_size, embedding_dim)
+         print("Using text model:", model_name)
+
+         for param in self.encoder.parameters():
+             param.requires_grad = True
+         for param in self.projection.parameters():
+             param.requires_grad = True
+
+     def forward(self, text_list):
+         """
+         Args:
+             text_list: List[str], batch of text inputs
+
+         Returns:
+             text_feats: (B, Nt, D)
+             attention_mask: (B, Nt)
+         """
+         inputs = self.tokenizer(
+             text_list,
+             padding=True,
+             truncation=True,
+             add_special_tokens=False,
+             max_length=128,
+             return_tensors="pt"
+         )
+         device = next(self.parameters()).device
+         for k in inputs:
+             inputs[k] = inputs[k].to(device)
+
+         outputs = self.encoder(**inputs)
+         hidden_states = outputs.last_hidden_state   # (B, Nt, hidden_size)
+         text_feats = self.projection(hidden_states) # (B, Nt, D)
+
+         return text_feats, inputs["attention_mask"]
+
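+ # --- Illustrative usage (editor's sketch, not part of the original file) ---
+ # Nt is the padded token length of the batch (special tokens disabled,
+ # capped at max_length=128):
+ #
+ #     text_embedder = TextEmbedder(embedding_dim=512)
+ #     feats, mask = text_embedder(["a dog barking"])
+ #     # feats: (1, Nt, 512), mask: (1, Nt)
+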
+ #################################################################
+ # Visual Embedder
+ #################################################################
+ class ViTEmbedder(nn.Module):
+     """
+     DINOv2 to extract patch embeddings from an image.
+     Then projects to a common dimension with a linear layer.
+     """
+     def __init__(self, model_name='facebookresearch/dinov2', arch='dinov2_vitb14',
+                  embedding_dim=512, dropout_prob=0.1):
+         super().__init__()
+         self.model = torch.hub.load(model_name, arch)
+         print("Using DINOv2 model:", arch)
+         self.projection = nn.Linear(self.model.embed_dim, embedding_dim)
+         self.dropout = nn.Dropout(p=dropout_prob)
+
+         for param in self.model.parameters():
+             param.requires_grad = True
+
+     def forward(self, x):
+         """
+         Args:
+             x: (B, 3, H, W), e.g. (B, 3, 224, 224) image batch
+         Returns:
+             visual_feats: (B, Nv, D)
+                 Nv = number of visual tokens
+                 D  = embedding_dim
+         """
+         if x.dim() == 5:       # shape: (1, 1, 3, 224, 224)
+             x = x.squeeze(0)   # get (1, 3, 224, 224)
+         if x.dim() == 3:       # single image: (3, H, W)
+             x = x.unsqueeze(0)
+         # Patch tokens from the last transformer block (no CLS token).
+         patches = self.model.get_intermediate_layers(x, n=1)[0]
+         feats = self.projection(patches)
+         feats = self.dropout(feats)
+
+         return feats
+
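+ # --- Illustrative shape note (editor's sketch, not part of the original file) ---
+ # dinov2_vitb14 uses 14x14 pixel patches, so a 224x224 input yields
+ # Nv = (224 / 14) ** 2 = 256 patch tokens:
+ #
+ #     vit = ViTEmbedder()
+ #     imgs = torch.randn(2, 3, 224, 224)
+ #     feats = vit(imgs)             # -> (2, 256, 512)
+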
+ class Triad(nn.Module):
+     def __init__(
+         self,
+         audio_model_name="facebook/hubert-base-ls960",
+         text_model_name="distilbert/distilbert-base-uncased",
+         temperature=2.0,
+         patch_sparsity_threshold=0.3,
+         patch_sparsity_weight=0.1,
+         visual_dropout_prob=0.1
+     ):
+         super().__init__()
+
+         self.audio_embedder = AudioEmbedder(embedding_dim=512, hubert_name=audio_model_name)
+         self.text_embedder = TextEmbedder(embedding_dim=512, model_name=text_model_name)
+         self.visual_embedder = ViTEmbedder(arch='dinov2_vitb14',
+                                            embedding_dim=512,
+                                            dropout_prob=visual_dropout_prob)
+
+         self.temperature = nn.Parameter(torch.tensor(temperature))
+         self.patch_sparsity_threshold = patch_sparsity_threshold
+         self.patch_sparsity_weight = patch_sparsity_weight
+
+     def compute_similarity_matrix(self, feats1, feats2):
+         """
+         Generic token-level dot-product similarity between feats1 and feats2.
+             feats1: (B, N1, D)
+             feats2: (B, N2, D)
+         Returns sim: (B, N1, N2)
+         """
+         sim = torch.bmm(feats1, feats2.transpose(1, 2))
+         return sim / self.temperature
+
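+     # Worked shape example (editor's note, not part of the original file):
+     # with text_feats (1, Nt, 512) and visual_feats (1, Nv, 512),
+     # compute_similarity_matrix(text_feats, visual_feats) returns (1, Nt, Nv),
+     # i.e. one dot-product score per (text token, image patch) pair, divided
+     # by the learnable temperature.
+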
+     def forward(self, image=None, audio=None, text_list=None):
+         assert image is not None or audio is not None or text_list is not None, \
+             "At least one modality must be provided"
+         if image is not None:
+             assert isinstance(image, str), "image must be a path to an image file"
+         if audio is not None:
+             assert isinstance(audio, torch.Tensor) and audio.dim() == 2 and audio.shape[0] == 1, \
+                 "Audio must be a PyTorch tensor of shape (1, T)"
+         if text_list is not None:
+             assert isinstance(text_list, list) and len(text_list) == 1, \
+                 "Text list must be a list of strings of length 1"
+         if image is not None:
+             image = Image.open(image).convert('RGB')
+             transform = transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+             image = transform(image)
+
+         embeddings = {}
+         if image is not None:
+             embeddings['visual_feats'] = self.visual_embedder(image)
+         if audio is not None:
+             embeddings['audio_feats'] = self.audio_embedder(audio)
+         if text_list is not None:
+             embeddings['text_feats'], _ = self.text_embedder(text_list)
+
+         # If two or more modalities are present, compute the pairwise similarity matrices.
+         if image is not None and text_list is not None:
+             embeddings['vis_text_sim_matrix'] = self.compute_similarity_matrix(
+                 embeddings['text_feats'], embeddings['visual_feats'])
+         if audio is not None and image is not None:
+             embeddings['vis_audio_sim_matrix'] = self.compute_similarity_matrix(
+                 embeddings['audio_feats'], embeddings['visual_feats'])
+         if text_list is not None and audio is not None:
+             embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(
+                 embeddings['text_feats'], embeddings['audio_feats'])
+         return embeddings
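+
+
+ # --- Minimal end-to-end sketch (editor's addition, not part of the original
+ # file). Assumes the referenced checkpoints are downloadable and that
+ # "frame.jpg" is a hypothetical local image path. ---
+ if __name__ == "__main__":
+     model = Triad().eval()
+     audio = torch.randn(1, 16000)  # 1 s of 16 kHz audio, shape (1, T)
+     with torch.no_grad():
+         out = model(image="frame.jpg", audio=audio, text_list=["a dog barking"])
+     # Prints feature and similarity-matrix shapes for the provided modalities.
+     print({k: tuple(v.shape) for k, v in out.items()})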