import torch
import torch.nn as nn
from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download


class CLIP2MT5_CrossAttention(nn.Module):
    """Couples a CLIP vision encoder with an mT5 seq2seq model: CLIP patch
    embeddings are projected into the mT5 embedding space and prepended to the
    text token embeddings, and the mT5 decoder then attends to the fused
    sequence through its usual encoder-decoder cross-attention."""

    def __init__(self, clip_name='openai/clip-vit-base-patch32',
                 t5_name='mukayese/mt5-base-turkish-summarization'):
        super().__init__()

        self.clip = CLIPModel.from_pretrained(clip_name)
        self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)

        # Project CLIP vision hidden states to the mT5 model dimension (d_model).
        self.vis_proj = nn.Linear(
            self.clip.config.vision_config.hidden_size,
            self.t5.config.d_model
        )

    def forward(self, images, input_ids, attention_mask, labels=None):
        # Encode the image and project the patch embeddings to d_model.
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)

        # Embed the text tokens with the mT5 encoder embedding table.
        text_embeds = self.t5.encoder.embed_tokens(input_ids)

        # Prepend the visual embeddings to the text embeddings as an encoder prefix.
        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

        # Extend the attention mask so the visual prefix is always attended to.
        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        if labels is not None:
            # Replace padding with -100 so it is ignored by the cross-entropy loss.
            labels = labels.clone()
            labels[labels == self.tokenizer.pad_token_id] = -100

        return self.t5(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            labels=labels,
            return_dict=True
        )

    @torch.no_grad()
    def generate(self, images, input_ids, attention_mask, **gen_kwargs):
        # Build the same fused vision + text encoder inputs as in forward().
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)

        text_embeds = self.t5.encoder.embed_tokens(input_ids)

        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        # Delegate decoding (greedy, beam search, sampling, ...) to the underlying mT5.
        return self.t5.generate(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            **gen_kwargs
        )
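
# Illustrative helper (not part of the original code): a minimal sketch of how
# forward() is meant to be used during training. It relies only on what the
# module above guarantees: with labels provided, forward() returns a
# Seq2SeqLMOutput whose .loss already ignores padded label positions.
def example_training_step(model, pixel_values, input_ids, attention_mask, labels, optimizer):
    model.train()
    out = model(images=pixel_values, input_ids=input_ids,
                attention_mask=attention_mask, labels=labels)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return out.loss.item()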


def load_model(
    repo_id: str,
    filename: str = "model.pt",
    clip_name="openai/clip-vit-base-patch32",
    t5_name="mukayese/mt5-base-turkish-summarization",
    device=None
):
    """Download the fine-tuned checkpoint from the Hugging Face Hub and load it
    into a freshly built CLIP2MT5_CrossAttention model."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fetch the checkpoint file from the Hub (cached locally after the first call).
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Rebuild the architecture, then load the fine-tuned weights on top of it.
    model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model
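

# --- Example inference (illustrative sketch, not part of the original code) ---
# The repo id, image path, and prompt below are placeholders; CLIPProcessor is
# assumed to match the CLIP checkpoint used inside the model.
if __name__ == "__main__":
    from PIL import Image
    from transformers import CLIPProcessor

    model = load_model(repo_id="your-username/your-checkpoint-repo")  # placeholder repo id
    device = next(model.parameters()).device

    # Preprocess the image with the processor matching the CLIP vision encoder.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # Tokenize the text prompt with the mT5 tokenizer held by the model.
    enc = model.tokenizer("placeholder prompt", return_tensors="pt").to(device)

    output_ids = model.generate(
        images=pixel_values,
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_new_tokens=64,
        num_beams=4,
    )
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))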