# imagenet-caption / model.py
from pathlib import Path

import comet_ml  # imported first, for its experiment auto-logging side effects
import datasets
import evaluate
import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from timm import create_model, data
from tokenizers import Tokenizer
from torch import nn
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
eos_token_id = 50256  # GPT-2's <|endoftext|> token id
class Projection(nn.Module):
    """Two-layer MLP that lifts backbone image features into the LLM embedding space."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_features, in_features * 3),
            nn.GELU(),
            nn.Linear(in_features * 3, out_features),
        )

    def forward(self, x):
        return self.network(x)
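# A minimal shape sketch (values assumed, not from the original file): with the
# vit_mediumd_patch16_reg4_gap_384 backbone (512-dim tokens) and GPT-2 (768-dim
# embeddings), the projection maps (batch, tokens, 512) to (batch, tokens, 768)
# so image tokens can sit alongside text embeddings:
#
#   proj = Projection(in_features=512, out_features=768)
#   fake_features = torch.randn(2, 580, 512)  # hypothetical ViT token count
#   assert proj(fake_features).shape == (2, 580, 768)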
class ImageNetCaptionModel(L.LightningModule):
def __init__(self):
super().__init__()
# backbone model to extract image feature token
self.backbone = create_model(
"vit_mediumd_patch16_reg4_gap_384", pretrained=True
)
self.llm = GPT2LMHeadModel.from_pretrained("gpt2")
self.image_start_token = "<image_start>"
self.image_end_token = "<image_end>"
self.tokenizer = Tokenizer.from_pretrained("gpt2")
self.tokenizer.add_special_tokens(
[self.image_start_token, self.image_end_token]
)
self.image_start_token_id = self.tokenizer.token_to_id(self.image_start_token)
self.image_end_token_id = self.tokenizer.token_to_id(self.image_end_token)
        self.eos_token_id = eos_token_id
        # grow GPT-2's embedding table to cover the two new special tokens
        self.llm.resize_token_embeddings(self.tokenizer.get_vocab_size())
self.embedding = self.llm.get_input_embeddings()
self.projection = Projection(
in_features=512, out_features=self.llm.config.hidden_size
)
self.bleu_metric = evaluate.load("bleu")
self.meteor_metric = evaluate.load("meteor")
        # Freeze the backbone; the LLM stays trainable and is fine-tuned with a
        # much smaller learning rate (see configure_optimizers).
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self.llm.parameters():
            param.requires_grad = True
def get_tokenizer(self):
return self.tokenizer
    def forward(self, image=None, input_caption=None, **kwargs):
        # **kwargs absorbs extra batch keys such as original_caption_enriched
        image_feature = self.backbone.forward_features(image)
        projection = self.projection(image_feature)
        input_caption_embedding = self.embedding(input=input_caption)
        # concat: <image_start> + projected image tokens + <image_end> + caption
        image_start_token, image_end_token = self.get_image_separation_tokens(
            image=image
        )
input_embedding = torch.cat(
[image_start_token, projection, image_end_token, input_caption_embedding],
dim=1,
)
attention_mask = torch.ones(
input_embedding.size()[:-1], dtype=torch.long, device=image.device
)
        # -100 masks positions from the loss; only caption tokens are scored
        labels = torch.full(
            (input_embedding.size(0), input_embedding.size(1)),
            -100,
            dtype=torch.long,
            device=image.device,
        )
        labels[:, projection.size(1) + 2 :] = input_caption  # +2 for start/end tokens
llm_output = self.llm(
inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels
)
return llm_output
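    # Sequence layout handed to GPT-2 (one batch row; the image-token count N
    # depends on the ViT configuration and is illustrative here):
    #
    #   embeds: [<image_start>] [img_1 ... img_N] [<image_end>] [cap_1 ... cap_60]
    #   labels: [-100]          [-100 ...   -100] [-100]        [cap_1 ... cap_60]
    #
    # so the cross-entropy loss supervises only the caption positions.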
def training_step(self, batch, batch_idx):
output = self.forward(**batch)
self.log("loss", output.loss.item())
return output.loss
    def validation_step(self, batch, batch_idx):
        if batch_idx < 5:
            pred = self.predict_step(batch=batch, batch_idx=batch_idx)
            print(
                "evaluation",
                "pred:",
                pred,
                "original caption:",
                batch["original_caption_enriched"],
            )
            bleu = self.bleu_metric.compute(
                predictions=pred, references=batch["original_caption_enriched"]
            )
            self.log("bleu", bleu["bleu"])
            self.log("brevity_penalty", bleu["brevity_penalty"])
            meteor = self.meteor_metric.compute(
                predictions=pred, references=batch["original_caption_enriched"]
            )
            print(meteor)
            self.log_dict(meteor)
    def get_image_separation_tokens(self, image):
        # embed <image_start>/<image_end> once, then tile to (batch, 1, hidden)
        image_start_embedding = self.embedding(
            torch.tensor([self.image_start_token_id], device=image.device)
        )
        image_end_embedding = self.embedding(
            torch.tensor([self.image_end_token_id], device=image.device)
        )
        image_start_token = image_start_embedding.unsqueeze(0).repeat(len(image), 1, 1)
        image_end_token = image_end_embedding.unsqueeze(0).repeat(len(image), 1, 1)
        return image_start_token, image_end_token
    def configure_optimizers(self):
        # two parameter groups: the freshly initialised projection trains with a
        # larger learning rate than the pretrained GPT-2 weights
        proj_params = [p for p in self.projection.parameters() if p.requires_grad]
        llm_params = [p for p in self.llm.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
[
{"params": proj_params, "lr": 1e-4, "weight_decay": 0.01},
{"params": llm_params, "lr": 5e-6, "weight_decay": 0.01},
]
)
return optimizer
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        image = batch["image"]
        image_feature = self.backbone.forward_features(image)
        projection = self.projection(image_feature)
        image_start_token, image_end_token = self.get_image_separation_tokens(
            image=image
        )
        input_embedding = torch.cat(
            [image_start_token, projection, image_end_token],
            dim=1,
        )
        attention_mask = torch.ones(
            input_embedding.size()[:-1], dtype=torch.long, device=image.device
        )
        # with inputs_embeds (and no input_ids) generate returns only the newly
        # generated token ids
        outputs = self.llm.generate(
            inputs_embeds=input_embedding,
            attention_mask=attention_mask,
            eos_token_id=self.eos_token_id,  # stop at GPT-2's <|endoftext|>
            max_new_tokens=30,
            do_sample=True,  # add randomness
            top_p=0.9,  # nucleus sampling
            temperature=0.7,
        )
        # generate returns a [batch_size, sequence_length] tensor; decode_batch
        # expects plain Python lists of token ids
        return self.tokenizer.decode_batch(outputs.tolist(), skip_special_tokens=True)
    def generate(self, image):
        # reuse the already-loaded backbone to resolve the eval-time transforms
        data_config = data.resolve_model_data_config(self.backbone)
        transforms = data.create_transform(**data_config, is_training=False)
        image = transforms(image).unsqueeze(0).to(self.device)
        return self.predict_step(batch={"image": image}, batch_idx=0)[0]
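# A hypothetical usage sketch (checkpoint path and image file are assumptions):
# load a trained checkpoint and caption a single PIL image.
#
#   from PIL import Image
#   model = ImageNetCaptionModel.load_from_checkpoint("checkpoint/last.ckpt")
#   model.eval()
#   print(model.generate(Image.open("dog.jpg").convert("RGB")))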
def collate_fn(batch):
    # stack per-sample tensors into batch tensors; captions stay as raw strings
    collected = {"image": [], "input_caption": [], "original_caption_enriched": []}
    for sample in batch:
        collected["image"].append(torch.tensor(sample["image"], dtype=torch.float))
        collected["input_caption"].append(
            torch.tensor(sample["input_caption"], dtype=torch.long)
        )
        collected["original_caption_enriched"].append(
            sample["original_caption_enriched"]
        )
    return {
        "image": torch.stack(collected["image"], dim=0),
        "input_caption": torch.stack(collected["input_caption"], dim=0),
        "original_caption_enriched": collected["original_caption_enriched"],
    }
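# Resulting batch layout (shapes assume the 384px ViT input size and the
# 60-token captions produced by augment() below; both are illustrative):
#   image:                     float tensor, (batch, 3, 384, 384)
#   input_caption:             long tensor,  (batch, 60)
#   original_caption_enriched: list[str] of length batch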
def augment(tokenizer: Tokenizer):
    data_config = data.resolve_model_data_config(
        create_model("vit_mediumd_patch16_reg4_gap_384", pretrained=True)
    )
    transforms = data.create_transform(**data_config, is_training=False)

    def transform(sample):
        # Encoding.ids returns a copy, so build the fixed-length sequence as a
        # plain list: truncate to 59 tokens, append EOS, then right-pad with
        # EOS (GPT-2 has no dedicated pad token) to exactly 60 tokens.
        token_ids = tokenizer.encode(sample["caption_enriched"]).ids[:59]
        token_ids.append(eos_token_id)
        token_ids += [eos_token_id] * (60 - len(token_ids))
        decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
        print("original", sample["caption_enriched"], "decoded", decoded)
        sample["input_caption"] = torch.tensor(token_ids, dtype=torch.long)
        sample["original_caption_enriched"] = sample["caption_enriched"]
        sample["image"] = transforms(sample["image"])
        return sample

    return transform
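# A quick sanity sketch (hypothetical caption): after augment(), every
# input_caption is exactly 60 tokens long and ends in at least one EOS:
#
#   tok = Tokenizer.from_pretrained("gpt2")
#   ids = tok.encode("a dog running through a field").ids[:59] + [eos_token_id]
#   ids += [eos_token_id] * (60 - len(ids))
#   assert len(ids) == 60 and ids[-1] == eos_token_id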
def is_valid_image(example):
    try:
        # keep only RGB images; touching .mode forces PIL to parse the file
        if example["image"].mode == "RGB":
            return True
        return False
    except Exception as e:
        # PIL raises ValueError on some corrupt files (e.g. when a PNG text
        # chunk exceeds MAX_TEXT_CHUNK); drop those samples
        print("false", example["image"])
        print("Exception:", e)
        return False
def train(
    root_path: Path,
    dataset: datasets.DatasetDict,
    num_loader_worker: int = 0,
    batch_size: int = 16,
    logger=None,
):
    # Expected input, e.g.:
    # datasets.load_dataset("visual-layer/imagenet-1k-vl-enriched", split="validation")
    #     .shuffle(seed=42).select(range(20000)).train_test_split(test_size=0.1)
test_ds = dataset["test"]
train_ds = dataset["train"]
model = ImageNetCaptionModel()
tokenizer = model.get_tokenizer()
    # Filter out broken images, then tokenize captions and apply the transforms
    train_ds = train_ds.filter(is_valid_image)
    train_ds = train_ds.map(augment(tokenizer=tokenizer))
    test_ds = test_ds.filter(is_valid_image)
    test_ds = test_ds.map(augment(tokenizer=tokenizer))
train_data_loader = DataLoader(
dataset=train_ds,
drop_last=True,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=num_loader_worker,
)
evaluation_data_loader = DataLoader(
dataset=test_ds,
drop_last=True,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=num_loader_worker,
)
if logger is None:
logger = TensorBoardLogger(save_dir=str(root_path), version=1, name="logs")
checkpoint_callback = ModelCheckpoint(
dirpath=root_path / "checkpoint",
filename="checkpoint-{epoch:02d}-{loss:.2f}",
every_n_epochs=1,
save_top_k=-1,
)
print("path", root_path)
trainer = L.Trainer(
logger=logger,
max_epochs=2,
default_root_dir=root_path,
callbacks=[checkpoint_callback],
)
trainer.fit(
model=model,
train_dataloaders=train_data_loader,
val_dataloaders=evaluation_data_loader,
)
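

# A minimal entry-point sketch (assumed, not part of the original upload); it
# wires in the dataset from the comment inside train(). The root path and
# worker count are illustrative.
if __name__ == "__main__":
    ds = (
        datasets.load_dataset(
            "visual-layer/imagenet-1k-vl-enriched", split="validation"
        )
        .shuffle(seed=42)
        .select(range(20000))
        .train_test_split(test_size=0.1)
    )
    train(root_path=Path("./runs"), dataset=ds, num_loader_worker=4)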