Spaces:
Sleeping
Sleeping
import os
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import GPT2Tokenizer

from model import VQAModel
from train import AugmentedVQADataset, Vocab, save_checkpoint, plot_losses
def create_optimizer_with_differential_lr(model, clip_lr=5e-7, gpt_lr=5e-7, other_lr=3e-5):
    """Build an AdamW optimizer with a separate learning rate per sub-network.

    Trainable parameters are bucketed by their qualified name: names containing
    ``'clip_model'`` go to the first group, names containing ``'gpt2_model'`` to
    the second, and every other trainable parameter to the third. Parameters
    with ``requires_grad=False`` are not given to the optimizer at all.

    Args:
        model: module whose ``named_parameters()`` are partitioned.
        clip_lr: learning rate of ``param_groups[0]`` (CLIP parameters).
        gpt_lr: learning rate of ``param_groups[1]`` (GPT-2 parameters).
        other_lr: learning rate of ``param_groups[2]`` (all remaining parameters).

    Returns:
        A ``torch.optim.AdamW`` with those three groups and ``weight_decay=1e-4``.
    """
    buckets = {"clip": [], "gpt": [], "other": []}
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue  # frozen weights are never handed to the optimizer
        # The 'clip_model' test runs first, so it wins if a name contains both substrings.
        if 'clip_model' in param_name:
            bucket = "clip"
        elif 'gpt2_model' in param_name:
            bucket = "gpt"
        else:
            bucket = "other"
        buckets[bucket].append(tensor)

    grouped = [
        {'params': buckets["clip"], 'lr': clip_lr},
        {'params': buckets["gpt"], 'lr': gpt_lr},
        {'params': buckets["other"], 'lr': other_lr},
    ]
    optimizer = torch.optim.AdamW(grouped, weight_decay=1e-4)
    n_clip, n_gpt, n_other = (len(buckets[key]) for key in ("clip", "gpt", "other"))
    print(f"Optimizer: CLIP params: {n_clip}, GPT-2 params: {n_gpt}, Other params: {n_other}")
    return optimizer
def train_one_epoch(model, dataloader, optimizer, device, vocab, scaler):
    """Train *model* for one pass over *dataloader* with AMP and gradient clipping.

    The ground-truth answer ids are fed to the model (``answer_input_ids``) and
    the output is scored with label-smoothed cross-entropy against the answer
    shifted by one position, ignoring padding. Batches with a NaN/inf loss are
    skipped and excluded from the reported average.

    Args:
        model: VQA model, called as ``model(images, questions, answer_input_ids=...)``.
        dataloader: yields dicts with ``image``, ``question_ids``,
            ``question_mask`` and ``answer_ids`` tensors.
        optimizer: optimizer that is stepped through *scaler*.
        device: target device (``'cuda'``, ``'cpu'``, ``'cuda:0'`` or a ``torch.device``).
        vocab: answer vocabulary; only ``pad_token_id`` is read here.
        scaler: ``torch.amp.GradScaler`` used for loss scaling.

    Returns:
        Mean loss over the batches actually trained on, or ``nan`` if no batch
        was trained on (empty dataloader or every loss non-finite).
    """
    model.train()
    # torch.amp.autocast wants a bare device type ("cuda"/"cpu"), so normalise
    # inputs such as "cuda:0" or a torch.device first.
    device_type = torch.device(device).type
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    total_loss = 0.0
    trained_batches = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device),
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device_type):
            logits = model(images, questions, answer_input_ids=answers)
            # The logit at position t is scored against answer token t+1.
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_answers = answers[:, 1:].contiguous()
            loss = criterion(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_answers.view(-1),
            )
        # Catch inf as well as NaN: either would poison the epoch average.
        if not torch.isfinite(loss):
            print("Non-finite loss detected, skipping batch.")
            continue
        scaler.scale(loss).backward()
        # After backward() the gradients are still multiplied by the loss scale;
        # unscale them so the max-norm of 1.0 applies to the real gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        trained_batches += 1
    if trained_batches == 0:
        return float('nan')
    return total_loss / trained_batches
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate *model* on *dataloader*: mean loss and exact-match accuracy.

    The loss is computed like in training (ground-truth answer ids fed in,
    logits scored against the answer shifted by one position, padding ignored),
    but without label smoothing. Exact match compares the whitespace-stripped
    ``vocab.decoder`` output of ``model(images, questions)`` with the decoded
    reference answer.

    Args:
        model: VQA model; called with ``answer_input_ids`` for the loss and
            without it to obtain predictions.
        dataloader: yields dicts with ``image``, ``question_ids``,
            ``question_mask`` and ``answer_ids`` tensors.
        device: target device (``'cuda'``, ``'cpu'``, ``'cuda:0'`` or a ``torch.device``).
        vocab: answer vocabulary; ``pad_token_id`` and ``decoder`` are used.

    Returns:
        ``(avg_loss, exact_match_acc)``. On an empty dataloader ``avg_loss`` is
        ``nan`` and ``exact_match_acc`` is ``0.0``.
    """
    model.eval()
    # Same autocast device handling as training; the old hard-coded "cuda"
    # silently disabled autocast (with a warning per batch) on CPU.
    device_type = torch.device(device).type
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    total_loss = 0.0
    num_batches = 0
    exact_matches = 0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device),
            }
            answers = batch['answer_ids'].to(device)
            with torch.amp.autocast(device_type):
                logits = model(images, questions, answer_input_ids=answers)
                shifted_logits = logits[:, :-1, :].contiguous()
                shifted_answers = answers[:, 1:].contiguous()
                loss = criterion(
                    shifted_logits.view(-1, shifted_logits.size(-1)),
                    shifted_answers.view(-1),
                )
            total_loss += loss.item()
            num_batches += 1

            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1

    avg_loss = total_loss / num_batches if num_batches else float('nan')
    exact_match_acc = exact_matches / total_samples if total_samples else 0.0
    return avg_loss, exact_match_acc
def filter_spatial_directional_data(df):
    """Return the rows of *df* whose question or answer looks spatial/directional.

    A row is kept when its ``question`` contains one of the spatial keywords or
    its ``answer`` contains one of the directional words. Matching is
    case-insensitive and on whole words, with an optional trailing "s". This
    stops substring false positives such as "cup" matching "up", "bright"
    matching "right" or "laptop" matching "top". Plurals like "directions"
    still match. Missing values never match.

    Args:
        df: DataFrame with ``question`` and ``answer`` text columns.

    Returns:
        A copy of the matching rows with the original index preserved.
    """
    spatial_keywords = [
        'right', 'left', 'above', 'below', 'top', 'bottom',
        'front', 'behind', 'next to', 'beside', 'near',
        'looking', 'facing', 'pointing', 'direction',
        'where is', 'which side', 'what side',
    ]
    directional_answers = [
        'up', 'down', 'left', 'right', 'forward', 'backward',
        'north', 'south', 'east', 'west', 'straight', 'sideways',
    ]

    def _whole_word_pattern(words):
        # Non-capturing group (avoids pandas' "match groups" warning); \b on both
        # sides prevents matches inside longer words; s? keeps simple plurals.
        alternatives = '|'.join(re.escape(word) for word in words)
        return rf'\b(?:{alternatives})s?\b'

    spatial_mask = df['question'].str.lower().str.contains(
        _whole_word_pattern(spatial_keywords), regex=True, na=False
    )
    directional_mask = df['answer'].str.lower().str.contains(
        _whole_word_pattern(directional_answers), regex=True, na=False
    )
    spatial_df = df[spatial_mask | directional_mask].copy()
    print(f"Found {len(spatial_df)} spatial/directional samples out of {len(df)} total")
    return spatial_df
def _restore_vocab(checkpoint):
    """Rebuild the answer ``Vocab`` from the fields stored in a training checkpoint."""
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    return vocab


def _build_training_frame(metadata, spatial_data):
    """Return the rows to fine-tune on: spatial rows, topped up with general rows when scarce."""
    if len(spatial_data) >= 1000:
        print(f"Using {len(spatial_data)} spatial/directional samples")
        return spatial_data

    print(f"\nWARNING: Only {len(spatial_data)} spatial samples found!")
    general_pool = metadata[~metadata.index.isin(spatial_data.index)]
    # Roughly one general row per two spatial rows, capped at a third of the
    # dataset and at the size of the non-spatial pool (sample() raises otherwise).
    n_general = min(len(spatial_data) // 2, len(metadata) // 3, len(general_pool))
    general_data = general_pool.sample(n=n_general, random_state=42)
    mixed = pd.concat([spatial_data, general_data]).sample(frac=1, random_state=42).reset_index(drop=True)
    # Report the real proportions instead of a fixed "70/30" that did not match n_general.
    spatial_pct = 100 * len(spatial_data) / max(len(mixed), 1)
    print(f"Mixing {spatial_pct:.0f}% spatial data with {100 - spatial_pct:.0f}% general data for balanced training")
    return mixed


def main():
    """Fine-tune a pretrained VQA checkpoint on spatial/directional samples.

    Loads the checkpoint at PRETRAINED_CHECKPOINT and filters the metadata CSV
    for spatial/directional rows. When fewer than 1000 such rows exist, general
    rows are mixed in. The model's ``unfreeze_clip_layers`` and
    ``unfreeze_gpt2_layers`` are called with 8 layers each. Training uses early
    stopping on validation exact match. The best checkpoint, a per-epoch CSV
    log and a loss plot are written under OUTPUT_DIR.
    """
    print("# VQA: Spatial-Enhanced Fine-Tuning")
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    PRETRAINED_CHECKPOINT = "./output2/feature_extraction/vqa_checkpoint.pt"
    OUTPUT_DIR = "./output2/spatial_finetuning"
    FINE_TUNED_CHECKPOINT = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    batch_size = 64
    num_epochs = 50
    patience = 8
    clip_layers_to_unfreeze = 8
    gpt_layers_to_unfreeze = 8

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(PRETRAINED_CHECKPOINT, map_location=device)

    metadata = pd.read_csv(CSV_PATH)
    print(f"\nOriginal dataset size: {len(metadata)}")
    spatial_data = filter_spatial_directional_data(metadata)
    mixed_data = _build_training_frame(metadata, spatial_data)
    if mixed_data.empty:
        raise ValueError(f"No samples available for fine-tuning after filtering {CSV_PATH}")

    vocab = _restore_vocab(checkpoint)
    print(f"Answer vocabulary size: {len(vocab.vocab)}")

    # Single source of truth for sequence lengths: the model and both datasets
    # previously disagreed whenever the checkpoint stored non-default values.
    question_max_len = checkpoint.get('question_max_len', 20)
    answer_max_len = checkpoint.get('answer_max_len', 12)

    model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2,
    ).to(device)

    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))

    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # strict=False ignores key mismatches silently; surface them so a partial load is noticed.
    if load_result.missing_keys:
        print(f"WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"WARNING: {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored, e.g. {load_result.unexpected_keys[:5]}")
    print("Pretrained model loaded successfully!\n")

    print(f"UNFREEZING {clip_layers_to_unfreeze} CLIP LAYERS & {gpt_layers_to_unfreeze} GPT-2 LAYERS FOR SPATIAL UNDERSTANDING")
    model.unfreeze_clip_layers(num_layers=clip_layers_to_unfreeze)
    model.unfreeze_gpt2_layers(num_layers=gpt_layers_to_unfreeze)

    # 80/10/10 split; test_df is held out and not evaluated by this script.
    train_df, test_df = train_test_split(mixed_data, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}\n")

    train_dataset = AugmentedVQADataset(train_df, DATA_DIR, question_tokenizer, vocab,
                                        clip_processor=model.clip_preprocess, augment=True,
                                        question_max_len=question_max_len, answer_max_len=answer_max_len)
    val_dataset = AugmentedVQADataset(val_df, DATA_DIR, question_tokenizer, vocab,
                                      clip_processor=model.clip_preprocess, augment=False,
                                      question_max_len=question_max_len, answer_max_len=answer_max_len)
    # Pinned host memory only speeds up host-to-GPU copies.
    pin_memory = device == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)

    optimizer = create_optimizer_with_differential_lr(
        model,
        clip_lr=3e-7,
        gpt_lr=3e-7,
        other_lr=2e-5,
    )
    # mode='max' because the monitored metric is exact-match accuracy. The deprecated
    # verbose flag is dropped; the per-epoch line below prints every group's LR.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4)
    scaler = torch.amp.GradScaler(device)

    print("\nSTARTING SPATIAL-ENHANCED FINE-TUNING")
    best_exact_match = 0.0
    logs = []
    counter = 0
    for epoch in range(num_epochs):
        print(f"\nSpatial Fine-tuning Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, device, vocab, scaler)
        val_loss, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        # Group order is fixed by create_optimizer_with_differential_lr: CLIP, GPT-2, other.
        lr_text = ", ".join(f"{group['lr']:.2e}" for group in optimizer.param_groups)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Exact Match: {val_exact_match:.4f} | LR (clip/gpt/other): {lr_text}")
        scheduler.step(val_exact_match)

        if val_exact_match > best_exact_match:
            best_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, FINE_TUNED_CHECKPOINT)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement for {counter} epochs.")

        # Log before any early-stopping break so the final epoch is not dropped.
        # The "lr" column keeps its meaning: param_groups[0] (CLIP) after scheduler.step.
        logs.append([epoch + 1, train_loss, val_loss, val_exact_match, optimizer.param_groups[0]['lr']])
        pd.DataFrame(logs, columns=["epoch", "train_loss", "val_loss", "val_exact_match", "lr"]).to_csv(LOG_CSV, index=False)
        plot_losses([x[1] for x in logs], [x[2] for x in logs], save_path=LOSS_GRAPH_PATH)

        if counter >= patience:
            print(f"\nEarly stopping after {patience} epochs without improvement")
            break

    print("\nFINE-TUNING COMPLETE")
    print(f"Best exact match: {best_exact_match:.4f}")
    print(f"Model saved to: {FINE_TUNED_CHECKPOINT}")
if __name__ == "__main__":
    # Launch fine-tuning only when executed as a script, never on import.
    main()