Spaces:
Running
Running
| import os | |
| import torch | |
| from torch.utils.data import DataLoader, random_split | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from dataset_advanced import COCODataset | |
| from tqdm import tqdm | |
| from PIL import Image | |
| from pycocoevalcap.cider.cider import Cider | |
| # ========================= | |
| # GENERATE CAPTION | |
| # ========================= | |
def generate_caption(model, processor, image, device, *, max_length=30, num_beams=5):
    """Generate one caption for a single image using beam search.

    Args:
        model: Captioning model exposing ``generate`` (this script passes a
            ``BlipForConditionalGeneration``).
        processor: Processor that converts ``image`` into model inputs and
            decodes generated token ids back into text (a ``BlipProcessor``
            in this script).
        image: Image accepted by ``processor(images=...)``; the caller in this
            file passes an RGB ``PIL.Image``.
        device: Device the model lives on; the processed inputs are moved there.
        max_length: ``max_length`` forwarded to ``model.generate``.
        num_beams: Beam width forwarded to ``model.generate``.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
        )
    return processor.decode(generated_ids[0], skip_special_tokens=True)
| # ========================= | |
| # CIDEr EVALUATION | |
| # ========================= | |
def evaluate_cider(
    model,
    processor,
    val_dataset,
    device,
    max_samples=200,
    *,
    image_dir="train2017",
):
    """Score captions from ``generate_caption`` on a prefix of the validation split with CIDEr.

    The model is put in eval mode for the evaluation and is always switched
    back to train mode afterwards, even if caption generation raises.

    Args:
        model: Captioning model passed to ``generate_caption``.
        processor: Processor passed to ``generate_caption``.
        val_dataset: A ``torch.utils.data.Subset`` (as produced by
            ``random_split``); its ``.indices`` map subset positions to the
            underlying dataset, whose ``.annotations`` entries are dicts with
            ``"image"`` (a file name) and ``"captions"`` keys. Presumably
            ``"captions"`` is a non-empty list of reference strings, which
            pycocoevalcap requires — confirm against ``COCODataset``.
        device: Device the model lives on.
        max_samples: Maximum number of validation items to score (the first
            ``max_samples`` positions of the subset).
        image_dir: Directory that the ``"image"`` file names are relative to.

    Returns:
        The mean CIDEr score as returned by ``Cider.compute_score``, or
        ``0.0`` when there is nothing to score.
    """
    num_samples = min(max_samples, len(val_dataset))
    if num_samples <= 0:
        # Cider.compute_score on empty dicts yields nan plus numpy warnings.
        print("CIDEr Score: skipped (no validation samples)")
        return 0.0

    ground_truth = {}
    predictions = {}
    model.eval()
    try:
        for idx in tqdm(range(num_samples), desc="CIDEr Eval"):
            # Map the Subset position back to the full dataset's annotation.
            real_idx = val_dataset.indices[idx]
            ann = val_dataset.dataset.annotations[real_idx]
            image_path = os.path.join(image_dir, ann["image"])
            # convert() loads the pixels, so the file can be closed right away.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            ground_truth[idx] = ann["captions"]
            predictions[idx] = [generate_caption(model, processor, image, device)]
    finally:
        model.train()

    # NOTE(review): pycocoevalcap's CIDEr is normally fed PTBTokenizer output
    # (lowercased, punctuation split off). Raw captions make e.g. "Dog." and
    # "dog" distinct n-grams, which likely deflates the score — consider
    # normalizing both references and predictions the same way.
    score, _ = Cider().compute_score(ground_truth, predictions)
    print(f"CIDEr Score: {score:.4f}")
    return score
| # ========================= | |
| # MAIN | |
| # ========================= | |
def _train_one_epoch(model, train_loader, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader`` and return the mean loss."""
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        # NOTE(review): float16 autocast is used without a GradScaler, so small
        # gradients may underflow; if the loss turns nan/stalls, consider a
        # GradScaler or bfloat16 (support depends on the torch/macOS version).
        with torch.autocast(device_type=device.type, dtype=torch.float16):
            outputs = model(**batch)
            loss = outputs.loss
        # Backward outside the autocast region, as PyTorch recommends.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # One device->host sync per step instead of two.
        loss_value = loss.item()
        total_loss += loss_value
        progress_bar.set_postfix(loss=loss_value)
    return total_loss / len(train_loader)


def _validation_loss(model, val_loader, device):
    """Return the mean model loss over ``val_loader`` in eval mode without gradients."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            val_loss += model(**batch).loss.item()
    return val_loss / len(val_loader)


def main():
    """Phase-2 fine-tuning of BLIP base on a 10k COCO subset (Apple MPS only).

    Every vision-encoder parameter except those in layers 10 and 11 is frozen.
    Parameters outside ``model.vision_model`` (the text decoder) keep their
    default ``requires_grad`` and are trained as well. Each epoch reports the
    train loss, the validation loss and CIDEr. The best-CIDEr checkpoint is
    saved to ``saved_model_phase2``, and training stops early after
    ``patience`` epochs without a CIDEr improvement.

    Raises:
        RuntimeError: If the MPS backend is not available.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")
    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    SEED = 42  # fixes the train/val split so runs (and CIDEr) are comparable
    UNFROZEN_VISION_LAYERS = (10, 11)
    FINAL_MODEL_DIR = "saved_model_phase2"
    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Freeze the vision encoder except its last two layers. The text decoder
    # is not touched here and therefore stays fully trainable.
    trainable_prefixes = tuple(
        f"encoder.layers.{i}." for i in UNFROZEN_VISION_LAYERS
    )
    for name, param in model.vision_model.named_parameters():
        param.requires_grad = any(p in name for p in trainable_prefixes)

    # Non-reentrant checkpointing is required. With the reentrant variant (the
    # transformers default), a checkpointed layer whose input does not require
    # grad (layer 10's input comes from frozen layer 9) yields an output with
    # no gradient path, so the "unfrozen" vision layers would never update.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
    )
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SEED),
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LR,
    )
    # Stepped once per epoch, so the cosine period spans the whole run.
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = float("-inf")  # guarantees the first epoch is checkpointed
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, device, epoch
        )
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        val_loss = _validation_loss(model, val_loader, device)
        print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")

        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # Keep only the best-CIDEr checkpoint; count epochs without improvement.
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break

        scheduler.step()

    print("Phase 2 training complete.")
# Run the Phase-2 training only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()