evanec committed on
Commit 1809762 · verified · 1 Parent(s): c890bd5

Upload 12 files
src/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Makes this directory a python package
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (140 Bytes)
 
src/__pycache__/inference.cpython-311.pyc ADDED
Binary file (6.25 kB)
 
src/__pycache__/interpretability.cpython-311.pyc ADDED
Binary file (23.2 kB)
 
src/__pycache__/train.cpython-311.pyc ADDED
Binary file (6.86 kB)
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (11.2 kB)
 
src/evaluate.py ADDED
@@ -0,0 +1,192 @@
1
+ # evaluate.py
2
+ import os
3
+ import json
4
+ from tqdm import tqdm
5
+ import torch
6
+ from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
7
+ from src.utils import count_encoder_decoder_params, load_experiment
8
+ from src.inference import load_image, generate_caption
9
+ from data.transforms import build_coco_transform
10
+
11
+
12
+ # COCO metrics
13
+ try:
14
+ from pycocoevalcap.cider.cider import Cider
15
+ from pycocoevalcap.rouge.rouge import Rouge
16
+ HAS_COCOEVAL = True
17
+ except ImportError:
18
+ print("WARNING: pycocoevalcap not installed → CIDEr/ROUGE disabled.")
19
+ HAS_COCOEVAL = False
20
+
21
+
22
+ def evaluate(model, tokenizer, preprocess, image_size, data_dir="data/processed", save_dir="checkpoints", device="cuda"):
23
+
24
+
25
+ captions_path = os.path.join(data_dir, "captions.json")
26
+ splits_path = os.path.join(data_dir, "splits.json")
27
+
28
+ captions = json.load(open(captions_path))
29
+ splits = json.load(open(splits_path))
30
+ val_ids = splits["val"]
31
+
32
+ preds = []
33
+ refs_tokenized = [] # for BLEU
34
+ refs_strings = [] # for JSON log
35
+
36
+ print(f"Running evaluation on {len(val_ids)} images…\n")
37
+ with torch.no_grad():
38
+ for idx, img_id in enumerate(tqdm(val_ids, desc="Evaluating")):
39
+ img_path = os.path.join(data_dir, "images", f"{int(img_id):012d}.jpg")
40
+
41
+ img_tensor = load_image(img_path, preprocess).to(device)
42
+
43
+
44
+ pred_caption = generate_caption(model, tokenizer, img_tensor, device=device)
45
+ gt_caps = captions[str(img_id)]["captions"]
46
+
47
+ # Tokenized refs for BLEU
48
+ refs_tokenized.append([c.split() for c in gt_caps])
49
+
50
+ # String refs for JSON
51
+ refs_strings.append(gt_caps)
52
+
53
+ preds.append(pred_caption)
54
+
55
+ #if idx >= 20:
56
+ # break
57
+
58
+ # Print 20 sample predictions
59
+ print("\nSample Predictions:\n")
60
+ num_examples = 20
61
+ for i in range(min(num_examples, len(preds))):
62
+ img_id = val_ids[i]
63
+ print(f"Image ID: {img_id}")
64
+ print(f"Prediction: {preds[i]}")
65
+ print(f"Ground Truths:")
66
+ for ref in refs_strings[i]:
67
+ print(f" - {ref}")
68
+ print("-" * 60)
69
+
70
+
71
+ #print("Number of preds:", len(preds))
72
+ #print("Number of refs_tokenized:", len(refs_tokenized))
73
+ #print("Example hypothesis:", preds[0])
74
+ #print("Example hypothesis tokens:", preds[0].split())
75
+ #print("Example references:", refs_strings[0])
76
+ #print("Example references tokenized:", refs_tokenized[0])
77
+
78
+ #if HAS_COCOEVAL:
79
+ # Show first 2 examples only
80
+ # for i in range(min(2, len(preds))):
81
+ # img_id = str(int(val_ids[i]))
82
+ # print(f"\nImage ID: {img_id}")
83
+ # print(" COCOEvalCap refs (list of strings):")
84
+ # print(" ", captions[img_id]["captions"])
85
+ # print(" COCOEvalCap pred:")
86
+ # print(" ", preds[i])
87
+
88
+
89
+ # BLEU
90
+ smoothie = SmoothingFunction().method3
91
+
92
+ bleu1 = corpus_bleu(
93
+ refs_tokenized, [p.split() for p in preds],
94
+ weights=(1, 0, 0, 0),
95
+ smoothing_function=smoothie
96
+ )
97
+
98
+ bleu4 = corpus_bleu(
99
+ refs_tokenized, [p.split() for p in preds],
100
+ weights=(0.25, 0.25, 0.25, 0.25),
101
+ smoothing_function=smoothie
102
+ )
103
+
104
+ scores = {"BLEU-1": bleu1, "BLEU-4": bleu4}
105
+
106
+ # CIDEr / ROUGE
107
+ if HAS_COCOEVAL:
108
+ cider_refs = {}
109
+ cider_preds = {}
110
+
111
+ for i in range(len(preds)):
112
+ img_id = val_ids[i]
113
+ cid = str(int(img_id))
114
+ cider_refs[cid] = captions[cid]["captions"]
115
+ cider_preds[cid] = [preds[i]]
116
+
117
+ #keys = list(cider_refs.keys())[:5]
118
+ #for k in keys:
119
+ # print(f"{k}: {cider_refs[k]}")
120
+
121
+ #keys = list(cider_preds.keys())[:5]
122
+ #for k in keys:
123
+ # print(f"{k}: {cider_preds[k]}")
124
+
125
+ cider = Cider()
126
+ cider_score, _ = cider.compute_score(cider_refs, cider_preds)
127
+ scores["CIDEr"] = cider_score
128
+
129
+ rouge = Rouge()
130
+ rouge_score, _ = rouge.compute_score(cider_refs, cider_preds)
131
+ scores["ROUGE-L"] = rouge_score
132
+
133
+ # Save all samples
134
+ samples_full = []
135
+ for i in range(len(preds)):
136
+ img_id = val_ids[i]
137
+ samples_full.append({
138
+ "id": int(img_id),
139
+ "prediction": preds[i],
140
+ "references": refs_strings[i],
141
+ "image": f"{int(img_id):012d}.jpg",
142
+ })
143
+
144
+ # Save a preview subset (first 20)
145
+ samples_preview = samples_full[:20]
146
+
147
+ param_info = count_encoder_decoder_params(model)
148
+
149
+ out_path = os.path.join(save_dir, "eval_results.json")
150
+
151
+ with open(out_path, "w") as f:
152
+ json.dump({
153
+ "scores": scores,
154
+ "derived_params": param_info,
155
+ "samples_preview": samples_preview,
156
+ "samples_full": samples_full
157
+ }, f, indent=2)
158
+
159
+ # Print final scores
160
+ print("\nEvaluation Scores:")
161
+ for k, v in scores.items():
162
+ print(f"{k}: {v:.4f}")
163
+
164
+ print(f"\nSaved detailed results to: {out_path}")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ import argparse
169
+ parser = argparse.ArgumentParser()
170
+
171
+ parser.add_argument("--checkpoint", type=str, default="checkpoints/vision_t5")
172
+ parser.add_argument("--data_dir", type=str, default="data/processed")
173
+
174
+ args = parser.parse_args()
175
+
176
+ device = "cuda" if torch.cuda.is_available() else "cpu"
177
+ print(f"Device: {device}")
178
+
179
+ model, tokenizer, meta, config = load_experiment(args.checkpoint, device=device)
180
+ image_size = config["model"].get("image_size", 224)
181
+ preprocess = build_coco_transform(image_size=image_size)
182
+
183
+ evaluate(
184
+ model,
185
+ tokenizer,
186
+ preprocess=preprocess,
+ image_size=image_size,
187
+ data_dir=args.data_dir,
188
+ save_dir=args.checkpoint,
189
+ device=device
190
+ )
191
+
192
+ # Usage (from the repo root): python -m src.evaluate --checkpoint checkpoints/vision_t5/20251117_171912
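
Note on the BLEU inputs built above: nltk's corpus_bleu expects, per image, a list of tokenized reference captions plus one tokenized hypothesis, which is why refs_tokenized is a list of lists of token lists. A minimal, self-contained sketch with made-up captions (not data from this repo):

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# One image: two tokenized references, one tokenized hypothesis.
refs = [[["a", "dog", "runs", "on", "the", "grass"],
         ["a", "dog", "is", "running", "outside"]]]
hyps = [["a", "dog", "runs", "outside"]]

smoothie = SmoothingFunction().method3
print(corpus_bleu(refs, hyps, weights=(1, 0, 0, 0), smoothing_function=smoothie))  # BLEU-1
print(corpus_bleu(refs, hyps, weights=(0.25,) * 4, smoothing_function=smoothie))   # BLEU-4
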
src/evaluate_batched.py ADDED
@@ -0,0 +1,199 @@
1
+ # evaluate_batched.py
2
+ import os
3
+ import json
4
+ from tqdm import tqdm
5
+ import torch
6
+ from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
7
+ from src.utils import count_encoder_decoder_params, load_experiment
8
+ from src.inference import load_image, generate_caption
9
+ from data.transforms import build_coco_transform
10
+
11
+ try:
12
+ from pycocoevalcap.cider.cider import Cider
13
+ from pycocoevalcap.rouge.rouge import Rouge
14
+ HAS_COCOEVAL = True
15
+ except ImportError:
16
+ print("WARNING: pycocoevalcap not installed → CIDEr/ROUGE disabled.")
17
+ HAS_COCOEVAL = False
18
+
19
+ # Batched Evaluation (non-breaking addition)
20
+
21
+ @torch.no_grad()
22
+ def evaluate_batched(
23
+ model,
24
+ tokenizer,
25
+ preprocess,
26
+ image_size,
27
+ data_dir="data/processed",
28
+ save_dir="checkpoints",
29
+ device="cuda",
30
+ batch_size=16,
31
+ num_beams=1,
32
+ ):
33
+ """
34
+ Batched version of evaluate().
35
+ """
36
+
37
+ from src.inference import load_images_batch, generate_captions_batch
38
+
39
+ captions_path = os.path.join(data_dir, "captions.json")
40
+ splits_path = os.path.join(data_dir, "splits.json")
41
+
42
+ captions = json.load(open(captions_path))
43
+ splits = json.load(open(splits_path))
44
+ val_ids = splits["val"]
45
+
46
+ preds = []
47
+ refs_tokenized = []
48
+ refs_strings = []
49
+
50
+ print(f"Running *batched* evaluation on {len(val_ids)} images… (batch={batch_size})\n")
51
+
52
+ # Loop in batches
53
+ for start in tqdm(range(0, len(val_ids), batch_size), desc="Evaluating (batched)"):
54
+ end = min(start + batch_size, len(val_ids))
55
+ batch_ids = val_ids[start:end]
56
+
57
+ # Image paths
58
+ img_paths = [
59
+ os.path.join(data_dir, "images", f"{int(i):012d}.jpg")
60
+ for i in batch_ids
61
+ ]
62
+
63
+ # Load batch into tensor
64
+ img_batch = load_images_batch(img_paths, preprocess, image_size).to(device)
65
+
66
+ # Generate predictions for batch
67
+ batch_preds = generate_captions_batch(
68
+ model,
69
+ tokenizer,
70
+ img_batch,
71
+ device=device,
72
+ num_beams=num_beams,
73
+ max_new_tokens=32
74
+ )
75
+
76
+ # Collect references
77
+ for i, img_id in enumerate(batch_ids):
78
+ gt_caps = captions[str(img_id)]["captions"]
79
+
80
+ refs_strings.append(gt_caps)
81
+ refs_tokenized.append([c.split() for c in gt_caps])
82
+ preds.append(batch_preds[i])
83
+
84
+ # Print sample predictions (20 samples, same as evaluate())
85
+ print("\nSample Predictions:\n")
86
+ num_examples = 20
87
+ for i in range(min(num_examples, len(preds))):
88
+ img_id = val_ids[i]
89
+ print(f"Image ID: {img_id}")
90
+ print(f"Prediction: {preds[i]}")
91
+ print("Ground Truths:")
92
+ for ref in refs_strings[i]:
93
+ print(f" - {ref}")
94
+ print("-" * 60)
95
+
96
+ # Compute BLEU
97
+ smoothie = SmoothingFunction().method3
98
+
99
+ bleu1 = corpus_bleu(
100
+ refs_tokenized, [p.split() for p in preds],
101
+ weights=(1, 0, 0, 0),
102
+ smoothing_function=smoothie
103
+ )
104
+
105
+ bleu4 = corpus_bleu(
106
+ refs_tokenized, [p.split() for p in preds],
107
+ weights=(0.25, 0.25, 0.25, 0.25),
108
+ smoothing_function=smoothie
109
+ )
110
+
111
+ scores = {"BLEU-1": bleu1, "BLEU-4": bleu4}
112
+
113
+ # CIDEr / ROUGE
114
+ if HAS_COCOEVAL:
115
+ cider_refs = {}
116
+ cider_preds = {}
117
+
118
+ for i in range(len(preds)):
119
+ img_id = val_ids[i]
120
+ cid = str(int(img_id))
121
+ cider_refs[cid] = captions[cid]["captions"]
122
+ cider_preds[cid] = [preds[i]]
123
+
124
+ cider = Cider()
125
+ cider_score, _ = cider.compute_score(cider_refs, cider_preds)
126
+ scores["CIDEr"] = cider_score
127
+
128
+ rouge = Rouge()
129
+ rouge_score, _ = rouge.compute_score(cider_refs, cider_preds)
130
+ scores["ROUGE-L"] = rouge_score
131
+
132
+ # Save results
133
+ samples_full = []
134
+ for i in range(len(preds)):
135
+ img_id = val_ids[i]
136
+ samples_full.append({
137
+ "id": int(img_id),
138
+ "prediction": preds[i],
139
+ "references": refs_strings[i],
140
+ "image": f"{int(img_id):012d}.jpg",
141
+ })
142
+
143
+ samples_preview = samples_full[:20]
144
+ param_info = count_encoder_decoder_params(model)
145
+
146
+ out_path = os.path.join(save_dir, "eval_results.json")
147
+ with open(out_path, "w") as f:
148
+ json.dump({
149
+ "scores": scores,
150
+ "derived_params": param_info,
151
+ "samples_preview": samples_preview,
152
+ "samples_full": samples_full
153
+ }, f, indent=2)
154
+
155
+ # Print final scores
156
+ print("\nEvaluation Scores:")
157
+ for k, v in scores.items():
158
+ print(f"{k}: {v:.4f}")
159
+
160
+ print(f"\nSaved batched results to: {out_path}")
161
+
162
+
163
+ if __name__ == "__main__":
164
+ import argparse
165
+ parser = argparse.ArgumentParser()
166
+
167
+ parser.add_argument("--checkpoint", type=str, required=True,
168
+ help="Path to checkpoint directory")
169
+ parser.add_argument("--data_dir", type=str, default="data/processed")
170
+ parser.add_argument("--batch_size", type=int, default=16,
171
+ help="Batch size for batched evaluation")
172
+ parser.add_argument("--num_beams", type=int, default=1,
173
+ help="For beam search")
174
+
175
+ args = parser.parse_args()
176
+
177
+ device = "cuda" if torch.cuda.is_available() else "cpu"
178
+ print(f"Device: {device}")
179
+
180
+ # Load model + tokenizer
181
+ model, tokenizer, meta, config = load_experiment(args.checkpoint, device=device)
182
+ image_size = config["model"].get("image_size", 224)
183
+ preprocess = build_coco_transform(image_size=image_size)
184
+
185
+ # Run batched evaluation
186
+ evaluate_batched(
187
+ model=model,
188
+ tokenizer=tokenizer,
189
+ preprocess=preprocess,
190
+ image_size=image_size,
191
+ data_dir=args.data_dir,
192
+ save_dir=args.checkpoint,
193
+ device=device,
194
+ batch_size=args.batch_size,
195
+ num_beams=args.num_beams
196
+ )
197
+
198
+ # Usage:
199
+ # python -m src.evaluate_batched --checkpoint checkpoints/vision_t5/20251117_171912 --batch_size 16
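
For the CIDEr/ROUGE block shared by both evaluation scripts: the pycocoevalcap scorers take two dicts keyed by the same image ids, with ground truths mapping to a list of reference strings and results mapping to a single-element list holding the prediction. A minimal sketch with made-up captions, assuming pycocoevalcap is installed:

from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.rouge.rouge import Rouge

gts = {
    "139": ["a dog runs on the grass", "a dog is running outside"],
    "285": ["a cat sleeps on a sofa", "a cat is lying on the couch"],
}
res = {"139": ["a dog runs outside"], "285": ["a cat sleeps on the couch"]}

cider_score, _ = Cider().compute_score(gts, res)
rouge_score, _ = Rouge().compute_score(gts, res)
print(f"CIDEr: {cider_score:.4f}  ROUGE-L: {rouge_score:.4f}")
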
src/inference.py ADDED
@@ -0,0 +1,151 @@
1
+ # inference.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+
7
+ from transformers import T5TokenizerFast
8
+ from transformers.modeling_outputs import BaseModelOutput
9
+ from models.vision_t5 import VisionT5
10
+ from src.utils import load_experiment
11
+ from data.transforms import build_coco_transform
12
+
13
+
14
+
15
+ def load_image(path, preprocess):
16
+ img = Image.open(path).convert("RGB")
17
+ return preprocess(img).unsqueeze(0) # (1, 3, H, W)
18
+
19
+
20
+
21
+ @torch.no_grad()
22
+ def generate_caption(model, tokenizer, image_tensor, max_new_tokens=32, num_beams=1, device=None):
23
+ if device is None:
24
+ device = next(model.parameters()).device
25
+
26
+ model.eval()
27
+
28
+ image_tensor = image_tensor.to(device)
29
+
30
+ # Encode image
31
+ vision_out = model.vision_encoder(image_tensor)
32
+ img_embeds = vision_out["image_embeds"]
33
+
34
+ if img_embeds.dim() == 2:
35
+ img_embeds = img_embeds.unsqueeze(1)
36
+
37
+ projected = model.projector(img_embeds)
38
+
39
+ encoder_outputs = BaseModelOutput(last_hidden_state=projected)
40
+
41
+ start_token = model.t5.config.decoder_start_token_id
42
+
43
+ # explicit decoder inputs & mask (FIXES THE ERROR)
44
+ input_ids = torch.tensor([[start_token]], device=device)
45
+ attention_mask = torch.tensor([[1]], device=device)
46
+
47
+ output_ids = model.t5.generate(
48
+ encoder_outputs=encoder_outputs,
49
+ decoder_start_token_id=start_token,
50
+ input_ids=input_ids,
51
+ attention_mask=attention_mask,
52
+ num_beams=num_beams,
53
+ max_new_tokens=max_new_tokens,
54
+ )
55
+
56
+ caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
57
+ return caption
58
+
59
+
60
+
61
+
62
+ # Batched evaluation helpers (non-breaking)
63
+
64
+ @torch.no_grad()
65
+ def load_images_batch(paths, preprocess, image_size):
66
+ resize = transforms.Resize((image_size, image_size))
67
+ tensors = []
68
+ for p in paths:
69
+ img = Image.open(p).convert("RGB")
70
+ img = resize(img)
71
+ t = preprocess(img).unsqueeze(0)
72
+ tensors.append(t)
73
+ return torch.cat(tensors, dim=0)
74
+
75
+
76
+ @torch.no_grad()
77
+ def generate_captions_batch(
78
+ model,
79
+ tokenizer,
80
+ image_batch, # (B, 3, H, W)
81
+ max_new_tokens=32,
82
+ num_beams=1,
83
+ device=None,
84
+ ):
85
+ """
86
+ Batched version of generate_caption().
87
+ Does NOT replace or modify existing generate_caption().
88
+ """
89
+
90
+ if device is None:
91
+ device = next(model.parameters()).device
92
+
93
+ model.eval()
94
+
95
+ image_batch = image_batch.to(device)
96
+
97
+ # Encode in batch
98
+ vision_out = model.vision_encoder(image_batch)
99
+ img_embeds = vision_out["image_embeds"] # (B, D) or (B, S, D)
100
+
101
+ if img_embeds.dim() == 2:
102
+ img_embeds = img_embeds.unsqueeze(1)
103
+
104
+ projected = model.projector(img_embeds) # (B, S, d_model)
105
+ encoder_outputs = BaseModelOutput(last_hidden_state=projected)
106
+
107
+ # Build batched decoder inputs
108
+ start = model.t5.config.decoder_start_token_id
109
+ B = image_batch.size(0)
110
+
111
+ input_ids = torch.full((B, 1), start, dtype=torch.long, device=device)
112
+ attention_mask = torch.ones((B, 1), dtype=torch.long, device=device)
113
+
114
+ # Standard HF batching
115
+ output_ids = model.t5.generate(
116
+ encoder_outputs=encoder_outputs,
117
+ decoder_start_token_id=start,
118
+ input_ids=input_ids,
119
+ attention_mask=attention_mask,
120
+ num_beams=num_beams,
121
+ max_new_tokens=max_new_tokens,
122
+ )
123
+
124
+ # Decode individually
125
+ return [
126
+ tokenizer.decode(ids, skip_special_tokens=True)
127
+ for ids in output_ids
128
+ ]
129
+
130
+
131
+ if __name__ == "__main__":
132
+ import argparse
133
+ parser = argparse.ArgumentParser()
134
+
135
+ parser.add_argument("--image", type=str, required=True, help="Path to image")
136
+ parser.add_argument("--checkpoint", type=str, default="checkpoints/vision_t5")
137
+
138
+ args = parser.parse_args()
139
+
140
+ # Load model + tokenizer + config
141
+ model, tokenizer, meta, config = load_experiment(args.checkpoint)
142
+ image_size = config["model"].get("image_size", 224)
143
+ preprocess = build_coco_transform(image_size)
144
+
145
+ # Load image
146
+ image_tensor = load_image(args.image, preprocess)
147
+
148
+ # Generate caption
149
+ caption = generate_caption(model, tokenizer, image_tensor)
150
+ print("\nCaption:", caption)
151
+
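
A possible end-to-end use of the batched helpers above, assuming a checkpoint directory written by save_experiment() and COCO-style image files; the run id and image paths are placeholders:

import torch
from src.utils import load_experiment
from src.inference import load_images_batch, generate_captions_batch
from data.transforms import build_coco_transform

device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer, meta, config = load_experiment("checkpoints/vision_t5/<run_id>", device=device)

image_size = config["model"].get("image_size", 224)
preprocess = build_coco_transform(image_size=image_size)

paths = ["data/processed/images/000000000001.jpg", "data/processed/images/000000000002.jpg"]
batch = load_images_batch(paths, preprocess, image_size)
for path, caption in zip(paths, generate_captions_batch(model, tokenizer, batch, device=device)):
    print(path, "->", caption)
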
src/interpretability.py ADDED
@@ -0,0 +1,466 @@
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from torchvision.transforms.functional import resize
6
+ from transformers.modeling_outputs import BaseModelOutput
7
+ import cv2
8
+ from transformers.models.vit.modeling_vit import ViTModel
9
+
10
+
14
+
15
+
16
+ class GradCAM:
17
+ def __init__(self, vision_encoder):
18
+ self.model = vision_encoder.model
19
+ self.target_layer = self._find_last_conv_layer()
20
+
21
+ self.activations = None
22
+ self.gradients = None
23
+
24
+ self.target_layer.register_forward_hook(self._hook_forward)
25
+ self.target_layer.register_backward_hook(self._hook_backward)
26
+
27
+ def _find_last_conv_layer(self):
28
+ for module in reversed(list(self.model.modules())):
29
+ if isinstance(module, torch.nn.Conv2d):
30
+ return module
31
+ raise RuntimeError("No Conv2D layer found for Grad-CAM.")
32
+
33
+ def _hook_forward(self, module, inp, out):
34
+ self.activations = out.detach()
35
+
36
+ def _hook_backward(self, module, grad_in, grad_out):
37
+ self.gradients = grad_out[0].detach()
38
+
39
+ def generate(self, image_tensor):
40
+ self.model.zero_grad()
41
+
42
+ out = self.model(image_tensor) # (B, C, H, W)
43
+
44
+ if out.ndim == 4:
45
+ pooled = out.mean(dim=[2, 3]) # (B, C)
46
+ elif out.ndim == 3:
47
+ pooled = out.mean(dim=1)
48
+ else:
49
+ pooled = out
50
+
51
+ score = pooled.norm()
52
+ score.backward()
53
+
54
+ weights = self.gradients.mean(dim=(2, 3), keepdim=True)
55
+ cam = (weights * self.activations).sum(dim=1).squeeze()
56
+
57
+ cam = F.relu(cam)
58
+ cam -= cam.min()
59
+ cam /= cam.max() + 1e-8
60
+
61
+ return cam.cpu().numpy()
62
+
63
+ def save(self, img_tensor, save_path):
64
+ cam = self.generate(img_tensor)
65
+
66
+ img_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
67
+ img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
68
+
69
+ cam_resized = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))
70
+
71
+ plt.figure(figsize=(6, 6))
72
+ plt.imshow(img_np)
73
+ plt.imshow(cam_resized, cmap="inferno", alpha=0.45)
74
+ plt.axis("off")
75
+ plt.tight_layout()
76
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
77
+ plt.close()
78
+
79
+ print(f"[GradCAM] Saved to {save_path}")
80
+
81
+
82
+ def get_vit_self_attention(model, image_tensor):
83
+ vision = model.vision_encoder
84
+
85
+ if "Resnet" in type(vision).__name__:
86
+ return None
87
+
88
+ # Check for CLIP
89
+ if hasattr(vision, "model"):
90
+
91
+ if hasattr(vision.model, "vision_model"):
92
+ hf_vit = vision.model.vision_model
93
+
94
+ outputs = hf_vit(
95
+ pixel_values=image_tensor,
96
+ output_attentions=True,
97
+ return_dict=True,
98
+ )
99
+ return outputs.attentions
100
+
101
+ # Check for ViT
102
+ if isinstance(vision.model, ViTModel):
103
+
104
+ outputs = vision.model(
105
+ pixel_values=image_tensor,
106
+ output_attentions=True,
107
+ return_dict=True,
108
+ )
109
+ return outputs.attentions
110
+
111
+ raise ValueError("Vision encoder does not expose ViT attentions.")
112
+
113
+
114
+
115
+ # ATTENTION ROLLOUT (across layers)
116
+ def attention_rollout(attn_mats, discard_ratio=0.0):
117
+
118
+ device = attn_mats[0].device
119
+ result = torch.eye(attn_mats[0].size(-1), device=device)
120
+
121
+ for attn in attn_mats:
122
+ attn = attn.mean(dim=0) # average heads
123
+
124
+ if discard_ratio > 0:
125
+ flat = attn.view(-1)
126
+ threshold = flat.topk(int(flat.numel() * discard_ratio), largest=False)[0].max()
127
+ attn = torch.where(attn < threshold, torch.zeros_like(attn), attn)
128
+
129
+ attn = attn / attn.sum(dim=-1, keepdim=True)
130
+ result = attn @ result
131
+
132
+ return result
133
+
134
+
135
+ def rollout_to_image(rollout, image_size):
136
+ tokens = rollout.size(0)
137
+ num_patches = int((tokens - 1) ** 0.5)
138
+
139
+ spatial = rollout[0, 1:].reshape(num_patches, num_patches)
140
+ spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min())
141
+
142
+ spatial = resize(
143
+ spatial.unsqueeze(0).unsqueeze(0),
144
+ (image_size, image_size)
145
+ )
146
+
147
+ return spatial.squeeze().detach().cpu().numpy()
148
+
149
+
150
+ def plot_attention_overlay(image, heatmap, alpha=0.45):
151
+ if torch.is_tensor(image):
152
+ image = image.permute(1,2,0).cpu().numpy()
153
+
154
+ image = (image - image.min()) / (image.max() - image.min())
155
+
156
+ plt.figure(figsize=(6,6))
157
+ plt.imshow(image)
158
+ plt.imshow(heatmap, cmap='inferno', alpha=alpha)
159
+ plt.axis("off")
160
+ plt.show()
161
+
162
+
163
+ # GRADIENT MAP
164
+ def token_gradient_map(model, tokenizer, image_tensor, target_word, device="cuda"):
165
+ model.eval()
166
+
167
+ image_tensor = image_tensor.to(device)
168
+ image_tensor.requires_grad_(True)
169
+
170
+ vision_out = model.vision_encoder(image_tensor)
171
+ img_embeds = vision_out["image_embeds"]
172
+
173
+ if img_embeds.dim() == 2:
174
+ img_embeds = img_embeds.unsqueeze(1)
175
+
176
+ projected = model.projector(img_embeds)
177
+
178
+ encoder_outputs = BaseModelOutput(last_hidden_state=projected)
179
+
180
+ start = model.t5.config.decoder_start_token_id
181
+ decoder_input_ids = torch.tensor([[start]], device=device)
182
+
183
+ outputs = model.t5(
184
+ encoder_outputs=encoder_outputs,
185
+ decoder_input_ids=decoder_input_ids,
186
+ return_dict=True,
187
+ )
188
+
189
+ logits = outputs.logits[:, -1, :]
190
+ target_id = tokenizer.convert_tokens_to_ids(target_word)
191
+ logit = logits[0, target_id]
192
+
193
+ logit.backward()
194
+
195
+ grad = image_tensor.grad.abs().mean(dim=1).squeeze().cpu().numpy()
196
+ grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
197
+
198
+ return grad
199
+
200
+ # ATTENTION x GRAD
201
+ def attngrad(model, tokenizer, image_tensor, target_word, image_size=224, device="cuda"):
202
+
203
+ raw_attns = get_vit_self_attention(model, image_tensor.to(device))
204
+ attn_mats = [a[0] for a in raw_attns]
205
+
206
+ rollout = attention_rollout(attn_mats)
207
+ roll_map = rollout_to_image(rollout, image_size)
208
+
209
+ grad_map = token_gradient_map(model, tokenizer, image_tensor, target_word, device)
210
+
211
+ combined = roll_map * grad_map
212
+ combined = (combined - combined.min()) / (combined.max() - combined.min())
213
+ return combined
214
+
215
+
216
+
217
+ def token_gradient_map_smooth(model, tokenizer, image_tensor, target_word, sigma=5, device="cuda"):
218
+ model.eval()
219
+
220
+ image_tensor = image_tensor.to(device)
221
+ image_tensor.requires_grad_(True)
222
+
223
+ # Vision encoder
224
+ vision_out = model.vision_encoder(image_tensor)
225
+ img_embeds = vision_out["image_embeds"]
226
+ if img_embeds.dim() == 2:
227
+ img_embeds = img_embeds.unsqueeze(1)
228
+
229
+ projected = model.projector(img_embeds)
230
+
231
+
232
+ encoder_outputs = BaseModelOutput(last_hidden_state=projected)
233
+
234
+
235
+ start_token = model.t5.config.decoder_start_token_id
236
+
237
+ decoder_input_ids = torch.tensor(
238
+ [[start_token]], device=device, dtype=torch.long
239
+ )
240
+ attention_mask = torch.tensor([[1]], device=device)
241
+
242
+ outputs = model.t5(
243
+ encoder_outputs=encoder_outputs,
244
+ decoder_input_ids=decoder_input_ids,
245
+ attention_mask=attention_mask,
246
+ output_attentions=False,
247
+ output_hidden_states=False,
248
+ return_dict=True,
249
+ )
250
+
251
+ vocab_logits = outputs.logits[:, -1, :]
252
+ target_id = tokenizer.convert_tokens_to_ids(target_word)
253
+ logit = vocab_logits[0, target_id]
254
+
255
+ logit.backward()
256
+
257
+ grad = image_tensor.grad.data.abs().mean(dim=1).squeeze().cpu().numpy()
258
+ grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
259
+
260
+ grad_smooth = smooth_heatmap(grad, sigma=sigma)
261
+ return grad_smooth
262
+
263
+
264
+ def integrated_gradients(
265
+ model,
266
+ tokenizer,
267
+ image_tensor,
268
+ target_word,
269
+ steps=30,
270
+ device="cuda"
271
+ ):
272
+ model.eval()
273
+ device = torch.device(device)
274
+
275
+ image_tensor = image_tensor.to(device)
276
+ image_tensor.requires_grad_(True)
277
+
278
+ baseline = torch.zeros_like(image_tensor)
279
+
280
+ target_id = tokenizer.convert_tokens_to_ids(target_word)
281
+
282
+ total_grad = torch.zeros_like(image_tensor)
283
+
284
+ for i in range(1, steps + 1):
285
+ alpha = i / steps
286
+
287
+ img = baseline + alpha * (image_tensor - baseline)
288
+ img.requires_grad_(True)
289
+
290
+ vision_out = model.vision_encoder(img)
291
+ img_embeds = vision_out["image_embeds"]
292
+ if img_embeds.dim() == 2:
293
+ img_embeds = img_embeds.unsqueeze(1)
294
+
295
+ projected = model.projector(img_embeds)
296
+ encoder_outputs = BaseModelOutput(last_hidden_state=projected)
297
+
298
+ start_token = model.t5.config.decoder_start_token_id
299
+ decoder_input_ids = torch.tensor([[start_token]], device=device)
300
+ attention_mask = torch.tensor([[1]], device=device)
301
+
302
+ outputs = model.t5(
303
+ encoder_outputs=encoder_outputs,
304
+ decoder_input_ids=decoder_input_ids,
305
+ attention_mask=attention_mask,
306
+ return_dict=True,
307
+ )
308
+
309
+ vocab_logits = outputs.logits[:, -1, :]
310
+ logit = vocab_logits[0, target_id]
311
+
312
+ grads = torch.autograd.grad(
313
+ outputs=logit,
314
+ inputs=img,
315
+ retain_graph=True,
316
+ create_graph=False,
317
+ allow_unused=True,
318
+ )[0]
319
+
320
+ if grads is None:
321
+ raise RuntimeError("Integrated gradients: grad is None — gradient path was broken.")
322
+
323
+ total_grad += grads
324
+
325
+ avg_grad = total_grad / steps
326
+
327
+ heat = avg_grad.abs().mean(dim=1).squeeze().cpu().numpy()
328
+ heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
329
+
330
+ return heat
331
+
332
+
333
+
334
+ def smooth_heatmap(hm, k=21, sigma=6):
335
+ hm = cv2.GaussianBlur(hm, (k, k), sigma)
336
+ hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)
337
+ return hm
338
+
339
+
340
+ def get_cross_attention(model, encoder_outputs, decoder_input_ids, device="cuda"):
341
+
342
+ model.eval()
343
+ with torch.no_grad():
344
+ outputs = model.t5(
345
+ encoder_outputs=encoder_outputs,
346
+ decoder_input_ids=decoder_input_ids.to(device),
347
+ output_attentions=True,
348
+ return_dict=True,
349
+ )
350
+
351
+ # outputs.cross_attentions is a tuple of layers (batch, heads, tgt_len, src_len)
352
+ cross = outputs.cross_attentions
353
+ attn_layers = [c[0] for c in cross] # use batch 0
354
+ return attn_layers
355
+
356
+
357
+ """
358
+ def cross_attention_to_image(attn, image_size=224):
359
+
360
+ attn = attn.mean(dim=0) # (tgt_len, src_len)
361
+
362
+ attn = attn[-1] # (src_len,)
363
+
364
+ attn = attn[1:]
365
+
366
+ num_patches = int(attn.numel() ** 0.5)
367
+ heat = attn.reshape(num_patches, num_patches)
368
+
369
+ heat = heat - heat.min()
370
+ heat = heat / (heat.max() + 1e-8)
371
+
372
+ heat = resize(
373
+ heat.unsqueeze(0).unsqueeze(0),
374
+ (image_size, image_size)
375
+ ).squeeze()
376
+
377
+ return heat.detach().cpu().numpy()
378
+ """
379
+
380
+ def cross_attention_to_image(attn):
381
+
382
+ attn = torch.tensor(attn) if not torch.is_tensor(attn) else attn
383
+
384
+ if attn.numel() == 0:
385
+ return np.zeros((14, 14), dtype=np.float32)
386
+
387
+ if attn.dim() == 2:
388
+ attn_vec = attn[-1] # use last generated token
389
+ elif attn.dim() == 1:
390
+ attn_vec = attn
391
+ else:
392
+ raise ValueError(f"Unexpected attn shape: {attn.shape}")
393
+
394
+ # Drop the CLS token (index 0): CLIP ViT-L/14 yields 197 tokens but only 196 spatial patches
395
+ if attn_vec.size(0) == 197:
396
+ attn_vec = attn_vec[1:] # now length = 196
397
+
398
+ src_len = attn_vec.size(0)
399
+ side = int(src_len**0.5)
400
+
401
+ if side * side != src_len:
402
+ new_len = side * side
403
+ padded = torch.zeros(new_len, device=attn_vec.device)
404
+ padded[:min(new_len, src_len)] = attn_vec[:min(new_len, src_len)]
405
+ attn_vec = padded
406
+
407
+ attn_vec = attn_vec / (attn_vec.max() + 1e-8)
408
+
409
+ heatmap = attn_vec.reshape(side, side).cpu().numpy()
410
+
411
+ return heatmap
412
+
413
+
414
+
415
+ def plot_cross_attention_overlay(image_tensor, heatmap, save_path=None, alpha=0.45):
416
+ img = image_tensor[0].permute(1,2,0).cpu().numpy()
417
+ img = (img - img.min()) / (img.max() - img.min())
418
+
419
+ plt.figure(figsize=(6,6))
420
+ plt.imshow(img)
421
+ plt.imshow(heatmap, cmap='inferno', alpha=alpha)
422
+ plt.axis("off")
423
+
424
+ if save_path:
425
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
426
+ plt.close()
427
+ print(f"[CrossAttention] Saved to {save_path}")
428
+ else:
429
+ plt.show()
430
+
431
+
432
+ def visualize_cross_attention(model, tokenizer, image_tensor, word, device="cuda"):
433
+ device = torch.device(device)
434
+ image_tensor = image_tensor.to(device)
435
+
436
+ vision_out = model.vision_encoder(image_tensor)
437
+ img_embeds = vision_out["image_embeds"]
438
+ if img_embeds.dim() == 2:
439
+ img_embeds = img_embeds.unsqueeze(1)
440
+
441
+ projected = model.projector(img_embeds)
442
+ encoder_outputs = BaseModelOutput(last_hidden_state=projected)
443
+
444
+ generated = [model.t5.config.decoder_start_token_id]
445
+
446
+ for _ in range(30):
447
+ decoder_input_ids = torch.tensor([generated], device=device)
448
+ attn_layers = get_cross_attention(
449
+ model, encoder_outputs, decoder_input_ids
450
+ )
451
+
452
+ logits = model.t5(
453
+ encoder_outputs=encoder_outputs,
454
+ decoder_input_ids=decoder_input_ids,
455
+ return_dict=True
456
+ ).logits[:, -1, :]
457
+ next_id = int(logits.argmax())
458
+ generated.append(next_id)
459
+
460
+ if next_id == tokenizer.convert_tokens_to_ids(word):
461
+ break
462
+
463
+ last_attn = attn_layers[-1].mean(dim=0) # (heads, T, S) → average heads → (T, S)
464
+ heat = cross_attention_to_image(last_attn)
465
+
466
+ plot_cross_attention_overlay(image_tensor, heat)
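
A sketch of how the helpers above can be chained for a ViT/CLIP backbone, mirroring the call pattern of attngrad(); the checkpoint path, image path, and target word are placeholders:

import torch
from src.utils import load_experiment
from src.inference import load_image
from src.interpretability import (
    get_vit_self_attention, attention_rollout, rollout_to_image,
    plot_attention_overlay, token_gradient_map_smooth,
)
from data.transforms import build_coco_transform

device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer, meta, config = load_experiment("checkpoints/vision_t5/<run_id>", device=device)
preprocess = build_coco_transform(image_size=224)
img = load_image("data/processed/images/<image_id>.jpg", preprocess).to(device)

# Self-attention rollout across all encoder layers
attns = get_vit_self_attention(model, img)          # tuple of (B, heads, T, T)
rollout = attention_rollout([a[0] for a in attns])  # drop batch dim, multiply layers
plot_attention_overlay(img[0].detach(), rollout_to_image(rollout, 224))

# Gradient of one target token's logit w.r.t. the input pixels
heat = token_gradient_map_smooth(model, tokenizer, img, target_word="dog", device=device)
plot_attention_overlay(img[0].detach(), heat)
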
src/train.py ADDED
@@ -0,0 +1,176 @@
1
+ import argparse
2
+ import yaml
3
+ import torch
4
+ from torch.optim import AdamW
5
+ from tqdm import tqdm
6
+ from transformers import T5TokenizerFast
7
+
8
+
9
+ from models.vision_t5 import VisionT5
10
+ from models.encoder_projection_t5 import ImageProjection
11
+ import models.encoders as encoders
12
+
13
+ from data.loaders import get_coco_dataloaders
14
+ from src.inference import generate_caption
15
+ from src.utils import save_experiment, filter_kwargs, build_model
16
+ from torch.optim.lr_scheduler import CosineAnnealingLR
17
+
18
+ import math
19
+
20
+
21
+
22
+ def build_cosine_warmup_scheduler(optimizer, num_warmup_steps, num_training_steps):
23
+ def lr_lambda(step):
24
+ if step < num_warmup_steps:
25
+ return float(step) / float(max(1, num_warmup_steps))
26
+ progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
27
+ return 0.5 * (1 + math.cos(math.pi * progress)) # cosine decay
28
+
29
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
30
+
31
+
32
+
33
+ def train_one_epoch(model, dataloader, optimizer, device, scaler, scheduler):
34
+ model.train()
35
+ running_loss = 0.0
36
+
37
+ for batch in tqdm(dataloader, desc="Training"):
38
+ pixel_values = batch["pixel_values"].to(device)
39
+ input_ids = batch["input_ids"].to(device)
40
+ attention_mask = batch["attention_mask"].to(device)
41
+
42
+ # teacher forcing labels
43
+ labels = input_ids.clone()
44
+ labels[labels == model.t5.config.pad_token_id] = -100 # HF provided value to ignore in labels for loss calc.
45
+
46
+ optimizer.zero_grad()
47
+
48
+ # Using AMP to save memory
49
+ with torch.cuda.amp.autocast():
50
+ outputs = model(
51
+ pixel_values=pixel_values,
52
+ input_ids=input_ids,
53
+ attention_mask=attention_mask,
54
+ labels=labels,
55
+ )
56
+ loss = outputs.loss
57
+
58
+ scaler.scale(loss).backward()
59
+
60
+ # Gradient clipping
61
+ scaler.unscale_(optimizer)
62
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
63
+
64
+ scaler.step(optimizer)
65
+ scaler.update()
66
+ scheduler.step()
67
+ running_loss += loss.item()
68
+
69
+ return running_loss / len(dataloader)
70
+
71
+
72
+ # Validation
73
+ @torch.no_grad()
74
+ def validate(model, tokenizer, dataloader, device, preview=False):
75
+ model.eval()
76
+ running_loss = 0.0
77
+
78
+ sample_img = None
79
+ sample_gt = None
80
+
81
+ for batch in tqdm(dataloader, desc="Validation"):
82
+ pixel_values = batch["pixel_values"].to(device)
83
+ input_ids = batch["input_ids"].to(device)
84
+ attention_mask = batch["attention_mask"].to(device)
85
+
86
+ # Teacher-forcing labels
87
+ labels = input_ids.clone()
88
+ labels[labels == tokenizer.pad_token_id] = -100
89
+
90
+ outputs = model(
91
+ pixel_values=pixel_values,
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ labels=labels,
95
+ )
96
+
97
+ running_loss += outputs.loss.item()
98
+
99
+ # Store sample for preview
100
+ if preview and sample_img is None:
101
+ sample_img = pixel_values[0].detach().cpu()
102
+ # decode GT caption (first non-pad tokens)
103
+ gt_ids = input_ids[0][input_ids[0] != tokenizer.pad_token_id]
104
+ sample_gt = tokenizer.decode(gt_ids, skip_special_tokens=True)
105
+
106
+ # preview
107
+ if preview and sample_img is not None:
108
+ print("\n--- Validation Preview ---")
109
+ pred = generate_caption(model, tokenizer, sample_img.unsqueeze(0), device=device)
110
+ print("Prediction:", pred)
111
+ print("Ground Truth:", sample_gt)
112
+ print("--------------------------\n")
113
+
114
+ return running_loss / len(dataloader)
115
+
116
+
117
+
118
+ def main(config):
119
+ device = "cuda" if torch.cuda.is_available() else "cpu"
120
+
121
+ # Model + tokenizer
122
+ model, tokenizer = build_model(config)
123
+ model.to(device)
124
+
125
+ # Data
126
+ batch_size = config["training"]["batch_size"]
127
+ image_size = config["model"].get("image_size", 224)
128
+ train_loader, val_loader, _ = get_coco_dataloaders(batch_size=batch_size, data_dir=config["paths"]["data_dir"], image_size=image_size)
129
+
130
+ optimizer = AdamW(model.parameters(), lr=config["training"]["lr"])
131
+ scaler = torch.cuda.amp.GradScaler() # For mixed precision
132
+
133
+ num_training_steps = len(train_loader) * config["training"]["epochs"]
134
+ num_warmup_steps = int(0.05 * num_training_steps)
135
+ scheduler = build_cosine_warmup_scheduler(
136
+ optimizer,
137
+ num_warmup_steps=num_warmup_steps,
138
+ num_training_steps=num_training_steps
139
+ )
140
+ best_val = float("inf")
141
+ best_epoch = -1
142
+
143
+ # Train loop
144
+ for epoch in range(1, config["training"]["epochs"] + 1):
145
+ print(f"\nEpoch {epoch}/{config['training']['epochs']}")
146
+
147
+ train_loss = train_one_epoch(model, train_loader, optimizer, device, scaler, scheduler)
148
+ print("Train Loss:", train_loss)
149
+
150
+ val_loss = validate(model, tokenizer, val_loader, device, preview=config["training"]["preview_val"])
151
+ print("Val Loss:", val_loss)
152
+
153
+ if val_loss < best_val:
154
+ best_val = val_loss
155
+ best_epoch = epoch
156
+
157
+ save_experiment(
158
+ model=model,
159
+ tokenizer=tokenizer,
160
+ config=config,
161
+ save_dir=config["paths"]["output_dir"],
162
+ notes=f"BEST checkpoint epoch={epoch}, val_loss={val_loss:.4f}"
163
+ )
164
+ print(f"[CHECKPOINT] Saved new BEST model at epoch {epoch}")
165
+
166
+
167
+
168
+ if __name__ == "__main__":
169
+ parser = argparse.ArgumentParser()
170
+ parser.add_argument("--config", type=str, required=True)
171
+ args = parser.parse_args()
172
+
173
+ with open(args.config, "r") as f:
174
+ config = yaml.safe_load(f)
175
+
176
+ main(config)
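
For reference, a minimal config illustrating the keys this script and utils.build_model() read. The encoder class name and paths below are placeholders, not values from this repo's actual config files; the same structure would normally live in the YAML file passed via --config:

config = {
    "model": {
        "encoder": "ClipViTEncoder",   # placeholder: must match a class defined in models/encoders.py
        "encoder_params": {},
        "t5_name": "t5-small",
        "decoder_params": {},
        "image_size": 224,
    },
    "training": {
        "batch_size": 32,
        "lr": 3.0e-4,
        "epochs": 10,
        "preview_val": True,
    },
    "paths": {
        "data_dir": "data/processed",
        "output_dir": "checkpoints/vision_t5",
    },
}
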
src/utils.py ADDED
@@ -0,0 +1,256 @@
1
+ # utils.py
2
+ import os
3
+ import yaml
4
+ import torch
5
+ from datetime import datetime
6
+
7
+ from transformers import T5TokenizerFast
8
+ from models.vision_t5 import VisionT5
9
+ import models.encoders as encoders
10
+ from models.encoder_projection_t5 import ImageProjection
11
+ import inspect
12
+
13
+
14
+
15
+ def timestamp():
16
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
17
+
18
+
19
+
20
+ def save_experiment(model, tokenizer, config, save_dir, notes="", run_name=None, add_timestamp=True):
21
+
22
+ if add_timestamp:
23
+ tag = timestamp()
24
+ if run_name:
25
+ save_dir = os.path.join(save_dir, f"{run_name}_{tag}")
26
+ else:
27
+ save_dir = os.path.join(save_dir, tag)
28
+
29
+ os.makedirs(save_dir, exist_ok=True)
30
+
31
+ torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
32
+
33
+ tok_dir = os.path.join(save_dir, "tokenizer")
34
+ os.makedirs(tok_dir, exist_ok=True)
35
+ tokenizer.save_pretrained(tok_dir)
36
+
37
+ with open(os.path.join(save_dir, "config_trained.yaml"), "w") as f:
38
+ yaml.safe_dump(config, f)
39
+
40
+ metadata = {
41
+ "encoder": config["model"]["encoder"],
42
+ "encoder_params": config["model"].get("encoder_params", {}),
43
+ "decoder": config["model"]["t5_name"],
44
+ "decoder_params": config["model"].get("decoder_params", {}),
45
+ "train_epochs": config["training"]["epochs"],
46
+ "batch_size": config["training"]["batch_size"],
47
+ "lr": config["training"]["lr"],
48
+ "notes": notes,
49
+ "run_name": run_name,
50
+ "timestamp": timestamp(),
51
+ }
52
+
53
+ with open(os.path.join(save_dir, "metadata.yaml"), "w") as f:
54
+ yaml.safe_dump(metadata, f)
55
+
56
+ print(f"[OK] Experiment saved → {save_dir}")
57
+ return save_dir
58
+
59
+
60
+
61
+
62
+ def load_experiment(checkpoint_dir, device="cpu"):
63
+ import yaml, torch, os
64
+
65
+ metadata_path = os.path.join(checkpoint_dir, "metadata.yaml")
66
+ config_path = os.path.join(checkpoint_dir, "config_trained.yaml")
67
+
68
+ if not os.path.exists(metadata_path):
69
+ raise FileNotFoundError(f"No metadata.yaml found at {checkpoint_dir}")
70
+ if not os.path.exists(config_path):
71
+ raise FileNotFoundError(f"No config_trained.yaml found at {checkpoint_dir}")
72
+
73
+ with open(metadata_path, "r") as f:
74
+ metadata = yaml.safe_load(f)
75
+
76
+ with open(config_path, "r") as f:
77
+ config = yaml.safe_load(f)
78
+
79
+
80
+ model, tokenizer = build_model(config)
81
+
82
+ tok_dir = os.path.join(checkpoint_dir, "tokenizer")
83
+ if os.path.isdir(tok_dir):
84
+ tokenizer = T5TokenizerFast.from_pretrained(tok_dir)
85
+
86
+ ckpt_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
87
+ weights = torch.load(ckpt_path, map_location=device)
88
+ model.load_state_dict(weights, strict=False)
89
+
90
+ model.to(device)
91
+ model.eval()
92
+
93
+ print(f"Loaded experiment from {checkpoint_dir}")
94
+ return model, tokenizer, metadata, config
95
+
96
+
97
+
98
+ def filter_kwargs(cls, kwargs):
99
+ sig = inspect.signature(cls.__init__).parameters
100
+ return {k: v for k, v in kwargs.items() if k in sig}
101
+
102
+
103
+
104
+ def build_model(config):
105
+
106
+ encoder_name = config["model"]["encoder"]
107
+ raw_encoder_params = config["model"].get("encoder_params", {})
108
+
109
+ t5_name = config["model"]["t5_name"]
110
+ decoder_params = config["model"].get("decoder_params", {})
111
+
112
+ tokenizer = T5TokenizerFast.from_pretrained(t5_name)
113
+
114
+ # dynamically load encoder class
115
+ if not hasattr(encoders, encoder_name):
116
+ raise ValueError(f"Encoder '{encoder_name}' not found in encoders.py")
117
+
118
+ EncoderClass = getattr(encoders, encoder_name)
119
+
120
+ encoder_params = filter_kwargs(EncoderClass, raw_encoder_params)
121
+
122
+ # Instantiate encoder
123
+ vision_encoder = EncoderClass(**encoder_params)
124
+
125
+ # Projection layer
126
+ t5_hidden = VisionT5.get_t5_hidden_size(t5_name)
127
+ projector = ImageProjection(
128
+ encoder_dim=vision_encoder.get_output_dim(),
129
+ t5_hidden_size=t5_hidden
130
+ )
131
+
132
+ # Construct model
133
+ model = VisionT5(
134
+ vision_encoder=vision_encoder,
135
+ projector=projector,
136
+ t5_name=t5_name,
137
+ decoder_params=decoder_params
138
+ )
139
+
140
+ return model, tokenizer
141
+
142
+
143
+
144
+ def load_yaml(path):
145
+ with open(path, "r") as f:
146
+ return yaml.safe_load(f)
147
+
148
+
149
+
150
+ def count_encoder_decoder_params(model):
151
+
152
+ enc_total = enc_train = 0
153
+ proj_total = proj_train = 0
154
+ dec_total = dec_train = 0
155
+ other_total = other_train = 0
156
+
157
+ for name, p in model.named_parameters():
158
+ n = p.numel()
159
+
160
+ # Vision Encoder
161
+ if name.startswith("vision_encoder."):
162
+ enc_total += n
163
+ if p.requires_grad:
164
+ enc_train += n
165
+ continue
166
+
167
+ # Projector
168
+ if name.startswith("projector."):
169
+ proj_total += n
170
+ if p.requires_grad:
171
+ proj_train += n
172
+ continue
173
+
174
+ # T5 Decoder (covers small, base, large, AND LoRA)
175
+ if (
176
+ name.startswith("t5.decoder.") or
177
+ "decoder.block" in name or
178
+ name.startswith("t5.model.decoder.") or
179
+ name.startswith("t5.lm_head.") or
180
+ name.startswith("t5.shared.")
181
+ ):
182
+ dec_total += n
183
+ if p.requires_grad:
184
+ dec_train += n
185
+ continue
186
+
187
+ if "lora_" in name and "decoder" in name:
188
+ dec_total += n
189
+ if p.requires_grad:
190
+ dec_train += n
191
+ continue
192
+
193
+ # T5 Encoder (always frozen)
194
+ if name.startswith("t5.encoder."):
195
+ other_total += n
196
+ if p.requires_grad:
197
+ other_train += n
198
+ continue
199
+
200
+ # Other params
201
+ other_total += n
202
+ if p.requires_grad:
203
+ other_train += n
204
+
205
+ total_params = enc_total + proj_total + dec_total + other_total
206
+ trainable_params = enc_train + proj_train + dec_train + other_train
207
+
208
+ return {
209
+ "encoder_total_params": enc_total,
210
+ "encoder_trainable_params": enc_train,
211
+ "encoder_trainable_fraction":
212
+ enc_train / enc_total if enc_total else None,
213
+
214
+ "projector_total_params": proj_total,
215
+ "projector_trainable_params": proj_train,
216
+ "projector_trainable_fraction":
217
+ proj_train / proj_total if proj_total else None,
218
+
219
+ "decoder_total_params": dec_total,
220
+ "decoder_trainable_params": dec_train,
221
+ "decoder_trainable_fraction":
222
+ dec_train / dec_total if dec_total else None,
223
+
224
+ "other_total_params": other_total,
225
+ "other_trainable_params": other_train,
226
+
227
+ "total_params": total_params,
228
+ "trainable_params": trainable_params,
229
+ "trainable_params_fraction":
230
+ trainable_params / total_params if total_params else None,
231
+ }
232
+
233
+
234
+ def classify_param(name):
235
+
236
+ if name.startswith("vision_encoder."):
237
+ return "encoder"
238
+
239
+ if name.startswith("projector."):
240
+ return "projector"
241
+
242
+ if (
243
+ name.startswith("t5.decoder.") or
244
+ name.startswith("t5.model.decoder.") or
245
+ "decoder.block" in name or
246
+ name.startswith("t5.lm_head.") or
247
+ name.startswith("t5.shared.") or
248
+ ("lora_" in name and "decoder" in name)
249
+ ):
250
+ return "decoder"
251
+
252
+ if name.startswith("t5.encoder."):
253
+ return "t5_encoder_frozen"
254
+
255
+ return "other"
256
+
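
A quick illustration of how classify_param() buckets parameter names; the names below are made up to mirror the prefixes it checks, not taken from a real checkpoint:

from src.utils import classify_param

for name in [
    "vision_encoder.model.encoder.layers.0.mlp.fc1.weight",
    "projector.proj.weight",
    "t5.decoder.block.0.layer.0.SelfAttention.q.weight",
    "t5.lm_head.weight",
    "t5.encoder.block.0.layer.0.SelfAttention.q.weight",
    "t5.decoder.block.0.layer.0.SelfAttention.q.lora_A.weight",
]:
    print(f"{classify_param(name):18s} <- {name}")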