ClipClap
- .gitattributes +1 -0
- 2111.09734v1.pdf +3 -0
- ClipClap.zip +3 -0
- captions_examples/conceptual.zip +3 -0
- notebooks/evaluation_mlp_baseline.ipynb +1 -0
- notebooks/evaluation_transformer_advance.ipynb +0 -0
- notebooks/evaluation_transformer_baseline.ipynb +0 -0
- notebooks/evaluation_transformer_gpt3.ipynb +0 -0
- notebooks/mlp_gpt2_inference.ipynb +0 -0
- notebooks/transformer_advance_inference.ipynb +1 -0
- notebooks/transformer_baseline_inference.ipynb +0 -0
- notebooks/transformer_gpt3_inference.ipynb +0 -0
- pretrained_models/mlp_gpt2_weights.pt +3 -0
- starter_code/CLIP_prefix_caption.zip +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+2111.09734v1.pdf filter=lfs diff=lfs merge=lfs -text
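As a quick illustration (not part of the commit), the snippet below checks which of the newly added files are already matched by the three LFS patterns visible in this hunk; the paper PDF matches none of them, which is why the commit adds a dedicated 2111.09734v1.pdf rule. The file list and patterns come from this page; the helper itself is hypothetical.

# Hypothetical check: which added files are matched by the visible LFS patterns?
from fnmatch import fnmatch

visible_patterns = ["*.zip", "*.zst", "*tfevents*"]
added_files = [
    "2111.09734v1.pdf",
    "ClipClap.zip",
    "captions_examples/conceptual.zip",
]

for path in added_files:
    basename = path.rsplit("/", 1)[-1]
    hit = any(fnmatch(basename, pat) for pat in visible_patterns)
    print(f"{path}: {'matched' if hit else 'not matched by these patterns'}")
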
2111.09734v1.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba821a26ced327c767343d8cdd626a2857b190cf0760be65a041482d1771e12b
+size 9493745
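Each of the added binaries is stored through Git LFS, so only a small pointer file like the one above lives in the repository; the actual bytes are fetched from LFS storage on checkout. A minimal sketch of reading such a pointer (the parse_lfs_pointer helper is hypothetical, not part of the repo; the values are the ones shown above):

# Sketch: parse the version/oid/size lines of a Git LFS pointer file.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:ba821a26ced327c767343d8cdd626a2857b190cf0760be65a041482d1771e12b
size 9493745"""

info = parse_lfs_pointer(pointer)
print(info["oid"], int(info["size"]))  # sha256:ba821a... 9493745
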
ClipClap.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ec7c1e890d239de7852586ae6be9582a5e70161ab014ddd435391487d126abe
+size 212161994
captions_examples/conceptual.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68451099f873bb570b65b062a97abc076a05d7cb9585dec9e225899495ab8aff
+size 387263
notebooks/evaluation_mlp_baseline.ipynb
ADDED
@@ -0,0 +1 @@
+
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":31929,"status":"ok","timestamp":1683689242625,"user":{"displayName":"Neha Jain","userId":"09908231183836366208"},"user_tz":300},"id":"9iQfOPEL4gYd","outputId":"d7d393be-2fb2-4a51-d660-89caa31d6b1c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting transformers\n"," Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m23.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n","Collecting huggingface-hub<1.0,>=0.11.0 (from transformers)\n"," Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m224.5/224.5 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n","Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)\n"," Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m39.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (2023.4.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n","Installing collected packages: tokenizers, huggingface-hub, transformers\n","Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting git+https://github.com/openai/CLIP.git\n"," Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-frymm3v5\n"," Running command git clone 
--filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-frymm3v5\n"," Resolved https://github.com/openai/CLIP.git to commit a9b1bf5920416aaeaec965c25dd9e8f98c864f16\n"," Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n","Collecting ftfy (from clip==1.0)\n"," Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (2022.10.31)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (4.65.0)\n","Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (2.0.0+cu118)\n","Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (0.15.1+cu118)\n","Requirement already satisfied: wcwidth>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ftfy->clip==1.0) (0.2.6)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->clip==1.0) (3.12.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->clip==1.0) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->clip==1.0) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->clip==1.0) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->clip==1.0) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->clip==1.0) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->clip==1.0) (3.25.2)\n","Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->clip==1.0) (16.0.3)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision->clip==1.0) (1.22.4)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision->clip==1.0) (2.27.1)\n","Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->clip==1.0) (8.4.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->clip==1.0) (2.1.2)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->clip==1.0) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->clip==1.0) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->clip==1.0) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->clip==1.0) (3.4)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->clip==1.0) (1.3.0)\n","Building wheels for collected packages: clip\n"," Building wheel for clip (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n"," Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369370 sha256=7844bf9b1a6d8a44f9356e8d3520c1177a11f9bdd3c4b1c1f783b86aa4aef401\n"," Stored in directory: /tmp/pip-ephem-wheel-cache-lparjd81/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4\n","Successfully built clip\n","Installing collected packages: ftfy, clip\n","Successfully installed clip-1.0 ftfy-6.1.1\n"]}],"source":["#@title Install\n","!pip install transformers\n","! pip install git+https://github.com/openai/CLIP.git"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":16899,"status":"ok","timestamp":1683689259520,"user":{"displayName":"Neha Jain","userId":"09908231183836366208"},"user_tz":300},"id":"PzIdzzVY62_u","outputId":"507e0c9d-1689-4885-90cc-5e3fb3f337da"},"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n"]}],"source":["from google.colab import drive\n","drive.mount(\"/content/gdrive\",force_remount=True)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"cZZU19Ah67hT"},"outputs":[],"source":["import os\n","os.chdir(\"/content/gdrive/MyDrive/CS444_Neha/CS444Project/CLIP_prefix_caption/\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"fO-ycrHj4dLx"},"outputs":[],"source":["#@title Imports\n","\n","import clip\n","import os\n","from torch import nn\n","import numpy as np\n","import torch\n","import torch.nn.functional as nnf\n","import sys\n","from typing import Tuple, List, Union, Optional\n","from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup\n","from tqdm import tqdm, trange\n","from google.colab import files\n","import skimage.io as io\n","import PIL.Image\n","from IPython.display import Image\n","\n","N = type(None)\n","V = np.array\n","ARRAY = np.ndarray\n","ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]\n","VS = Union[Tuple[V, ...], List[V]]\n","VN = Union[V, N]\n","VNS = Union[VS, N]\n","T = torch.Tensor\n","TS = Union[Tuple[T, ...], List[T]]\n","TN = Optional[T]\n","TNS = Union[Tuple[TN, ...], List[TN]]\n","TSN = Optional[TS]\n","TA = Union[T, ARRAY]\n","\n","\n","D = torch.device\n","CPU = torch.device('cpu')\n","\n","\n","def get_device(device_id: int) -> D:\n"," if not torch.cuda.is_available():\n"," return CPU\n"," device_id = min(torch.cuda.device_count() - 1, device_id)\n"," return torch.device(f'cuda:{device_id}')\n","\n","\n","CUDA = get_device"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"n62Y4gFZ4XY8"},"outputs":[],"source":["#@title Model\n","\n","class MLP(nn.Module):\n","\n"," def forward(self, x: T) -> T:\n"," return self.model(x)\n","\n"," def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):\n"," super(MLP, self).__init__()\n"," layers = []\n"," for i in range(len(sizes) -1):\n"," layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))\n"," if i < len(sizes) - 2:\n"," layers.append(act())\n"," self.model = nn.Sequential(*layers)\n","\n","\n","class ClipCaptionModel(nn.Module):\n","\n"," #@functools.lru_cache #FIXME\n"," def get_dummy_token(self, batch_size: int, device: D) -> T:\n"," return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)\n","\n"," def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):\n"," embedding_text = self.gpt.transformer.wte(tokens)\n"," prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, 
self.gpt_embedding_size)\n"," #print(embedding_text.size()) #torch.Size([5, 67, 768])\n"," #print(prefix_projections.size()) #torch.Size([5, 1, 768])\n"," embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)\n"," if labels is not None:\n"," dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)\n"," labels = torch.cat((dummy_token, tokens), dim=1)\n"," out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)\n"," return out\n","\n"," def __init__(self, prefix_length: int, prefix_size: int = 512):\n"," super(ClipCaptionModel, self).__init__()\n"," self.prefix_length = prefix_length\n"," self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')\n"," self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]\n"," if prefix_length > 10: # not enough memory\n"," self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)\n"," else:\n"," self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))\n","\n","\n","class ClipCaptionPrefix(ClipCaptionModel):\n","\n"," def parameters(self, recurse: bool = True):\n"," return self.clip_project.parameters()\n","\n"," def train(self, mode: bool = True):\n"," super(ClipCaptionPrefix, self).train(mode)\n"," self.gpt.eval()\n"," return self"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"5S6Vccv387Os"},"outputs":[],"source":["#@title Caption prediction\n","\n","def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,\n"," entry_length=67, temperature=1., stop_token: str = '.'):\n","\n"," model.eval()\n"," stop_token_index = tokenizer.encode(stop_token)[0]\n"," tokens = None\n"," scores = None\n"," device = next(model.parameters()).device\n"," seq_lengths = torch.ones(beam_size, device=device)\n"," is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)\n"," with torch.no_grad():\n"," if embed is not None:\n"," generated = embed\n"," else:\n"," if tokens is None:\n"," tokens = torch.tensor(tokenizer.encode(prompt))\n"," tokens = tokens.unsqueeze(0).to(device)\n"," generated = model.gpt.transformer.wte(tokens)\n"," for i in range(entry_length):\n"," outputs = model.gpt(inputs_embeds=generated)\n"," logits = outputs.logits\n"," logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)\n"," logits = logits.softmax(-1).log()\n"," if scores is None:\n"," scores, next_tokens = logits.topk(beam_size, -1)\n"," generated = generated.expand(beam_size, *generated.shape[1:])\n"," next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)\n"," if tokens is None:\n"," tokens = next_tokens\n"," else:\n"," tokens = tokens.expand(beam_size, *tokens.shape[1:])\n"," tokens = torch.cat((tokens, next_tokens), dim=1)\n"," else:\n"," logits[is_stopped] = -float(np.inf)\n"," logits[is_stopped, 0] = 0\n"," scores_sum = scores[:, None] + logits\n"," seq_lengths[~is_stopped] += 1\n"," scores_sum_average = scores_sum / seq_lengths[:, None]\n"," scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)\n"," next_tokens_source = next_tokens // scores_sum.shape[1]\n"," seq_lengths = seq_lengths[next_tokens_source]\n"," next_tokens = next_tokens % scores_sum.shape[1]\n"," next_tokens = next_tokens.unsqueeze(1)\n"," tokens = tokens[next_tokens_source]\n"," tokens = torch.cat((tokens, next_tokens), dim=1)\n"," generated = generated[next_tokens_source]\n"," scores = scores_sum_average * seq_lengths\n"," is_stopped = is_stopped[next_tokens_source]\n"," 
next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)\n"," generated = torch.cat((generated, next_token_embed), dim=1)\n"," is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()\n"," if is_stopped.all():\n"," break\n"," scores = scores / seq_lengths\n"," output_list = tokens.cpu().numpy()\n"," output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]\n"," order = scores.argsort(descending=True)\n"," output_texts = [output_texts[i] for i in order]\n"," return output_texts\n","\n","\n","def generate2(\n"," model,\n"," tokenizer,\n"," tokens=None,\n"," prompt=None,\n"," embed=None,\n"," entry_count=1,\n"," entry_length=67, # maximum number of words\n"," top_p=0.8,\n"," temperature=1.,\n"," stop_token: str = '.',\n","):\n"," model.eval()\n"," generated_num = 0\n"," generated_list = []\n"," stop_token_index = tokenizer.encode(stop_token)[0]\n"," filter_value = -float(\"Inf\")\n"," device = next(model.parameters()).device\n","\n"," with torch.no_grad():\n","\n"," for entry_idx in range(entry_count):\n"," if embed is not None:\n"," generated = embed\n"," else:\n"," if tokens is None:\n"," tokens = torch.tensor(tokenizer.encode(prompt))\n"," tokens = tokens.unsqueeze(0).to(device)\n","\n"," generated = model.gpt.transformer.wte(tokens)\n","\n"," for i in range(entry_length):\n","\n"," outputs = model.gpt(inputs_embeds=generated)\n"," logits = outputs.logits\n"," logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)\n"," sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n"," cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)\n"," sorted_indices_to_remove = cumulative_probs > top_p\n"," sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[\n"," ..., :-1\n"," ].clone()\n"," sorted_indices_to_remove[..., 0] = 0\n","\n"," indices_to_remove = sorted_indices[sorted_indices_to_remove]\n"," logits[:, indices_to_remove] = filter_value\n"," next_token = torch.argmax(logits, -1).unsqueeze(0)\n"," next_token_embed = model.gpt.transformer.wte(next_token)\n"," if tokens is None:\n"," tokens = next_token\n"," else:\n"," tokens = torch.cat((tokens, next_token), dim=1)\n"," generated = torch.cat((generated, next_token_embed), dim=1)\n"," if stop_token_index == next_token.item():\n"," break\n","\n"," output_list = list(tokens.squeeze().cpu().numpy())\n"," output_text = tokenizer.decode(output_list)\n"," generated_list.append(output_text)\n","\n"," return generated_list[0]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6VOqclm8bUsn"},"outputs":[],"source":["# Validation Data \n","import zipfile\n","\n","val_data_zip = zipfile.ZipFile(\"/content/gdrive/MyDrive/CS444_Neha/CS444Project/CLIP_prefix_caption/data/coco/val2014.zip\",\"r\") #Opens the tar file in read mode\n","val_data_zip.extractall(\"/tmp\") #Extracts the files into the /tmp folder\n","val_data_zip.close()"]},{"cell_type":"code","source":["import json\n","from pycocotools.coco import COCO\n","from pycocoevalcap.eval import COCOEvalCap\n","#from clipcap_model import ClipCapModel # import your ClipCap model\n","\n","# Load the COCO annotations\n","annFile = '/content/gdrive/MyDrive/CS444_Neha/CS444Project/CLIP_prefix_caption/data/coco/annotations/captions_val2014.json' # change this to the path of the annotations file\n","coco = 
COCO(annFile)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6HK6Ok7vdJUj","executionInfo":{"status":"ok","timestamp":1683689353260,"user_tz":300,"elapsed":6144,"user":{"displayName":"Neha Jain","userId":"09908231183836366208"}},"outputId":"f34b1069-d808-4de8-c781-47e3d3ba43a0"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["loading annotations into memory...\n","Done (t=0.77s)\n","creating index...\n","index created!\n"]}]},{"cell_type":"code","source":["import random\n","# Select a random subset of image IDs\n","imgIds = [image_id for image_id in coco.getImgIds()]\n","imgIds = list(set(imgIds)) # remove duplicates\n","# Set the seed\n","random.seed(123)\n","\n","random.shuffle(imgIds,)\n","imgIds = imgIds[:5000]"],"metadata":{"id":"joQo6XJ2c9lx"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":99,"referenced_widgets":["87405f917af54c26813f8ab7700c2a66","75b045a978ea44e2bcffa096537553b6","3d9dc024cbf2489a95c7cfed1d58cd15","da2963f6d3aa4e6f96d7199b5b0a5cf2","df80e45ce13746d587e4df5be1102b54","48b98f2308fc4c25a8fbc4c38ef4a408","769aca47327b478b9d02d59b93f7c962","c13843c99727412593a036641d77a664","ccd1b835a14d4713a4991208e15e22bb","f7952d6e58a54f9594456633e600e59b","f08b3f212bbb4de9ae90dbf2bbec8a43","5a4811a7bfa845818fabdc7208508079","a858b08f50bc4a67b211f26e1347b31f","e0385c86cb824f38a3873da730dad850","eaccf5ffff8647f7be8b148a55845ef6","bf2e8bd15f554634a7876885ebbd4a98","2b74e3c5ac0143fc85d63307bf045024","1931c16ccd9949bbbb5ac8131fc0bc4d","5413529457c540d3988c17abe1687c98","dd4ad19645a845bd9022a154e4dd89c4","28c7d6ae0eaf4838a2cde87e332173cc","6c7fe187d16f47dea6d454180e384d75"]},"id":"3kXG5yBi38ST","outputId":"c5eb7c48-f270-4c4a-8678-7d83af4210db","executionInfo":{"status":"ok","timestamp":1683690364600,"user_tz":300,"elapsed":983233,"user":{"displayName":"Neha Jain","userId":"09908231183836366208"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["100%|ββββββββββββββββββββββββββββββββββββββββ| 338M/338M [00:03<00:00, 115MiB/s]\n"]},{"output_type":"display_data","data":{"text/plain":["Downloading (β¦)olve/main/vocab.json: 0%| | 0.00/1.04M [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"87405f917af54c26813f8ab7700c2a66"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Downloading (β¦)olve/main/merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5a4811a7bfa845818fabdc7208508079"}},"metadata":{}}],"source":["# Load the ClipCap model\n","current_directory = os.getcwd()\n","save_path = \"/content/gdrive/MyDrive/CS444_Neha/CS444Project/pretrained_models\"\n","model_path = os.path.join(save_path, 'mlp_gpt2_weights.pt')\n","\n","prefix_length = 10\n","\n","model = ClipCaptionModel(prefix_length)\n","\n","model.load_state_dict(torch.load(model_path, map_location=CPU)) \n","\n","model = model.eval() \n","is_gpu = True \n","device = CUDA(0) if is_gpu else \"cpu\"\n","model = model.to(device)# clipCap model \n","\n","#@title CLIP model + GPT2 tokenizer\n","\n","device = CUDA(0) if is_gpu else \"cpu\"\n","clip_model, preprocess = clip.load(\"ViT-B/32\", device=device, jit=False)\n","tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n","\n","use_beam_search = False \n","prefix_length = 10\n","\n","current_directory = os.getcwd()\n","\n","# Generate captions for the test 
set\n","results = []\n","images_path = \"/tmp/val2014/\"\n","for img_id in imgIds:\n"," img = coco.loadImgs(img_id)[0]\n"," name_ = f\"COCO_val2014_{int(img_id):012d}.jpg\"\n"," UPLOADED_FILE = os.path.join(images_path, name_)\n"," if not os.path.isfile(UPLOADED_FILE):\n"," continue\n","\n"," image = io.imread(UPLOADED_FILE)\n"," pil_image = PIL.Image.fromarray(image)\n"," image = preprocess(pil_image).unsqueeze(0).to(device)\n"," with torch.no_grad():\n"," prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)\n"," prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)\n"," if use_beam_search:\n"," generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]\n"," else:\n"," generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)\n","\n"," #caption = clipcap.generate_caption(file_name) # generate a caption using your ClipCap model\n"," results.append({\n"," 'image_id': img_id,\n"," 'caption': generated_text_prefix\n"," })\n","\n","# Evaluate the results using the COCO evaluation metrics\n","resFile = 'mlp_baseline_results.json' # save the results to a JSON file\n","with open(resFile, 'w') as f:\n"," json.dump(results, f)\n"]},{"cell_type":"code","source":["cocoRes = coco.loadRes(resFile) # load the generated captions\n","cocoEval = COCOEvalCap(coco, cocoRes)\n","cocoEval.evaluate()\n","\n","# Print the evaluation results\n","for metric, score in cocoEval.eval.items():\n"," print('%s: %.4f' % (metric, score))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VEKSI0ercSwX","executionInfo":{"status":"ok","timestamp":1683690400707,"user_tz":300,"elapsed":36129,"user":{"displayName":"Neha Jain","userId":"09908231183836366208"}},"outputId":"088d4267-d79e-48fb-fe0d-d9dcb822b5a4"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Loading and preparing results...\n","DONE (t=0.03s)\n","creating index...\n","index created!\n","tokenization...\n","setting up scorers...\n","computing Bleu score...\n","{'testlen': 49206, 'reflen': 48617, 'guess': [49206, 44206, 39206, 34206], 'correct': [38636, 22723, 11985, 6193]}\n","ratio: 1.0121151037702654\n","Bleu_1: 0.785\n","Bleu_2: 0.635\n","Bleu_3: 0.498\n","Bleu_4: 0.387\n","computing METEOR score...\n","METEOR: 0.299\n","computing Rouge score...\n","ROUGE_L: 0.593\n","computing CIDEr score...\n","CIDEr: 1.307\n","Bleu_1: 0.7852\n","Bleu_2: 0.6353\n","Bleu_3: 0.4978\n","Bleu_4: 0.3866\n","METEOR: 0.2987\n","ROUGE_L: 0.5927\n","CIDEr: 1.3068\n"]}]}],"metadata":{"colab":{"provenance":[],"gpuType":"T4"},"gpuClass":"standard","kernelspec":{"display_name":"Python 
3","name":"python3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"87405f917af54c26813f8ab7700c2a66":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_75b045a978ea44e2bcffa096537553b6","IPY_MODEL_3d9dc024cbf2489a95c7cfed1d58cd15","IPY_MODEL_da2963f6d3aa4e6f96d7199b5b0a5cf2"],"layout":"IPY_MODEL_df80e45ce13746d587e4df5be1102b54"}},"75b045a978ea44e2bcffa096537553b6":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_48b98f2308fc4c25a8fbc4c38ef4a408","placeholder":"β","style":"IPY_MODEL_769aca47327b478b9d02d59b93f7c962","value":"Downloading (β¦)olve/main/vocab.json: 100%"}},"3d9dc024cbf2489a95c7cfed1d58cd15":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_c13843c99727412593a036641d77a664","max":1042301,"min":0,"orientation":"horizontal","style":"IPY_MODEL_ccd1b835a14d4713a4991208e15e22bb","value":1042301}},"da2963f6d3aa4e6f96d7199b5b0a5cf2":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_f7952d6e58a54f9594456633e600e59b","placeholder":"β","style":"IPY_MODEL_f08b3f212bbb4de9ae90dbf2bbec8a43","value":" 1.04M/1.04M [00:00<00:00, 
34.3MB/s]"}},"df80e45ce13746d587e4df5be1102b54":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"48b98f2308fc4c25a8fbc4c38ef4a408":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"769aca47327b478b9d02d59b93f7c962":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"c13843c99727412593a036641d77a664":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ccd1b835a14d4713a4
991208e15e22bb":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"f7952d6e58a54f9594456633e600e59b":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f08b3f212bbb4de9ae90dbf2bbec8a43":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"5a4811a7bfa845818fabdc7208508079":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_a858b08f50bc4a67b211f26e1347b31f","IPY_MODEL_e0385c86cb824f38a3873da730dad850","IPY_MODEL_eaccf5ffff8647f7be8b148a55845ef6"],"layout":"IPY_MODEL_bf2e8bd15f554634a7876885ebbd4a98"}},"a858b08f50bc4a67b211f26e1347b31f":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_2b74e3c5ac0143fc85d63307bf045024","placeholder":"β","style":"IPY_MODEL_1931c16ccd9949bbbb5ac8131fc0bc4d","value":"Downloading (β¦)olve/main/merges.txt: 
100%"}},"e0385c86cb824f38a3873da730dad850":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_5413529457c540d3988c17abe1687c98","max":456318,"min":0,"orientation":"horizontal","style":"IPY_MODEL_dd4ad19645a845bd9022a154e4dd89c4","value":456318}},"eaccf5ffff8647f7be8b148a55845ef6":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_28c7d6ae0eaf4838a2cde87e332173cc","placeholder":"β","style":"IPY_MODEL_6c7fe187d16f47dea6d454180e384d75","value":" 456k/456k [00:00<00:00, 29.7MB/s]"}},"bf2e8bd15f554634a7876885ebbd4a98":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2b74e3c5ac0143fc85d63307bf045024":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1931c16ccd9949bbbb5ac8131fc0bc4d":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","
_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"5413529457c540d3988c17abe1687c98":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"dd4ad19645a845bd9022a154e4dd89c4":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"28c7d6ae0eaf4838a2cde87e332173cc":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"6c7fe187d16f47dea6d454180e384d75":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"nbformat":4,"nbformat_minor":0}
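For reference, the evaluation loop in this notebook reduces to the following per-image flow: encode the image with CLIP, project the embedding to GPT-2 prefix embeddings with the MLP mapping network, then decode a caption. This condensed sketch assumes the notebook's clip_model, preprocess, model (ClipCaptionModel), tokenizer, device, prefix_length and generate2 are already defined as in the cells above; the image path is a placeholder.

# Condensed per-image captioning flow from the notebook above (sketch only).
import torch
import skimage.io as io
import PIL.Image

image_path = "/tmp/val2014/COCO_val2014_000000000042.jpg"  # placeholder example

image = io.imread(image_path)
image = preprocess(PIL.Image.fromarray(image)).unsqueeze(0).to(device)

with torch.no_grad():
    # CLIP image embedding (512-d for ViT-B/32), cast to float32
    prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    # MLP mapping network projects it to prefix_length GPT-2 token embeddings
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)

# Decode a caption from the prefix (nucleus-sampling path of the notebook)
caption = generate2(model, tokenizer, embed=prefix_embed)
print(caption)
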
notebooks/evaluation_transformer_advance.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
notebooks/evaluation_transformer_baseline.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
notebooks/evaluation_transformer_gpt3.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
notebooks/mlp_gpt2_inference.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
notebooks/transformer_advance_inference.ipynb
ADDED
@@ -0,0 +1 @@
+
{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":29911,"status":"ok","timestamp":1683692886248,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"NIZ4485gg2-Q","outputId":"420a96c0-c7a2-48bd-d3fe-9ae87727fc68"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/gdrive\n"]}],"source":["from google.colab import drive\n","drive.mount(\"/content/gdrive\")"]},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1485,"status":"ok","timestamp":1683692890041,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"OSsx04hzg5Kw","outputId":"bb46768c-fbc0-40b2-a36f-cd5d94097556"},"outputs":[{"name":"stdout","output_type":"stream","text":["/content/gdrive/MyDrive/CS444Project/CLIP_prefix_caption/Pallaw/notebooks\n"]}],"source":["import os\n","os.chdir(\"/content/gdrive/MyDrive/CS444Project/CLIP_prefix_caption/Pallaw/notebooks\")\n","!pwd"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":23328,"status":"ok","timestamp":1683692917201,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"GRfpGaz27IWs","outputId":"defde245-ab30-45e3-907b-fa3e13a2c11e"},"outputs":[{"name":"stdout","output_type":"stream","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting transformers\n"," Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m74.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n","Collecting huggingface-hub\u003c1.0,\u003e=0.11.0 (from transformers)\n"," Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m224.5/224.5 kB\u001b[0m \u001b[31m25.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: numpy\u003e=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: packaging\u003e=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: pyyaml\u003e=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n","Collecting tokenizers!=0.11.3,\u003c0.14,\u003e=0.11.1 (from transformers)\n"," Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m79.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: tqdm\u003e=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub\u003c1.0,\u003e=0.11.0-\u003etransformers) 
(2023.4.0)\n","Requirement already satisfied: typing-extensions\u003e=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub\u003c1.0,\u003e=0.11.0-\u003etransformers) (4.5.0)\n","Requirement already satisfied: urllib3\u003c1.27,\u003e=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etransformers) (1.26.15)\n","Requirement already satisfied: certifi\u003e=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etransformers) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etransformers) (2.0.12)\n","Requirement already satisfied: idna\u003c4,\u003e=2.5 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etransformers) (3.4)\n","Installing collected packages: tokenizers, huggingface-hub, transformers\n","Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting git+https://github.com/openai/CLIP.git\n"," Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-30z8umjw\n"," Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-30z8umjw\n"," Resolved https://github.com/openai/CLIP.git to commit a9b1bf5920416aaeaec965c25dd9e8f98c864f16\n"," Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n","Collecting ftfy (from clip==1.0)\n"," Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (2022.10.31)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (4.65.0)\n","Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (2.0.0+cu118)\n","Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from clip==1.0) (0.15.1+cu118)\n","Requirement already satisfied: wcwidth\u003e=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ftfy-\u003eclip==1.0) (0.2.6)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch-\u003eclip==1.0) (3.12.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch-\u003eclip==1.0) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch-\u003eclip==1.0) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch-\u003eclip==1.0) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch-\u003eclip==1.0) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch-\u003eclip==1.0) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0-\u003etorch-\u003eclip==1.0) (3.25.2)\n","Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0-\u003etorch-\u003eclip==1.0) (16.0.3)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision-\u003eclip==1.0) (1.22.4)\n","Requirement already satisfied: requests in 
/usr/local/lib/python3.10/dist-packages (from torchvision-\u003eclip==1.0) (2.27.1)\n","Requirement already satisfied: pillow!=8.3.*,\u003e=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision-\u003eclip==1.0) (8.4.0)\n","Requirement already satisfied: MarkupSafe\u003e=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2-\u003etorch-\u003eclip==1.0) (2.1.2)\n","Requirement already satisfied: urllib3\u003c1.27,\u003e=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etorchvision-\u003eclip==1.0) (1.26.15)\n","Requirement already satisfied: certifi\u003e=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etorchvision-\u003eclip==1.0) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etorchvision-\u003eclip==1.0) (2.0.12)\n","Requirement already satisfied: idna\u003c4,\u003e=2.5 in /usr/local/lib/python3.10/dist-packages (from requests-\u003etorchvision-\u003eclip==1.0) (3.4)\n","Requirement already satisfied: mpmath\u003e=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy-\u003etorch-\u003eclip==1.0) (1.3.0)\n","Building wheels for collected packages: clip\n"," Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369370 sha256=80a8719bafe67451346307f1b8c5f4ed5e0ef47fc461771b389c33b853f33252\n"," Stored in directory: /tmp/pip-ephem-wheel-cache-r9trctt0/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4\n","Successfully built clip\n","Installing collected packages: ftfy, clip\n","Successfully installed clip-1.0 ftfy-6.1.1\n"]}],"source":["#@title Install\n","!pip install transformers\n","! pip install git+https://github.com/openai/CLIP.git\n"]},{"cell_type":"code","execution_count":4,"metadata":{"executionInfo":{"elapsed":10826,"status":"ok","timestamp":1683692933448,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"OArDkm_24w4L"},"outputs":[],"source":["#@title Imports\n","\n","import clip\n","import os\n","from torch import nn\n","import numpy as np\n","import torch\n","import torch.nn.functional as nnf\n","import sys\n","from typing import Tuple, List, Union, Optional\n","from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup\n","from tqdm import tqdm, trange\n","from google.colab import files\n","import skimage.io as io\n","import PIL.Image\n","from IPython.display import Image \n","from enum import Enum\n","\n","\n","\n","N = type(None)\n","V = np.array\n","ARRAY = np.ndarray\n","ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]\n","VS = Union[Tuple[V, ...], List[V]]\n","VN = Union[V, N]\n","VNS = Union[VS, N]\n","T = torch.Tensor\n","TS = Union[Tuple[T, ...], List[T]]\n","TN = Optional[T]\n","TNS = Union[Tuple[TN, ...], List[TN]]\n","TSN = Optional[TS]\n","TA = Union[T, ARRAY]\n","\n","\n","D = torch.device\n","CPU = torch.device('cpu')\n","\n","\n","def get_device(device_id: int) -\u003e D:\n"," if not torch.cuda.is_available():\n"," return CPU\n"," device_id = min(torch.cuda.device_count() - 1, device_id)\n"," return torch.device(f'cuda:{device_id}')\n","\n","\n","CUDA = get_device"]},{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":591,"status":"ok","timestamp":1683692937301,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"4ClW2ebek8DK"},"outputs":[],"source":["#@title 
Model\n","\n","\n","class MappingType(Enum):\n"," MLP = 'mlp'\n"," Transformer = 'transformer'\n","\n","\n","class MlpTransformer(nn.Module):\n"," def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):\n"," super().__init__()\n"," out_d = out_d if out_d is not None else in_dim\n"," self.fc1 = nn.Linear(in_dim, h_dim)\n"," self.act = act\n"," self.fc2 = nn.Linear(h_dim, out_d)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," def forward(self, x):\n"," x = self.fc1(x)\n"," x = self.act(x)\n"," x = self.dropout(x)\n"," x = self.fc2(x)\n"," x = self.dropout(x)\n"," return x\n","\n","class MLP(nn.Module):\n","\n"," def forward(self, x: torch.Tensor) -\u003e torch.Tensor:\n"," return self.model(x)\n","\n"," def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):\n"," super(MLP, self).__init__()\n"," layers = []\n"," for i in range(len(sizes) - 1):\n"," layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))\n"," if i \u003c len(sizes) - 2:\n"," layers.append(act())\n"," self.model = nn.Sequential(*layers)\n","\n","\n","class MultiHeadAttention(nn.Module):\n","\n"," def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):\n"," super().__init__()\n"," self.num_heads = num_heads\n"," head_dim = dim_self // num_heads\n"," self.scale = head_dim ** -0.5\n"," self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)\n"," self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)\n"," self.project = nn.Linear(dim_self, dim_self)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," def forward(self, x, y=None, mask=None):\n"," y = y if y is not None else x\n"," b, n, c = x.shape\n"," _, m, d = y.shape\n"," # b n h dh\n"," queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)\n"," # b m 2 h dh\n"," keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)\n"," keys, values = keys_values[:, :, 0], keys_values[:, :, 1]\n"," attention = torch.einsum('bnhd,bmhd-\u003ebnmh', queries, keys) * self.scale\n"," if mask is not None:\n"," if mask.dim() == 2:\n"," mask = mask.unsqueeze(1)\n"," attention = attention.masked_fill(mask.unsqueeze(3), float(\"-inf\"))\n"," attention = attention.softmax(dim=2)\n"," out = torch.einsum('bnmh,bmhd-\u003ebnhd', attention, values).reshape(b, n, c)\n"," out = self.project(out)\n"," return out, attention\n","\n","\n","class TransformerLayer(nn.Module):\n","\n"," def forward_with_attention(self, x, y=None, mask=None):\n"," x_, attention = self.attn(self.norm1(x), y, mask)\n"," x = x + x_\n"," x = x + self.mlp(self.norm2(x))\n"," return x, attention\n","\n"," def forward(self, x, y=None, mask=None):\n"," x = x + self.attn(self.norm1(x), y, mask)[0]\n"," x = x + self.mlp(self.norm2(x))\n"," return x\n","\n"," def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,\n"," norm_layer: nn.Module = nn.LayerNorm):\n"," super().__init__()\n"," self.norm1 = norm_layer(dim_self)\n"," self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)\n"," self.norm2 = norm_layer(dim_self)\n"," self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)\n","\n","\n","class Transformer(nn.Module):\n","\n"," def forward_with_attention(self, x, y=None, mask=None):\n"," attentions = []\n"," for layer in self.layers:\n"," x, att = layer.forward_with_attention(x, y, mask)\n"," attentions.append(att)\n"," return x, attentions\n","\n"," def forward(self, x, y=None, 
mask=None):\n"," for i, layer in enumerate(self.layers):\n"," if i % 2 == 0 and self.enc_dec: # cross\n"," x = layer(x, y)\n"," elif self.enc_dec: # self\n"," x = layer(x, x, mask)\n"," else: # self or cross\n"," x = layer(x, y, mask)\n"," return x\n","\n"," def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,\n"," mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):\n"," super(Transformer, self).__init__()\n"," dim_ref = dim_ref if dim_ref is not None else dim_self\n"," self.enc_dec = enc_dec\n"," if enc_dec:\n"," num_layers = num_layers * 2\n"," layers = []\n"," for i in range(num_layers):\n"," if i % 2 == 0 and enc_dec: # cross\n"," layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))\n"," elif enc_dec: # self\n"," layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))\n"," else: # self or cross\n"," layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))\n"," self.layers = nn.ModuleList(layers)\n","\n","\n","class TransformerMapper(nn.Module):\n","\n"," def forward(self, x):\n"," x = self.linear(x).view(x.shape[0], self.clip_length, -1)\n"," prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)\n"," prefix = torch.cat((x, prefix), dim=1)\n"," out = self.transformer(prefix)[:, self.clip_length:]\n"," return out\n","\n"," def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):\n"," super(TransformerMapper, self).__init__()\n"," self.clip_length = clip_length\n"," self.transformer = Transformer(dim_embedding, 8, num_layers)\n"," self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)\n"," self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)\n","\n","\n","class ClipCaptionModel(nn.Module):\n","\n"," def get_dummy_token(self, batch_size: int, device: torch.device) -\u003e torch.Tensor:\n"," return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)\n","\n"," def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,\n"," labels: Optional[torch.Tensor] = None):\n"," embedding_text = self.gpt.transformer.wte(tokens)\n"," prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)\n"," embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)\n"," if labels is not None:\n"," dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)\n"," labels = torch.cat((dummy_token, tokens), dim=1)\n"," out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)\n"," return out\n","\n"," def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,\n"," num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):\n"," super(ClipCaptionModel, self).__init__()\n"," self.prefix_length = prefix_length\n"," self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')\n"," self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]\n"," if mapping_type == MappingType.MLP:\n"," self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,\n"," self.gpt_embedding_size * prefix_length))\n"," else:\n"," self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,\n"," clip_length, 
num_layers)\n","\n","\n","class ClipCaptionPrefix(ClipCaptionModel):\n","\n"," def parameters(self, recurse: bool = True):\n"," return self.clip_project.parameters()\n","\n"," def train(self, mode: bool = True):\n"," super(ClipCaptionPrefix, self).train(mode)\n"," self.gpt.eval()\n"," return self"]},{"cell_type":"code","execution_count":6,"metadata":{"executionInfo":{"elapsed":424,"status":"ok","timestamp":1683692945613,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"V7xocT3TUgey"},"outputs":[],"source":["#@title Caption prediction\n","\n","def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,\n"," entry_length=67, temperature=1., stop_token: str = '.'):\n","\n"," model.eval()\n"," stop_token_index = tokenizer.encode(stop_token)[0]\n"," tokens = None\n"," scores = None\n"," device = next(model.parameters()).device\n"," seq_lengths = torch.ones(beam_size, device=device)\n"," is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)\n"," with torch.no_grad():\n"," if embed is not None:\n"," generated = embed\n"," else:\n"," if tokens is None:\n"," tokens = torch.tensor(tokenizer.encode(prompt))\n"," tokens = tokens.unsqueeze(0).to(device)\n"," generated = model.gpt.transformer.wte(tokens)\n"," for i in range(entry_length):\n"," outputs = model.gpt(inputs_embeds=generated)\n"," logits = outputs.logits\n"," logits = logits[:, -1, :] / (temperature if temperature \u003e 0 else 1.0)\n"," logits = logits.softmax(-1).log()\n"," if scores is None:\n"," scores, next_tokens = logits.topk(beam_size, -1)\n"," generated = generated.expand(beam_size, *generated.shape[1:])\n"," next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)\n"," if tokens is None:\n"," tokens = next_tokens\n"," else:\n"," tokens = tokens.expand(beam_size, *tokens.shape[1:])\n"," tokens = torch.cat((tokens, next_tokens), dim=1)\n"," else:\n"," logits[is_stopped] = -float(np.inf)\n"," logits[is_stopped, 0] = 0\n"," scores_sum = scores[:, None] + logits\n"," seq_lengths[~is_stopped] += 1\n"," scores_sum_average = scores_sum / seq_lengths[:, None]\n"," scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)\n"," next_tokens_source = next_tokens // scores_sum.shape[1]\n"," seq_lengths = seq_lengths[next_tokens_source]\n"," next_tokens = next_tokens % scores_sum.shape[1]\n"," next_tokens = next_tokens.unsqueeze(1)\n"," tokens = tokens[next_tokens_source]\n"," tokens = torch.cat((tokens, next_tokens), dim=1)\n"," generated = generated[next_tokens_source]\n"," scores = scores_sum_average * seq_lengths\n"," is_stopped = is_stopped[next_tokens_source]\n"," next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)\n"," generated = torch.cat((generated, next_token_embed), dim=1)\n"," is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()\n"," if is_stopped.all():\n"," break\n"," scores = scores / seq_lengths\n"," output_list = tokens.cpu().numpy()\n"," output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]\n"," order = scores.argsort(descending=True)\n"," output_texts = [output_texts[i] for i in order]\n"," return output_texts\n","\n","\n","def generate2(\n"," model,\n"," tokenizer,\n"," tokens=None,\n"," prompt=None,\n"," embed=None,\n"," entry_count=1,\n"," entry_length=67, # maximum number of words\n"," top_p=0.8,\n"," temperature=1.,\n"," stop_token: str = '.',\n","):\n"," model.eval()\n"," generated_num = 0\n"," 
generated_list = []\n"," stop_token_index = tokenizer.encode(stop_token)[0]\n"," filter_value = -float(\"Inf\")\n"," device = next(model.parameters()).device\n","\n"," with torch.no_grad():\n","\n"," for entry_idx in trange(entry_count):\n"," if embed is not None:\n"," generated = embed\n"," else:\n"," if tokens is None:\n"," tokens = torch.tensor(tokenizer.encode(prompt))\n"," tokens = tokens.unsqueeze(0).to(device)\n","\n"," generated = model.gpt.transformer.wte(tokens)\n","\n"," for i in range(entry_length):\n","\n"," outputs = model.gpt(inputs_embeds=generated)\n"," logits = outputs.logits\n"," logits = logits[:, -1, :] / (temperature if temperature \u003e 0 else 1.0)\n"," sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n"," cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)\n"," sorted_indices_to_remove = cumulative_probs \u003e top_p\n"," sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[\n"," ..., :-1\n"," ].clone()\n"," sorted_indices_to_remove[..., 0] = 0\n","\n"," indices_to_remove = sorted_indices[sorted_indices_to_remove]\n"," logits[:, indices_to_remove] = filter_value\n"," next_token = torch.argmax(logits, -1).unsqueeze(0)\n"," next_token_embed = model.gpt.transformer.wte(next_token)\n"," if tokens is None:\n"," tokens = next_token\n"," else:\n"," tokens = torch.cat((tokens, next_token), dim=1)\n"," generated = torch.cat((generated, next_token_embed), dim=1)\n"," if stop_token_index == next_token.item():\n"," break\n","\n"," output_list = list(tokens.squeeze().cpu().numpy())\n"," output_text = tokenizer.decode(output_list)\n"," generated_list.append(output_text)\n","\n"," return generated_list[0]"]},{"cell_type":"code","execution_count":7,"metadata":{"executionInfo":{"elapsed":568,"status":"ok","timestamp":1683692950777,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"7lCgFHSgr_ny"},"outputs":[],"source":["#@title GPU/CPU\n","\n","\n","is_gpu = True #@param {type:\"boolean\"} \n"]},{"cell_type":"code","execution_count":8,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":130},"executionInfo":{"elapsed":19058,"status":"ok","timestamp":1683692972212,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"6bi_2zQ3QD57","outputId":"1e42798c-35d7-40fe-bf42-e414ea71bbe1"},"outputs":[{"name":"stderr","output_type":"stream","text":["100%|ββββββββββββββββββββββββββββββββββββββββ| 338M/338M [00:01\u003c00:00, 218MiB/s]\n"]},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"58c5e4f2f9734b2b828d489c36f42a63","version_major":2,"version_minor":0},"text/plain":["Downloading (β¦)olve/main/vocab.json: 0%| | 0.00/1.04M [00:00\u003c?, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"d73f3959aaae4e369ffc0c7eddb8a020","version_major":2,"version_minor":0},"text/plain":["Downloading (β¦)olve/main/merges.txt: 0%| | 0.00/456k [00:00\u003c?, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"7aef10017be94d0f84177ecb3d62d794","version_major":2,"version_minor":0},"text/plain":["Downloading (β¦)lve/main/config.json: 0%| | 0.00/665 [00:00\u003c?, ?B/s]"]},"metadata":{},"output_type":"display_data"}],"source":["#@title CLIP model + GPT2 tokenizer\n","\n","device = CUDA(0) if is_gpu else \"cpu\"\n","clip_model, preprocess = clip.load(\"ViT-B/32\", device=device, jit=False)\n","tokenizer = 
GPT2Tokenizer.from_pretrained(\"gpt2\")"]},{"cell_type":"code","execution_count":12,"metadata":{"executionInfo":{"elapsed":21615,"status":"ok","timestamp":1683693503536,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"glBzYsgIwhwF"},"outputs":[],"source":["#@title Load model weights\n","\n","\n","prefix_length = 40\n","\n","model = ClipCaptionPrefix(prefix_length, clip_length=40, prefix_size=512,\n"," num_layers=16, mapping_type='transformer')\n","model.load_state_dict(torch.load(\"/content/gdrive/MyDrive/CS444Project/CLIP_prefix_caption/Pallaw/coco_train/coco_prefix-008.pt\", map_location=CPU)) \n","\n","model = model.eval() \n","device = CUDA(0) if is_gpu else \"cpu\"\n","model = model.to(device)\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/","height":38},"id":"m5jPDsEA5Kub"},"outputs":[{"data":{"text/html":["\n"," \u003cinput type=\"file\" id=\"files-48ba8b30-fd4b-4d30-b34d-8032f329c349\" name=\"files[]\" multiple disabled\n"," style=\"border:none\" /\u003e\n"," \u003coutput id=\"result-48ba8b30-fd4b-4d30-b34d-8032f329c349\"\u003e\n"," Upload widget is only available when the cell has been executed in the\n"," current browser session. Please rerun this cell to enable.\n"," \u003c/output\u003e\n"," \u003cscript\u003e// Copyright 2017 Google LLC\n","//\n","// Licensed under the Apache License, Version 2.0 (the \"License\");\n","// you may not use this file except in compliance with the License.\n","// You may obtain a copy of the License at\n","//\n","// http://www.apache.org/licenses/LICENSE-2.0\n","//\n","// Unless required by applicable law or agreed to in writing, software\n","// distributed under the License is distributed on an \"AS IS\" BASIS,\n","// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n","// See the License for the specific language governing permissions and\n","// limitations under the License.\n","\n","/**\n"," * @fileoverview Helpers for google.colab Python module.\n"," */\n","(function(scope) {\n","function span(text, styleAttributes = {}) {\n"," const element = document.createElement('span');\n"," element.textContent = text;\n"," for (const key of Object.keys(styleAttributes)) {\n"," element.style[key] = styleAttributes[key];\n"," }\n"," return element;\n","}\n","\n","// Max number of bytes which will be uploaded at a time.\n","const MAX_PAYLOAD_SIZE = 100 * 1024;\n","\n","function _uploadFiles(inputId, outputId) {\n"," const steps = uploadFilesStep(inputId, outputId);\n"," const outputElement = document.getElementById(outputId);\n"," // Cache steps on the outputElement to make it available for the next call\n"," // to uploadFilesContinue from Python.\n"," outputElement.steps = steps;\n","\n"," return _uploadFilesContinue(outputId);\n","}\n","\n","// This is roughly an async generator (not supported in the browser yet),\n","// where there are multiple asynchronous steps and the Python side is going\n","// to poll for completion of each step.\n","// This uses a Promise to block the python side on completion of each step,\n","// then passes the result of the previous step as the input to the next step.\n","function _uploadFilesContinue(outputId) {\n"," const outputElement = document.getElementById(outputId);\n"," const steps = outputElement.steps;\n","\n"," const next = steps.next(outputElement.lastPromiseValue);\n"," return Promise.resolve(next.value.promise).then((value) =\u003e {\n"," // Cache the last promise value 
to make it available to the next\n"," // step of the generator.\n"," outputElement.lastPromiseValue = value;\n"," return next.value.response;\n"," });\n","}\n","\n","/**\n"," * Generator function which is called between each async step of the upload\n"," * process.\n"," * @param {string} inputId Element ID of the input file picker element.\n"," * @param {string} outputId Element ID of the output display.\n"," * @return {!Iterable\u003c!Object\u003e} Iterable of next steps.\n"," */\n","function* uploadFilesStep(inputId, outputId) {\n"," const inputElement = document.getElementById(inputId);\n"," inputElement.disabled = false;\n","\n"," const outputElement = document.getElementById(outputId);\n"," outputElement.innerHTML = '';\n","\n"," const pickedPromise = new Promise((resolve) =\u003e {\n"," inputElement.addEventListener('change', (e) =\u003e {\n"," resolve(e.target.files);\n"," });\n"," });\n","\n"," const cancel = document.createElement('button');\n"," inputElement.parentElement.appendChild(cancel);\n"," cancel.textContent = 'Cancel upload';\n"," const cancelPromise = new Promise((resolve) =\u003e {\n"," cancel.onclick = () =\u003e {\n"," resolve(null);\n"," };\n"," });\n","\n"," // Wait for the user to pick the files.\n"," const files = yield {\n"," promise: Promise.race([pickedPromise, cancelPromise]),\n"," response: {\n"," action: 'starting',\n"," }\n"," };\n","\n"," cancel.remove();\n","\n"," // Disable the input element since further picks are not allowed.\n"," inputElement.disabled = true;\n","\n"," if (!files) {\n"," return {\n"," response: {\n"," action: 'complete',\n"," }\n"," };\n"," }\n","\n"," for (const file of files) {\n"," const li = document.createElement('li');\n"," li.append(span(file.name, {fontWeight: 'bold'}));\n"," li.append(span(\n"," `(${file.type || 'n/a'}) - ${file.size} bytes, ` +\n"," `last modified: ${\n"," file.lastModifiedDate ? file.lastModifiedDate.toLocaleDateString() :\n"," 'n/a'} - `));\n"," const percent = span('0% done');\n"," li.appendChild(percent);\n","\n"," outputElement.appendChild(li);\n","\n"," const fileDataPromise = new Promise((resolve) =\u003e {\n"," const reader = new FileReader();\n"," reader.onload = (e) =\u003e {\n"," resolve(e.target.result);\n"," };\n"," reader.readAsArrayBuffer(file);\n"," });\n"," // Wait for the data to be ready.\n"," let fileData = yield {\n"," promise: fileDataPromise,\n"," response: {\n"," action: 'continue',\n"," }\n"," };\n","\n"," // Use a chunked sending to avoid message size limits. 
See b/62115660.\n"," let position = 0;\n"," do {\n"," const length = Math.min(fileData.byteLength - position, MAX_PAYLOAD_SIZE);\n"," const chunk = new Uint8Array(fileData, position, length);\n"," position += length;\n","\n"," const base64 = btoa(String.fromCharCode.apply(null, chunk));\n"," yield {\n"," response: {\n"," action: 'append',\n"," file: file.name,\n"," data: base64,\n"," },\n"," };\n","\n"," let percentDone = fileData.byteLength === 0 ?\n"," 100 :\n"," Math.round((position / fileData.byteLength) * 100);\n"," percent.textContent = `${percentDone}% done`;\n","\n"," } while (position \u003c fileData.byteLength);\n"," }\n","\n"," // All done.\n"," yield {\n"," response: {\n"," action: 'complete',\n"," }\n"," };\n","}\n","\n","scope.google = scope.google || {};\n","scope.google.colab = scope.google.colab || {};\n","scope.google.colab._files = {\n"," _uploadFiles,\n"," _uploadFilesContinue,\n","};\n","})(self);\n","\u003c/script\u003e "],"text/plain":["\u003cIPython.core.display.HTML object\u003e"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Saving WhatsApp Image 2023-05-09 at 11.35.19 PM.jpeg to WhatsApp Image 2023-05-09 at 11.35.19 PM (1).jpeg\n","WhatsApp Image 2023-05-09 at 11.35.19 PM.jpeg\n"]}],"source":["#@title Upload Image\n","\n","\n","uploaded = files.upload()\n","\n","if not uploaded:\n"," UPLOADED_FILE = ''\n","elif len(uploaded) == 1:\n"," UPLOADED_FILE = list(uploaded.keys())[0]\n","else:\n"," raise AssertionError('Please upload one image at a time')\n","\n","print(UPLOADED_FILE)"]},{"cell_type":"code","execution_count":11,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1xsI6hhFVQlMn1njmmptvMjQGQd7_pg0d"},"executionInfo":{"elapsed":11536,"status":"ok","timestamp":1683693460232,"user":{"displayName":"Pallaw Kumar","userId":"00907428457681337218"},"user_tz":300},"id":"xQC2FQPtKLbz","outputId":"9b22fdce-adc8-49b9-9ace-85f79ac54fd3"},"outputs":[],"source":["#@title Inference\n","use_beam_search = True #@param {type:\"boolean\"} \n","\n","image = io.imread(UPLOADED_FILE)\n","pil_image = PIL.Image.fromarray(image)\n","#pil_img = Image(filename=UPLOADED_FILE)\n","display(pil_image)\n","\n","image = preprocess(pil_image).unsqueeze(0).to(device)\n","with torch.no_grad():\n"," prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)\n"," prefix = prefix / prefix.norm(2, -1).item()\n"," prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)\n","if use_beam_search:\n"," generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]\n","else:\n"," generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)\n","\n","\n","print('\\n')\n","print(generated_text_prefix)"]}],"metadata":{"accelerator":"GPU","colab":{"name":"","version":""},"gpuClass":"standard","kernelspec":{"display_name":"PyCharm 
(cvpr22)","language":"python","name":"pycharm-98db7c03"},"language_info":{"name":"python"},"stem_cell":{"cell_type":"raw","metadata":{"pycharm":{"metadata":false}},"source":""},"widgets":{"application/vnd.jupyter.widget-state+json":{"00313f0eea8442f581c88acf43747a00":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0781496131554f349fb066b7f9e7e855":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"146aa1d835224d408e6dbd01d95c59af":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"ProgressStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"17925067685e4efd8b8abfcd9190fc51":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_9f4fce0c32234986b90ca8899f622efd","placeholder":"β","style":"IPY_MODEL_e3b439e95d534877b4a6feb155bf311e","value":"Downloading (β¦)olve/main/merges.txt: 100%"}},"297b592a2ccf4fce842de19ccf3c89ca":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_5f53fc687b9243e09c7c49c045e81297","placeholder":"β","style":"IPY_MODEL_8efc05dccbe54135a39e54a1f8b39164","value":" 665/665 [00:00\u0026lt;00:00, 
26.1kB/s]"}},"2c8ecae59f1b40b6baec28e105c460aa":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2e7b57c2d9ab40b680a894d300f2278f":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"32049bc06ae14bf3b6f7da0297a2e076":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"3afcb996944e41d3a8939e7f19ec626a":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"3d1b37dbf5ee4492bb
e205aabc06659d":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"4475cccabbc94a98a38858f6e6279ced":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"49ca9127f872417eb2193c57b55c2f4a":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"58c5e4f2f9734b2b828d489c36f42a63":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HBoxModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_ed65fc1b5325453aa6f0db70b6aa47ca","IPY_MO
DEL_f7624ba11614459ab02a349a26eba72a","IPY_MODEL_fc376267c8664dc8b4b93a52a2165a00"],"layout":"IPY_MODEL_3d1b37dbf5ee4492bbe205aabc06659d"}},"5f53fc687b9243e09c7c49c045e81297":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"633463fe1bb849f8b6d6d5426215c83e":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"6f905d07a9304138ba2fbc38f404d23f":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"ProgressStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"74e8c936102d40bca4065aaa88865228":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"over
flow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"78e4f4f39bdf48b28904982ca9e628cd":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"ProgressStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"799df42e67124c3887228857ab4ca4a4":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"7a6b544c34d74445a9c8ffa75a9a7ca1":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"FloatProgressModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_2c8ecae59f1b40b6baec28e105c460aa","max":456318,"min":0,"orientation":"horizontal","style":"IPY_MODEL_6f905d07a9304138ba2fbc38f404d23f","value":456318}},"7aef10017be94d0f84177ecb3d62d794":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HBoxModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_d201e967837a424baee8652c48717b55","IPY_MODEL_7fe24c5dfc4140fca56d988edef31e8e","IPY_MODEL_297b592a2ccf4fce842de19ccf3c89ca"],"layout":"IPY_MODEL_633463fe1bb849f8b6d6d5426215c83e"}},"7fe24c5dfc4140fca56d988edef31e8e":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"FloatProgressModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_49ca9127f872417eb2193c57b55c2f4a","max":665,"min":0,"orientation":"horizontal","style":"IPY_MODEL_78e4f4f39bdf48b28904982ca9e628cd","value":665}},"8efc05dccbe54135a39e54a1f8b39164":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"950078fcd2cc4d57a017d5433dc11c74":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/contro
ls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c9aaec74f3fe46d99af920de071b36ba","placeholder":"β","style":"IPY_MODEL_799df42e67124c3887228857ab4ca4a4","value":" 456k/456k [00:00\u0026lt;00:00, 618kB/s]"}},"9f4fce0c32234986b90ca8899f622efd":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b688502a349147669f5ea93237457ae8":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"c9aaec74f3fe46d99af920de071b36ba":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d201e967837a424baee8652c48717b55":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_00313f0eea8442f581c88acf43747a00","placeholder":"β","style":"IPY_MODEL_b688502a349147669f5ea93237457ae8","value":"Downloading (β¦)lve/main/config.json: 
100%"}},"d73f3959aaae4e369ffc0c7eddb8a020":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HBoxModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_17925067685e4efd8b8abfcd9190fc51","IPY_MODEL_7a6b544c34d74445a9c8ffa75a9a7ca1","IPY_MODEL_950078fcd2cc4d57a017d5433dc11c74"],"layout":"IPY_MODEL_74e8c936102d40bca4065aaa88865228"}},"e3b439e95d534877b4a6feb155bf311e":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"ed65fc1b5325453aa6f0db70b6aa47ca":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_3afcb996944e41d3a8939e7f19ec626a","placeholder":"β","style":"IPY_MODEL_32049bc06ae14bf3b6f7da0297a2e076","value":"Downloading (β¦)olve/main/vocab.json: 100%"}},"f7624ba11614459ab02a349a26eba72a":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"FloatProgressModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_2e7b57c2d9ab40b680a894d300f2278f","max":1042301,"min":0,"orientation":"horizontal","style":"IPY_MODEL_146aa1d835224d408e6dbd01d95c59af","value":1042301}},"fc376267c8664dc8b4b93a52a2165a00":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_4475cccabbc94a98a38858f6e6279ced","placeholder":"β","style":"IPY_MODEL_0781496131554f349fb066b7f9e7e855","value":" 1.04M/1.04M [00:00\u0026lt;00:00, 2.13MB/s]"}}}}},"nbformat":4,"nbformat_minor":0}
notebooks/transformer_baseline_inference.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

notebooks/transformer_gpt3_inference.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
pretrained_models/mlp_gpt2_weights.pt
ADDED
@@ -0,0 +1,3 @@
| 1 | + version https://git-lfs.github.com/spec/v1
| 2 | + oid sha256:0f179e3da4662f132d181f5aef4989d72c7e3b61c2fe04691fa72c45047c6b2f
| 3 | + size 636286431
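The three lines above are a Git LFS pointer; the actual ~636 MB checkpoint is materialized by `git lfs pull`. As a rough, hedged sketch (the `prefix_length` used to train these weights is an assumption here and must match the training configuration), the file could be loaded with the MLP-mapper variant of `ClipCaptionModel` defined in the notebook code earlier in this commit:

```python
# Hedged sketch: load the MLP+GPT-2 weights added in this commit.
# ClipCaptionModel / MappingType come from the notebook code above;
# prefix_length is an assumption and must match the training configuration.
import torch

prefix_length = 10  # assumption, not taken from this commit

model = ClipCaptionModel(prefix_length, prefix_size=512,
                         mapping_type=MappingType.MLP)
state_dict = torch.load("pretrained_models/mlp_gpt2_weights.pt",
                        map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model = model.eval()
```

If the checkpoint was saved with extra or renamed keys, `load_state_dict(..., strict=False)` may be needed; that detail is not visible from the pointer file itself.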
starter_code/CLIP_prefix_caption.zip
ADDED
@@ -0,0 +1,3 @@
| 1 | + version https://git-lfs.github.com/spec/v1
| 2 | + oid sha256:e3e5491ded7a74f37b4866d5f1d657d1f1cb7966a2536fc3713271b0202db555
| 3 | + size 4531268