cactuarix committed on
Commit
7a895c1
·
1 Parent(s): f0359ed

first commit

Files changed (2)
  1. app.py +212 -0
  2. requirements.txt +121 -0
app.py ADDED
@@ -0,0 +1,212 @@
+ import gradio as gr
+ import json
+ import numpy as np
+ import torch
+ from typing import Optional
+
+ from torch import nn
+ from torchvision import models
+ from torchvision import transforms as tr
+ from transformers import PretrainedConfig, PreTrainedModel
+ from huggingface_hub import hf_hub_download
+
+
+ class img_fe_class_vit(nn.Module):
+     """ViT image feature extractor: frozen vit_b_16 backbone with a trainable projection head."""
+
+     def __init__(self, base_model, emb_size):
+         super(img_fe_class_vit, self).__init__()
+         self.patch = base_model.conv_proj
+         self.encoder = base_model.encoder
+         self.pos_embedding = base_model.encoder.pos_embedding.requires_grad_(False)
+         self.class_token = base_model.class_token.requires_grad_(False)
+         # freeze the pretrained backbone; only self.fc stays trainable
+         for param in self.encoder.parameters():
+             param.requires_grad_(False)
+         for param in self.patch.parameters():
+             param.requires_grad_(False)
+         self.fc = nn.Linear(base_model.heads.head.in_features, emb_size)
+
+     def forward(self, imgs):
+         imgs = self.patch(imgs)                 # (B, 768, 14, 14) patch embeddings
+         imgs = imgs.flatten(2).transpose(1, 2)  # (B, 196, 768) token sequence
+         imgs = torch.cat([self.class_token.expand(imgs.shape[0], -1, -1), imgs], dim=1)
+         imgs = imgs + self.pos_embedding
+         embeddings = self.encoder(imgs)
+         embeddings = self.fc(embeddings)        # project to emb_size
+         return embeddings
+
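+ # Shape sketch for the extractor above (a minimal check, assuming a 224x224 RGB
+ # input and vit_b_16 defaults: 16x16 patches -> 196 tokens + 1 class token = 197):
+ #   fe = img_fe_class_vit(models.vit_b_16(weights='IMAGENET1K_V1'), 300)
+ #   fe(torch.randn(2, 3, 224, 224)).shape  # torch.Size([2, 197, 300])
+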
+ class text_fe_class_transformer(nn.Module):
+     """Transformer decoder that attends to image features while generating caption tokens."""
+
+     def __init__(self, num_heads, num_layers):
+         super(text_fe_class_transformer, self).__init__()
+         self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=300, padding_idx=tok_to_ind['<PAD>'])
+         # self.embed.weight = nn.Parameter(
+         #     torch.from_numpy(glove_weights).to(dtype=self.embed.weight.dtype),
+         #     requires_grad=True,
+         # )
+         self.transformer_layer = nn.TransformerDecoderLayer(d_model=300, nhead=num_heads, dim_feedforward=2048, batch_first=True, activation='gelu', dropout=0.1)
+         self.transformer = nn.TransformerDecoder(self.transformer_layer, num_layers=num_layers)
+
+     def forward(self, texts, img_features):
+         emb = self.embed(texts)
+         causal_mask = nn.Transformer.generate_square_subsequent_mask(texts.shape[-1])
+         # float mask: -inf at <PAD> positions (matching padding_idx above), 0 elsewhere
+         padding_mask = torch.where(texts == tok_to_ind['<PAD>'], -torch.inf, 0)
+         out = self.transformer(emb, img_features, tgt_mask=causal_mask.to(device), tgt_key_padding_mask=padding_mask.to(device), tgt_is_causal=True)
+         return out
+
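+ # Shape sketch (illustrative, assuming the vocab globals defined below are loaded):
+ # with texts of shape (B, T) and img_features of shape (B, 197, 300),
+ #   dec = text_fe_class_transformer(num_heads=6, num_layers=3)
+ #   dec(texts, img_features).shape  # (B, T, 300)
+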
+ class image_captioning_model_transformer(nn.Module):
+     """Full captioning model: ViT feature extractor + transformer decoder + vocabulary projection."""
+
+     def __init__(self, num_heads, num_layers):
+         super(image_captioning_model_transformer, self).__init__()
+         self.feature_extractor = img_fe_class_vit(models.vit_b_16(weights='IMAGENET1K_V1'), 300)
+         self.caption_generator = text_fe_class_transformer(num_heads, num_layers)
+         self.fc = nn.Linear(300, vocab_size, bias=False)
+
+     def forward(self, img_batch, texts_batch):
+         img_batch_features = self.feature_extractor(img_batch)
+         out = self.caption_generator(texts_batch, img_batch_features)
+         out = self.fc(out)  # (B, T, vocab_size) logits
+         return out
+
+ def generate(
+     model,
+     image,
+     max_seq_len: Optional[int] = 20,
+     top_p: Optional[float] = None,
+     top_k: Optional[int] = None,
+ ):
+     """
+     Generate a caption for `image` with `model`, stopping either at the '<EOS>'
+     token or after max_seq_len tokens.
+     top_k -> keep the top_k most probable words of each prediction and sample
+              from them with renormalized probabilities
+     top_p -> keep the smallest set of words whose total probability is at most
+              top_p, then sample from them with renormalized probabilities
+     otherwise -> sample a random word with the predicted probabilities
+     """
+     assert top_p is None or top_k is None, "Don't use top_p and top_k at the same time"
+
+     model.eval()
+     result_tokens = []
+     result_text = []
+     image = image_prepare_val(image).to(device)
+     with torch.no_grad():
+         if top_k is not None:
+             # logits, hid = model(image.unsqueeze(0), torch.IntTensor([tok_to_ind['<BOS>']]).unsqueeze(0).to(device), None)
+             logits = model(image.unsqueeze(0), torch.IntTensor([tok_to_ind['<BOS>']]).unsqueeze(0).to(device))[:, -1, :]
+             prev_tokens = torch.IntTensor([tok_to_ind['<BOS>']]).unsqueeze(0).to(device)
+             for _ in range(max_seq_len - 1):
+                 top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)
+                 probs = nn.functional.softmax(top_k_logits, dim=-1)
+                 sampled_index = torch.multinomial(probs[0], 1)
+                 next_token = torch.squeeze(top_k_indices, dim=-2)[torch.squeeze(sampled_index).item()]
+                 if next_token.item() == tok_to_ind['<EOS>']:
+                     break
+                 result_tokens.append(next_token.item())
+                 result_text.append(ind_to_tok[next_token.item()])
+                 # logits, hid = model(image.unsqueeze(0), next_token.unsqueeze(0).unsqueeze(0), hid)
+                 # re-run the model on the full prefix and take the last-position logits
+                 prev_tokens = torch.concat((prev_tokens, next_token.unsqueeze(0).unsqueeze(0)), dim=-1)
+                 logits = model(image.unsqueeze(0), prev_tokens)[:, -1, :]
+     return result_tokens, ' '.join(result_text)
+
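+ # Note: only the top_k branch is implemented above, although the docstring also
+ # describes top_p (nucleus) sampling. A minimal sketch of that missing branch,
+ # assuming logits of shape (1, vocab_size) as in the loop; sample_top_p is a
+ # hypothetical helper, not part of the original commit:
+ def sample_top_p(logits, top_p):
+     probs = nn.functional.softmax(logits, dim=-1)
+     sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
+     cumulative = sorted_probs.cumsum(dim=-1)
+     # keep the prefix of words whose total probability is at most top_p
+     keep = cumulative <= top_p
+     keep[..., 0] = True  # always keep the most probable word
+     sorted_probs = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
+     sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
+     sampled = torch.multinomial(sorted_probs[0], 1)
+     return sorted_indices[0, sampled]
+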
+ class ImageCaptioningConfig(PretrainedConfig):
+     model_type = "image_captioning_transformer"
+
+     def __init__(
+         self,
+         num_heads=6,
+         num_layers=3,
+         vocab_size=3478,
+         emb_size=300,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.vocab_size = vocab_size
+         self.emb_size = emb_size
+
+ class ImageCaptioningModel(PreTrainedModel):
+     config_class = ImageCaptioningConfig
+
+     def __init__(self, config, original_model=None):
+         super().__init__(config)
+         if original_model is None:
+             # when loading from the Hub, build the model from the config
+             self.model = image_captioning_model_transformer(
+                 num_heads=config.num_heads,
+                 num_layers=config.num_layers
+             )
+         else:
+             # when saving, wrap the existing trained model
+             self.model = original_model
+
+     def forward(self, image, input_ids, **kwargs):
+         return self.model(image, input_ids)
+
+     def generate(self, image, max_seq_len=20, top_p=None, top_k=None):
+         """Text-generation interface."""
+         result_tokens, result_text = generate(
+             self.model,
+             image,
+             max_seq_len=max_seq_len,
+             top_p=top_p,
+             top_k=top_k
+         )
+         return {"tokens": result_tokens, "text": result_text}
+
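+ # For reference, a sketch of how a trained checkpoint could be published with
+ # this wrapper (hypothetical variable names; the repo id is taken from the
+ # downloads below):
+ #   config = ImageCaptioningConfig(num_heads=6, num_layers=3)
+ #   hub_model = ImageCaptioningModel(config, original_model=trained_model)
+ #   hub_model.push_to_hub("cactuarix/image-captioning-vit-transformer")
+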
+ # per-channel normalization statistics
+ channel_mean = np.array([0.4579829, 0.44630096, 0.40314582])
+ channel_std = np.array([0.24192157, 0.23313912, 0.23692572])
+
+ image_prepare_val = tr.Compose([
+     tr.Resize((224, 224)),
+     tr.ToTensor(),
+     tr.Normalize(mean=channel_mean, std=channel_std),
+ ])
+
+ vocab_size = 3478
+ config_path = hf_hub_download(
+     repo_id="cactuarix/image-captioning-vit-transformer",
+     filename="tokenizer_config.json"
+ )
+
+ with open(config_path, "r") as f:
+     tokenizer_config = json.load(f)
+
+ tok_to_ind = tokenizer_config["tok_to_ind"]
+ ind_to_tok = tokenizer_config["ind_to_tok"]
+
+ config = ImageCaptioningConfig.from_pretrained("cactuarix/image-captioning-vit-transformer")
+ model = ImageCaptioningModel.from_pretrained("cactuarix/image-captioning-vit-transformer")
+
+ # JSON object keys are strings; rebuild ind_to_tok with integer keys
+ ind_to_tok = {int(key): tok for key, tok in ind_to_tok.items()}
+
+ device = torch.device('cpu')
+
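+ # From the usage above, tokenizer_config.json is assumed to hold both vocabulary
+ # maps, roughly (indices illustrative):
+ #   {"tok_to_ind": {"<BOS>": ..., "<EOS>": ..., "<PAD>": ..., ...},
+ #    "ind_to_tok": {"0": ..., "1": ..., ...}}
+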
+ def predict(image):
+     output = model.generate(image, top_k=3)
+     return output["text"]
+
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs="text",
+     title="Image Captioning",
+     description="Upload an image to generate a description"
+ )
+
+ iface.launch()
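+
+ # A minimal sketch of querying the running app from another process with
+ # gradio_client (assumptions: default local URL and the auto-generated
+ # "/predict" endpoint; the image path is illustrative):
+ #   from gradio_client import Client, handle_file
+ #   client = Client("http://127.0.0.1:7860")
+ #   print(client.predict(handle_file("example.jpg"), api_name="/predict"))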
requirements.txt ADDED
@@ -0,0 +1,121 @@
+ aiofiles==24.1.0
+ annotated-types==0.7.0
+ anyio==4.9.0
+ asttokens==3.0.0
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ click==8.1.8
+ comm==0.2.2
+ contourpy==1.3.1
+ cycler==0.12.1
+ debugpy==1.8.13
+ decorator==5.2.1
+ executing==2.2.0
+ fastapi==0.115.12
+ ffmpy==0.5.0
+ filelock==3.18.0
+ fonttools==4.56.0
+ fsspec==2025.3.0
+ gradio==5.31.0
+ gradio_client==1.10.1
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.1.2
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.32.2
+ idna==3.10
+ ipykernel==6.29.5
+ ipython==9.0.2
+ ipython_pygments_lexers==1.1.1
+ ipywidgets==8.1.7
+ jedi==0.19.2
+ Jinja2==3.1.6
+ joblib==1.4.2
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ jupyterlab_widgets==3.0.15
+ kiwisolver==1.4.8
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ matplotlib==3.10.1
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mpmath==1.3.0
+ nest-asyncio==1.6.0
+ networkx==3.4.2
+ nltk==3.9.1
+ numpy==2.2.4
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-cusparselt-cu12==0.6.2
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ opencv-python==4.11.0.86
+ orjson==3.10.18
+ packaging==24.2
+ pandas==2.2.3
+ parso==0.8.4
+ pexpect==4.9.0
+ pillow==11.1.0
+ platformdirs==4.3.7
+ prompt_toolkit==3.0.50
+ psutil==7.0.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pydantic==2.11.5
+ pydantic_core==2.33.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ pyparsing==3.2.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.1
+ PyYAML==6.0.2
+ pyzmq==26.3.0
+ regex==2024.11.6
+ requests==2.32.3
+ rich==14.0.0
+ ruff==0.11.11
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scikit-learn==1.6.1
+ scipy==1.15.2
+ semantic-version==2.10.0
+ setuptools==77.0.3
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ stack-data==0.6.3
+ starlette==0.46.2
+ sympy==1.13.1
+ termcolor==2.5.0
+ threadpoolctl==3.6.0
+ tokenizers==0.21.1
+ tomlkit==0.13.2
+ torch==2.6.0
+ torchaudio==2.6.0
+ torchdata==0.7.1
+ torchvision==0.21.0
+ tornado==6.4.2
+ tqdm==4.67.1
+ traitlets==5.14.3
+ transformers==4.52.3
+ triton==3.2.0
+ typer==0.16.0
+ typing-inspection==0.4.1
+ typing_extensions==4.12.2
+ tzdata==2025.1
+ urllib3==2.3.0
+ uvicorn==0.34.2
+ wcwidth==0.2.13
+ websockets==15.0.1
+ widgetsnbextension==4.0.14