| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, random_split |
| from text_encoder import * |
| from vision_encoder import * |
| import os |
| import json |
| import numpy as np |
| import random |
| from tqdm import tqdm |
| import datetime |
|
|
| |
class VisionCaptionDataset(torch.utils.data.Dataset):
    """Pairs precomputed vision embeddings with tokenized captions.

    Each item is ``(vision_embedding, token_ids, attention_mask)``. Captions
    are looked up by image id from a JSON file; one caption is sampled
    uniformly at random per access, so repeated reads of the same index may
    return different captions.
    """

    def __init__(self, captions_path, embeddings_dir, normalize=True):
        # The captions JSON maps image_id -> list of caption entries, each
        # carrying "tokenized" ids and an "attention_mask".
        with open(captions_path, 'r') as f:
            self.captions_dict = json.load(f)

        self.embeddings_dir = embeddings_dir
        self.image_ids = list(self.captions_dict.keys())
        # NOTE(review): `normalize` is stored but never applied anywhere in
        # this class — confirm whether embedding normalization was intended.
        self.normalize = normalize

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]

        # Pick one caption uniformly among those attached to this image.
        chosen = random.choice(self.captions_dict[image_id])

        # The vision embedding is stored on disk as "<image_id>.npy".
        features = np.load(os.path.join(self.embeddings_dir, f"{image_id}.npy"))

        return (
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(chosen["tokenized"], dtype=torch.long),
            torch.tensor(chosen["attention_mask"], dtype=torch.long),
        )
|
|
|
|
class JointNetwork(nn.Module):
    """Two-tower CLIP-style model projecting text and vision features into a
    shared 512-d space.

    The text tower is a ModernBERT encoder (project-local ``modernBERT``
    wrapper); the vision tower consumes precomputed patch embeddings
    (1152-d per patch — presumably from a frozen vision backbone; confirm
    against the saved embedding files). Both outputs are L2-normalized.
    """

    def __init__(self):
        super(JointNetwork, self).__init__()

        self.text_encoder = modernBERT("answerdotai/ModernBERT-base")
        # Explicitly mark the text encoder trainable here; train_model()
        # may freeze it again depending on its arguments.
        for p in self.text_encoder.parameters():
            p.requires_grad = True

        # Project both modalities into a common 512-d embedding space.
        self.vision_projector = nn.Linear(1152, 512)
        self.text_projector = nn.Linear(768, 512)

    def forward(self, tokenized_text, image_encoding):
        # Mean-pool over the sequence/patch dimension on both towers.
        pooled_vision = image_encoding.mean(dim=1)
        pooled_text = self.text_encoder(tokenized_text).mean(dim=1)

        projected_vision = self.vision_projector(pooled_vision)
        projected_text = self.text_projector(pooled_text)

        # Unit-normalize so downstream dot products are cosine similarities.
        return (
            F.normalize(projected_text, dim=1),
            F.normalize(projected_vision, dim=1),
        )
|
|
|
|
def infoNCE_loss(text_features, vision_features, temperature=0.07):
    """Symmetric InfoNCE (CLIP-style) contrastive loss.

    Matched text/vision rows are positives; every other row in the batch is
    a negative. Inputs are re-normalized here, so the loss is invariant to
    input scale.

    Args:
        text_features: (B, D) text embeddings.
        vision_features: (B, D) vision embeddings.
        temperature: softmax temperature applied to the similarity logits.

    Returns:
        Scalar tensor: mean of the text->image and image->text
        cross-entropy losses.
    """
    t = F.normalize(text_features, p=2, dim=-1)
    v = F.normalize(vision_features, p=2, dim=-1)

    # logits[i, j] = cosine similarity of text i against image j, scaled.
    logits = torch.matmul(t, v.T) / temperature

    # Positive pairs sit on the diagonal.
    targets = torch.arange(v.size(0), device=v.device)

    text_to_image = F.cross_entropy(logits, targets)
    image_to_text = F.cross_entropy(logits.T, targets)
    return (text_to_image + image_to_text) / 2
|
|
|
|
def train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=5, freeze_text_encoder=True, checkpoint_path=None):
    """Train the joint text/vision model with symmetric InfoNCE loss.

    Runs `num_epochs` epochs, each consisting of a training pass and a
    validation pass; the best model (lowest average validation loss) is
    checkpointed to `checkpoint_path` when one is given.

    Args:
        model: module exposing `text_encoder`, `vision_projector` and
            `text_projector` submodules (see JointNetwork).
        train_loader: yields (image_embeddings, token_ids, attention_masks).
        val_loader: same batch shape as `train_loader`.
        optimizer: already constructed over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: number of passes over `train_loader`.
        freeze_text_encoder: if True, disable gradients for the text
            encoder; the projection heads always train.
        checkpoint_path: destination for the best checkpoint; None disables
            saving.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best_val_loss = float('inf')

    if freeze_text_encoder:
        for param in model.text_encoder.parameters():
            param.requires_grad = False
        # NOTE: if the optimizer was built before freezing (as in __main__),
        # frozen params simply receive no gradients and are skipped.

    # The projection heads are always trained.
    for param in model.vision_projector.parameters():
        param.requires_grad = True
    for param in model.text_projector.parameters():
        param.requires_grad = True

    model.to(device)

    # Fix: torch.save does not create directories — ensure the checkpoint
    # directory exists up front so the first save cannot fail mid-training.
    if checkpoint_path is not None:
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):

        # ---- training pass ----
        model.train()
        total_loss = 0.0

        print(f"\nEpoch {epoch + 1}/{num_epochs} - Training:")
        train_progress = tqdm(train_loader, desc="Training", leave=True)

        for image_embeddings, tokenized_captions, attention_masks in train_progress:
            text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
            image_embeddings = image_embeddings.to(device)

            optimizer.zero_grad()
            text_features, vision_features = model(text_inputs, image_embeddings)
            loss = infoNCE_loss(text_features, vision_features)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_progress.set_postfix(loss=loss.item())

        # Scheduler advances once per epoch (CosineAnnealingLR in __main__).
        scheduler.step()

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0

        print(f"\nEpoch {epoch + 1}/{num_epochs} - Validation:")
        val_progress = tqdm(val_loader, desc="Validation", leave=True)

        with torch.no_grad():
            for image_embeddings, tokenized_captions, attention_masks in val_progress:
                text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
                image_embeddings = image_embeddings.to(device)

                text_features, vision_features = model(text_inputs, image_embeddings)
                loss = infoNCE_loss(text_features, vision_features)
                val_loss += loss.item()
                val_progress.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f"\nEpoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Keep only the best model by validation loss.
        if checkpoint_path is not None:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': best_val_loss
                }, checkpoint_path)
                print(f"New Best Model Saved at: {checkpoint_path} (Val Loss: {best_val_loss:.4f})")

    print("Training completed!")
|
|
|
|
|
|
if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenized Flickr30k captions and the matching precomputed vision
    # embeddings (one .npy per image id).
    captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
    embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2'

    # 85/15 random train/validation split.
    full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)
    train_size = int(0.85 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)

    model = JointNetwork().to(device)

    # Fix: create the checkpoint directory before training so torch.save
    # inside train_model cannot fail on a missing "./checkpoints" path.
    os.makedirs("./checkpoints", exist_ok=True)
    checkpoint_path = f"./checkpoints/model_checkpoint_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"

    # Optimizer only sees currently-trainable params; train_model may freeze
    # the text encoder afterwards (its grads then stay None and are skipped).
    initial_lr = 1e-4
    min_lr = 1e-6
    num_epochs = 16
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

    print("\n### Phase 1: Training new layers only (Text Encoder Frozen) ###")
    train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=True, checkpoint_path=checkpoint_path)
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |