Spaces:
Running
Running
| import os | |
| import torch | |
| from torch.utils.data import DataLoader, random_split | |
| from transformers import ( | |
| VisionEncoderDecoderModel, | |
| ViTImageProcessor, | |
| AutoTokenizer, | |
| GPT2Config, | |
| GPT2LMHeadModel, | |
| ViTModel | |
| ) | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from dataset_vit_gpt2 import COCODatasetViTGPT2 | |
| from tqdm import tqdm | |
| from pycocoevalcap.cider.cider import Cider | |
| from PIL import Image | |
| # ========================================== | |
| # GENERATE CAPTION | |
| # ========================================== | |
def generate_caption(model, processor, tokenizer, image, device, *,
                     num_beams=5, max_length=20):
    """Generate a caption for a single image with beam search.

    Args:
        model: Object exposing ``generate(pixel_values=..., ...)``.
        processor: Image processor called as
            ``processor(images=image, return_tensors="pt")``; its
            ``pixel_values`` attribute is fed to the model.
        tokenizer: Provides ``eos_token_id`` and ``decode``.
        image: Passed to ``processor`` unchanged.
        device: Device the pixel tensor is moved to before generation.
        num_beams: Beam width (defaults to the previously hard-coded 5).
        max_length: Maximum generated length passed to ``generate``
            (defaults to the previously hard-coded 20).

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    encoded = processor(images=image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values=pixel_values,
            num_beams=num_beams,
            max_length=max_length,
            length_penalty=1.0,
            # The EOS id is deliberately reused as the padding id.
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
| # ========================================== | |
| # CIDEr EVALUATION | |
| # ========================================== | |
def evaluate_cider(model, processor, tokenizer, val_dataset, device,
                   max_samples=200, image_dir="train2017"):
    """Compute a CIDEr score over the first ``max_samples`` validation items.

    Args:
        model: Captioning model; switched to eval mode for the duration of
            the evaluation and restored to its previous mode afterwards.
        processor: Image processor forwarded to ``generate_caption``.
        tokenizer: Tokenizer forwarded to ``generate_caption``.
        val_dataset: Must expose ``indices`` and ``dataset.annotations``
            (as the subsets returned by ``random_split`` do). Each
            annotation is read via its ``"image"`` and ``"captions"`` keys.
        device: Device forwarded to ``generate_caption``.
        max_samples: Upper bound on the number of items evaluated.
        image_dir: Directory that annotation ``"image"`` names are relative
            to (defaults to the previously hard-coded ``"train2017"``).

    Returns:
        The corpus-level score returned by ``Cider().compute_score``, or
        ``0.0`` when there is nothing to evaluate.
    """
    sample_count = min(max_samples, len(val_dataset))
    if sample_count <= 0:
        # Scoring empty dictionaries is meaningless; report a neutral score
        # instead of handing the scorer an empty corpus.
        print("CIDEr Eval skipped: no validation samples.")
        return 0.0

    ground_truth = {}
    predictions = {}

    was_training = model.training
    model.eval()
    try:
        for idx in tqdm(range(sample_count), desc="CIDEr Eval"):
            # Map the subset position back to the underlying dataset row.
            real_idx = val_dataset.indices[idx]
            ann = val_dataset.dataset.annotations[real_idx]

            image_path = os.path.join(image_dir, ann["image"])
            # The context manager releases the file handle promptly;
            # convert() returns an independent in-memory copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            pred_caption = generate_caption(
                model, processor, tokenizer, image, device
            )

            # presumably ann["captions"] is a list of reference strings;
            # verify against the dataset loader.
            ground_truth[idx] = ann["captions"]
            predictions[idx] = [pred_caption]

        score, _ = Cider().compute_score(ground_truth, predictions)
    finally:
        # Restore the caller's mode even if generation or scoring fails.
        model.train(was_training)

    print(f"CIDEr Score: {score:.4f}")
    return score
| # ========================================== | |
| # MAIN | |
| # ========================================== | |
def main():
    """Fine-tune a ViT-encoder / GPT-2-decoder captioner and keep the best
    checkpoint according to the per-epoch CIDEr score.

    Reads annotations from ``annotations/subset_10k.jsonl`` and images from
    ``train2017``; writes the best model, tokenizer and processor to
    ``saved_vit_gpt2``.
    """
    # Prefer CUDA, then Apple MPS, and fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5
    SAVE_DIR = "saved_vit_gpt2"
    os.makedirs(SAVE_DIR, exist_ok=True)

    # ------------------------------------------
    # Build Encoder + Decoder
    # ------------------------------------------
    encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    # GPT-2 is configured as a decoder with cross-attention enabled so it
    # can attend to the encoder's outputs.
    decoder_config = GPT2Config.from_pretrained("gpt2")
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    decoder = GPT2LMHeadModel.from_pretrained("gpt2", config=decoder_config)

    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

    processor = ViTImageProcessor.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # The EOS token is reused as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model.config.pad_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.vocab_size = model.config.decoder.vocab_size
    model.to(device)

    # ------------------------------------------
    # DATASET
    # ------------------------------------------
    dataset = COCODatasetViTGPT2(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        tokenizer,
        mode="short",
    )

    # 90 / 10 train / validation split.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=LR)
    # Stepped once per epoch, so the cosine schedule spans the whole run.
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # Start below any attainable score so the first evaluated epoch is always
    # checkpointed, even when its CIDEr is exactly 0.0.
    best_cider = float("-inf")

    # ==========================================
    # TRAIN LOOP
    # ==========================================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()

            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Train Loss: {avg_loss:.4f}")

        # ------------------------------------------
        # CIDEr Evaluation
        # ------------------------------------------
        cider_score = evaluate_cider(
            model,
            processor,
            tokenizer,
            val_dataset,
            device,
        )

        # Save best model
        if cider_score > best_cider:
            best_cider = cider_score
            model.save_pretrained(SAVE_DIR)
            tokenizer.save_pretrained(SAVE_DIR)
            processor.save_pretrained(SAVE_DIR)
            print("Best model saved.")

        scheduler.step()

    print("ViT-GPT2 Training complete.")
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()