import torch
import torch.nn as nn
from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download


class CLIP2MT5_CrossAttention(nn.Module):
    """Vision-conditioned mT5: CLIP patch features are projected into the
    mT5 embedding space and prepended to the text token embeddings.

    Note: despite the class name, fusion happens by prefix concatenation at
    the encoder input, not through extra cross-attention layers.
    """

    def __init__(self, clip_name='openai/clip-vit-base-patch32',
                 t5_name='mukayese/mt5-base-turkish-summarization'):
        super().__init__()

        self.clip = CLIPModel.from_pretrained(clip_name)
        self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)

        # Project CLIP vision hidden states into the mT5 embedding space.
        self.vis_proj = nn.Linear(
            self.clip.config.vision_config.hidden_size,
            self.t5.config.d_model
        )

    def forward(self, images, input_ids, attention_mask, labels=None):
        # Patch-level CLIP features: (batch, num_patches + 1, vision_hidden).
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)

        # Look up mT5 token embeddings for the text input.
        text_embeds = self.t5.encoder.embed_tokens(input_ids)

        # Prepend the visual tokens to the text tokens along the sequence axis.
        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

        # Visual positions are always valid, so their mask is all ones.
        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        if labels is not None:
            # Replace padding with -100 so the loss ignores those positions.
            labels = labels.clone()
            labels[labels == self.tokenizer.pad_token_id] = -100

        return self.t5(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            labels=labels,
            return_dict=True
        )

    @torch.no_grad()
    def generate(self, images, input_ids, attention_mask, **gen_kwargs):
        # Same prefix fusion as forward(), routed through mT5's generate().
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)

        text_embeds = self.t5.encoder.embed_tokens(input_ids)

        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        return self.t5.generate(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            **gen_kwargs
        )
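

# Illustrative sketch (not from the original code): how one optimization step
# over this model might look. `batch` and its keys are hypothetical names; any
# dataloader producing CLIP pixel_values plus mT5 input_ids / attention_mask /
# labels would work.
def example_training_step(model, batch, optimizer):
    out = model(batch["pixel_values"], batch["input_ids"],
                batch["attention_mask"], labels=batch["labels"])
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return out.loss.item()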


def load_model(
    repo_id: str,
    filename: str = "model.pt",
    clip_name="openai/clip-vit-base-patch32",
    t5_name="mukayese/mt5-base-turkish-summarization",
    device=None
):
    """Download a saved state_dict from the Hugging Face Hub and load it."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)

    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model
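

# Illustrative usage sketch. The repo id, image path, and Turkish prompt below
# are placeholders, and the checkpoint is assumed to be a plain state_dict
# saved with torch.save(model.state_dict(), ...).
if __name__ == "__main__":
    from transformers import CLIPProcessor
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model("your-username/clip2mt5-turkish")  # hypothetical repo id

    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    image = Image.open("example.jpg")  # hypothetical local image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # "Görseli özetle:" = "Summarize the image:" (placeholder prompt).
    prompt = model.tokenizer("Görseli özetle:", return_tensors="pt").to(device)
    output_ids = model.generate(
        pixel_values,
        prompt.input_ids,
        prompt.attention_mask,
        max_new_tokens=64,
    )
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))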