pusulcuk Mueris committed on
Commit 6d6a355 · 0 Parent(s):

Duplicate from Mueris/TurkishVLMTAMGA


Co-authored-by: Mustafa Eren Işıktaşlı <Mueris@users.noreply.huggingface.co>

Files changed (9)
  1. .gitattributes +35 -0
  2. README.txt +20 -0
  3. __init__.py +2 -0
  4. app.py +19 -0
  5. config.json +8 -0
  6. inference.py +46 -0
  7. model.pt +3 -0
  8. model.py +91 -0
  9. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.txt ADDED
@@ -0,0 +1,20 @@
+ # CLIP2MT5 CrossAttention VQA Model
+
+ This is a Vision-Language model combining **CLIP-ViT** and **mT5** using a custom cross-attention bridge.
+ It supports Visual Question Answering (VQA) in Turkish.
+
+ ## Usage
+
+ ```python
+ from PIL import Image
+ from hf_clip2mt5 import load_for_inference, predict
+
+ repo_id = "MUERIS/TurkishVLMTAMGA"
+
+ model, tokenizer, device = load_for_inference(repo_id)
+
+ image = Image.open("example.jpg")
+ question = "Görselde kaç kişi var?"
+
+ answer = predict(model, tokenizer, device, image, question)
+ print(answer)
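
The README example imports from `hf_clip2mt5`, i.e. it assumes these files are installed as a package under that name. A minimal sketch of the same call path when the repo is simply cloned and its files used as top-level modules (the layout `app.py` relies on):

```python
# Sketch, not part of the commit: local VQA call using the repo files directly.
# Assumes inference.py and model.py are on the Python path, as in app.py.
from PIL import Image
from inference import load_for_inference, predict

model, tokenizer, device = load_for_inference("MUERIS/TurkishVLMTAMGA")

image = Image.open("example.jpg")  # any local RGB image
print(predict(model, tokenizer, device, image, "Görselde kaç kişi var?"))
```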
__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model import CLIP2MT5_CrossAttention, load_model
+ from .inference import load_for_inference, predict
app.py ADDED
@@ -0,0 +1,19 @@
+ import gradio as gr
+ from inference import load_for_inference, predict
+
+ REPO_ID = "MUERIS/TurkishVLMTAMGA"
+
+ model, tokenizer, device = load_for_inference(REPO_ID)
+
+ def answer(image, question):
+     return predict(model, tokenizer, device, image, question)
+
+ gr.Interface(
+     fn=answer,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Textbox(label="Question")
+     ],
+     outputs="text",
+     title="CLIP2MT5 Visual Question Answering"
+ ).launch()
config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "model_type": "clip2mt5-crossattention",
+   "library": "pytorch",
+   "architectures": ["CLIP2MT5_CrossAttention"],
+   "pipeline_tag": "image-text-to-text",
+   "description": "CLIP + mT5 VQA Model using cross-attention.",
+   "author": "MUERIS"
+ }
inference.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ from model import load_model
+
+
+ # Preprocessing
+ _transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=[0.4815, 0.4578, 0.4082],
+         std=[0.2686, 0.2613, 0.2758]
+     )
+ ])
+
+
+ def load_for_inference(repo_id, filename="model.pt"):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = load_model(repo_id=repo_id, filename=filename, device=device)
+     tokenizer = model.tokenizer
+     return model, tokenizer, device
+
+
+ def predict(model, tokenizer, device, image: Image.Image, question: str):
+     image_tensor = _transform(image).unsqueeze(0).to(device)
+
+     q = tokenizer(
+         question,
+         return_tensors='pt',
+         padding=True,
+         truncation=True,
+         max_length=64
+     ).to(device)
+
+     with torch.no_grad():
+         output_ids = model.generate(
+             images=image_tensor,
+             input_ids=q.input_ids,
+             attention_mask=q.attention_mask,
+             max_length=64,
+             num_beams=4
+         )
+
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
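
The hard-coded `Normalize` statistics are CLIP's image mean/std, so `_transform` reproduces CLIP's expected input distribution at 224×224. A roughly equivalent sketch (an alternative, not what this repo ships) that derives the pixel tensor from the checkpoint's own processor instead, at the cost of CLIP's slightly different resize-and-center-crop policy:

```python
# Sketch: CLIP's bundled image processor as an alternative to the manual
# torchvision pipeline above; output shape is (1, 3, 224, 224).
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
pixel_values = processor(images=Image.open("example.jpg"),
                         return_tensors="pt").pixel_values
# pixel_values can stand in for _transform(image).unsqueeze(0) in predict().
```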
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7006cd53a99bc96da69f28839e99492d83fc73f55a6d429adb8cf7f8ccfa6d45
+ size 2937292702
model.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
+ from huggingface_hub import hf_hub_download
+
+
+ class CLIP2MT5_CrossAttention(nn.Module):
+     def __init__(self, clip_name='openai/clip-vit-base-patch32',
+                  t5_name='mukayese/mt5-base-turkish-summarization'):
+         super().__init__()
+
+         self.clip = CLIPModel.from_pretrained(clip_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
+         self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)
+
+         self.vis_proj = nn.Linear(
+             self.clip.config.vision_config.hidden_size,
+             self.t5.config.d_model
+         )
+
+     def forward(self, images, input_ids, attention_mask, labels=None):
+         vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
+         vision_embeds = self.vis_proj(vision_outputs)
+
+         text_embeds = self.t5.encoder.embed_tokens(input_ids)
+
+         extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
+
+         extended_attention_mask = torch.cat([
+             torch.ones(vision_embeds.size(0), vision_embeds.size(1),
+                        dtype=attention_mask.dtype, device=attention_mask.device),
+             attention_mask
+         ], dim=1)
+
+         if labels is not None:
+             labels = labels.clone()
+             labels[labels == self.tokenizer.pad_token_id] = -100
+
+         return self.t5(
+             inputs_embeds=extended_input_embeds,
+             attention_mask=extended_attention_mask,
+             labels=labels,
+             return_dict=True
+         )
+
+     @torch.no_grad()
+     def generate(self, images, input_ids, attention_mask, **gen_kwargs):
+         vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
+         vision_embeds = self.vis_proj(vision_outputs)
+
+         text_embeds = self.t5.encoder.embed_tokens(input_ids)
+
+         extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
+
+         extended_attention_mask = torch.cat([
+             torch.ones(vision_embeds.size(0), vision_embeds.size(1),
+                        dtype=attention_mask.dtype, device=attention_mask.device),
+             attention_mask
+         ], dim=1)
+
+         return self.t5.generate(
+             inputs_embeds=extended_input_embeds,
+             attention_mask=extended_attention_mask,
+             **gen_kwargs
+         )
+
+
+
+ # HF Loader for STATE_DICT
+
+
+ def load_model(
+     repo_id: str,
+     filename: str = "model.pt",
+     clip_name="openai/clip-vit-base-patch32",
+     t5_name="mukayese/mt5-base-turkish-summarization",
+     device=None
+ ):
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+
+     model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)
+
+     state = torch.load(model_path, map_location=device)
+     model.load_state_dict(state)
+
+     model.to(device)
+     model.eval()
+     return model
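
Both `forward` and `generate` prepend the projected CLIP patch sequence (50 vision tokens for ViT-B/32 at 224×224: 49 patches plus the class token) to the mT5 question embeddings and extend the attention mask to match, so mT5's encoder attends over image and text jointly. A quick smoke-test sketch of that flow; with a freshly initialized `vis_proj` the decoded answer is meaningless, so `load_model()` and `model.pt` are still needed for real predictions:

```python
# Sketch: shape/flow check for CLIP2MT5_CrossAttention with dummy pixels.
import torch
from model import CLIP2MT5_CrossAttention

model = CLIP2MT5_CrossAttention().eval()
images = torch.randn(1, 3, 224, 224)  # dummy 224x224 RGB batch
q = model.tokenizer("Görselde kaç kişi var?", return_tensors="pt")

with torch.no_grad():
    ids = model.generate(images=images,
                         input_ids=q.input_ids,
                         attention_mask=q.attention_mask,
                         max_length=16)
print(model.tokenizer.decode(ids[0], skip_special_tokens=True))
```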
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ torchvision
+ transformers
+ huggingface_hub
+ gradio
+ Pillow