# image_captioning/src/training/train_phase2.py
# Phase 2 training: fine-tune BLIP image captioning with the last two
# vision-encoder layers unfrozen, on Apple MPS.
import os
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, random_split
from transformers import BlipForConditionalGeneration, BlipProcessor
from tqdm import tqdm
from src.data.coco_advanced_dataset import COCODatasetAdvanced
from src.evaluation.cider_eval import evaluate_cider
def _load_model(device: torch.device):
    """Load the BLIP captioning model/processor and unfreeze the last 2 vision layers.

    Returns:
        (model, processor) with the model moved to *device*, gradient
        checkpointing enabled, and only vision ``encoder.layers.10/11``
        trainable inside the vision tower (the text decoder keeps its
        default ``requires_grad=True``).
    """
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    # Freeze the whole vision tower except its last two transformer layers.
    # NOTE(review): substring match assumes a 12-layer ViT encoder
    # (layers 0-11), which holds for blip-image-captioning-base.
    for name, param in model.vision_model.named_parameters():
        param.requires_grad = (
            "encoder.layers.10" in name or "encoder.layers.11" in name
        )
    model.gradient_checkpointing_enable()
    # use_cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    model.to(device)
    return model, processor


def _build_loaders(processor, batch_size: int, num_workers: int):
    """Build train/val DataLoaders from a 90/10 random split of the COCO subset.

    Returns:
        (train_loader, val_loader, val_dataset) — val_dataset is also
        returned because CIDEr evaluation consumes the dataset directly.
    """
    MODE = "long"  # change to "short" or "mixed" as needed
    full_dataset = COCODatasetAdvanced(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        mode=MODE,
    )
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, val_loader, val_dataset


def _train_one_epoch(model, loader, optimizer, device, epoch: int) -> float:
    """Run one mixed-precision training epoch and return the average loss."""
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(loader, desc=f"Epoch {epoch + 1}")
    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Only the forward pass runs under autocast; backward is executed
        # outside the autocast region, as the PyTorch AMP docs recommend.
        with torch.autocast(device_type="mps", dtype=torch.float16):
            outputs = model(**batch)
            loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    return total_loss / len(loader)


def _validate(model, loader, device) -> float:
    """Compute the average validation loss without gradient tracking."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            val_loss += outputs.loss.item()
    return val_loss / len(loader)


def main() -> None:
    """Phase-2 fine-tuning loop: train, validate, track CIDEr, early-stop.

    Saves the best-CIDEr checkpoint to ``saved_model_phase2`` and stops
    after ``patience`` epochs without CIDEr improvement.

    Raises:
        RuntimeError: if the Apple MPS backend is unavailable.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")
    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"
    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    model, processor = _load_model(device)
    train_loader, val_loader, val_dataset = _build_loaders(
        processor, BATCH_SIZE, NUM_WORKERS
    )

    # Optimize only the trainable parameters (unfrozen vision layers + decoder).
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR,
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # Early stopping on CIDEr: stop after `patience` epochs with no improvement.
    best_cider = 0.0
    patience = 3
    counter = 0

    for epoch in range(EPOCHS):
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, device, epoch
        )
        print(f"Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}")

        val_loss = _validate(model, val_loader, device)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss:.4f}")

        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # Checkpoint only when CIDEr improves; otherwise count toward patience.
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break  # early stop skips the remaining scheduler steps, as before

        scheduler.step()

    print("Phase 2 training complete.")
# Script entry guard: importing this module has no side effects.
if __name__ == "__main__":
    main()