|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import warnings |
|
|
from transformers import ( |
|
|
HubertModel, |
|
|
AutoProcessor, |
|
|
AutoTokenizer, |
|
|
AutoModel |
|
|
) |
|
|
warnings.filterwarnings("ignore") |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
class AudioEmbedder(nn.Module):
    """
    Extract token-level audio features from raw waveforms with a pre-trained
    HuBERT encoder and project them to a common embedding dimension.

    Args:
        embedding_dim: Output size of the linear projection.
        hubert_name: Hugging Face checkpoint loaded as the HuBERT encoder.
        processor_name: Hugging Face checkpoint whose processor prepares the
            waveform before it is fed to the encoder. The default is the
            checkpoint that used to be hard-coded here, so existing callers
            get exactly the same preprocessing as before.
    """

    def __init__(self, embedding_dim=512, hubert_name="facebook/hubert-base-ls960",
                 processor_name="facebook/hubert-large-ls960-ft"):
        super().__init__()
        # NOTE(review): by default the processor and the encoder come from
        # different checkpoints; confirm their preprocessing settings agree.
        self.processor = AutoProcessor.from_pretrained(processor_name)
        self.hubert = HubertModel.from_pretrained(hubert_name)
        self.projection = nn.Linear(self.hubert.config.hidden_size, embedding_dim)

        # Both the encoder and the projection head are left trainable.
        for param in self.hubert.parameters():
            param.requires_grad = True
        for param in self.projection.parameters():
            param.requires_grad = True

    def forward(self, audio_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            audio_input: (B, T) raw audio waveform at 16kHz. A 3-D input has
                its leading dimension squeezed first (a no-op unless that
                dimension has size 1).

        Returns:
            audio_feats: (B, Na, D)
                B = batch size
                Na = number of audio tokens (T/320 for Hubert)
                D = embedding_dim
        """
        if audio_input.dim() == 3:
            audio_input = audio_input.squeeze(0)

        # The processor output carries an extra leading dimension for tensor
        # input, which the trailing squeeze(0) drops again.
        # NOTE(review): the attention mask is requested but never forwarded
        # to the encoder — confirm this is intended.
        inputs = self.processor(
            audio_input,
            return_tensors="pt",
            sampling_rate=16000,
            padding=True,
            return_attention_mask=True,
        ).input_values.squeeze(0)

        # Run on whatever device the module's weights currently live on.
        device = next(self.parameters()).device
        inputs = inputs.to(device)

        hubert_output = self.hubert(inputs).last_hidden_state
        audio_feats = self.projection(hubert_output)
        return audio_feats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextEmbedder(nn.Module):
    """
    Token-level text encoder: a pre-trained Hugging Face ``AutoModel``
    backbone (ModernBERT by default) followed by a linear projection to the
    requested embedding dimension.
    """

    def __init__(self, embedding_dim=512, model_name="answerdotai/ModernBERT-base"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        backbone_width = self.encoder.config.hidden_size
        self.projection = nn.Linear(backbone_width, embedding_dim)
        print("Using text model: ", model_name)

        # Keep every weight of the backbone and of the head trainable.
        for weight in self.encoder.parameters():
            weight.requires_grad_(True)
        for weight in self.projection.parameters():
            weight.requires_grad_(True)

    def forward(self, text_list):
        """
        Tokenize and encode a batch of strings.

        Args:
            text_list: List[str], batch of text inputs

        Returns:
            text_feats: (B, Nt, D) projected token features
            attention_mask: (B, Nt) tokenizer attention mask, on the module's device
        """
        # Padded, truncated to 128 tokens, and without special tokens.
        encoded = self.tokenizer(
            text_list,
            padding=True,
            truncation=True,
            add_special_tokens=False,
            max_length=128,
            return_tensors="pt",
        )

        # Move every tokenizer output onto the device holding the weights.
        target_device = next(self.parameters()).device
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        token_states = self.encoder(**encoded).last_hidden_state
        return self.projection(token_states), encoded["attention_mask"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ViTEmbedder(nn.Module):
    """
    Visual token encoder backed by a DINOv2 model loaded through
    ``torch.hub``. The backbone's token embeddings are mapped to a common
    dimension by a linear layer and then passed through dropout.
    """

    def __init__(self, model_name='facebookresearch/dinov2', arch='dinov2_vitb14',
                 embedding_dim=512, dropout_prob=0.1):
        super().__init__()
        self.model = torch.hub.load(model_name, arch)
        print("Using DINOv2 model: ", arch)
        self.projection = nn.Linear(self.model.embed_dim, embedding_dim)
        self.dropout = nn.Dropout(p=dropout_prob)

        # The backbone is fine-tuned rather than frozen.
        for weight in self.model.parameters():
            weight.requires_grad_(True)

    def forward(self, x):
        """
        Args:
            x: (B, 3, H, W), e.g. (B,3,224,224) image batch. A 5-D input has
               its leading dimension squeezed; a 3-D input gains a batch
               dimension.

        Returns:
            visual_feats: (B, Nv, D)
                Nv = number of visual tokens
                D = embedding_dim
        """
        n_dims = x.dim()
        if n_dims == 5:
            x = x.squeeze(0)
        elif n_dims == 3:
            x = x.unsqueeze(0)

        # Take the first entry returned for the last (n=1) backbone layer.
        layer_outputs = self.model.get_intermediate_layers(x, n=1)
        projected = self.projection(layer_outputs[0])
        return self.dropout(projected)
|
|
|
|
|
class Triad(nn.Module):
    """
    Joint image / audio / text embedder that also produces token-level
    cross-modal similarity matrices.

    Args:
        audio_model_name: Checkpoint name passed to ``AudioEmbedder``.
        text_model_name: Checkpoint name passed to ``TextEmbedder``.
        temperature: Initial value of the learnable scalar that divides every
            similarity matrix. Coerced to float so an integer argument still
            yields a trainable parameter.
        patch_sparsity_threshold: Stored on the instance; not used by the
            methods of this class.
        patch_sparsity_weight: Stored on the instance; not used by the
            methods of this class.
        visual_dropout_prob: Dropout probability passed to ``ViTEmbedder``.
        embedding_dim: Shared feature size of all three embedders
            (previously fixed at 512).
    """

    def __init__(
        self,
        audio_model_name="facebook/hubert-base-ls960",
        text_model_name="distilbert/distilbert-base-uncased",
        temperature=2.0,
        patch_sparsity_threshold=0.3,
        patch_sparsity_weight=0.1,
        visual_dropout_prob=0.1,
        embedding_dim=512,
    ):
        super().__init__()

        self.audio_embedder = AudioEmbedder(embedding_dim=embedding_dim, hubert_name=audio_model_name)
        self.text_embedder = TextEmbedder(embedding_dim=embedding_dim, model_name=text_model_name)
        self.visual_embedder = ViTEmbedder(arch='dinov2_vitb14',
                                           embedding_dim=embedding_dim,
                                           dropout_prob=visual_dropout_prob)

        # float() guards against an int argument, which would create an
        # integer tensor that cannot require gradients.
        self.temperature = nn.Parameter(torch.tensor(float(temperature)))
        self.patch_sparsity_threshold = patch_sparsity_threshold
        self.patch_sparsity_weight = patch_sparsity_weight

    def compute_similarity_matrix(self, feats1, feats2):
        """
        Generic token-level dot-product similarity between feats1 and feats2,
        divided by the learnable temperature.

        feats1: (B, N1, D)
        feats2: (B, N2, D)
        Returns sim: (B, N1, N2)
        """
        sim = torch.bmm(feats1, feats2.transpose(1, 2))
        return sim / self.temperature

    @staticmethod
    def _build_image_transform():
        """Build the resize-to-224 / to-tensor / per-channel-normalize pipeline used for image files."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_image(path, transform, device):
        """Open one image file as RGB, apply ``transform`` and move the result to ``device``."""
        # The context manager closes the file handle once the pixels are read.
        with Image.open(path) as img:
            return transform(img.convert('RGB')).to(device)

    def _prepare_image(self, image, device):
        """Turn a path, a sequence of paths, or a tensor into a batched image tensor on ``device``."""
        if isinstance(image, str):
            return self._load_image(image, self._build_image_transform(), device).unsqueeze(0)

        if isinstance(image, (list, tuple)):
            assert len(image) > 0, "Image list must contain at least one path"
            # Build the pipeline once instead of once per image.
            transform = self._build_image_transform()
            return torch.stack([self._load_image(path, transform, device) for path in image], dim=0)

        # Tensor input: add a batch dimension to a single (C, H, W) image.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        return image.to(device)

    def forward(self, image=None, audio=None, text_list=None):
        """
        Embed whichever modalities are given and compare every available pair.

        Args:
            image: A file path, a list/tuple of file paths, or an image
                tensor ((C, H, W) or batched). Files are resized to 224x224.
            audio: Tensor of shape (B, T).
            text_list: A list containing exactly one string.

        Returns:
            dict with, depending on the inputs provided:
                'visual_feats', 'audio_feats', 'text_feats': token features
                'vis_text_sim_matrix': text tokens vs. visual tokens
                'vis_audio_sim_matrix': audio tokens vs. visual tokens
                'text_audio_sim_matrix': text tokens vs. audio tokens

        Raises:
            AssertionError: If no modality is given or an input has an
                unsupported type or shape.
        """
        # Validation stays as assert statements so callers keep seeing
        # AssertionError, as before.
        assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
        if image is not None:
            # The previous check compared the value with the `str` type
            # itself and therefore never failed.
            assert isinstance(image, (str, list, tuple, torch.Tensor)), \
                "Image must be a file path, a list of file paths, or a PyTorch tensor"
        if audio is not None:
            assert isinstance(audio, torch.Tensor) and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (B, T)"
        if text_list is not None:
            assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"

        embeddings = {}
        if image is not None:
            device = next(self.parameters()).device
            image = self._prepare_image(image, device)
            embeddings['visual_feats'] = self.visual_embedder(image)
        if audio is not None:
            embeddings['audio_feats'] = self.audio_embedder(audio)
        if text_list is not None:
            # The attention mask returned by the text embedder is not needed here.
            embeddings['text_feats'], _ = self.text_embedder(text_list)

        if image is not None and text_list is not None:
            embeddings['vis_text_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['visual_feats'])
        if audio is not None and image is not None:
            embeddings['vis_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['audio_feats'], embeddings['visual_feats'])
        if text_list is not None and audio is not None:
            embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
        return embeddings
|
|
|
|
|
|