|
|
from pathlib import Path |
|
|
import comet_ml |
|
|
import datasets |
|
|
import evaluate |
|
|
import lightning as L |
|
|
import torch |
|
|
from timm import create_model, data |
|
|
from tokenizers import Tokenizer |
|
|
from torch import nn |
|
|
from torch.utils.data import DataLoader |
|
|
from transformers import ( |
|
|
GPT2LMHeadModel, |
|
|
) |
|
|
from lightning.pytorch.loggers import TensorBoardLogger |
|
|
from lightning.pytorch.callbacks import ModelCheckpoint |
|
|
|
|
|
|
|
|
# Token id of GPT-2's "<|endoftext|>" token in the stock GPT-2 vocabulary;
# appended to every training caption (see agument) and stored on the model.
eos_token_id = 50256
|
|
|
|
|
|
|
|
class Projection(nn.Module):
    """Two-layer MLP (GELU in between) mapping image features to LM width.

    The hidden layer expands the input threefold before projecting down
    to ``out_features``; applied position-wise over the last dimension.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = in_features * 3
        self.network = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, input):
        """Run the MLP on *input* and return the projected tensor."""
        return self.network(input)
|
|
|
|
|
|
|
|
class ImageNetCaptionModel(L.LightningModule):
    """Image-captioning LightningModule: frozen ViT features prefixed to GPT-2.

    Each image is encoded by a frozen timm ViT backbone, projected into
    GPT-2's embedding space, wrapped between <image_start>/<image_end>
    marker embeddings, and prepended to the caption token embeddings for
    causal language-model training.
    """

    def __init__(self):
        super().__init__()

        # Frozen vision backbone, used purely as a feature extractor.
        self.backbone = create_model(
            "vit_mediumd_patch16_reg4_gap_384", pretrained=True
        )
        self.llm = GPT2LMHeadModel.from_pretrained("gpt2")

        # Marker tokens that delimit the image-feature span in the LM input.
        self.image_start_token = "<image_start>"
        self.image_end_token = "<image_end>"
        self.tokenizer = Tokenizer.from_pretrained("gpt2")
        self.tokenizer.add_special_tokens(
            [self.image_start_token, self.image_end_token]
        )
        self.image_start_token_id = self.tokenizer.token_to_id(self.image_start_token)
        self.image_end_token_id = self.tokenizer.token_to_id(self.image_end_token)
        # NOTE: despite the name, this attribute holds the token *id* (50256).
        self.eos_token = eos_token_id

        # Grow GPT-2's embedding matrix to cover the two new special tokens.
        self.llm.resize_token_embeddings(self.tokenizer.get_vocab_size())
        self.embedding = self.llm.get_input_embeddings()

        # Map backbone features into GPT-2's hidden size so they can be
        # concatenated with token embeddings.  in_features=512 is assumed to
        # match this ViT variant's feature dim -- TODO confirm against timm.
        self.projection = Projection(
            in_features=512, out_features=self.llm.config.hidden_size
        )

        self.bleu_metric = evaluate.load("bleu")
        self.meteor_metric = evaluate.load("meteor")

        # Only the projection and the LM are trained; the backbone stays frozen.
        for param in self.backbone.parameters():
            param.requires_grad = False

        for param in self.llm.parameters():
            param.requires_grad = True

    def get_tokenizer(self):
        """Return the special-token-augmented GPT-2 tokenizer."""
        return self.tokenizer

    def forward(self, image=None, input_caption=None, **kwargs):
        """Teacher-forced forward pass.

        Builds the embedding sequence
        ``[<image_start>, projected image features, <image_end>, caption]``
        and runs GPT-2 with labels such that the loss is computed only over
        the caption positions (all image positions are masked with -100).

        Returns the ``GPT2LMHeadModel`` output (including ``.loss``).
        """
        image_feature = self.backbone.forward_features(image)
        projection = self.projection(image_feature)
        input_caption_embedding = self.embedding(input=input_caption)

        image_start_token, image_end_token = self.get_image_seperation_token(
            image=image
        )
        input_embedding = torch.cat(
            [image_start_token, projection, image_end_token, input_caption_embedding],
            dim=1,
        )
        attention_mask = torch.ones(
            input_embedding.size()[:-1], dtype=torch.long, device=image.device
        )

        # -100 is ignored by the LM cross-entropy; only caption positions
        # receive real labels.
        labels = torch.full(
            (input_embedding.size(0), input_embedding.size(1)),
            -100,
            dtype=torch.long,
            device=image.device,
        )
        # +2 skips the <image_start> and <image_end> marker positions.
        labels[:, projection.size(1) + 2 :] = input_caption

        llm_output = self.llm(
            inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels
        )
        return llm_output

    def training_step(self, batch, batch_idx):
        """One optimization step; logs the LM loss."""
        output = self.forward(**batch)
        self.log("loss", output.loss.item())
        return output.loss

    def validation_step(self, batch, batch_idx):
        """Generate captions for the first five batches and log BLEU/METEOR."""
        if batch_idx < 5:
            pred = self.predict_step(batch=batch, batch_idx=batch_idx)
            print(
                "evaluation ",
                "pred",
                pred,
                "original caption",
                batch["original_caption_enriched"],
            )
            bleu = self.bleu_metric.compute(
                predictions=pred, references=batch["original_caption_enriched"]
            )
            self.log("bleu", bleu["bleu"])
            # Fixed: this value is the BLEU brevity penalty; it was
            # previously logged under the misleading name "precision".
            self.log("brevity_penalty", bleu["brevity_penalty"])
            meteor = self.meteor_metric.compute(
                predictions=pred, references=batch["original_caption_enriched"]
            )
            print(meteor)
            self.log_dict(meteor)

    def get_image_seperation_token(self, image):
        """Return batched <image_start>/<image_end> marker embeddings.

        Each tensor is shaped (batch, 1, hidden) so it can be concatenated
        along the sequence dimension.  (The method name keeps the original
        spelling for backward compatibility with existing callers.)
        """
        image_start_embedding = self.embedding(
            torch.tensor([self.image_start_token_id], device=image.device)
        )
        image_end_embedding = self.embedding(
            torch.tensor([self.image_end_token_id], device=image.device)
        )
        image_start_token = image_start_embedding.unsqueeze(0).repeat(len(image), 1, 1)
        image_end_token = image_end_embedding.unsqueeze(0).repeat(len(image), 1, 1)

        return image_start_token, image_end_token

    def configure_optimizers(self):
        """AdamW with a higher LR for the fresh projection than for GPT-2."""
        proj_params = [p for p in self.projection.parameters() if p.requires_grad]
        llm_params = [p for p in self.llm.parameters() if p.requires_grad]

        optimizer = torch.optim.AdamW(
            [
                {"params": proj_params, "lr": 1e-4, "weight_decay": 0.01},
                {"params": llm_params, "lr": 5e-6, "weight_decay": 0.01},
            ]
        )

        return optimizer

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate captions for a batch of images; returns decoded strings."""
        image = batch["image"]
        image_feature = self.backbone.forward_features(image)
        projection = self.projection(image_feature)

        # Fixed: reuse the helper instead of duplicating its embedding logic.
        image_start_token, image_end_token = self.get_image_seperation_token(
            image=image
        )

        input_embedding = torch.cat(
            [
                image_start_token,
                projection,
                image_end_token,
            ],
            dim=1,
        )
        attention_mask = torch.ones(
            input_embedding.size()[:-1], dtype=torch.long, device=image.device
        )

        outputs = self.llm.generate(
            inputs_embeds=input_embedding,
            attention_mask=attention_mask,
            # 0 is the pad id used by agument() to right-pad training
            # captions, so generation stops once the model emits padding.
            # NOTE(review): confirm this is intended rather than EOS 50256.
            eos_token_id=0,
            max_new_tokens=30,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )

        # generate() normally yields a (batch, seq) tensor; keep the
        # fallback for configurations that already return a list.
        if outputs.dim() == 2:
            outputs_list = outputs.tolist()
        else:
            outputs_list = outputs

        return self.tokenizer.decode_batch(outputs_list, skip_special_tokens=True)

    def generate(self, image):
        """Caption a single image using the backbone's eval-time transform."""
        # Fixed: resolve the preprocessing config from the already-loaded
        # backbone instead of constructing a second pretrained model.
        data_config = data.resolve_model_data_config(self.backbone)
        transforms = data.create_transform(**data_config, is_training=False)
        image = transforms(image)

        return self.predict_step(batch={"image": image.unsqueeze(0)}, batch_idx=0)[0]
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """Collate dataset examples into a batched dict for the DataLoader.

    Returns:
        dict with "image" (float tensor, stacked along dim 0),
        "input_caption" (long tensor, stacked along dim 0), and
        "original_caption_enriched" (list of raw caption strings).
    """
    images = []
    captions = []
    raw_captions = []

    # Fixed: the loop variable was named `data`, shadowing the `timm.data`
    # module imported at file level; `torch.as_tensor` also avoids the
    # copy-construct warning `torch.tensor` raises on tensor inputs
    # (images are already tensors after the agument() transform).
    for sample in batch:
        images.append(torch.as_tensor(sample["image"], dtype=torch.float))
        captions.append(torch.as_tensor(sample["input_caption"], dtype=torch.long))
        raw_captions.append(sample["original_caption_enriched"])

    return {
        "image": torch.stack(images, dim=0),
        "input_caption": torch.stack(captions, dim=0),
        "original_caption_enriched": raw_captions,
    }
|
|
|
|
|
|
|
|
def agument(tokenizer: Tokenizer):
    """Build the per-example preprocessing transform for the dataset.

    The returned callable maps an example to:
      - "input_caption": LongTensor of exactly 60 token ids — the caption
        truncated to 59 ids, then EOS appended, then right-padded with 0,
      - "original_caption_enriched": the untouched caption string,
      - "image": the backbone's eval-time image transform applied.
    """
    data_config = data.resolve_model_data_config(
        create_model("vit_mediumd_patch16_reg4_gap_384", pretrained=True)
    )
    transforms = data.create_transform(**data_config, is_training=False)

    def transform(example):
        # Fixed: work on a plain list.  tokenizers' `Encoding.ids` returns a
        # fresh copy on each access, so the previous `ids.ids.append(...)` /
        # `ids.ids = ...` mutations never stuck to the encoding.
        ids = tokenizer.encode(example["caption_enriched"]).ids

        # Truncate to 59 ids so the EOS always fits within length 60.
        ids = ids[:59]
        ids.append(eos_token_id)

        # Right-pad with 0 (the Encoding.pad default pad id) to length 60.
        ids.extend([0] * (60 - len(ids)))

        decoded = tokenizer.decode(ids, skip_special_tokens=True)
        print("original", example["caption_enriched"], "decoded", decoded)

        example["input_caption"] = torch.tensor(ids, dtype=torch.long)

        example["original_caption_enriched"] = example["caption_enriched"]
        example["image"] = transforms(example["image"])
        return example

    return transform
|
|
|
|
|
|
|
|
def is_valid_image(example):
    """Dataset filter: keep only examples whose image reports RGB mode.

    Any example that raises while being inspected (e.g. a broken image
    object) is printed and rejected instead of crashing the filter pass.
    """
    try:
        return example["image"].mode == "RGB"
    except Exception as e:
        print("false", example["image"])
        print("Exception:", e)
        return False
|
|
|
|
|
|
|
|
def train(
    root_path: Path,
    dataset: datasets.Dataset,
    num_loader_worker: int = 0,
    batch_size=16,
    logger=None,
):
    """Fine-tune the caption model on *dataset*, checkpointing under *root_path*.

    Args:
        root_path: directory for logs and per-epoch checkpoints.
        dataset: a datasets.DatasetDict with "train" and "test" splits.
        num_loader_worker: DataLoader worker count.
        batch_size: batch size for both splits.
        logger: optional Lightning logger; defaults to TensorBoard.
    """
    test_ds = dataset["test"]
    train_ds = dataset["train"]

    model = ImageNetCaptionModel()

    tokenizer = model.get_tokenizer()

    # Fixed: build the preprocessing transform once.  Calling agument()
    # per split constructed (and downloaded) a fresh pretrained ViT each
    # time just to resolve the same transform config.
    transform = agument(tokenizer=tokenizer)

    train_ds = train_ds.filter(is_valid_image)
    train_ds = train_ds.map(transform)

    test_ds = test_ds.filter(is_valid_image)
    test_ds = test_ds.map(transform)

    def _make_loader(ds):
        # Both splits share identical loader settings; drop_last keeps
        # batch shapes uniform.
        return DataLoader(
            dataset=ds,
            drop_last=True,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_loader_worker,
        )

    train_data_loader = _make_loader(train_ds)
    evaluation_data_loader = _make_loader(test_ds)

    if logger is None:
        logger = TensorBoardLogger(save_dir=str(root_path), version=1, name="logs")
    # Save a checkpoint every epoch and keep all of them (save_top_k=-1).
    checkpoint_callback = ModelCheckpoint(
        dirpath=root_path / "checkpoint",
        filename="checkpoint-{epoch:02d}-{loss:.2f}",
        every_n_epochs=1,
        save_top_k=-1,
    )
    print("path", root_path)
    trainer = L.Trainer(
        logger=logger,
        max_epochs=2,
        default_root_dir=root_path,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_data_loader,
        val_dataloaders=evaluation_data_loader,
    )
|
|
|