# image_captioning / train_phase2.py
# Page header captured with this file: "pchandragrid's picture",
# commit message "Deploy Streamlit app", commit a745a5e.
import os
import torch
from torch.utils.data import DataLoader, random_split
from transformers import BlipProcessor, BlipForConditionalGeneration
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from dataset_advanced import COCODataset
from tqdm import tqdm
from PIL import Image
from pycocoevalcap.cider.cider import Cider
# =========================
# GENERATE CAPTION
# =========================
def generate_caption(model, processor, image, device, max_length=30, num_beams=5):
    """Generate a single caption for ``image`` with beam search.

    The model's train/eval mode is not changed here; callers that want
    deterministic inference should call ``model.eval()`` first.

    Args:
        model: Captioning model exposing ``generate(**inputs, ...)``.
        processor: Callable that turns ``images=...`` into model inputs
            (an object with a ``.to(device)`` method) and that can
            ``decode`` generated token ids back to text.
        image: The image to caption, in whatever form ``processor`` accepts.
        device: Device the processed inputs are moved to before generation.
        max_length: Maximum length passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The decoded caption for the first generated sequence, with special
        tokens stripped.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
        )

    # Only the first returned sequence is decoded (one image in, one caption out).
    return processor.decode(generated_ids[0], skip_special_tokens=True)
# =========================
# CIDEr EVALUATION
# =========================
def evaluate_cider(model, processor, val_dataset, device, max_samples=200,
                   image_dir="train2017"):
    """Compute a CIDEr score on the first ``max_samples`` validation items.

    The model is switched to eval mode for generation and is always left in
    training mode on return, so the caller's training loop can continue.

    Args:
        model: Captioning model passed through to ``generate_caption``.
        processor: Processor passed through to ``generate_caption``.
        val_dataset: Subset-like object exposing ``indices`` and
            ``dataset.annotations``; each annotation is indexed with the
            ``"image"`` (file name) and ``"captions"`` keys.
            Presumably ``"captions"`` is a list of reference strings --
            TODO confirm against the dataset class.
        device: Device used for generation.
        max_samples: Upper bound on the number of items scored.
        image_dir: Directory that annotation image file names are relative to.

    Returns:
        The corpus-level CIDEr score, or ``0.0`` when there is nothing to
        evaluate.
    """
    model.eval()

    num_samples = min(max_samples, len(val_dataset))
    if num_samples <= 0:
        # The scorer has no references to work with; report a neutral score
        # instead of handing it empty dictionaries.
        print("CIDEr Score: skipped (no validation samples).")
        model.train()
        return 0.0

    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(num_samples), desc="CIDEr Eval"):
        # Map the position inside the split back to the underlying dataset row.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        image_path = os.path.join(image_dir, ann["image"])
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        # Both dictionaries share the same keys; the prediction is wrapped in a
        # one-element list. Captions are passed to the scorer as-is.
        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = Cider().compute_score(ground_truth, predictions)
    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
# =========================
# MAIN
# =========================
def main():
    """Run phase-2 fine-tuning of BLIP with CIDEr-based model selection.

    Steps, in order:
      1. Require the MPS backend and select it as the device.
      2. Load the pretrained BLIP processor and captioning model.
      3. In the vision tower, leave only encoder layers 10 and 11 trainable;
         parameters outside ``model.vision_model`` are not modified here.
      4. Split the dataset 90/10 into train and validation parts.
      5. For each epoch: train, compute validation loss, compute CIDEr,
         save the model when CIDEr improves, and stop early after
         ``patience`` epochs without improvement.

    Raises:
        RuntimeError: If the MPS backend is not available.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")
    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"
    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Unfreeze the LAST 2 vision layers only; every other vision parameter
    # is frozen.
    trainable_markers = ("encoder.layers.10", "encoder.layers.11")
    for name, param in model.vision_model.named_parameters():
        param.requires_grad = any(marker in name for marker in trainable_markers)

    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor
    )
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    # Start below any reachable score so the first evaluated epoch is always
    # saved, even if its CIDEr is exactly 0.0.
    best_cider = float("-inf")
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass under float16 autocast; backward runs outside it.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Read the scalar once and reuse it for the sum and the display.
            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix(loss=loss_value)

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()
        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break

        # Advance the cosine schedule once per completed epoch.
        scheduler.step()

    print("Phase 2 training complete.")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()