from pathlib import Path import comet_ml import datasets import evaluate import lightning as L import torch from timm import create_model, data from tokenizers import Tokenizer from torch import nn from torch.utils.data import DataLoader from transformers import ( GPT2LMHeadModel, ) from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.callbacks import ModelCheckpoint eos_token_id = 50256 # obtained from gpt model class Projection(nn.Module): def __init__(self, in_features, out_features): super().__init__() self.network = nn.Sequential( nn.Linear(in_features, in_features * 3), nn.GELU(), nn.Linear(in_features * 3, out_features), ) def forward(self, input): return self.network(input) class ImageNetCaptionModel(L.LightningModule): def __init__(self): super().__init__() # backbone model to extract image feature token self.backbone = create_model( "vit_mediumd_patch16_reg4_gap_384", pretrained=True ) self.llm = GPT2LMHeadModel.from_pretrained("gpt2") self.image_start_token = "" self.image_end_token = "" self.tokenizer = Tokenizer.from_pretrained("gpt2") self.tokenizer.add_special_tokens( [self.image_start_token, self.image_end_token] ) self.image_start_token_id = self.tokenizer.token_to_id(self.image_start_token) self.image_end_token_id = self.tokenizer.token_to_id(self.image_end_token) self.eos_token = eos_token_id self.llm.resize_token_embeddings(self.tokenizer.get_vocab_size()) self.embedding = self.llm.get_input_embeddings() self.projection = Projection( in_features=512, out_features=self.llm.config.hidden_size ) self.bleu_metric = evaluate.load("bleu") self.meteor_metric = evaluate.load("meteor") ## freeze backbone and gpt models. for param in self.backbone.parameters(): param.requires_grad = False for param in self.llm.parameters(): param.requires_grad = True def get_tokenizer(self): return self.tokenizer def forward(self, image=None, input_caption=None, **kwargs): image_feature = self.backbone.forward_features(image) projection = self.projection(image_feature) input_caption_embedding = self.embedding(input=input_caption) # concat start_image_token + projection + end_image_token + input_caption image_start_token, image_end_token = self.get_image_seperation_token( image=image ) input_embedding = torch.cat( [image_start_token, projection, image_end_token, input_caption_embedding], dim=1, ) attention_mask = torch.ones( input_embedding.size()[:-1], dtype=torch.long, device=image.device ) labels = torch.full( (input_embedding.size(0), input_embedding.size(1)), -100, dtype=torch.long, device=image.device, ) labels[:, projection.size(1) + 2 :] = input_caption # align text labels llm_output = self.llm( inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels ) return llm_output def training_step(self, batch, batch_idx): output = self.forward(**batch) self.log("loss", output.loss.item()) return output.loss def validation_step(self, batch, batch_idx): if batch_idx < 5: pred = self.predict_step(batch=batch, batch_idx=batch_idx) print( "evaluation ", "pred", pred, "original caption", batch["original_caption_enriched"], ) bleu = self.bleu_metric.compute( predictions=pred, references=batch["original_caption_enriched"] ) self.log("bleu", bleu["bleu"]) self.log("precision", bleu["brevity_penalty"]) metor = self.meteor_metric.compute( predictions=pred, references=batch["original_caption_enriched"] ) print(metor) self.log_dict(metor) def get_image_seperation_token(self, image): image_start_embedding = self.embedding( torch.tensor([self.image_start_token_id], device=image.device) ) image_end_embedding = self.embedding( torch.tensor([self.image_end_token_id], device=image.device) ) image_start_token = image_start_embedding.unsqueeze(0).repeat(len(image), 1, 1) image_end_token = image_end_embedding.unsqueeze(0).repeat(len(image), 1, 1) return image_start_token, image_end_token def configure_optimizers(self): proj_params = [p for p in self.projection.parameters() if p.requires_grad] llm_params = [p for p in self.llm.parameters() if p.requires_grad] optimizer = torch.optim.AdamW( [ {"params": proj_params, "lr": 1e-4, "weight_decay": 0.01}, {"params": llm_params, "lr": 5e-6, "weight_decay": 0.01}, ] ) return optimizer def predict_step(self, batch, batch_idx, dataloader_idx=0): image = batch["image"] image_feature = self.backbone.forward_features(image) projection = self.projection(image_feature) image_start_embedding = self.embedding( torch.tensor([self.image_start_token_id], device=image.device) ) image_end_embedding = self.embedding( torch.tensor([self.image_end_token_id], device=image.device) ) input_start_image_embedding_batch = image_start_embedding.unsqueeze(0).repeat( len(image), 1, 1 ) input_end_image_embedding_batch = image_end_embedding.unsqueeze(0).repeat( len(image), 1, 1 ) input_embedding = torch.cat( [ input_start_image_embedding_batch, projection, input_end_image_embedding_batch, ], dim=1, ) attention_mask = torch.ones( input_embedding.size()[:-1], dtype=torch.long, device=image.device ) outputs = self.llm.generate( inputs_embeds=input_embedding, attention_mask=attention_mask, eos_token_id=0, max_new_tokens=30, do_sample=True, # add randomness top_p=0.9, # nucleus sampling temperature=0.7, ) # Convert tensor to list of lists for decode_batch if outputs.dim() == 2: # outputs is [batch_size, sequence_length], convert to list of lists outputs_list = outputs.tolist() else: # outputs is already a list/sequence outputs_list = outputs return self.tokenizer.decode_batch(outputs_list, skip_special_tokens=True) def generate(self, image): data_config = data.resolve_model_data_config( create_model("vit_mediumd_patch16_reg4_gap_384", pretrained=True) ) transforms = data.create_transform(**data_config, is_training=False) image = transforms(image) return self.predict_step(batch={"image":image.unsqueeze(0)},batch_idx=0)[0] def collate_fn(batch): collected = {"image": [], "input_caption": [], "original_caption_enriched": []} for data in batch: collected["image"].append(torch.tensor(data["image"], dtype=torch.float)) collected["input_caption"].append( torch.tensor(data["input_caption"], dtype=torch.long) ) collected["original_caption_enriched"].append(data["original_caption_enriched"]) return { "image": torch.stack(collected["image"], dim=0), "input_caption": torch.stack(collected["input_caption"], dim=0), "original_caption_enriched": collected["original_caption_enriched"], } def agument(tokenizer: Tokenizer): data_config = data.resolve_model_data_config( create_model("vit_mediumd_patch16_reg4_gap_384", pretrained=True) ) transforms = data.create_transform(**data_config, is_training=False) def transform(data): ids = tokenizer.encode(data["caption_enriched"]) # Handle sequences based on length if len(ids.ids) <= 59: # For short sequences, just append EOS ids.ids.append(eos_token_id) else: # For long sequences, truncate to 59 tokens and append EOS ids.ids = ids.ids[:59] ids.ids.append(eos_token_id) # Pad to exactly 60 tokens ids.ids = ids.ids[:60] # Ensure we don't exceed 60 ids.pad(60) decoded = tokenizer.decode(ids.ids, skip_special_tokens=True) print("original", data["caption_enriched"], "decoded", decoded) data["input_caption"] = torch.tensor(ids.ids, dtype=torch.long) data["original_caption_enriched"] = data["caption_enriched"] data["image"] = transforms(data["image"]) return data return transform def is_valid_image(example): try: # Try opening the image if example["image"].mode == "RGB": return True return False except Exception as e: # ValueError will catch the MAX_TEXT_CHUNK error print("false", example["image"]) print("Exception:", e) return False def train( root_path: Path, dataset: datasets.Dataset, num_loader_worker: int = 0, batch_size=16, logger=None, ): # dataset = datasets.load_dataset("visual-layer/imagenet-1k-vl-enriched", split="validation").shuffle(seed=42).select(range(20000)).train_test_split(test_size=0.1) test_ds = dataset["test"] train_ds = dataset["train"] model = ImageNetCaptionModel() tokenizer = model.get_tokenizer() # Apply transformation to both datasets train_ds = train_ds.filter(is_valid_image) train_ds = train_ds.map(agument(tokenizer=tokenizer)) test_ds = test_ds.filter(is_valid_image) test_ds = test_ds.map(agument(tokenizer=tokenizer)) train_data_loader = DataLoader( dataset=train_ds, drop_last=True, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_loader_worker, ) evaluation_data_loader = DataLoader( dataset=test_ds, drop_last=True, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_loader_worker, ) if logger is None: logger = TensorBoardLogger(save_dir=str(root_path), version=1, name="logs") checkpoint_callback = ModelCheckpoint( dirpath=root_path / "checkpoint", filename="checkpoint-{epoch:02d}-{loss:.2f}", every_n_epochs=1, save_top_k=-1, ) print("path", root_path) trainer = L.Trainer( logger=logger, max_epochs=2, default_root_dir=root_path, callbacks=[checkpoint_callback], ) trainer.fit( model=model, train_dataloaders=train_data_loader, val_dataloaders=evaluation_data_loader, )