|
|
import cv2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

# NOTE(review): CLIPModel / CLIPProcessor are imported but never instantiated
# in this section, while SmoLLM_processor refers to `clip_model` and
# `clip_processor` -- confirm where those two objects are created.

# Report whether a CUDA device is visible (kept from the original script).
print(torch.cuda.is_available())

# Single source of truth for placement: 'cuda' when available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The tokenizer is loaded from the SmolLM2-360M repository and the language
# model from the captioner repository.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M")

# Move to `device` instead of a literal 'cuda' so the script also loads on
# hosts without a GPU.
llm_model = AutoModelForCausalLM.from_pretrained("alibidaran/SMOLL_image_captioner").to(device)
|
|
|
|
|
class SmoLLM_processor():
    """Read an image from disk and encode it with a CLIP image encoder.

    The instance keeps an image model and the pre-processor that prepares
    pixel inputs for it.
    """

    def __init__(self, image_model=None, image_processor=None):
        """Store the image encoder and its pre-processor.

        Args:
            image_model: Model used to encode images. When omitted, the
                module-level ``clip_model`` is used.
            image_processor: Pre-processor matching ``image_model``. When
                omitted, the module-level ``clip_processor`` is used.
        """
        # The fallbacks are looked up at call time rather than in the
        # signature, so defining the class does not require `clip_model` /
        # `clip_processor` to exist yet.
        self.image_model = clip_model if image_model is None else image_model
        self.image_processor = (
            clip_processor if image_processor is None else image_processor
        )

    def get_features(self, image_path, device=None):
        """Return the image features for the image stored at ``image_path``.

        Args:
            image_path: Path of the image file, as accepted by ``cv2.imread``.
            device: Device the pixel inputs are moved to. Defaults to the
                device of ``image_model``'s parameters.

        Returns:
            The value returned by ``image_model.get_image_features``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable file by returning None rather
            # than raising; fail here with a clear message instead of inside
            # cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path!r}")

        # OpenCV decodes to BGR channel order; convert to RGB before encoding.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if device is None:
            device = next(self.image_model.parameters()).device

        # NOTE(review): the original body stopped after the colour conversion
        # and returned nothing. The encoding step below is reconstructed from
        # the two attributes this class stores -- confirm it matches the
        # features the caption model was trained on.
        pixel_inputs = self.image_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            return self.image_model.get_image_features(**pixel_inputs)
|
|
def build_caption_inputs(image_url, processor, target_device=None):
    """Assemble the image features and tokenized prompt for one image.

    The original statements ended in a bare ``return data`` with no enclosing
    ``def``; they are wrapped in a function so the ``return`` is valid.

    Args:
        image_url: Image location forwarded to ``processor.get_features``.
        processor: Object exposing ``get_features(image_url, device=...)``.
        target_device: Device for the returned tensors. Defaults to the
            module-level ``device``.

    Returns:
        dict with three entries: ``'image_features'`` (whatever
        ``processor.get_features`` returns), ``'label'`` (token ids of the
        prompt) and ``'attention_mask'`` (the tokenizer's mask for the
        prompt).
    """
    if target_device is None:
        # Follow the module-level placement instead of a literal 'cuda'.
        target_device = device

    image_features = processor.get_features(image_url, device=target_device)

    # Side effect on the shared module-level tokenizer: EOS is reused as the
    # padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Runtime text, kept exactly as written (including the "Assitant"
    # spelling).
    # NOTE(review): the source layout was mangled; confirm the line breaks
    # match the prompt used at training time.
    prompt = "\n##User <image> Write a caption\n##Assitant:"
    tokenized = tokenizer(prompt, return_tensors='pt')

    return {
        'image_features': image_features,
        'label': tokenized['input_ids'].to(target_device),
        'attention_mask': tokenized['attention_mask'].to(target_device),
    }
|
|
|
|
|
|
|
|
|
|
|
class SMOLLm_VISION_ImageCaptioning(torch.nn.Module):
    """Causal language model conditioned on one projected image vector.

    An image feature vector is projected by a linear layer plus GELU and
    placed in front of the token embeddings as a single extra position; the
    combined sequence is then run through the wrapped language model.
    """

    def __init__(self, llm_model, hidden_dim):
        """Build the projection head around ``llm_model``.

        Args:
            llm_model: Language model. It must provide
                ``prepare_inputs_for_generation`` and
                ``get_input_embeddings``, accept ``inputs_embeds`` and
                ``attention_mask`` keywords, and return an object with a
                ``logits`` attribute.
            hidden_dim: Accepted for interface compatibility but not used;
                the projection is fixed at 768 -> 960.
        """
        # Zero-argument super(): the original named a class
        # (ImageCaptioningModel) that does not exist.
        super().__init__()
        self.llm_model = llm_model
        self.fc = torch.nn.Linear(768, 960)
        # The attribute is called `relu` but holds a GELU; the name is kept so
        # existing references to it keep working.
        self.relu = torch.nn.GELU()

    def forward(self, images, input_ids, att):
        """Run the language model on ``[image embedding, token embeddings]``.

        Args:
            images: Image features whose last dimension is 768; one vector per
                sample, i.e. ``(batch, 768)``, since a sequence axis is added
                with ``unsqueeze(1)``.
            input_ids: Token ids, ``(batch, seq)``.
            att: Attention mask. It may cover only the text tokens
                (``(batch, seq)``), in which case a column of ones is
                prepended for the image position, or already cover the full
                sequence (``(batch, seq + 1)``), in which case it is passed
                through unchanged.

        Returns:
            Tuple ``(logits, combined_inputs)``: the logits with the image
            position (index 0) dropped, and the embedding sequence that was
            fed to the language model.
        """
        image_features = self.relu(self.fc(images))

        llama_inputs = self.llm_model.prepare_inputs_for_generation(input_ids)
        # The embedding lookup runs without gradient tracking, so no gradient
        # reaches the embedding table through this path.
        with torch.no_grad():
            llama_embeds = self.llm_model.get_input_embeddings()(llama_inputs['input_ids'])

        # One image position first, then the text positions.
        combined_inputs = torch.cat(
            [image_features.unsqueeze(1).float(), llama_embeds], dim=1
        )

        # A text-only mask is one column shorter than `combined_inputs`;
        # extend it so the image position is attended to and lengths agree.
        if att is not None and att.shape[1] == llama_embeds.shape[1]:
            image_mask = att.new_ones((att.shape[0], 1))
            att = torch.cat([image_mask, att], dim=1)

        outputs = self.llm_model(inputs_embeds=combined_inputs, attention_mask=att)

        return outputs.logits[:, 1:, :], combined_inputs
|
|
|