import os

import torch
import torch.nn as nn
import wandb
import streamlit as st

import clip
import pytorch_lightning as pl
from transformers import GPT2Tokenizer, GPT2LMHeadModel


class ImageEncoder(nn.Module):
    """Frozen CLIP image encoder that returns L2-normalised embeddings."""

    def __init__(self, base_network):
        super(ImageEncoder, self).__init__()
        self.base_network = base_network
        # Width of CLIP's text transformer; for the ViT models loaded here it
        # matches the dimensionality returned by encode_image.
        self.embedding_size = self.base_network.token_embedding.weight.shape[1]

    def forward(self, images):
        # CLIP stays frozen, so no gradients are tracked here.
        with torch.no_grad():
            x = self.base_network.encode_image(images)
            # Normalise each embedding to unit length, as CLIP does internally.
            x = x / x.norm(dim=1, keepdim=True)
            x = x.float()  # encode_image can return float16 on GPU

        return x


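# A minimal usage sketch for ImageEncoder. Assumptions: the OpenAI `clip`
# package is installed and "example.jpg" is a hypothetical local image file.
def _demo_image_encoder():
    from PIL import Image

    clip_model, preprocess = clip.load("ViT-L/14", device="cpu")
    encoder = ImageEncoder(clip_model)
    image = preprocess(Image.open("example.jpg")).unsqueeze(0)  # (1, 3, 224, 224)
    features = encoder(image)
    print(features.shape)  # torch.Size([1, 768]), unit-norm rows

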
class Mapping(nn.Module):
    """Maps one CLIP embedding to a prefix of `length` GPT-2 embeddings."""

    def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):
        super(Mapping, self).__init__()

        self.clip_embedding_size = clip_embedding_size
        self.gpt_embedding_size = gpt_embedding_size
        self.length = length

        # A single linear layer produces all `length` prefix embeddings at once.
        self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)

    def forward(self, x):
        x = self.fc1(x)

        # Reshape the flat projection into a (batch, length, embedding) prefix.
        return x.view(-1, self.length, self.gpt_embedding_size)


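# Quick shape check for Mapping. The sizes assume ViT-L/14 CLIP and GPT-2,
# which both use 768-dimensional embeddings:
def _demo_mapping():
    mapper = Mapping(clip_embedding_size=768, gpt_embedding_size=768, length=30)
    clip_features = torch.randn(4, 768)  # a batch of 4 CLIP image embeddings
    prefix = mapper(clip_features)
    print(prefix.shape)  # torch.Size([4, 30, 768])

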
class TextDecoder(nn.Module):
    """GPT-2 wrapper that consumes embeddings instead of token ids."""

    def __init__(self, base_network):
        super(TextDecoder, self).__init__()
        self.base_network = base_network
        # Read both sizes off GPT-2's token-embedding matrix (vocab x embed_dim).
        self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
        self.vocab_size = self.base_network.transformer.wte.weight.shape[0]

    def forward(self, concat_embedding, mask=None):
        # Passing inputs_embeds lets us prepend the image prefix to the caption.
        return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)

    def get_embedding(self, texts):
        # Look up GPT-2 token embeddings for the caption token ids.
        return self.base_network.transformer.wte(texts)


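# Shape sketch for TextDecoder with dummy token ids. GPT-2's vocabulary has
# 50257 entries and 768-dimensional embeddings:
def _demo_text_decoder():
    decoder = TextDecoder(GPT2LMHeadModel.from_pretrained('gpt2'))
    tokens = torch.randint(0, decoder.vocab_size, (2, 10))  # (batch, seq_len)
    embeddings = decoder.get_embedding(tokens)              # (2, 10, 768)
    out = decoder(embeddings)
    print(out.logits.shape)                                 # (2, 10, 50257)

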
class ImageCaptioner(pl.LightningModule):
    def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
        super(ImageCaptioner, self).__init__()

        self.padding_token_id = tokenizer.pad_token_id

        # Frozen CLIP encoder, GPT-2 decoder, and the trainable mapping network.
        self.clip = ImageEncoder(clip_model)
        self.gpt = TextDecoder(gpt_model)
        self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)

        self.total_steps = total_steps
        self.max_length = max_length
        self.clip_embedding_size = self.clip.embedding_size
        self.gpt_embedding_size = self.gpt.embedding_size
        self.gpt_vocab_size = self.gpt.vocab_size

    def forward(self, images, texts, masks):
        texts_embedding = self.gpt.get_embedding(texts)
        images_embedding = self.clip(images)

        # Project CLIP features to a prefix of GPT-2 embeddings and prepend it
        # to the caption embeddings along the sequence dimension.
        images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
        embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)

        out = self.gpt(embedding_concat, masks)

        return out


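# End-to-end forward-pass sketch with dummy data. Assumptions: 224x224 inputs
# for ViT-L/14 and the same GPT-2 models the loaders below use:
def _demo_image_captioner():
    clip_model, _ = clip.load("ViT-L/14", device="cpu")
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    model = ImageCaptioner(clip_model, GPT2LMHeadModel.from_pretrained('gpt2'), tokenizer, total_steps=0)

    images = torch.randn(2, 3, 224, 224)
    texts = torch.randint(0, model.gpt_vocab_size, (2, 10))
    masks = torch.ones(2, model.max_length + 10)  # prefix + caption positions
    out = model(images, texts, masks)
    print(out.logits.shape)  # (2, max_length + 10, 50257)

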
@st.cache_resource
def load_clip_model():
    # Cache CLIP and its preprocessing transform across Streamlit reruns.
    clip_model, image_transform = clip.load("ViT-L/14", device="cpu")

    return clip_model, image_transform


@st.cache_resource
def load_gpt_model():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')

    # GPT-2 has no padding token, so reuse the end-of-text token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    return gpt_model, tokenizer


@st.cache_resource
def load_model():
    # Checkpoint path inside the locally downloaded wandb model artifact.
    artifact_dir = "./artifacts/model-ql03493w:v3"
    PATH = f"{os.getcwd()}/{artifact_dir[2:]}/model.ckpt"

    clip_model, image_transform = load_clip_model()
    gpt_model, tokenizer = load_gpt_model()

    # total_steps is a training-time argument, so 0 is fine for inference.
    model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])

    return model, image_transform, tokenizer
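

# Usage sketch (assumption: the wandb artifact has already been downloaded
# to ./artifacts/model-ql03493w:v3 before the app starts):
def _demo_load_model():
    model, image_transform, tokenizer = load_model()
    model.eval()  # inference mode for captioning in the app
    print(type(model).__name__, model.max_length)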