{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "463ed6b4", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "id": "c62b8cc4", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import random\n", "import requests\n", "from tqdm import tqdm\n", "from pycocotools.coco import COCO\n", "from pathlib import Path\n", "\n", "\n", "SUBSET_SIZE = 30_000\n", "RANDOM_SEED = 42\n", "\n", "OUTPUT_DIR = Path(\"data/processed\")\n", "IMAGES_DIR = OUTPUT_DIR / \"images\"\n", "\n", "CAPTIONS_OUT = OUTPUT_DIR / \"captions.json\"\n", "SPLITS_OUT = OUTPUT_DIR / \"splits.json\"\n", "\n", "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n", "IMAGES_DIR.mkdir(parents=True, exist_ok=True)\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "4f821cb9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Captions file ready: annotations/captions_train2017.json\n" ] } ], "source": [ "ANNOT_DIR = Path(\"annotations\")\n", "ANNOT_DIR.mkdir(exist_ok=True)\n", "\n", "train_caps_path = ANNOT_DIR / \"captions_train2017.json\"\n", "\n", "if not train_caps_path.exists():\n", " print(\"Downloading COCO captions_train2017.json (~45 MB)...\")\n", " url = \"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\"\n", " zip_path = ANNOT_DIR / \"annotations_trainval2017.zip\"\n", "\n", " # download zip\n", " r = requests.get(url, stream=True)\n", " with open(zip_path, \"wb\") as f:\n", " for chunk in r.iter_content(chunk_size=8192):\n", " if chunk:\n", " f.write(chunk)\n", "\n", " # unzip only the captions file\n", " import zipfile\n", " with zipfile.ZipFile(zip_path, \"r\") as z:\n", " for name in z.namelist():\n", " if \"captions_train2017.json\" in name:\n", " z.extract(name, ANNOT_DIR)\n", " # Normalize to expected name\n", " (ANNOT_DIR / name).rename(train_caps_path)\n", "\n", "print(\"Captions file ready:\", 
train_caps_path)\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "864b367e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading captions_train2017.json ...\n", "loading annotations into memory...\n", "Done (t=0.49s)\n", "creating index...\n", "index created!\n", "Total train images: 118287\n" ] } ], "source": [ "print(\"Loading captions_train2017.json ...\")\n", "coco = COCO(str(train_caps_path))\n", "\n", "image_ids = list(coco.imgs.keys())\n", "print(f\"Total train images: {len(image_ids)}\")\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "dcd38f25", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sampled: 30000\n" ] } ], "source": [ "random.seed(RANDOM_SEED)\n", "sampled_ids = random.sample(image_ids, SUBSET_SIZE)\n", "\n", "print(\"Sampled:\", len(sampled_ids))\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "72d5e7b1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "30000" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Collect metadata for faster processing\n", "metadata = {}\n", "\n", "for img_id in sampled_ids:\n", " info = coco.imgs[img_id]\n", " anns = coco.imgToAnns[img_id]\n", "\n", " captions = [ann[\"caption\"] for ann in anns]\n", "\n", " metadata[str(img_id)] = {\n", " \"file_name\": info[\"file_name\"],\n", " \"url\": info[\"coco_url\"],\n", " \"captions\": captions,\n", " }\n", "\n", "len(metadata)\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "1ded3635", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Previewing 3 random metadata samples:\n", "\n", "Image ID: 68130\n", "File name: 000000068130.jpg\n", "URL: http://images.cocodataset.org/train2017/000000068130.jpg\n", "Captions:\n", " - a group of people crossing a city street \n", " - A group of people cross a cross walk in a big city\n", " - a lot of people crossing a 
crosswalk \n", " - Many people are crossing the street in front of some large buildings.\n", " - a photo taken on the corner facing old looking buildings\n", "------------------------------------------------------------\n", "Image ID: 222195\n", "File name: 000000222195.jpg\n", "URL: http://images.cocodataset.org/train2017/000000222195.jpg\n", "Captions:\n", " - A cay laying on top of a blue couch arm next to a wall.\n", " - A cat lying down with a packaged toothbrush on its head.\n", " - The cat is irritated that there is a packaged toothbrush resting on its head.\n", " - A close up view of cat carrying a toothbrush on its head.\n", " - A grey and black cat with a toothbrush on its head.\n", "------------------------------------------------------------\n", "Image ID: 133386\n", "File name: 000000133386.jpg\n", "URL: http://images.cocodataset.org/train2017/000000133386.jpg\n", "Captions:\n", " - some people are on skateboards at a skate park\n", " - A skateboard park with many skateboarders doing different stunts \n", " - The young man is practicing his tricks on his skateboard.\n", " - An edited photo showing a single boy performing various skateboard tricks in a single picture.\n", " - A kid performs many tricks while at the skate park.\n", "------------------------------------------------------------\n" ] } ], "source": [ "print(\"Previewing 3 random metadata samples:\\n\")\n", "\n", "for img_id in list(metadata.keys())[:3]:\n", " item = metadata[img_id]\n", " print(\"Image ID:\", img_id)\n", " print(\"File name:\", item[\"file_name\"])\n", " print(\"URL:\", item[\"url\"])\n", " print(\"Captions:\")\n", " for c in item[\"captions\"]:\n", " print(\" -\", c)\n", " print(\"-\" * 60)\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "4962a0c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading images...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 30000/30000 [6:22:54<00:00, 
1.31it/s] \n" ] } ], "source": [ "def download_image(url, out_path):\n", " try:\n", " r = requests.get(url, timeout=10, stream=True)\n", " if r.status_code != 200:\n", " return False\n", " with open(out_path, \"wb\") as f:\n", " for chunk in r.iter_content(8192):\n", " if chunk:\n", " f.write(chunk)\n", " return True\n", " except:\n", " return False\n", "\n", "print(\"Downloading images...\")\n", "\n", "for img_id, item in tqdm(metadata.items()):\n", " url = item[\"url\"]\n", " fname = item[\"file_name\"]\n", "\n", " out_path = IMAGES_DIR / fname\n", " if out_path.exists():\n", " continue\n", "\n", " ok = download_image(url, out_path)\n", " if not ok:\n", " print(\"FAILED DOWNLOAD:\", url)\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "87d7a066", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Writing captions.json ...\n" ] } ], "source": [ "print(\"Writing captions.json ...\")\n", "with open(CAPTIONS_OUT, \"w\") as f:\n", " json.dump(metadata, f, indent=2)\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "e4356477", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating splits...\n", "Splits saved:\n", "{'train': 25500, 'val': 3000, 'test': 1500}\n", "Done!\n" ] } ], "source": [ "print(\"Creating splits...\")\n", "\n", "ids = list(metadata.keys())\n", "random.shuffle(ids)\n", "\n", "n = len(ids)\n", "n_train = int(0.85 * n)\n", "n_val = int(0.10 * n)\n", "n_test = n - n_train - n_val\n", "\n", "splits = {\n", " \"train\": ids[:n_train],\n", " \"val\": ids[n_train:n_train + n_val],\n", " \"test\": ids[n_train + n_val:]\n", "}\n", "\n", "with open(SPLITS_OUT, \"w\") as f:\n", " json.dump(splits, f, indent=2)\n", "\n", "print(\"Splits saved:\")\n", "print({k: len(v) for k, v in splits.items()})\n", "print(\"Done!\")\n" ] }, { "cell_type": "markdown", "id": "f1451b3b", "metadata": {}, "source": [ "## Unit Tests" ] }, { "cell_type": "code", "execution_count": 5, "id": 
"a3aac8a3", "metadata": {}, "outputs": [], "source": [ "import yaml\n", "from transformers import T5TokenizerFast\n", "\n", "with open(\"configs/default.yaml\") as f:\n", " cfg = yaml.safe_load(f)\n", "\n", "tokenizer = T5TokenizerFast.from_pretrained(cfg[\"model\"][\"t5_name\"])\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "3cfac312", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch shape: torch.Size([4, 3, 224, 224])\n", "Input IDs shape: torch.Size([4, 64])\n", "Attention mask shape: torch.Size([4, 64])\n", "Caption example: A man in brown sitting on a motorcycle\n" ] } ], "source": [ "from data.loaders import get_coco_dataloaders\n", "import torch\n", "\n", "train_loader, val_loader, test_loader = get_coco_dataloaders(\n", " batch_size=4, data_dir=\"data/processed\"\n", ")\n", "\n", "batch = next(iter(train_loader))\n", "\n", "print(\"Image batch shape:\", batch[\"pixel_values\"].shape) # Expect [B, 3, H, W]\n", "print(\"Input IDs shape:\", batch[\"input_ids\"].shape)\n", "print(\"Attention mask shape:\", batch[\"attention_mask\"].shape)\n", "\n", "# Show one caption\n", "tokenizer_pad = batch[\"input_ids\"][0]\n", "decoded = tokenizer.decode(\n", " tokenizer_pad[tokenizer_pad != tokenizer.pad_token_id],\n", " skip_special_tokens=True\n", ")\n", "print(\"Caption example:\", decoded)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "389d7636", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model loaded with encoder: ResnetCNNEncoder\n", "T5 model: t5-small\n" ] } ], "source": [ "import yaml\n", "import torch\n", "from models.vision_t5 import VisionT5\n", "from models.encoder_projection_t5 import ImageProjection\n", "import models.encoders as encoders\n", "from train import build_model\n", "\n", "with open(\"configs/default.yaml\") as f:\n", " config = yaml.safe_load(f)\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "model, 
tokenizer = build_model(config)\n", "model.to(device)\n", "\n", "print(\"Model loaded with encoder:\", config[\"model\"][\"encoder\"])\n", "print(\"T5 model:\", config[\"model\"][\"t5_name\"])\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "10c47a24", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DEBUG — encoder hidden shape: torch.Size([4, 1, 512])\n", "DEBUG — encoder hidden mean: 0.001422882080078125\n", "Forward pass OK. Loss: 6.275390625\n" ] } ], "source": [ "batch = next(iter(train_loader))\n", "pixel_values = batch[\"pixel_values\"].to(device)\n", "input_ids = batch[\"input_ids\"].to(device)\n", "attention_mask = batch[\"attention_mask\"].to(device)\n", "\n", "labels = input_ids.clone()\n", "labels[labels == tokenizer.pad_token_id] = -100\n", "\n", "# torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast\n", "with torch.amp.autocast(\"cuda\"):\n", " outputs = model(\n", " pixel_values=pixel_values,\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", "print(\"Forward pass OK. 
Loss:\", outputs.loss.item())\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7ceca299", "metadata": {}, "outputs": [], "source": [ "from torch.optim import AdamW\n", "from torch.cuda.amp import GradScaler\n", "\n", "optimizer = AdamW(model.parameters(), lr=1e-4)\n", "scaler = GradScaler()\n", "\n", "model.train()\n", "\n", "pixel_values = batch[\"pixel_values\"].to(device)\n", "input_ids = batch[\"input_ids\"].to(device)\n", "attention_mask = batch[\"attention_mask\"].to(device)\n", "\n", "labels = input_ids.clone()\n", "labels[labels == tokenizer.pad_token_id] = -100\n", "\n", "optimizer.zero_grad()\n", "\n", "with torch.cuda.amp.autocast():\n", " outputs = model(\n", " pixel_values=pixel_values,\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", " loss = outputs.loss\n", "\n", "print(\"Loss before backward:\", loss.item())\n", "\n", "scaler.scale(loss).backward()\n", "scaler.step(optimizer)\n", "scaler.update()\n", "\n", "print(\" Training step passed (no errors)\")\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "75a98267", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0, 3, 0, 0, 0, 0, 3, 0, 0, 3]], device='cuda:0')\n", " \n" ] } ], "source": [ "ids = model.t5.generate(\n", " input_ids=torch.tensor([[tokenizer.pad_token_id]]).to(device),\n", " max_length=10,\n", ")\n", "print(ids)\n", "print(tokenizer.decode(ids[0], skip_special_tokens=True))" ] }, { "cell_type": "code", "execution_count": 14, "id": "0b2af979", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Decoder start token: 0\n" ] } ], "source": [ "print(\"Decoder start token:\", model.t5.config.decoder_start_token_id)\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "1a75eb6c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unconditioned generation: \n" ] } ], "source": [ "test_ids = 
model.t5.generate(\n", " input_ids=torch.tensor([[tokenizer.pad_token_id]]).to(device),\n", " max_length=10\n", ")\n", "\n", "print(\"Unconditioned generation:\", tokenizer.decode(test_ids[0], skip_special_tokens=True))\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "54d23198", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Vision out mean: 0.11751262843608856\n", "Projected mean: 0.14342853426933289\n", "Projected shape: torch.Size([1, 512])\n", "Generated caption: e e e e e e e e e e e e e e e e\n" ] } ], "source": [ "from inference import generate_caption\n", "\n", "model.eval()\n", "\n", "sample_img = batch[\"pixel_values\"][0:1].to(device) # one image batch\n", "vision_out = model.vision_encoder(sample_img)\n", "print(\"Vision out mean:\", vision_out[\"image_embeds\"].abs().mean().item())\n", "\n", "proj = model.projector(vision_out[\"image_embeds\"])\n", "print(\"Projected mean:\", proj.abs().mean().item())\n", "print(\"Projected shape:\", proj.shape)\n", "\n", "\n", "caption = generate_caption(model, tokenizer, sample_img, device=device)\n", "print(\"Generated caption:\", caption)\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "f6402df8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DEBUG — encoder hidden shape: torch.Size([1, 1, 512])\n", "DEBUG — encoder hidden mean: 0.01798422262072563\n", "Val-style loss: 8.727174758911133\n", "Prediction: e e e e e e e e e e e e e e e e\n", "Ground Truth: A woman dancing around with a umbrella in her hand.\n" ] } ], "source": [ "model.eval()\n", "\n", "with torch.no_grad():\n", " pixel_values = batch[\"pixel_values\"][0:1].to(device)\n", " input_ids = batch[\"input_ids\"][0:1].to(device)\n", " attention_mask = batch[\"attention_mask\"][0:1].to(device)\n", "\n", " labels = input_ids.clone()\n", " labels[labels == tokenizer.pad_token_id] = -100\n", "\n", " outputs = model(\n", " pixel_values=pixel_values,\n", " 
input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", "print(\"Val-style loss:\", outputs.loss.item())\n", "\n", "# Preview caption\n", "pred = generate_caption(model, tokenizer, pixel_values, device=device)\n", "gt_ids = input_ids[0][input_ids[0] != tokenizer.pad_token_id]\n", "gt_caption = tokenizer.decode(gt_ids, skip_special_tokens=True)\n", "\n", "print(\"Prediction:\", pred)\n", "print(\"Ground Truth:\", gt_caption)\n" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 5 }