English
poonai committed on
Commit
6eff0e6
·
verified ·
1 Parent(s): e1dca07

Upload 12 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ test_image_0.png filter=lfs diff=lfs merge=lfs -text
37
+ test_image_1.png filter=lfs diff=lfs merge=lfs -text
38
+ test_image_2.png filter=lfs diff=lfs merge=lfs -text
39
+ test_image_3.png filter=lfs diff=lfs merge=lfs -text
40
+ test_image_4.png filter=lfs diff=lfs merge=lfs -text
41
+ test_image_5.png filter=lfs diff=lfs merge=lfs -text
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.13
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ import model
checkpoint-epoch=01-loss=0.13.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2190cb9a3ba864e44fd8e28cb57595b043baf5b2ee32b4386c9b2d637164a24e
3
+ size 1774312061
inference.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import model
2
+ import datasets
3
+ from PIL import Image
4
+ vlm = model.ImageNetCaptionModel.load_from_checkpoint('checkpoint-epoch=01-loss=0.13.ckpt')
5
+
6
+ image = Image.open("test_image_5.png")
7
+ print(vlm.generate(image=image))
model.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import comet_ml
3
+ import datasets
4
+ import evaluate
5
+ import lightning as L
6
+ import torch
7
+ from timm import create_model, data
8
+ from tokenizers import Tokenizer
9
+ from torch import nn
10
+ from torch.utils.data import DataLoader
11
+ from transformers import (
12
+ GPT2LMHeadModel,
13
+ )
14
+ from lightning.pytorch.loggers import TensorBoardLogger
15
+ from lightning.pytorch.callbacks import ModelCheckpoint
16
+
17
+
18
+ eos_token_id = 50256 # obtained from gpt model
19
+
20
+
21
class Projection(nn.Module):
    """Two-layer MLP that maps image-backbone features into the LLM embedding space.

    The hidden width is ``in_features * expansion``; ``expansion`` defaults to 3,
    which matches the original hard-coded behavior, but is now a parameter so the
    capacity of the projector can be tuned without editing this class.

    Args:
        in_features: Dimensionality of the incoming backbone feature tokens.
        out_features: Target dimensionality (the LLM's hidden size).
        expansion: Multiplier for the hidden layer width (default 3).
    """

    def __init__(self, in_features, out_features, expansion=3):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_features, in_features * expansion),
            nn.GELU(),
            nn.Linear(in_features * expansion, out_features),
        )

    def forward(self, input):
        """Project ``input`` (..., in_features) to (..., out_features)."""
        return self.network(input)
32
+
33
+
34
class ImageNetCaptionModel(L.LightningModule):
    """Vision-language captioning model: a frozen ViT backbone feeds projected
    image tokens, bracketed by <image_start>/<image_end> markers, into GPT-2,
    which is fine-tuned to continue with the caption text.
    """

    def __init__(self):
        super().__init__()
        # backbone model to extract image feature tokens
        self.backbone = create_model(
            "vit_mediumd_patch16_reg4_gap_384", pretrained=True
        )
        self.llm = GPT2LMHeadModel.from_pretrained("gpt2")

        # Sentinel tokens that delimit the image-token span in the LLM input.
        self.image_start_token = "<image_start>"
        self.image_end_token = "<image_end>"
        self.tokenizer = Tokenizer.from_pretrained("gpt2")
        self.tokenizer.add_special_tokens(
            [self.image_start_token, self.image_end_token]
        )
        self.image_start_token_id = self.tokenizer.token_to_id(self.image_start_token)
        self.image_end_token_id = self.tokenizer.token_to_id(self.image_end_token)
        self.eos_token = eos_token_id

        # Grow the LLM embedding table to cover the two new special tokens.
        self.llm.resize_token_embeddings(self.tokenizer.get_vocab_size())
        self.embedding = self.llm.get_input_embeddings()

        # Maps backbone feature tokens into GPT-2's hidden space.
        # NOTE(review): assumes the backbone's feature dim is 512 — confirm
        # against vit_mediumd_patch16_reg4_gap_384's embed_dim.
        self.projection = Projection(
            in_features=512, out_features=self.llm.config.hidden_size
        )

        self.bleu_metric = evaluate.load("bleu")
        self.meteor_metric = evaluate.load("meteor")

        ## Freeze the vision backbone; the GPT-2 weights are explicitly left
        ## trainable (requires_grad=True) and get a small LR in
        ## configure_optimizers.
        for param in self.backbone.parameters():
            param.requires_grad = False

        for param in self.llm.parameters():
            param.requires_grad = True

    def get_tokenizer(self):
        """Return the shared tokenizer (used by the dataset transform)."""
        return self.tokenizer

    def forward(self, image=None, input_caption=None, **kwargs):
        """Teacher-forced forward pass; returns the HF LM output (with .loss)."""
        image_feature = self.backbone.forward_features(image)
        projection = self.projection(image_feature)
        input_caption_embedding = self.embedding(input=input_caption)

        # concat start_image_token + projection + end_image_token + input_caption
        image_start_token, image_end_token = self.get_image_seperation_token(
            image=image
        )
        input_embedding = torch.cat(
            [image_start_token, projection, image_end_token, input_caption_embedding],
            dim=1,
        )
        attention_mask = torch.ones(
            input_embedding.size()[:-1], dtype=torch.long, device=image.device
        )

        # Labels: -100 (ignored by the LM loss) over the image span and its two
        # delimiters, caption token ids over the text span.
        labels = torch.full(
            (input_embedding.size(0), input_embedding.size(1)),
            -100,
            dtype=torch.long,
            device=image.device,
        )
        labels[:, projection.size(1) + 2 :] = input_caption  # align text labels

        llm_output = self.llm(
            inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels
        )
        return llm_output

    def training_step(self, batch, batch_idx):
        output = self.forward(**batch)
        self.log("loss", output.loss.item())
        return output.loss

    def validation_step(self, batch, batch_idx):
        # Only score the first 5 validation batches — generation is expensive.
        if batch_idx < 5:
            pred = self.predict_step(batch=batch, batch_idx=batch_idx)
            print(
                "evaluation ",
                "pred",
                pred,
                "original caption",
                batch["original_caption_enriched"],
            )
            bleu = self.bleu_metric.compute(
                predictions=pred, references=batch["original_caption_enriched"]
            )
            self.log("bleu", bleu["bleu"])
            # NOTE(review): this logs BLEU's brevity penalty under the metric
            # key "precision" — confirm the key name is intended.
            self.log("precision", bleu["brevity_penalty"])
            metor = self.meteor_metric.compute(
                predictions=pred, references=batch["original_caption_enriched"]
            )
            print(metor)
            self.log_dict(metor)

    def get_image_seperation_token(self, image):
        """Return (start, end) delimiter embeddings, each repeated to batch size.

        NOTE(review): "seperation" is a typo for "separation"; kept because the
        name is part of the public interface.
        """
        image_start_embedding = self.embedding(
            torch.tensor([self.image_start_token_id], device=image.device)
        )
        image_end_embedding = self.embedding(
            torch.tensor([self.image_end_token_id], device=image.device)
        )
        image_start_token = image_start_embedding.unsqueeze(0).repeat(len(image), 1, 1)
        image_end_token = image_end_embedding.unsqueeze(0).repeat(len(image), 1, 1)

        return image_start_token, image_end_token

    def configure_optimizers(self):
        """AdamW with a higher LR for the fresh projector (1e-4) than for the
        pretrained LLM (5e-6); the frozen backbone contributes no params."""
        proj_params = [p for p in self.projection.parameters() if p.requires_grad]
        llm_params = [p for p in self.llm.parameters() if p.requires_grad]

        optimizer = torch.optim.AdamW(
            [
                {"params": proj_params, "lr": 1e-4, "weight_decay": 0.01},
                {"params": llm_params, "lr": 5e-6, "weight_decay": 0.01},
            ]
        )

        return optimizer

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate captions for batch["image"]; returns a list of strings."""
        image = batch["image"]
        image_feature = self.backbone.forward_features(image)
        projection = self.projection(image_feature)

        image_start_embedding = self.embedding(
            torch.tensor([self.image_start_token_id], device=image.device)
        )
        image_end_embedding = self.embedding(
            torch.tensor([self.image_end_token_id], device=image.device)
        )
        input_start_image_embedding_batch = image_start_embedding.unsqueeze(0).repeat(
            len(image), 1, 1
        )
        input_end_image_embedding_batch = image_end_embedding.unsqueeze(0).repeat(
            len(image), 1, 1
        )

        # Same prefix layout as forward(), minus the caption embeddings.
        input_embedding = torch.cat(
            [
                input_start_image_embedding_batch,
                projection,
                input_end_image_embedding_batch,
            ],
            dim=1,
        )
        attention_mask = torch.ones(
            input_embedding.size()[:-1], dtype=torch.long, device=image.device
        )

        # NOTE(review): eos_token_id=0 disagrees with the module-level
        # eos_token_id (50256) used during training — confirm this is intended
        # and not a stray value.
        outputs = self.llm.generate(
            inputs_embeds=input_embedding,
            attention_mask=attention_mask,
            eos_token_id=0,
            max_new_tokens=30,
            do_sample=True,  # add randomness
            top_p=0.9,  # nucleus sampling
            temperature=0.7,
        )

        # Convert tensor to list of lists for decode_batch
        if outputs.dim() == 2:
            # outputs is [batch_size, sequence_length], convert to list of lists
            outputs_list = outputs.tolist()
        else:
            # outputs is already a list/sequence
            outputs_list = outputs

        return self.tokenizer.decode_batch(outputs_list, skip_special_tokens=True)

    def generate(self, image):
        """Caption a single PIL image: apply the backbone's eval transform,
        add a batch dim, and return the first decoded caption."""
        data_config = data.resolve_model_data_config(
            create_model("vit_mediumd_patch16_reg4_gap_384", pretrained=True)
        )
        transforms = data.create_transform(**data_config, is_training=False)
        image = transforms(image)

        return self.predict_step(batch={"image":image.unsqueeze(0)},batch_idx=0)[0]
212
+
213
+
214
+
215
def collate_fn(batch):
    """Collate per-example dicts into a training batch.

    Stacks ``image`` (float) and ``input_caption`` (long) into tensors of shape
    (B, ...) and keeps the reference captions as a plain list of strings.

    Fix: the loop variable was named ``data``, shadowing the ``timm.data``
    module imported at the top of the file; renamed to ``example``.

    Args:
        batch: List of dicts with keys "image", "input_caption",
            "original_caption_enriched".

    Returns:
        Dict with stacked "image" and "input_caption" tensors and the list of
        "original_caption_enriched" strings.
    """
    collected = {"image": [], "input_caption": [], "original_caption_enriched": []}

    for example in batch:
        collected["image"].append(torch.tensor(example["image"], dtype=torch.float))
        collected["input_caption"].append(
            torch.tensor(example["input_caption"], dtype=torch.long)
        )
        collected["original_caption_enriched"].append(
            example["original_caption_enriched"]
        )

    return {
        "image": torch.stack(collected["image"], dim=0),
        "input_caption": torch.stack(collected["input_caption"], dim=0),
        "original_caption_enriched": collected["original_caption_enriched"],
    }
230
+
231
+
232
def agument(tokenizer: Tokenizer):
    """Build a per-example dataset transform that tokenizes the caption and
    applies the backbone's eval-time image transform.

    NOTE(review): "agument" is a typo for "augment"; kept because callers
    (train) use this name.

    Args:
        tokenizer: The model's tokenizer (shared so special-token ids match).

    Returns:
        A function mapping a raw example dict (keys "caption_enriched",
        "image") to one with "input_caption" (60-token LongTensor),
        "original_caption_enriched", and a transformed "image".
    """
    data_config = data.resolve_model_data_config(
        create_model("vit_mediumd_patch16_reg4_gap_384", pretrained=True)
    )
    transforms = data.create_transform(**data_config, is_training=False)

    # NOTE(review): the inner parameter `data` shadows the `timm.data` module
    # imported at file top (harmless here since the transform is built above,
    # but worth renaming).
    def transform(data):
        ids = tokenizer.encode(data["caption_enriched"])

        # Handle sequences based on length so the result is at most 60 tokens
        # (<=59 caption tokens + 1 EOS).
        # NOTE(review): tokenizers.Encoding.ids is typically a read-only view;
        # confirm that appending to / reassigning ids.ids actually mutates the
        # encoding as intended.
        if len(ids.ids) <= 59:
            # For short sequences, just append EOS
            ids.ids.append(eos_token_id)
        else:
            # For long sequences, truncate to 59 tokens and append EOS
            ids.ids = ids.ids[:59]
            ids.ids.append(eos_token_id)

        # Pad to exactly 60 tokens
        ids.ids = ids.ids[:60]  # Ensure we don't exceed 60
        ids.pad(60)

        # Debug round-trip: print the original vs re-decoded caption.
        decoded = tokenizer.decode(ids.ids, skip_special_tokens=True)
        print("original", data["caption_enriched"], "decoded", decoded)

        data["input_caption"] = torch.tensor(ids.ids, dtype=torch.long)

        data["original_caption_enriched"] = data["caption_enriched"]
        data["image"] = transforms(data["image"])
        return data

    return transform
264
+
265
+
266
def is_valid_image(example):
    """Dataset filter: keep only examples whose image opens as an RGB image.

    Any exception while inspecting the image (e.g. a decode error) makes the
    example invalid; the failure is printed for debugging.
    """
    try:
        # Accessing .mode forces the image to be readable.
        image = example["image"]
        return image.mode == "RGB"
    except Exception as e:
        # ValueError will catch the MAX_TEXT_CHUNK error
        print("false", example["image"])
        print("Exception:", e)
        return False
278
+
279
+
280
def train(
    root_path: Path,
    dataset: datasets.Dataset,
    num_loader_worker: int = 0,
    batch_size=16,
    logger=None,
):
    """Fine-tune ImageNetCaptionModel for 2 epochs on the given split dict.

    Args:
        root_path: Directory for logs and checkpoints.
        dataset: A dataset dict with "train" and "test" splits.
        num_loader_worker: DataLoader worker count (0 = in-process loading).
        batch_size: Per-step batch size for both loaders.
        logger: Optional Lightning logger; defaults to a TensorBoardLogger
            under ``root_path``.
    """
    # dataset = datasets.load_dataset("visual-layer/imagenet-1k-vl-enriched", split="validation").shuffle(seed=42).select(range(20000)).train_test_split(test_size=0.1)
    caption_model = ImageNetCaptionModel()
    tok = caption_model.get_tokenizer()

    def _prepare(split):
        # Drop unreadable / non-RGB images, then tokenize + transform each row.
        return split.filter(is_valid_image).map(agument(tokenizer=tok))

    def _make_loader(split):
        # Both loaders share the same collation and worker settings.
        return DataLoader(
            dataset=split,
            drop_last=True,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_loader_worker,
        )

    train_loader = _make_loader(_prepare(dataset["train"]))
    eval_loader = _make_loader(_prepare(dataset["test"]))

    if logger is None:
        logger = TensorBoardLogger(save_dir=str(root_path), version=1, name="logs")

    # Keep a checkpoint per epoch (save_top_k=-1 retains all of them).
    ckpt_callback = ModelCheckpoint(
        dirpath=root_path / "checkpoint",
        filename="checkpoint-{epoch:02d}-{loss:.2f}",
        every_n_epochs=1,
        save_top_k=-1,
    )

    print("path", root_path)
    trainer = L.Trainer(
        logger=logger,
        max_epochs=2,
        default_root_dir=root_path,
        callbacks=[ckpt_callback],
    )
    trainer.fit(
        model=caption_model,
        train_dataloaders=train_loader,
        val_dataloaders=eval_loader,
    )
pyproject.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "imagenet-caption"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "comet-ml>=3.52.1",
9
+ "datasets[vision]>=4.0.0",
10
+ "evaluate[bleu,meteor]>=0.4.6",
11
+ "lightning>=2.5.5",
12
+ "modal>=1.1.4",
13
+ "nltk>=3.9.1",
14
+ "numpy>=2.3.3",
15
+ "pillow>=11.3.0",
16
+ "tensorboard>=2.20.0",
17
+ "tensorboardx>=2.6.4",
18
+ "timm>=1.0.19",
19
+ "tokenizers>=0.22.0",
20
+ "torch>=2.8.0",
21
+ "transformers[torch]>=4.56.1",
22
+ ]
23
+
24
+ [dependency-groups]
25
+ dev = [
26
+ "ipykernel>=6.30.1",
27
+ ]
test_image_0.png ADDED

Git LFS Details

  • SHA256: 0113fbe513067377deb1d9563642b5376e9d8d9f6162b8f95e7870b26ec647bd
  • Pointer size: 131 Bytes
  • Size of remote file: 480 kB
test_image_1.png ADDED

Git LFS Details

  • SHA256: 46cb6a8c672682c8d2d900b06dd71aead680140b9b92e12e626db7017278880f
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
test_image_2.png ADDED

Git LFS Details

  • SHA256: 5ebd4498fef9d60f106b54f73d4bf6f4671cee4cea0d8b23042e9d8ead60417d
  • Pointer size: 131 Bytes
  • Size of remote file: 463 kB
test_image_3.png ADDED

Git LFS Details

  • SHA256: 116bfcf0dbca6d82595ba5fb065fa828286c019f181753b92d43c5d49203b6a0
  • Pointer size: 131 Bytes
  • Size of remote file: 371 kB
test_image_4.png ADDED

Git LFS Details

  • SHA256: 04fca2a4fbfa3b36bb77162566477851001aa93ba97c03a430a5657af91bf215
  • Pointer size: 131 Bytes
  • Size of remote file: 407 kB
test_image_5.png ADDED

Git LFS Details

  • SHA256: 0cbbe7ccb2063fca831a5287054f054f7ea77f45ee3981dfea91baada21a81b2
  • Pointer size: 131 Bytes
  • Size of remote file: 401 kB