"""Train a small CNN + LSTM visual-question-answering classifier on VQA-RAD.

Pipeline: load the dataset, normalise question/answer text, keep the 50 most
frequent answers as classification targets, build a word vocabulary, train a
ResNet-18 + LSTM model, then save the weights, vocabulary and answer mapping
under ``weights/``.
"""

import os
import pickle
import re
from collections import Counter

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# ========================
# CONFIG
# ========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 50
BATCH_SIZE = 32
LR = 5e-4
MAX_LEN = 20  # questions are truncated / padded to this many tokens

# Special vocabulary tokens. They must be two *distinct* keys: if both are the
# same string the dict collapses to one entry, padding and unknown words share
# an id, and len(vocab) ends up smaller than the largest word id + 1, which
# makes the embedding lookup fail with an index-out-of-range error.
# clean_text() strips "<" and ">", so no real word can collide with them.
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"

# ========================
# LOAD DATASET
# ========================
dataset = load_dataset("flaviagiammarino/vqa-rad")

df = pd.DataFrame(dataset["train"])
# .copy() so the column assignments below operate on an independent frame
# instead of a slice of the original (avoids SettingWithCopyWarning).
df = df[["image", "question", "answer"]].copy()

# ========================
# CLEAN TEXT
# ========================
_NON_ALNUM_RE = re.compile(r"[^a-z0-9 ]")


def clean_text(text: str) -> str:
    """Lower-case *text* and drop every character except a-z, 0-9 and space."""
    return _NON_ALNUM_RE.sub("", text.lower())


df["question"] = df["question"].apply(clean_text)
df["answer"] = df["answer"].apply(clean_text)

# ========================
# FILTER TOP ANSWERS
# ========================
# Keep only rows whose answer is among the 50 most frequent; each of those
# answers becomes one output class.
top_answers = df["answer"].value_counts().nlargest(50).index
df = df[df["answer"].isin(top_answers)].copy()

answer_to_idx = {a: i for i, a in enumerate(top_answers)}
idx_to_answer = {i: a for a, i in answer_to_idx.items()}

df["answer_encoded"] = df["answer"].map(answer_to_idx)

# ========================
# VOCAB
# ========================
vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}

counter = Counter(w for q in df["question"] for w in q.split())

# Words seen more than twice get their own id; ids stay contiguous because
# each new word takes the current vocabulary size. Rarer words map to UNK.
for word, count in counter.items():
    if count > 2:
        vocab[word] = len(vocab)


def encode_question(q: str) -> list:
    """Encode a cleaned question as a list of exactly ``MAX_LEN`` token ids.

    Out-of-vocabulary words map to the UNK id; questions longer than
    ``MAX_LEN`` are truncated and shorter ones are right-padded with the
    PAD id.
    """
    unk_id = vocab[UNK_TOKEN]
    enc = [vocab.get(w, unk_id) for w in q.split()][:MAX_LEN]
    enc += [vocab[PAD_TOKEN]] * (MAX_LEN - len(enc))
    return enc


df["question_encoded"] = df["question"].apply(encode_question)

# ========================
# DATASET CLASS
# ========================
# NOTE(review): no mean/std normalisation is applied even though the CNN
# below loads pretrained weights -- confirm this is intentional. Any
# inference code must use the same transform as training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


class VQADataset(Dataset):
    """Dataset over a DataFrame of images, encoded questions and answers.

    Each row must provide ``image`` (an object with ``.convert("RGB")``),
    ``question_encoded`` (list of token ids) and ``answer_encoded``
    (integer class index).
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, question_ids, answer_index)`` for row *idx*."""
        row = self.df.iloc[idx]

        image = transform(row["image"].convert("RGB"))
        # Explicit int64: nn.Embedding and CrossEntropyLoss need long indices.
        question = torch.tensor(row["question_encoded"], dtype=torch.long)
        answer = torch.tensor(int(row["answer_encoded"]), dtype=torch.long)

        return image, question, answer


# ========================
# SPLIT DATA
# ========================
dataset_full = VQADataset(df)

train_size = int(0.8 * len(dataset_full))
val_size = len(dataset_full) - train_size

train_dataset, val_dataset = random_split(dataset_full, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


# ========================
# MODEL
# ========================
class VQAModel(nn.Module):
    """ResNet-18 image encoder + LSTM question encoder + MLP classifier.

    Args:
        vocab_size: number of rows in the word-embedding table.
        embed_dim: word-embedding dimension.
        hidden_dim: LSTM hidden-state dimension.
        num_answers: number of answer classes (size of the output logits).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
        super().__init__()

        self.cnn = models.resnet18(weights="DEFAULT")
        # Drop the classification head so the CNN outputs the pooled features
        # that feed fc1 (512 inputs are reserved for them below).
        self.cnn.fc = nn.Identity()

        # padding_idx keeps the PAD embedding at zero and out of the gradient.
        self.embedding = nn.Embedding(
            vocab_size, embed_dim, padding_idx=vocab[PAD_TOKEN]
        )
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

        self.fc1 = nn.Linear(512 + hidden_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_answers)

    def forward(self, image, question):
        """Return answer logits for a batch of images and encoded questions."""
        img_feat = self.cnn(image)

        q_embed = self.embedding(question)
        _, (h, _) = self.lstm(q_embed)
        # Final hidden state of the last LSTM layer; identical to squeezing
        # dim 0 for this single-layer LSTM.
        q_feat = h[-1]

        fused = torch.cat((img_feat, q_feat), dim=1)
        return self.fc2(self.relu(self.fc1(fused)))


model = VQAModel(len(vocab), 300, 256, len(answer_to_idx)).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# ========================
# TRAIN LOOP
# ========================
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for images, questions, answers in tqdm(train_loader):
        images = images.to(DEVICE)
        questions = questions.to(DEVICE)
        answers = answers.to(DEVICE)

        outputs = model(images, questions)
        loss = criterion(outputs, answers)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # VALIDATION
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for images, questions, answers in val_loader:
            images = images.to(DEVICE)
            questions = questions.to(DEVICE)
            answers = answers.to(DEVICE)

            outputs = model(images, questions)
            val_loss += criterion(outputs, answers).item()

    # max(..., 1) only guards against a zero-length loader; it does not
    # change the reported average otherwise.
    print(f"\nEpoch {epoch+1}")
    print(f"Train Loss: {total_loss/max(len(train_loader), 1):.4f}")
    print(f"Val Loss: {val_loss/max(len(val_loader), 1):.4f}")

# ========================
# SAVE MODEL
# ========================
os.makedirs("weights", exist_ok=True)

torch.save(model.state_dict(), "weights/vqa_model.pth")

with open("weights/vocab.pkl", "wb") as f:
    pickle.dump(vocab, f)

with open("weights/answers.pkl", "wb") as f:
    pickle.dump(idx_to_answer, f)

print("\n✅ Training Complete & Model Saved!")