import torch
import torch.nn as nn
import torchvision.models as models
from transformers import (
    GPT2LMHeadModel,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    BlipProcessor,
    BlipForConditionalGeneration
)

# -----------------------------------------------------------------------------
# 1. Custom ResNet + GPT-2 (Training from Scratch)
# -----------------------------------------------------------------------------
class ResNetEncoder(nn.Module):
    """Frozen ResNet-50 backbone that maps an image to a single GPT-2-sized embedding."""

    def __init__(self, embed_dim=768):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Drop the final classification layer; keep everything up to global average pooling.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # Freeze the backbone; only the projection and batch norm are trainable.
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.projection = nn.Linear(2048, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        features = self.resnet(images)                   # (B, 2048, 1, 1)
        features = features.view(features.size(0), -1)   # (B, 2048)
        features = self.projection(features)             # (B, embed_dim)
        features = self.bn(features)
        return features
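
# Illustrative usage sketch (not part of the original module): the tensor shapes
# ResNetEncoder is expected to produce, assuming ImageNet-style 224x224 RGB input.
def _demo_resnet_encoder():
    encoder = ResNetEncoder(embed_dim=768).eval()
    dummy_images = torch.randn(2, 3, 224, 224)   # (batch, channels, height, width)
    with torch.no_grad():
        features = encoder(dummy_images)
    print(features.shape)                        # torch.Size([2, 768])
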

class ResNetGPT2(nn.Module):
    """Image captioner that feeds the ResNet image embedding to GPT-2 as a prefix token."""

    def __init__(self, max_seq_len=40):
        super().__init__()
        self.encoder = ResNetEncoder(embed_dim=768)
        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
        self.max_seq_len = max_seq_len

    def forward(self, images, input_ids, attention_mask):
        image_embeds = self.encoder(images)
        token_embeds = self.gpt2.transformer.wte(input_ids)
        # Prepend the image embedding as the first "token" of the sequence.
        inputs_embeds = torch.cat((image_embeds.unsqueeze(1), token_embeds), dim=1)
        # Extend the attention mask to cover the extra image position.
        batch_size = images.shape[0]
        ones = torch.ones(batch_size, 1).to(images.device)
        attention_mask = torch.cat((ones, attention_mask), dim=1)
        outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs.logits

    def generate_caption(self, image, tokenizer, max_length=20, temperature=1.0):
        """Greedy autoregressive decoding from a single (unbatched) image tensor."""
        self.eval()
        with torch.no_grad():
            image_embed = self.encoder(image.unsqueeze(0))
            inputs_embeds = image_embed.unsqueeze(1)
            generated_tokens = []
            for _ in range(max_length):
                outputs = self.gpt2(inputs_embeds=inputs_embeds)
                # Note: with argmax decoding, the temperature scaling has no effect on the result.
                logits = outputs.logits[:, -1, :] / temperature
                next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
                if next_token.item() == tokenizer.eos_token_id:
                    break
                generated_tokens.append(next_token.item())
                # Feed the chosen token's embedding back in for the next step.
                next_token_embed = self.gpt2.transformer.wte(next_token)
                inputs_embeds = torch.cat((inputs_embeds, next_token_embed), dim=1)
        return tokenizer.decode(generated_tokens, skip_special_tokens=True)
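
# Illustrative usage sketch (assumption, not in the original code): captioning a
# single preprocessed image tensor with ResNetGPT2. Assumes the standard GPT-2
# tokenizer; with the randomly initialized projection head, the output is
# meaningless until the model has been trained.
def _demo_resnet_gpt2_caption():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = ResNetGPT2().eval()
    image = torch.randn(3, 224, 224)   # stand-in for a normalized 224x224 image tensor
    caption = model.generate_caption(image, tokenizer, max_length=20)
    print(caption)
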

# -----------------------------------------------------------------------------
# 2. ViT + GPT-2 (Pre-trained SOTA 1)
# -----------------------------------------------------------------------------
class ViTGPT2Captioner(nn.Module):
    """Wrapper around the pre-trained nlpconnect/vit-gpt2-image-captioning checkpoint."""

    def __init__(self):
        super().__init__()
        print("Loading ViT-GPT2 model...")
        self.model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

    def generate_caption(self, image, **kwargs):
        self.eval()
        with torch.no_grad():
            pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.model.device)
            output_ids = self.model.generate(pixel_values, max_length=20, num_beams=4)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()
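
# Illustrative usage sketch (assumption): captioning an image file with the
# pre-trained ViT-GPT2 pipeline. "example.jpg" is a placeholder path.
def _demo_vit_gpt2_caption():
    from PIL import Image
    captioner = ViTGPT2Captioner()
    image = Image.open("example.jpg").convert("RGB")
    print(captioner.generate_caption(image))
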

# -----------------------------------------------------------------------------
# 3. BLIP (Pre-trained SOTA 2 - Best)
# -----------------------------------------------------------------------------
class BLIPCaptioner(nn.Module):
    """Wrapper around the pre-trained Salesforce/blip-image-captioning-large checkpoint."""

    def __init__(self):
        super().__init__()
        print("Loading BLIP model (Salesforce/blip-image-captioning-large)...")
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

    def generate_caption(self, image, **kwargs):
        self.eval()
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt").to(self.model.device)
            output_ids = self.model.generate(**inputs, max_length=50, num_beams=5, repetition_penalty=1.2, min_length=5)
        caption = self.processor.decode(output_ids[0], skip_special_tokens=True)
        return caption
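
# Illustrative usage sketch (assumption): BLIP takes a PIL image directly through
# its processor. "example.jpg" is a placeholder path.
def _demo_blip_caption():
    from PIL import Image
    captioner = BLIPCaptioner()
    image = Image.open("example.jpg").convert("RGB")
    print(captioner.generate_caption(image))
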

# -----------------------------------------------------------------------------
# Factory
# -----------------------------------------------------------------------------
def get_model(config):
    if config.MODEL_TYPE == "resnet_gpt2":
        return ResNetGPT2()
    elif config.MODEL_TYPE == "vit_gpt2":
        return ViTGPT2Captioner()
    elif config.MODEL_TYPE == "blip":
        return BLIPCaptioner()
    else:
        raise ValueError(f"Unknown model type: {config.MODEL_TYPE}")
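
# Illustrative usage sketch (assumption): get_model only needs an object exposing a
# MODEL_TYPE attribute; a SimpleNamespace stands in for the project's config here.
if __name__ == "__main__":
    from types import SimpleNamespace
    demo_config = SimpleNamespace(MODEL_TYPE="blip")
    model = get_model(demo_config)
    print(type(model).__name__)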