mbiswas committed
Commit b781107 · verified · 1 Parent(s): d2b2a20

Upload 10 files

Files changed (10)
  1. constants.py +30 -0
  2. dataset.py +349 -0
  3. decoder_language_model.py +165 -0
  4. finetune_lm_head_ce_loss.py +418 -0
  5. infer.py +504 -0
  6. model_components.py +163 -0
  7. train.py +264 -0
  8. train_stage_2.py +267 -0
  9. utils.py +57 -0
  10. vision_language_model.py +400 -0
constants.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+
+ IMAGE_SIZE = 512
+ PATCH_SIZE = 16
+ HIDDEN_DIM = 256
+ CONTEXT_LENGTH = 1536
+ TEXT_LENGTH = 512  # Max length for the *target* sequence (coords)
+ PROMPT_LENGTH = 64  # Max length for the *prompt* sequence (description) - adjust as needed
+ DROPOUT = 0.1
+ NUM_HEADS = 8
+ NUM_LAYERS = 12  # Keep a moderate number of layers
+ BATCH_SIZE = 16
+ LEARNING_RATE = 1e-3  # A lower LR might be needed with contrastive loss
+ DTYPE = torch.float32  # torch.bfloat16 created some instability - why?
+ GRAD_ACCUMULATION_STEPS = 16
+ IMAGE_MEAN = [0.485, 0.456, 0.406]
+ IMAGE_STD = [0.229, 0.224, 0.225]
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ IMAGE_LOCATION = "./images/"
+ NUM_BINS = 32
+ SHARED_EMBED_DIM = 256  # Dimension of the contrastive embedding space
+ MAX_POINTS = 10  # Maximum number of points per image to handle
+
+ # Training loop constants
+ NUM_EPOCHS = 400  # Desired number of epochs
+ LOGGING_STEPS = 1  # Log every N optimization steps
+ MAX_GRAD_NORM = 1.0
+ LAMBDA_CONTRASTIVE = 2  # Weight for contrastive loss - TUNE THIS
+ LAMBDA_REGRESSION = 2  # Weight for regression loss - works but noisy
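
A quick sanity check on the sequence budget these constants imply (an illustrative sketch, not part of the commit; the patch count is the usual (IMAGE_SIZE // PATCH_SIZE)**2, which is how model.num_patches is used in the fine-tuning script below):

    # Context-budget check derived from constants.py.
    IMAGE_SIZE, PATCH_SIZE = 512, 16
    CONTEXT_LENGTH, PROMPT_LENGTH, TEXT_LENGTH = 1536, 64, 512

    num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2           # 32 * 32 = 1024 image tokens
    worst_case = num_patches + PROMPT_LENGTH + TEXT_LENGTH  # 1024 + 64 + 512 = 1600
    print(worst_case <= CONTEXT_LENGTH)  # False: a max-length prompt plus a max-length
    # target overflows CONTEXT_LENGTH, which is why the decoder carries truncation logic.
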
dataset.py ADDED
@@ -0,0 +1,349 @@
+ from tqdm.auto import tqdm
+ from constants import *
+ from utils import *
+ import pickle
+ from torch.utils.data import Dataset, DataLoader
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import os
+
+ def format_point_text(points):
+     # This function already handles multiple points correctly
+     text = "<result_start>"
+     for point in points:
+         # Ensure point coordinates are within [0, 100] before processing (.get() guards missing keys)
+         px = min(max(int(point.get('x', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
+         py = min(max(int(point.get('y', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
+         x_bin = min(px // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
+         y_bin = min(py // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
+         text += f"<pointx_start><coord_bin_{x_bin}><pointx_end><pointy_start><coord_bin_{y_bin}><pointy_end>"
+     text += "<result_end>" + tokenizer.eos_token
+     return text
+
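+ # Worked example of the binning above (illustrative): with IMAGE_SIZE = 512 and
+ # NUM_BINS = 32, each bin spans 512 // 32 = 16 px, so a point at x = 50 (percent)
+ # maps to px = 256 and x_bin = 256 // 16 = 16, i.e. the token <coord_bin_16>.
+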
+ def format_data_for_training(sample):
+     """Format a data sample for training, handling 0 to MAX_POINTS continuous coordinates."""
+     try:
+         # Check that 'points' exists and is a list; otherwise treat it as 0 points
+         sample_points = sample.get('points', [])
+         if not isinstance(sample_points, list):
+             print(f"Warning: Invalid 'points' type for {sample.get('image_url', 'N/A')}. Treating as 0 points.")
+             sample_points = []
+
+         # Limit the number of points processed
+         points_to_process = sample_points[:MAX_POINTS]
+         num_points = len(points_to_process)
+
+         # Load the image - this is where most memory is used
+         image_path = f"{IMAGE_LOCATION}{sample['image_url']}"
+
+         # Check that the file exists before attempting to open it
+         if not os.path.exists(image_path):
+             print(f"Warning: Image not found: {image_path}. Skipping.")
+             return None
+
+         # Open the image with error handling
+         try:
+             image = Image.open(image_path)
+             # Convert grayscale to RGB if needed
+             if image.mode != 'RGB':
+                 image = image.convert('RGB')
+             image_tensor = image_to_tensor(image)
+             # Explicitly delete the PIL image to free memory
+             del image
+         except Exception as e:
+             print(f"Error processing image {image_path}: {e}")
+             return None
+
+         # Process text with memory efficiency in mind
+         prompt_text = f"<point_start>{sample['label']}<point_end>"
+         # format_point_text correctly handles an empty points_to_process list
+         target_text = format_point_text(points_to_process)
+
+         # Tokenize with explicit max lengths
+         prompt_tokens = tokenizer(prompt_text, return_tensors="pt", max_length=PROMPT_LENGTH,
+                                   truncation=True, padding=False)
+         target_tokens = tokenizer(target_text, return_tensors="pt", max_length=TEXT_LENGTH,
+                                   truncation=True, padding=False)
+
+         # Check for empty tensors after tokenization
+         if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:
+             print(f"Warning: Empty tokens after tokenization for {sample.get('image_url', 'N/A')}. Skipping.")
+             return None
+
+         # --- Handle multiple continuous coordinates with padding (num_points == 0 is handled correctly) ---
+         continuous_coords_list = []
+         for point in points_to_process:  # This loop won't run if num_points is 0
+             coord_x = min(max(point.get('x', 50) / 100.0, 0.0), 1.0)
+             coord_y = min(max(point.get('y', 50) / 100.0, 0.0), 1.0)
+             continuous_coords_list.append([coord_x, coord_y])
+
+         # Pad the coordinates and create a mask; if the list is empty, build padding-only tensors
+         if num_points == 0:
+             padded_coords = torch.full((MAX_POINTS, 2), -1.0)
+             coords_mask = torch.zeros(MAX_POINTS)
+         else:
+             coords_tensor = torch.tensor(continuous_coords_list, dtype=torch.float32)
+             padding_needed = MAX_POINTS - num_points
+             padded_coords = F.pad(coords_tensor, (0, 0, 0, padding_needed), value=-1.0)
+             coords_mask = torch.cat([torch.ones(num_points, dtype=torch.float32),
+                                      torch.zeros(padding_needed, dtype=torch.float32)])
+
+         # Create and return the formatted sample
+         return {
+             "image": image_tensor,
+             "prompt_ids": prompt_tokens.input_ids[0],
+             "target_ids": target_tokens.input_ids[0],
+             "continuous_coords": padded_coords,
+             "coords_mask": coords_mask,
+             "num_points": num_points,
+             "label": sample['label'],
+             "image_url": sample['image_url']
+         }
+     except FileNotFoundError:
+         print(f"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.")
+         return None
+     except Exception as e:
+         print(f"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.")
+         import traceback
+         traceback.print_exc()
+         return None
+
+
+ class PointDataset(Dataset):
+     def __init__(self, data_path="active_point_dataset.pkl", split="train", test_size=1000):
+         with open(data_path, "rb") as f:
+             raw_data = pickle.load(f)
+
+         # Keep samples with 0 to MAX_POINTS points; check the type first so that
+         # a non-list 'points' value cannot crash len()
+         original_count = len(raw_data)
+         raw_data = [sample for sample in raw_data
+                     if isinstance(sample.get('points', []), list)
+                     and 0 <= len(sample.get('points', [])) <= MAX_POINTS]
+         filtered_count = len(raw_data)
+         print(f"Original raw data size: {original_count}")
+         print(f"Filtered raw data to {filtered_count} samples with 0 to {MAX_POINTS} points.")
+
+         total_samples = len(raw_data)
+         if total_samples == 0:
+             raise ValueError("No samples left after filtering. Check data or MAX_POINTS.")
+
+         if total_samples <= test_size:
+             print(f"Warning: Dataset size {total_samples} <= test_size {test_size}.")
+             test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0
+         train_end = total_samples - test_size
+         # Note that samples with 0 points are included
+         print(f"Dataset: {total_samples} total (0 to {MAX_POINTS} points), {train_end} train, {test_size} test")
+
+         # Split according to the actual train/test counts
+         if split == "train":
+             # Check that train_end is valid before slicing
+             if train_end <= 0: print("Warning: No samples allocated for training split.")
+             self.raw_data = raw_data[:train_end]
+         elif split == "test":
+             # Check that test_size is valid before slicing
+             if test_size <= 0: print("Warning: No samples allocated for test split.")
+             self.raw_data = raw_data[train_end:]
+         else:
+             raise ValueError("split must be 'train' or 'test'")
+
+         # Do NOT preprocess the data here - just store it raw.
+         # This is the key point: we never load all images at once.
+         print(f"Dataset initialized with {len(self.raw_data)} samples for {split}")
+
+         # Optional: cache recently processed items to speed up repeated access
+         self.cache_size = 8000  # Adjust based on memory constraints
+         self.cache = {}  # Simple bounded cache for processed samples
+
+     def __len__(self):
+         return len(self.raw_data)
+
+     def __getitem__(self, idx):
+         # Return the cached item if present
+         if idx in self.cache:
+             return self.cache[idx]
+
+         # Process the sample on demand
+         sample = self.raw_data[idx]
+         formatted = format_data_for_training(sample)
+
+         # If processing failed, try the next sample
+         if formatted is None:
+             # Find the next candidate index (with wrapping)
+             next_idx = (idx + 1) % len(self.raw_data)
+
+             # Bound the retries to prevent an infinite loop if all samples are invalid
+             attempts = 0
+             while formatted is None and attempts < min(10, len(self.raw_data)):
+                 sample = self.raw_data[next_idx]
+                 formatted = format_data_for_training(sample)
+                 next_idx = (next_idx + 1) % len(self.raw_data)
+                 attempts += 1
+
+             # If we still don't have a valid sample after the attempts, return a dummy sample
+             if formatted is None:
+                 print(f"Warning: Failed to find valid sample after {attempts} attempts")
+                 # Create a minimal valid sample with zeros
+                 formatted = self._create_dummy_sample()
+
+         # Update the cache, evicting the oldest entry when full
+         if len(self.cache) >= self.cache_size:
+             # Remove the oldest item (first key)
+             if self.cache:
+                 oldest_key = next(iter(self.cache))
+                 del self.cache[oldest_key]
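+                 # Note: eviction here is FIFO (dict insertion order); lookups do not
+                 # refresh recency, so this is a bounded cache rather than a true LRU.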
+
+         # Add to the cache
+         self.cache[idx] = formatted
+
+         return formatted
+
+     def _create_dummy_sample(self):
+         """Creates a minimal valid sample when all else fails."""
+         # Empty image tensor
+         image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)
+
+         # Minimal tokens
+         prompt_text = "<point_start>dummy<point_end>"
+         target_text = "<result_start><result_end>" + tokenizer.eos_token
+
+         prompt_tokens = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
+         target_tokens = tokenizer(target_text, return_tensors="pt").input_ids[0]
+
+         # Empty coordinates
+         padded_coords = torch.full((MAX_POINTS, 2), -1.0)
+         coords_mask = torch.zeros(MAX_POINTS)
+
+         return {
+             "image": image_tensor,
+             "prompt_ids": prompt_tokens,
+             "target_ids": target_tokens,
+             "continuous_coords": padded_coords,
+             "coords_mask": coords_mask,
+             "num_points": 0,
+             "label": "dummy",
+             "image_url": "none"
+         }
+
+     @staticmethod
+     def collate_fn(batch):
+         # Drop failed samples; bail out if the whole batch is invalid
+         batch = [item for item in batch if item is not None]
+         if not batch: return None
+
+         images = torch.stack([item['image'] for item in batch]).to(DTYPE)
+
+         # --- Pad prompt IDs ---
+         max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)
+         prompt_ids_padded, prompt_attention_mask = [], []
+         for item in batch:
+             ids, pad_len = item['prompt_ids'], max_prompt_len - item['prompt_ids'].size(0)
+             prompt_ids_padded.append(torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
+             prompt_attention_mask.append(torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)]))
+         prompt_ids = torch.stack(prompt_ids_padded)
+         prompt_attention_mask = torch.stack(prompt_attention_mask)
+
+         # --- Pad target IDs & create generative targets (stacks the padded coords and masks below) ---
+         max_target_len = max(item['target_ids'].size(0) for item in batch)
+         target_ids_padded, target_attention_mask, generative_targets = [], [], []
+         for item in batch:
+             ids, pad_len = item['target_ids'], max_target_len - item['target_ids'].size(0)
+             padded_ids = torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)])
+             target_ids_padded.append(padded_ids)
+             mask = torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)])
+             target_attention_mask.append(mask)
+             targets = torch.full_like(padded_ids, -100)
+             if ids.size(0) > 1:
+                 targets[:ids.size(0)-1] = ids[1:]
+             if ids.numel() > 0 and ids[-1] == tokenizer.eos_token_id:
+                 if ids.size(0) > 1:
+                     targets[ids.size(0)-1] = tokenizer.eos_token_id
+                 else:
+                     targets[0] = -100
+             generative_targets.append(targets)
+         target_ids = torch.stack(target_ids_padded)
+         target_attention_mask = torch.stack(target_attention_mask)
+         generative_targets = torch.stack(generative_targets)
+
+         # --- Stack continuous coords and masks ---
+         continuous_coords = torch.stack([item['continuous_coords'] for item in batch])
+         coords_mask = torch.stack([item['coords_mask'] for item in batch])
+         num_points = [item['num_points'] for item in batch]
+
+         labels = [item['label'] for item in batch]
+         image_urls = [item.get('image_url', '') for item in batch]
+
+         return {
+             'image': images,
+             'prompt_ids': prompt_ids,
+             'prompt_attention_mask': prompt_attention_mask,
+             'target_ids': target_ids,
+             'target_attention_mask': target_attention_mask,
+             'generative_targets': generative_targets,
+             'continuous_coords': continuous_coords,
+             'coords_mask': coords_mask,
+             'num_points': num_points,
+             'label': labels,
+             'image_url': image_urls
+         }
+
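+ # Example of the generative-target shift built in collate_fn (illustrative):
+ # for target_ids [A, B, EOS] padded to length 5, generative_targets becomes
+ # [B, EOS, EOS, -100, -100]: position i is trained to predict token i+1, and
+ # padded positions are ignored via cross_entropy's ignore_index=-100.
+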
+ def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
+     """Create the training dataloader with memory-efficient settings.
+
+     Args:
+         batch_size: Number of samples per batch
+         num_workers: Number of worker processes for data loading
+         prefetch_factor: Number of batches to prefetch per worker
+
+     Returns:
+         DataLoader instance, or None if the dataset is empty
+     """
+     dataset = PointDataset(split="train")
+     if len(dataset) == 0:
+         return None
+
+     # Configure the DataLoader for memory efficiency
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         collate_fn=PointDataset.collate_fn,
+         pin_memory=True,  # Speeds up CPU-to-GPU transfer
+         num_workers=num_workers,
+         prefetch_factor=prefetch_factor if num_workers > 0 else None,  # Only valid with workers
+         persistent_workers=num_workers > 0,  # Keep workers alive between epochs
+         drop_last=False  # Don't drop the last incomplete batch
+     )
+
+ def create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
+     """Create the test dataloader with memory-efficient settings.
+
+     Args:
+         batch_size: Number of samples per batch
+         num_workers: Number of worker processes for data loading
+         prefetch_factor: Number of batches to prefetch per worker
+
+     Returns:
+         DataLoader instance, or None if the dataset is empty
+     """
+     dataset = PointDataset(split="test")
+     if len(dataset) == 0:
+         print("Warning: Test dataset is empty. Returning None.")
+         return None
+
+     # Same memory settings as the train loader, but no shuffling
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         collate_fn=PointDataset.collate_fn,
+         pin_memory=True,
+         num_workers=num_workers,
+         prefetch_factor=prefetch_factor if num_workers > 0 else None,
+         persistent_workers=num_workers > 0,
+         drop_last=False
+     )
+
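
A minimal usage sketch for the loaders above (illustrative only; it assumes active_point_dataset.pkl and the ./images/ folder from constants.py are in place):

    from dataset import create_train_dataloader

    loader = create_train_dataloader(batch_size=4)
    batch = next(iter(loader))                # collate_fn returns None only for an all-invalid batch
    print(batch['image'].shape)               # torch.Size([4, 3, 512, 512])
    print(batch['prompt_ids'].shape)          # (4, longest prompt in the batch)
    print(batch['continuous_coords'].shape)   # torch.Size([4, 10, 2]), padded with -1.0
    print(batch['generative_targets'].shape)  # same shape as batch['target_ids']
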
decoder_language_model.py ADDED
@@ -0,0 +1,165 @@
+ from model_components import Block
+ from constants import *
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from utils import tokenizer, vocab_size
+
+ class DecoderLanguageModel(nn.Module):
+     """
+     Transformer decoder language model with an optional coordinate regression head.
+     Processes a combined sequence of embeddings.
+     Outputs logits for token prediction and optionally regressed coordinates (for MAX_POINTS points).
+     """
+     def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
+                  n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
+         super().__init__()
+         # --- Input embeddings ---
+         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(max_context, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+         # --- Transformer blocks ---
+         self.blocks = nn.ModuleList([
+             Block(n_embd, num_heads, dropout, is_decoder=True)
+             for _ in range(n_layer)
+         ])
+
+         # --- Final layer norm ---
+         self.ln_f = nn.LayerNorm(n_embd)
+
+         # --- Output heads ---
+         # 1. Head for token classification
+         self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
+
+         # 2. Head for direct coordinate regression (predicting MAX_POINTS * 2 values)
+         self.regression_head = nn.Sequential(
+             nn.Linear(n_embd, n_embd // 2),
+             nn.GELU(),
+             nn.Linear(n_embd // 2, MAX_POINTS * 2),  # Output MAX_POINTS * (x, y)
+             nn.Sigmoid()  # Constrain outputs to [0, 1]
+         )
+         # --- End output heads ---
+
+         self.n_embd = n_embd
+         self.max_context = max_context
+         # Tie the input embeddings to the LM head weights
+         self.token_embedding_table.weight = self.lm_head.weight
+         self.apply(self._init_weights)
+         print(f"DecoderLanguageModel initialized with {n_layer} layers.")
+
+     def _init_weights(self, module):
+         # Standard GPT-style initialization
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.zeros_(module.bias)
+             torch.nn.init.ones_(module.weight)
+
+     def forward(self, combined_embeds, attention_mask=None, targets=None):
+         """
+         Forward pass for training or inference where the loss is calculated.
+         The regression output is handled *outside* this module by the VLM.
+         """
+         # --- Input validation & processing ---
+         if combined_embeds.ndim != 3:
+             raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
+         B, T, C = combined_embeds.shape
+         if T > self.max_context:
+             # Keep only the most recent max_context positions
+             print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
+             combined_embeds = combined_embeds[:, -self.max_context:, :]
+             if attention_mask is not None: attention_mask = attention_mask[:, -self.max_context:]
+             if targets is not None: targets = targets[:, -self.max_context:]
+             T = self.max_context
+
+         # --- Positional encoding ---
+         pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
+         pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
+         pos_emb = self.position_embedding_table(pos)  # Shape: (T, C)
+         x = combined_embeds + pos_emb.unsqueeze(0)
+         x = self.dropout(x)
+
+         # --- Transformer blocks ---
+         for block in self.blocks:
+             x = block(x, attention_mask=attention_mask)
+
+         # --- Final layer norm ---
+         x_norm = self.ln_f(x)  # Shape: (B, T, C) - passed out for the VLM regression head
+
+         # --- Classification head output ---
+         logits = self.lm_head(x_norm)  # Shape: (B, T, VocabSize)
+
+         # --- Classification loss calculation ---
+         class_loss = None
+         if targets is not None:
+             try:
+                 class_loss = F.cross_entropy(
+                     logits.view(-1, logits.size(-1)),
+                     targets.view(-1),
+                     ignore_index=-100
+                 )
+                 if torch.isnan(class_loss):
+                     print("Warning: class_loss is NaN.")
+                     class_loss = None
+             except Exception as e:
+                 print(f"Error calculating cross_entropy: {e}")
+                 print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
+                 class_loss = None
+
+         # Return logits, class_loss, and the final normalized hidden states
+         return logits, class_loss, x_norm
+
+     # --- Generation method (example - if needed internally; otherwise the VLM handles it) ---
+     # If the VLM needs this class to perform generation based on token IDs:
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Autoregressive generation based on starting token IDs.
+         NOTE: This version doesn't handle combined embeddings directly.
+         The VisionLanguageModel should ideally use a method like
+         generate_from_embeddings or implement the loop externally.
+         """
+         self.eval()
+         for _ in range(max_new_tokens):
+             # --- Context management: crop idx if longer than the context length ---
+             idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]
+
+             # --- Forward pass ---
+             # Token embeddings
+             tok_embeds = self.token_embedding_table(idx_cond)  # (B, T, C)
+             # Positional embeddings
+             pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
+             pos = pos.clamp(max=self.max_context - 1)
+             pos_emb = self.position_embedding_table(pos).unsqueeze(0)  # (1, T, C)
+             x = self.dropout(tok_embeds + pos_emb)
+             # Pass through the blocks (no padding mask needed for a single unpadded sequence)
+             for block in self.blocks:
+                 x = block(x, attention_mask=None)  # The causal mask is internal to the block/head
+             # Final layer norm and head for the last token only
+             x = self.ln_f(x[:, -1:, :])  # (B, 1, C)
+             logits = self.lm_head(x)  # (B, 1, V)
+             logits = logits.squeeze(1)  # (B, V)
+
+             # --- Sampling ---
+             logits = logits / temperature
+             if top_k is not None and top_k > 0:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+
+             # Append the sampled token
+             idx = torch.cat((idx, idx_next), dim=1)
+
+             # Stop if every sequence produced EOS
+             if hasattr(tokenizer, 'eos_token_id') and (idx_next == tokenizer.eos_token_id).all():
+                 break
+         self.train()
+         return idx
+
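
A shape-check sketch for the decoder above (illustrative only; it assumes utils exposes an initialized tokenizer and vocab_size at import time):

    import torch
    from constants import HIDDEN_DIM
    from decoder_language_model import DecoderLanguageModel

    model = DecoderLanguageModel()
    embeds = torch.randn(2, 10, HIDDEN_DIM)  # (B, T, C) pre-combined embeddings
    logits, loss, hidden = model(embeds)     # no targets passed, so loss is None
    print(logits.shape, hidden.shape, loss)  # (2, 10, vocab_size) (2, 10, 256) None
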
finetune_lm_head_ce_loss.py ADDED
@@ -0,0 +1,418 @@
+ # finetune_lm_head_ce_loss.py
+ # python finetune_lm_head_ce_loss.py --pretrained_model_path model_regression_multi_stage_2_11.pth
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.optim.lr_scheduler import CosineAnnealingLR  # Cosine decay for fine-tuning
+ from tqdm.auto import tqdm
+ import wandb
+ from datetime import datetime
+ import numpy as np
+ import argparse
+ import os
+ import math
+ import traceback  # For detailed error printing
+ from constants import *
+
+ try:
+     # get_tokenizer defines the global tokenizer and vocab_size
+     from utils import get_tokenizer, tokenizer, vocab_size, tensor_to_image, image_to_tensor
+     if 'tokenizer' not in globals() or 'vocab_size' not in globals():
+         print("Initializing tokenizer...")
+         tokenizer, vocab_size = get_tokenizer()
+ except ImportError:
+     print("Error: Could not import required functions/variables from utils.py.")
+     exit()
+ except NameError:
+     print("Error: tokenizer or vocab_size not defined after importing utils.")
+     exit()
+ except Exception as e:
+     print(f"Error during utils import or tokenizer init: {e}")
+     exit()
+
+
+ try:
+     # The dataset must handle 0 points and the MAX_POINTS filter;
+     # its collate_fn must return the necessary keys, including 'generative_targets'
+     from dataset import create_train_dataloader, create_test_dataloader, PointDataset
+ except ImportError:
+     print("Error: Could not import from dataset.py.")
+     exit()
+
+ try:
+     # VisionLanguageModel.__init__ must match the one used in the training script;
+     # make sure DecoderLanguageModel etc. are also available
+     from vision_language_model import VisionLanguageModel
+ except ImportError:
+     print("Error: Could not import VisionLanguageModel from vision_language_model.py.")
+     exit()
+
+ # --- Fine-tuning-specific arguments ---
+ parser = argparse.ArgumentParser(description="Re-initialize and fine-tune the LM head using ONLY classification loss.")
+ parser.add_argument("--pretrained_model_path", type=str, required=True, help="Path to the pre-trained model state_dict (.pth file).")
+ parser.add_argument("--output_model_path", type=str, default="model_lm_reinit_ce_finetuned.pth", help="Path to save the fine-tuned model.")
+ parser.add_argument("--ft_epochs", type=int, default=10, help="Number of epochs for fine-tuning.")
+ parser.add_argument("--ft_lr", type=float, default=5e-4, help="Learning rate for fine-tuning.")
+ parser.add_argument("--ft_batch_size", type=int, default=BATCH_SIZE, help="Batch size for fine-tuning.")
+ parser.add_argument("--ft_grad_accum", type=int, default=GRAD_ACCUMULATION_STEPS, help="Gradient accumulation steps.")
+ parser.add_argument("--ft_log_steps", type=int, default=1, help="Logging frequency.")
+ parser.add_argument("--train_final_ln", action='store_true', help="Also train the final LayerNorm (ln_f) before the lm_head.")
+ parser.add_argument("--wandb_project", type=str, default="point-lm-head-reinit-ce-finetune", help="WandB project name.")
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+
+     # Use constants/args consistently
+     FT_BATCH_SIZE = args.ft_batch_size
+     FT_GRAD_ACCUM = args.ft_grad_accum
+     FT_LOG_STEPS = args.ft_log_steps
+
+     print(f"Using device: {DEVICE}")
+     print(f"Re-initializing and fine-tuning LM head (and final LN: {args.train_final_ln})")
+     print(f"Using ONLY classification (cross-entropy) loss")
+     print(f"Pretrained model: {args.pretrained_model_path}")
+     print(f"Output model: {args.output_model_path}")
+     print(f"Epochs: {args.ft_epochs}, LR: {args.ft_lr}, Batch Size: {FT_BATCH_SIZE}, Grad Accum: {FT_GRAD_ACCUM}")
+
+     # --- Load the model definition ---
+     print("Loading model definition...")
+     try:
+         # Use parameters consistent with the pre-trained model's architecture
+         model_args = {
+             'n_embd': HIDDEN_DIM, 'vocab_size': vocab_size, 'img_size': IMAGE_SIZE, 'patch_size': PATCH_SIZE,
+             'num_heads': NUM_HEADS, 'num_blks_vit': NUM_LAYERS, 'num_blks_dec': NUM_LAYERS,
+             'emb_dropout': 0.0, 'blk_dropout': 0.0, 'max_context': CONTEXT_LENGTH,
+             'shared_embed_dim': SHARED_EMBED_DIM,
+             # Auxiliary loss weights are zeroed: this script trains with CE loss only
+             'lambda_contrastive': 0.0,
+             'lambda_regression': 0.0,
+             'max_points': MAX_POINTS
+         }
+         model = VisionLanguageModel(**model_args).to(DEVICE)
+     except Exception as e:
+         print(f"Error initializing model structure: {e}")
+         exit()
+
+     # --- Load the pre-trained weights ---
+     print(f"Loading pre-trained weights from: {args.pretrained_model_path}")
+     try:
+         # Use strict=False in case parameter names differ slightly
+         model.load_state_dict(torch.load(args.pretrained_model_path, map_location=DEVICE, weights_only=True), strict=False)
+         print("Pre-trained weights loaded successfully.")
+     except FileNotFoundError: print(f"Error: Pre-trained model file not found at {args.pretrained_model_path}"); exit()
+     except Exception as e: print(f"Error loading model state_dict: {e}"); exit()
+
+     # --- Reinitialize the LM head ---
+     print("Reinitializing LM Head...")
+     model.decoder.lm_head.reset_parameters()
+     # --- Explicitly re-tie the weights AFTER reinitialization ---
+     model.decoder.token_embedding_table.weight = model.decoder.lm_head.weight
+     print("LM Head reinitialized and weights explicitly retied.")
+
+     # --- Freeze/unfreeze parameters (done ONCE before the loop) ---
+     print("Setting requires_grad flags...")
+     params_to_optimize = []
+     trainable_param_names = []
+     total_params = sum(p.numel() for p in model.parameters())
+
+     for param in model.parameters():
+         param.requires_grad = False  # Freeze everything first
+
+     print("\nParameters explicitly marked as trainable:")
+     for param in model.decoder.lm_head.parameters():
+         param.requires_grad = True
+         params_to_optimize.append(param)
+         for name, p in model.decoder.lm_head.named_parameters():
+             if p is param: trainable_param_names.append(f"decoder.lm_head.{name}"); break
+
+     if args.train_final_ln:
+         for param in model.decoder.ln_f.parameters():
+             param.requires_grad = True
+             params_to_optimize.append(param)
+             for name, p in model.decoder.ln_f.named_parameters():
+                 if p is param: trainable_param_names.append(f"decoder.ln_f.{name}"); break
+
+     # --- Create the optimizer over that specific list ---
+     print("\nParameters passed to optimizer:")
+     for name in trainable_param_names: print(f"- {name}")
+     trainable_params_count = sum(p.numel() for p in params_to_optimize)
+     print(f"\nTotal parameters: {total_params}")
+     print(f"Trainable parameters (optimizer target): {trainable_params_count} ({100 * trainable_params_count / total_params:.2f}%)")
+
+     # Verification print
+     print("\nVerification: All parameters with requires_grad=True:")
+     actual_trainable_count = 0
+     for name, param in model.named_parameters():
+         if param.requires_grad:
+             is_in_optimize_list = any(p is param for p in params_to_optimize)
+             print(f"- {name} (Requires Grad: {param.requires_grad}, In Optimizer List: {is_in_optimize_list})")
+             actual_trainable_count += param.numel()
+     print(f"Actual trainable count (incl. tied): {actual_trainable_count}")
+
+
+     if not params_to_optimize: print("Error: No parameters collected for the optimizer."); exit()
+     optimizer = torch.optim.AdamW(params_to_optimize, lr=args.ft_lr, betas=(0.9, 0.95), weight_decay=0.1)
+     print("Optimizer created.")
+
+     # --- Dataloaders & scheduler ---
+     print("Creating dataloaders...")
+     train_loader = create_train_dataloader(batch_size=FT_BATCH_SIZE, num_workers=4)
+     test_loader = create_test_dataloader(batch_size=FT_BATCH_SIZE, num_workers=2)
+     if train_loader is None: exit("Training loader failed to initialize.")
+     test_loader_has_data = test_loader and len(test_loader.dataset) > 0
+     scheduler = None
+     if train_loader and len(train_loader) > 0:
+         steps_per_epoch = (len(train_loader) // FT_GRAD_ACCUM) + (1 if len(train_loader) % FT_GRAD_ACCUM != 0 else 0)
+         total_steps = steps_per_epoch * args.ft_epochs
+         print(f"Fine-tuning: Total estimated optimization steps: {total_steps}")
+         scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=args.ft_lr / 10)
+     else: print("Warning: Train loader empty. Cannot set up scheduler.")
+
+     # --- Wandb setup ---
+     wandb_enabled = False
+     try:
+         wandb.init(
+             project=args.wandb_project,
+             name=f"lm-head-reinit-ce-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
+             config={ "fine_tuning_lr": args.ft_lr, "fine_tuning_epochs": args.ft_epochs, "batch_size": FT_BATCH_SIZE,
+                      "grad_accum": FT_GRAD_ACCUM, "pretrained_model": args.pretrained_model_path, "train_final_ln": args.train_final_ln,
+                      "loss": "Classification Only" } )
+         wandb_enabled = True
+     except Exception as e: print(f"Wandb initialization failed: {e}.")
+
+     # --- Fine-tuning loop ---
+     print("Starting LM head re-init fine-tuning with classification loss...")
+     torch.autograd.set_detect_anomaly(True)
+     step_counter = 0
+     optimizer.zero_grad()
+
+     for epoch in range(args.ft_epochs):
+         model.train()  # Set dropout/layernorm layers to train mode
+
+         epoch_class_loss_accum = 0.0
+         valid_batches_accum = 0
+         pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"FT Epoch {epoch+1}/{args.ft_epochs}", leave=False)
+
+         for batch_idx, batch in pbar:
+             if batch is None: continue
+             # --- Unpack the data ---
+             try:
+                 images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
+                 prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
+                 prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
+                 target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
+                 target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
+                 generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)  # Needed for the loss
+                 # Pass None for unused args (the model forward should handle this)
+                 continuous_coords = batch.get('continuous_coords'); coords_mask = batch.get('coords_mask'); num_points_list = batch.get('num_points')
+                 if continuous_coords is not None: continuous_coords = continuous_coords.to(DEVICE, non_blocking=True)
+                 if coords_mask is not None: coords_mask = coords_mask.to(DEVICE, non_blocking=True)
+             except KeyError as e: print(f"KeyError unpacking batch: {e}"); continue
+             except Exception as e: print(f"Error unpacking batch: {e}"); continue
+
+             # --- Forward pass ---
+             # Run the full model normally; autograd honors the requires_grad flags.
+             try:
+                 # We only need the logits output from the main model call
+                 logits, _, _, _, _, _, *_ = model(
+                     img_array=images, prompt_ids=prompt_ids, prompt_attention_mask=prompt_attention_mask,
+                     target_ids=target_ids, target_attention_mask=target_attention_mask,
+                     generative_targets=generative_targets,  # Pass the targets; the model may use them internally
+                     continuous_coords=continuous_coords, coords_mask=coords_mask,
+                 )
+                 if logits is None or not torch.isfinite(logits).all():
+                     print(f"!!! ERROR: NaN/Inf/None detected in logits. Skipping batch {batch_idx}. !!!")
+                     optimizer.zero_grad(); continue
+
+             except Exception as e:
+                 print(f"!!! ERROR during forward pass: {e} !!!"); traceback.print_exc()
+                 optimizer.zero_grad(); continue
+
+             # --- Calculate the classification loss EXTERNALLY ---
+             loss_to_backward = None
+             try:
+                 # Batch, sequence, and vocab sizes from the logits
+                 B, T_logits, V = logits.shape
+
+                 # --- Prepare PADDED targets for the external CE loss ---
+                 # Pad generative_targets to match T_logits
+                 B_targ, T_target_orig = generative_targets.shape
+                 N_img = model.num_patches
+                 T_prompt = prompt_ids.shape[1]
+                 T_combined_expected = N_img + T_prompt + T_target_orig  # Expected full length
+
+                 if T_logits != T_combined_expected:
+                     # Handle potential truncation due to the context length
+                     print(f"Warning: Logits length {T_logits} != Expected combined length {T_combined_expected}. Adjusting targets.")
+                     T_target_in_logits = max(0, T_logits - (N_img + T_prompt))
+                     generative_targets_sliced = generative_targets[:, :T_target_in_logits]
+                     combined_class_targets = torch.cat([
+                         torch.full((B, T_logits - T_target_in_logits), -100, dtype=torch.long, device=DEVICE),
+                         generative_targets_sliced
+                     ], dim=1)
+                 else:
+                     # Pad generative_targets normally
+                     combined_class_targets = torch.cat([
+                         torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
+                         generative_targets
+                     ], dim=1)
+
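+                 # Illustration (not executed): with IMAGE_SIZE = 512 and PATCH_SIZE = 16,
+                 # the image contributes N_img = (512 // 16) ** 2 = 1024 positions, so the
+                 # first N_img + T_prompt targets are -100 and only the target-token
+                 # positions contribute to the cross-entropy below.
+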
+                 # Verify the shapes before the loss calculation
+                 if logits.shape[1] != combined_class_targets.shape[1]:
+                     raise ValueError(f"Shape mismatch before CE Loss! Logits T={logits.shape[1]}, Targets T={combined_class_targets.shape[1]}")
+
+                 # Calculate the loss using the logits that require grad and the padded targets
+                 loss_to_backward = F.cross_entropy(
+                     logits.view(-1, V),               # Shape (B * T_logits, V)
+                     combined_class_targets.view(-1),  # Shape (B * T_logits)
+                     ignore_index=-100
+                 )
+
+                 if not torch.isfinite(loss_to_backward):
+                     print(f"Warning: NaN/Inf detected in calculated class_loss ({loss_to_backward}).")
+                     loss_to_backward = None
+
+             except Exception as e:
+                 print(f"Error calculating external CE loss: {e}")
+                 loss_to_backward = None
+
+             # Check the loss before backward
+             if loss_to_backward is None:
+                 print(f"Warning: Skipping batch {batch_idx} due to invalid loss calculation.")
+                 optimizer.zero_grad(); continue
+
+             # --- Verification ---
+             if loss_to_backward.grad_fn is None:
+                 print(f"!!! ERROR: loss_to_backward (value: {loss_to_backward.item()}) has no grad_fn! Batch {batch_idx} !!!")
+                 optimizer.zero_grad(); continue
+
+             # Accumulate for logging
+             epoch_class_loss_accum += loss_to_backward.item(); valid_batches_accum += 1
+             scaled_loss = loss_to_backward / FT_GRAD_ACCUM
+
+             # --- Backward pass ---
+             try:
+                 scaled_loss.backward()
+             except RuntimeError as e: print(f"!!! RUNTIME ERROR backward: {e} !!!"); optimizer.zero_grad(); continue
+
+             # --- Gradient accumulation step ---
+             if (batch_idx + 1) % FT_GRAD_ACCUM == 0 or (batch_idx + 1) == len(train_loader):
+                 # Check/clip the gradients of the OPTIMIZED parameters
+                 found_non_finite_grad = False
+                 for p in params_to_optimize:
+                     if p.grad is not None and not torch.isfinite(p.grad).all():
+                         print(f"!!! WARNING: NaN/Inf gradient BEFORE step. Skipping step. !!!")
+                         found_non_finite_grad = True; break
+                 if found_non_finite_grad: optimizer.zero_grad(); continue
+
+                 grad_norm = torch.nn.utils.clip_grad_norm_(params_to_optimize, MAX_GRAD_NORM)
+                 if not torch.isfinite(grad_norm): print(f"!!! WARNING: Grad norm NaN/Inf ({grad_norm.item()}) AFTER clipping. Skipping step. !!!"); optimizer.zero_grad(); continue
+
+                 optimizer.step()
+                 if scheduler: scheduler.step()
+                 optimizer.zero_grad()
+                 step_counter += 1
+
+                 # --- Logging ---
+                 if step_counter % FT_LOG_STEPS == 0 and valid_batches_accum > 0:
+                     avg_class_loss = epoch_class_loss_accum / valid_batches_accum
+                     current_lr = optimizer.param_groups[0]['lr']
+                     # --- Test evaluation (class loss only) ---
+                     test_class_loss_val = float('nan')
+                     if test_loader_has_data:
+                         model.eval()
+                         with torch.no_grad():
+                             try:
+                                 test_batch = next(iter(test_loader))
+                                 if test_batch:
+                                     # Unpack the test data needed for a forward pass -> logits
+                                     t_images = test_batch['image'].to(DEVICE).to(DTYPE)
+                                     t_p_ids = test_batch['prompt_ids'].to(DEVICE)
+                                     t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
+                                     t_t_ids = test_batch['target_ids'].to(DEVICE)
+                                     t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
+                                     t_gen_targets = test_batch['generative_targets'].to(DEVICE)  # Needed for the external CE calc
+                                     # Pass None for the other args if the model handles it
+                                     t_cont_coords = test_batch.get('continuous_coords'); t_coords_mask = test_batch.get('coords_mask'); t_num_pts = test_batch.get('num_points')
+                                     if t_cont_coords is not None: t_cont_coords = t_cont_coords.to(DEVICE)
+                                     if t_coords_mask is not None: t_coords_mask = t_coords_mask.to(DEVICE)
+
+                                     # Run a forward pass just to get the logits
+                                     logits_t, _, _, _, _, _, *_ = model(
+                                         t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask, t_gen_targets,
+                                         t_cont_coords, t_coords_mask
+                                     )
+
+                                     # Calculate the CE loss externally for logging
+                                     if logits_t is not None and t_gen_targets is not None:
+                                         try:
+                                             # Prepare padded targets matching the logits_t shape
+                                             B_test, T_logits_t, V_test = logits_t.shape
+                                             _, T_target_orig_t = t_gen_targets.shape
+                                             N_img_test = model.num_patches
+                                             T_prompt_test = t_p_ids.shape[1]
+                                             T_combined_expected_t = N_img_test + T_prompt_test + T_target_orig_t
+
+                                             if T_logits_t != T_combined_expected_t:
+                                                 T_target_in_logits_t = max(0, T_logits_t - (N_img_test + T_prompt_test))
+                                                 generative_targets_sliced_t = t_gen_targets[:, :T_target_in_logits_t]
+                                                 combined_class_targets_t = torch.cat([
+                                                     torch.full((B_test, T_logits_t - T_target_in_logits_t), -100, dtype=torch.long, device=DEVICE),
+                                                     generative_targets_sliced_t
+                                                 ], dim=1)
+                                             else:
+                                                 combined_class_targets_t = torch.cat([
+                                                     torch.full((B_test, N_img_test + T_prompt_test), -100, dtype=torch.long, device=DEVICE),
+                                                     t_gen_targets
+                                                 ], dim=1)
+
+                                             if logits_t.shape[1] != combined_class_targets_t.shape[1]:
+                                                 raise ValueError("Shape mismatch test CE!")
+
+                                             t_class_loss = F.cross_entropy(logits_t.view(-1, V_test), combined_class_targets_t.view(-1), ignore_index=-100)
+                                             test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
+                                         except Exception as e_ce_test: print(f"Error CE Test: {e_ce_test}")
+                             except StopIteration: print("Info: Test loader exhausted during logging.")
+                             except Exception as e: print(f"Error during test eval: {e}")
+                         model.train()  # Set back to train mode
+
+                     # Log the data
+                     log_data = {  # Simplified logging
+                         "train/class_loss": avg_class_loss,
+                         "test/class_loss": test_class_loss_val,
+                         "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
+                         "step": step_counter,
+                         "learning_rate": current_lr,
+                         "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else float('nan'),
+                     }
+                     pbar.set_postfix({"lr": f"{current_lr:.2e}", "cls_loss": f"{avg_class_loss:.4f}", "gnorm": f"{log_data['gradient_norm']:.3f}"})
+                     if wandb_enabled: wandb.log(log_data, step=step_counter)
+
+                     # Reset the accumulators
+                     epoch_class_loss_accum = 0.0; valid_batches_accum = 0
+
+         # --- End of epoch ---
+         print(f"\nFT Epoch {epoch+1}/{args.ft_epochs} completed.")
+         # Optional: save a checkpoint periodically
+         if (epoch + 1) % 5 == 0 or (epoch + 1) == args.ft_epochs:
+             chkpt_path = args.output_model_path.replace(".pth", f"_epoch{epoch+1}.pth")
+             try:
+                 torch.save(model.state_dict(), chkpt_path)
+                 print(f"Checkpoint saved to: {chkpt_path}")
+             except Exception as e: print(f"Error saving checkpoint: {e}")
+
+
+     # --- End of fine-tuning ---
+     print("\nLM head fine-tuning with CE loss completed!")
+     try:
+         torch.save(model.state_dict(), args.output_model_path)
+         print(f"Fine-tuned model saved to: {args.output_model_path}")
+     except Exception as e: print(f"Error saving fine-tuned model: {e}")
+
+     if wandb_enabled:
+         wandb.finish()
+     torch.autograd.set_detect_anomaly(False)  # Disable anomaly detection
infer.py ADDED
@@ -0,0 +1,504 @@
+ from constants import *
+ from utils import image_to_tensor, tokenizer, tensor_to_image, vocab_size
+ import torch
+ import torch.nn.functional as F
+ from PIL import ImageDraw, Image
+ from dataset import create_test_dataloader
+ from vision_language_model import VisionLanguageModel
+
+
+ model = VisionLanguageModel(
+     n_embd=HIDDEN_DIM,
+     vocab_size=vocab_size,
+     img_size=IMAGE_SIZE,
+     patch_size=PATCH_SIZE,
+     num_heads=NUM_HEADS,
+     num_blks_vit=NUM_LAYERS,  # Or a specific value for the ViT layers
+     num_blks_dec=NUM_LAYERS,  # Or a specific value for the decoder layers
+     emb_dropout=DROPOUT,
+     blk_dropout=DROPOUT,
+     max_context=CONTEXT_LENGTH,
+     shared_embed_dim=SHARED_EMBED_DIM,
+     lambda_contrastive=LAMBDA_CONTRASTIVE,
+     lambda_regression=LAMBDA_REGRESSION  # Pass the regression weight
+ ).to(DEVICE)
+
+ MODEL_PATH = "model_regression_multi_first_100.pth"  # "model_regression_multi_16.pth"
+
+ if DEVICE == "cuda":
+     model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
+ else:
+     model.load_state_dict(torch.load(MODEL_PATH, weights_only=True, map_location=torch.device('cpu')))
+ model.eval()
+
+ def generate_sample_from_image_text(
+     model,
+     image_path,
+     prompt_label,
+     tokenizer,
+     device,
+     max_new_tokens=70,
+     temperature=0.8,
+     top_k=10,
+     output_path="generated_output.png"
+ ):
+     """
+     Generates a prediction for an image and prompt text and saves it to a file.
+     The generation loop is implemented *within* this function.
+
+     Args:
+         model: The trained VisionLanguageModel.
+         image_path: Path to the input image.
+         prompt_label: Text prompt/label to use.
+         tokenizer: The tokenizer used for training.
+         device: The computation device ('cuda' or 'cpu').
+         max_new_tokens (int): Max tokens to generate after the prompt.
+         temperature (float): Softmax temperature for sampling.
+         top_k (int): K for top-k sampling (0 or None to disable).
+         output_path (str): Path where the output image is saved.
+
+     Returns:
+         None. Saves the image with the prompt and generated output to a file.
+     """
+     model.eval()  # Set the model to evaluation mode
+
+     try:
+         with torch.no_grad():  # No need to track gradients during inference
+             # --- 1. Prepare the initial inputs ---
+             # Load and process the image
+             image = Image.open(image_path)
+             image_tensor = image_to_tensor(image).unsqueeze(0).to(device)  # Add batch dim
+
+             # Tokenize the prompt
+             prompt_text = f"<point_start>{prompt_label}<point_end>"
+             prompt_tokens = tokenizer(prompt_text, return_tensors="pt", truncation=True, padding=False)
+             prompt_ids = prompt_tokens.input_ids.to(device)
+             prompt_attention_mask = prompt_tokens.attention_mask.to(device)
+             B = 1  # We process one sample at a time
+
+             print(f"--- Generating Sample (Manual Loop) ---")
+             print(f"Original Label/Prompt Hint: {prompt_label}")
+             print(f"Input Prompt Tokens Decoded: {prompt_text}")
+
+             # --- 2. Pre-compute the image & prompt embeddings (part of the VLM forward logic) ---
+             image_embeds_raw = model.vision_encoder(image_tensor)  # (1, N_img, C)
+             image_embeds_decoder = model.multimodal_projector(image_embeds_raw)  # (1, N_img, C)
+             prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids)  # (1, T_prompt, C)
+
+             result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
+             result_start_embed = model.decoder.token_embedding_table(
+                 torch.tensor([[result_start_token_id]], device=device)  # Shape (1, 1, C)
+             )
+
+             # The initial sequence fed to the decoder blocks consists of image + prompt
+             current_embeds = torch.cat([
+                 image_embeds_decoder,
+                 prompt_embeds_decoder,
+                 result_start_embed  # Add the embedding for the first expected output token
+             ], dim=1)
+             generated_ids = []  # Store newly generated IDs
+
+             # --- 3. Autoregressive generation loop ---
+             for _ in range(max_new_tokens):
+                 T_current = current_embeds.shape[1]
+
+                 # Truncate if necessary (keep the recent context)
+                 if T_current > model.decoder.max_context:  # Access max_context from the decoder
+                     print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
+                     current_embeds = current_embeds[:, -model.decoder.max_context:, :]
+                     T_current = model.decoder.max_context
+
+                 # Prepare the positional embeddings for the current length
+                 pos = torch.arange(0, T_current, dtype=torch.long, device=device)
+                 pos = pos.clamp(max=model.decoder.max_context - 1)  # Clamp indices
+                 pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0)  # (1, T_current, C)
+                 x = current_embeds + pos_emb
+
+                 # Create the attention mask (all ones; the causal mask handles the future).
+                 # Note: no padding mask is needed here since we handle a single unpadded sequence.
+                 attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)
+
+                 # Pass through the decoder blocks
+                 for block in model.decoder.blocks:
+                     # We assume the block forward takes (x, attention_mask)
+                     x = block(x, attention_mask=attention_mask)
+
+                 # Final layer norm and LM head for the *last* token prediction
+                 x = model.decoder.ln_f(x[:, -1:, :])  # (1, 1, C)
+                 logits = model.decoder.lm_head(x)  # (1, 1, V)
+                 logits = logits.squeeze(1)  # (1, V)
+
+                 # Sampling
+                 logits = logits / temperature
+                 if top_k is not None and top_k > 0:
+                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+
+                 probs = F.softmax(logits, dim=-1)
+                 # idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1) # test sampling from the distribution
+                 idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # test deterministic decoding
+
+                 # Store the generated ID
+                 generated_ids.append(idx_next)
+
+                 # Stop if the EOS token is generated
+                 if idx_next.item() == tokenizer.eos_token_id:
+                     print("EOS token generated.")
+                     break
+
+                 # Prepare for the next iteration: append the new token's embedding
+                 next_token_embed = model.decoder.token_embedding_table(idx_next)  # (1, 1, C)
+                 current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)  # Append along the sequence dim
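+                 # Note: each step re-runs the whole sequence through the decoder blocks
+                 # (there is no KV cache), so generation cost grows quadratically with length.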
152
+
153
+ # --- 4. Combine and Decode Results ---
154
+ if generated_ids:
155
+ generated_ids_tensor = torch.cat(generated_ids, dim=1) # (1, T_generated)
156
+ initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
157
+ full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1)
158
+ else:
159
+ full_generated_sequence_ids = prompt_ids # Nothing was generated
160
+
161
+ full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
162
+ print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")
163
+
164
+ # --- 5. Save visualization to file ---
165
+ save_coords_visualization(
166
+ image_tensor=image_tensor[0], # Remove batch dim for visualization
167
+ full_decoded_text=full_decoded_text,
168
+ tokenizer=tokenizer,
169
+ image_size=IMAGE_SIZE, # Assumes IMAGE_SIZE is globally defined
170
+ num_bins=NUM_BINS, # Assumes NUM_BINS is globally defined
171
+ output_path=output_path
172
+ )
173
+ print(f"Visualization saved to: {output_path}")
174
+
175
+ except Exception as e:
176
+ print(f"An error occurred during sample generation: {e}")
177
+ import traceback
178
+ traceback.print_exc()
179
+
180
+ def generate_sample_from_test_loader(
181
+ model,
182
+ test_loader,
183
+ tokenizer,
184
+ device,
185
+ max_new_tokens=70,
186
+ temperature=0.8,
187
+ top_k=10,
188
+ output_path="generated_output.png",
189
+ TEST_BATCH=8,
190
+ TEST_IDX=1
191
+ ):
192
+ """
193
+ Generates a prediction for one sample from the test loader and saves it to a file.
194
+ Generation loop is implemented *within* this function.
195
+
196
+ Args:
197
+ model: The trained VisionLanguageModel.
198
+ test_loader: DataLoader for the test set.
199
+ tokenizer: The tokenizer used for training.
200
+ device: The computation device ('cuda' or 'cpu').
201
+ max_new_tokens (int): Max tokens to generate after the prompt.
202
+ temperature (float): Softmax temperature for sampling.
203
+ top_k (int): K for top-k sampling (0 or None to disable).
204
+ output_path (str): Path where to save the output image.
205
+
206
+ Returns:
207
+ None. Saves the image with prompt and generated output to a file.
208
+ """
209
+
210
+ if not test_loader or len(test_loader.dataset) == 0:
211
+ print("Test loader is empty or not available.")
212
+ return
213
+
214
+ model.eval() # Set the model to evaluation mode
215
+
216
+ try:
217
+ # Get a single batch from the test loader
218
+ with torch.no_grad(): # No need to track gradients during inference
219
+ my_iter = iter(test_loader)
220
+ for i in range(TEST_BATCH):
221
+ _ = next(my_iter)
222
+ batch = next(my_iter)
223
+
224
+ if batch is None:
225
+ print("Test loader yielded an empty batch.")
226
+ return
227
+ if batch['image'].shape[0] == 0:
228
+ print("Test loader yielded a batch with 0 items.")
229
+ return
230
+
231
+ # --- 1. Prepare Initial Inputs ---
232
+ image_tensor = batch['image'][TEST_IDX:TEST_IDX+1].to(device) # (1, 3, H, W)
233
+ prompt_ids = batch['prompt_ids'][TEST_IDX:TEST_IDX+1].to(device) # (1, T_prompt)
234
+ prompt_attention_mask = batch['prompt_attention_mask'][TEST_IDX:TEST_IDX+1].to(device) # (1, T_prompt)
235
+ label = batch['label'][TEST_IDX]
236
+ B = 1 # We are processing one sample at a time
237
+
238
+ print(f"--- Generating Sample (Manual Loop) ---")
239
+ print(f"Original Label/Prompt Hint: {label}")
240
+ prompt_text = tokenizer.decode(prompt_ids[0], skip_special_tokens=False)
241
+ print(f"Input Prompt Tokens Decoded: {prompt_text}")
242
+
243
+ # --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
244
+ image_embeds_raw = model.vision_encoder(image_tensor) # (1, N_img, C)
245
+ image_embeds_decoder = model.multimodal_projector(image_embeds_raw) # (1, N_img, C)
246
+ prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) # (1, T_prompt, C)
247
+
248
+ result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
249
+ result_start_embed = model.decoder.token_embedding_table(
250
+ torch.tensor([[result_start_token_id]], device=device) # Shape (1, 1, C)
251
+ )
252
+
253
+ # The initial sequence fed to the decoder blocks consists of image + prompt
254
+ current_embeds = torch.cat([
255
+ image_embeds_decoder,
256
+ prompt_embeds_decoder,
257
+ result_start_embed # Add the embedding for the first expected output token
258
+ ], dim=1)
259
+ # current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1) # (1, T_initial, C)
260
+ generated_ids = [] # Store newly generated IDs
261
+
262
+ # --- 3. Autoregressive Generation Loop ---
263
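+ # Note: each step re-runs the decoder over the full embedding sequence
+ # (there is no KV cache), so per-step cost grows with the generated length.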
+ for _ in range(max_new_tokens):
264
+ T_current = current_embeds.shape[1]
265
+
266
+ # Truncate if necessary (keep recent context)
267
+ if T_current > model.decoder.max_context: # Access max_context from decoder
268
+ print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
269
+ current_embeds = current_embeds[:, -model.decoder.max_context:, :]
270
+ T_current = model.decoder.max_context
271
+
272
+ # Prepare positional embeddings for current length
273
+ pos = torch.arange(0, T_current, dtype=torch.long, device=device)
274
+ pos = pos.clamp(max=model.decoder.max_context - 1) # Clamp indices
275
+ pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) # (1, T_current, C)
276
+ x = current_embeds + pos_emb
277
+
278
+ # Create attention mask (all ones, causal handles future)
279
+ # Note: We don't need padding mask here as we handle one sequence without padding
280
+ attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)
281
+
282
+ # Pass through Decoder Blocks
283
+ for block in model.decoder.blocks:
284
+ # We assume the block forward takes (x, attention_mask)
285
+ x = block(x, attention_mask=attention_mask)
286
+
287
+ # Final Layer Norm and LM Head for the *last* token prediction
288
+ x = model.decoder.ln_f(x[:, -1:, :]) # (B, 1, C) -> (1, 1, C)
289
+ logits = model.decoder.lm_head(x) # (B, 1, V) -> (1, 1, V)
290
+ logits = logits.squeeze(1) # (B, V) -> (1, V)
291
+
292
+ # Sampling
293
+ logits = logits / temperature
294
+ if top_k is not None and top_k > 0:
295
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
296
+ logits[logits < v[:, [-1]]] = -float('Inf')
297
+
298
+ probs = F.softmax(logits, dim=-1) # only consumed by the sampling path below
299
+ # idx_next = torch.multinomial(probs, num_samples=1) # (1, 1) stochastic sampling
300
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True) # greedy (deterministic) decoding
301
+
302
+ # Store generated ID
303
+ generated_ids.append(idx_next)
304
+
305
+ # Stop if EOS token is generated
306
+ if idx_next.item() == tokenizer.eos_token_id:
307
+ print("EOS token generated.")
308
+ break
309
+
310
+ # Prepare for next iteration: Append embedding of new token
311
+ next_token_embed = model.decoder.token_embedding_table(idx_next) # (1, 1, C)
312
+ current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) # Append along sequence dim
313
+
314
+ # --- 4. Combine and Decode Results ---
315
+ if generated_ids:
316
+ generated_ids_tensor = torch.cat(generated_ids, dim=1) # (1, T_generated)
317
+ initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
318
+ full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1)
319
+ else:
320
+ full_generated_sequence_ids = prompt_ids # Nothing was generated
321
+
322
+ full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
323
+ print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")
324
+
325
+ # --- 5. Save visualization to file ---
326
+ save_coords_visualization(
327
+ image_tensor=image_tensor[0], # Remove batch dim for visualization
328
+ full_decoded_text=full_decoded_text,
329
+ tokenizer=tokenizer,
330
+ image_size=IMAGE_SIZE, # Assumes IMAGE_SIZE is globally defined
331
+ num_bins=NUM_BINS, # Assumes NUM_BINS is globally defined
332
+ output_path=output_path
333
+ )
334
+ print(f"Visualization saved to: {output_path}")
335
+
336
+ except StopIteration:
337
+ print("Test loader is exhausted.")
338
+ except Exception as e:
339
+ print(f"An error occurred during sample generation: {e}")
340
+ import traceback
341
+ traceback.print_exc()
342
+
343
+ def parse_coordinate_tokens(text, tokenizer, num_bins):
344
+ """
345
+ Parses generated text to extract coordinate bin tokens.
346
+
347
+ Args:
348
+ text (str): The decoded output text from the model.
349
+ tokenizer: The tokenizer.
350
+ num_bins (int): The number of coordinate bins used.
351
+
352
+ Returns:
353
+ list[tuple(int, int)]: A list of (x_bin, y_bin) tuples, or None if parsing fails.
354
+ """
355
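+ # Worked example (hypothetical text, assuming num_bins >= 13):
+ #   "<result_start><pointx_start><coord_bin_5><pointx_end>"
+ #   "<pointy_start><coord_bin_12><pointy_end><result_end>"
+ #   parses to [(5, 12)].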
+ coords = []
356
+ try:
357
+ # Basic parsing - look for the pattern
358
+ x_start_token = "<pointx_start>"
359
+ x_end_token = "<pointx_end>"
360
+ y_start_token = "<pointy_start>"
361
+ y_end_token = "<pointy_end>"
362
+ result_end_token = "<result_end>"
363
+
364
+ # Find where the actual results start
365
+ try:
366
+ start_index = text.index("<result_start>") + len("<result_start>")
367
+ except ValueError:
368
+ print("Warning: <result_start> not found in generated text.")
369
+ return None
370
+
371
+ # Find where results end
372
+ try:
373
+ end_index = text.index(result_end_token, start_index)
374
+ except ValueError:
375
+ end_index = len(text) # Use end of string if <result_end> is missing
376
+ print(f"Warning: {result_end_token} not found. Parsing until end of string.")
377
+
378
+
379
+ current_pos = start_index
380
+ while current_pos < end_index:
381
+ # Find next X coordinate
382
+ x_start_idx = text.find(x_start_token, current_pos)
383
+ if x_start_idx == -1 or x_start_idx >= end_index: break # No more x points found
384
+ x_start_idx += len(x_start_token)
385
+
386
+ x_end_idx = text.find(x_end_token, x_start_idx)
387
+ if x_end_idx == -1 or x_end_idx >= end_index: break # Malformed
388
+
389
+ x_token_str = text[x_start_idx:x_end_idx].strip()
390
+
391
+ # Find next Y coordinate (must follow X)
392
+ y_start_idx = text.find(y_start_token, x_end_idx)
393
+ if y_start_idx == -1 or y_start_idx >= end_index: break # No corresponding y point
394
+ y_start_idx += len(y_start_token)
395
+
396
+ y_end_idx = text.find(y_end_token, y_start_idx)
397
+ if y_end_idx == -1 or y_end_idx >= end_index: break # Malformed
398
+
399
+ y_token_str = text[y_start_idx:y_end_idx].strip()
400
+
401
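+ # The bin token appears in the text as "<coord_bin_N>"; strip the trailing '>'
+ # so the int() parse below sees a clean bin index.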
+ x_token_str = x_token_str[:-1]
402
+ y_token_str = y_token_str[:-1]
403
+
404
+ # Convert token strings to bin numbers
405
+ try:
406
+ x_bin = int(x_token_str.split("_")[-1])
407
+ y_bin = int(y_token_str.split("_")[-1])
408
+ if 0 <= x_bin < num_bins and 0 <= y_bin < num_bins:
409
+ coords.append((x_bin, y_bin))
410
+ else:
411
+ print(f"Warning: Parsed bin indices out of range ({x_bin}, {y_bin}). Skipping.")
412
+ except (ValueError, IndexError):
413
+ print(f"Warning: Could not parse bins from tokens '{x_token_str}', '{y_token_str}'. Skipping.")
414
+
415
+ # Move search position past the found Y token
416
+ current_pos = y_end_idx + len(y_end_token)
417
+
418
+ return coords if coords else None
419
+
420
+ except Exception as e:
421
+ print(f"Error during coordinate parsing: {e}")
422
+ return None
423
+
424
+
425
+ def save_coords_visualization(image_tensor, full_decoded_text, tokenizer, image_size, num_bins, output_path):
426
+ """Parses coords, draws them on the image, and saves to a file."""
427
+ parsed_bins = parse_coordinate_tokens(full_decoded_text, tokenizer, num_bins)
428
+
429
+ # Convert tensor to PIL image for drawing
430
+ try:
431
+ pil_image = tensor_to_image(image_tensor.cpu()) # Ensure tensor is on CPU
432
+ except Exception as e:
433
+ print(f"Error converting tensor to image: {e}")
434
+ # Create a placeholder image if conversion fails
435
+ pil_image = Image.new('RGB', (image_size, image_size), color='white')
436
+ draw = ImageDraw.Draw(pil_image)
437
+ draw.text((10, 10), "Image conversion failed", fill="black")
438
+ pil_image.save(output_path)
439
+ return
440
+
441
+ draw = ImageDraw.Draw(pil_image)
442
+ radius = 5 # Radius of the drawn point
443
+
444
+ if parsed_bins:
445
+ print(f"\nParsed Coordinate Bins: {parsed_bins}")
446
+ bin_size_pixels = image_size / num_bins
447
+ for x_bin, y_bin in parsed_bins:
448
+ # Calculate center of the bin in pixels
449
+ center_x = (x_bin + 0.5) * bin_size_pixels
450
+ center_y = (y_bin + 0.5) * bin_size_pixels
451
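+ # e.g. with image_size=512 and num_bins=32: bin_size_pixels = 16,
+ # so x_bin = 5 maps to center_x = (5 + 0.5) * 16 = 88.0 px.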
+
452
+ # Draw a circle
453
+ bbox = [center_x - radius, center_y - radius, center_x + radius, center_y + radius]
454
+ draw.ellipse(bbox, outline="red", width=3)
455
+ # Optional: Draw bin boundaries for debugging
456
+ # draw.rectangle([x_bin*bin_size_pixels, y_bin*bin_size_pixels, (x_bin+1)*bin_size_pixels, (y_bin+1)*bin_size_pixels], outline="blue", width=1)
457
+
458
+ # Add a text label with the coordinates at the top of the image
459
+ coord_text = f"Generated Point(s): {parsed_bins}"
460
+ draw.text((10, 10), coord_text, fill="red")
461
+ else:
462
+ print("\nCould not parse valid coordinates from the generated text.")
463
+ # Add a text label indicating no coordinates were found
464
+ draw.text((10, 10), "No Coordinates Parsed", fill="red")
465
+
466
+ # Save the image to file
467
+ pil_image.save(output_path)
468
+
469
+
470
+ import argparse
471
+
472
+ # --- Example Usage ---
473
+ # python infer.py --image ./data/test_images/image_1.png --prompt "a red apple"
474
+ if __name__ == "__main__":
475
+ parser = argparse.ArgumentParser()
476
+ parser.add_argument('--image', type=str, help='Path to input image')
477
+ parser.add_argument('--prompt', type=str, help='Prompt label for generation')
478
+ args = parser.parse_args()
479
+ if args.image and args.prompt:
480
+ # Use image and prompt based generation
481
+ if 'model' in locals() and 'tokenizer' in locals():
482
+ generate_sample_from_image_text(
483
+ model=model,
484
+ image_path=args.image,
485
+ prompt_label=args.prompt,
486
+ tokenizer=tokenizer,
487
+ device=DEVICE,
488
+ output_path="model_prediction.png"
489
+ )
490
+ else:
491
+ print("Please ensure 'model' and 'tokenizer' are loaded before running generation.")
492
+ else:
493
+ # Use test loader based generation
494
+ if 'model' in locals() and 'tokenizer' in locals():
495
+ test_loader = create_test_dataloader(batch_size=2, num_workers=0) # built fresh here, so no guard on 'test_loader' is needed
496
+ generate_sample_from_test_loader(
497
+ model=model,
498
+ test_loader=test_loader,
499
+ tokenizer=tokenizer,
500
+ device=DEVICE,
501
+ output_path="model_prediction.png"
502
+ )
503
+ else:
504
+ print("Please ensure 'model', 'test_loader', and 'tokenizer' are loaded before running generation.")
model_components.py ADDED
@@ -0,0 +1,163 @@
1
+ from constants import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class PatchEmbeddings(nn.Module):
7
+ def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
8
+ super().__init__()
9
+ self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
10
+
11
+ def forward(self, X):
12
+ X = self.conv(X) # (B, C, H/P, W/P)
13
+ X = X.flatten(2) # (B, C, N) where N = (H/P)*(W/P)
14
+ X = X.transpose(1, 2) # (B, N, C)
15
+ return X
16
+
17
+ class Head(nn.Module):
18
+ def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
19
+ super().__init__()
20
+ self.key = nn.Linear(n_embd, head_size, bias=False)
21
+ self.query = nn.Linear(n_embd, head_size, bias=False)
22
+ self.value = nn.Linear(n_embd, head_size, bias=False)
23
+ self.dropout = nn.Dropout(dropout)
24
+ self.is_decoder = is_decoder
25
+ # causal mask is registered persistent=False so it's not saved in state_dict
26
+ if self.is_decoder:
27
+ self.register_buffer("bias", torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
28
+ .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH), persistent=False)
29
+
30
+
31
+ def forward(self, x, attention_mask=None):
32
+ B, T, C = x.shape
33
+ # print(f"B = {B} T={T}, C={C}")
34
+ k = self.key(x) # (B, T, hs)
35
+ q = self.query(x) # (B, T, hs)
36
+ v = self.value(x) # (B, T, hs)
37
+
38
+ # Compute attention scores ("affinities")
39
+ wei = q @ k.transpose(-2, -1) * (k.size(-1)**-0.5) # (B, T, hs) @ (B, hs, T) -> (B, T, T)
40
+
41
+ if self.is_decoder:
42
+ # Apply causal mask
43
+ # Ensure the mask is sliced correctly if T < CONTEXT_LENGTH
44
+ causal_mask = self.bias[:, :T, :T]
45
+ wei = wei.masked_fill(causal_mask == 0, float('-inf'))
46
+
47
+ if attention_mask is not None:
48
+ # Apply padding mask (for text tokens)
49
+ # attention_mask shape: (B, T_combined) -> needs expansion
50
+ # Expand mask: (B, T) -> (B, 1, 1, T) or (B, 1, T, T) depending on what needs masking
51
+ # Mask where attention_mask is 0
52
+ # attention_mask shape: (B, T) == (B, T_key)
53
+ # Expand mask to align with wei's key dimension for broadcasting across queries
54
+ # Target shape for mask: [B, 1, T_key]
55
+ # print(f"attn mask = {attention_mask.shape}")
56
+ # print(f"wei shape = {wei.shape}")
57
+ mask = attention_mask.unsqueeze(1) # Shape [B, 1, T]
58
+ # Apply mask using broadcasting rules. masked_fill condition needs to be broadcastable to wei [B, T_query, T_key]
59
+ # (mask == 0) gives a boolean tensor of shape [B, 1, T]
60
+ # This broadcasts correctly: dim 2 (T vs T) matches, dim 1 (1 vs T) broadcasts 1->T, dim 0 (B vs B) matches.
61
+ wei = wei.masked_fill(mask == 0, float('-inf'))
62
+
63
+
64
+ # Apply softmax
65
+ wei = F.softmax(wei, dim=-1)
66
+ wei = self.dropout(wei)
67
+
68
+ # Perform weighted aggregation of values
69
+ out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
70
+ # print(f"out shape = {out.shape}")
71
+ return out
72
+
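+ # Note: the per-head attention above could be collapsed into a single fused call on
+ # PyTorch >= 2.0. A minimal sketch (an assumption, not what the trained checkpoints
+ # use), for q, k, v of shape (B, T, head_size):
+ #
+ #     out = F.scaled_dot_product_attention(
+ #         q, k, v,
+ #         dropout_p=DROPOUT if self.training else 0.0,
+ #         is_causal=self.is_decoder,  # applies the causal mask internally
+ #     )
+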
73
+ class MultiHeadAttention(nn.Module):
74
+ def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
75
+ super().__init__()
76
+ assert n_embd % num_heads == 0
77
+ head_size = n_embd // num_heads
78
+ self.heads = nn.ModuleList([
79
+ Head(n_embd, head_size, dropout, is_decoder)
80
+ for _ in range(num_heads)
81
+ ])
82
+ self.proj = nn.Linear(n_embd, n_embd) # n_embd = num_heads * head_size
83
+ self.dropout = nn.Dropout(dropout)
84
+ self.is_decoder = is_decoder # Store is_decoder status
85
+
86
+ def forward(self, x, attention_mask=None):
87
+ # Pass attention_mask only if it's a decoder block dealing with combined sequence
88
+ out = torch.cat([h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads], dim=-1)
89
+ out = self.dropout(self.proj(out))
90
+ return out
91
+
92
+
93
+ class FeedForward(nn.Module):
94
+ """ a simple linear layer followed by a non-linearity """
95
+ def __init__(self, n_embd, dropout=DROPOUT):
96
+ super().__init__()
97
+ self.net = nn.Sequential(
98
+ nn.Linear(n_embd, 4 * n_embd),
99
+ nn.GELU(), # Changed from ReLU to GELU, common in transformers
100
+ nn.Linear(4 * n_embd, n_embd), # Projection back to residual stream
101
+ nn.Dropout(dropout),
102
+ )
103
+
104
+ def forward(self, x):
105
+ return self.net(x)
106
+
107
+ class Block(nn.Module):
108
+ """ Transformer block: communication followed by computation """
109
+ def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
110
+ super().__init__()
111
+ self.ln1 = nn.LayerNorm(n_embd)
112
+ self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
113
+ self.ln2 = nn.LayerNorm(n_embd)
114
+ self.ffn = FeedForward(n_embd, dropout)
115
+ self.is_decoder = is_decoder # Store is_decoder status
116
+
117
+ def forward(self, x, attention_mask=None):
118
+ # Pass attention_mask only if it's a decoder block
119
+ # print(f"is decoder = {self.is_decoder} input shape = {x.shape}")
120
+ x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
121
+ x = x + self.ffn(self.ln2(x))
122
+ # print(f"output shape = {x.shape}")
123
+ return x
124
+
125
+ class ViT(nn.Module):
126
+ def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
127
+ num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
128
+ super().__init__()
129
+ self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
130
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
131
+ num_patches = (img_size // patch_size) ** 2
132
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02) # Smaller init
133
+ self.dropout = nn.Dropout(emb_dropout)
134
+ # ViT blocks are NOT decoders (no causal mask)
135
+ self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
136
+ self.layer_norm = nn.LayerNorm(num_hiddens) # Final LN
137
+
138
+ def forward(self, X):
139
+ x = self.patch_embedding(X) # (B, N, C)
140
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # (B, 1, C)
141
+ x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, C)
142
+ # Add positional embedding
143
+ x = x + self.pos_embedding # Uses broadcasting
144
+ x = self.dropout(x)
145
+ for block in self.blocks:
146
+ # ViT blocks don't need attention_mask
147
+ x = block(x)
148
+ x = self.layer_norm(x) # Apply final layer norm
149
+ return x
150
+
151
+ class MultiModalProjector(nn.Module):
152
+ # Projects image embedding dim to text embedding dim
153
+ def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
154
+ super().__init__()
155
+ self.net = nn.Sequential(
156
+ nn.Linear(image_embed_dim, text_embed_dim * 4), # Intermediate expansion
157
+ nn.GELU(),
158
+ nn.Linear(text_embed_dim * 4, text_embed_dim),
159
+ nn.Dropout(dropout)
160
+ )
161
+
162
+ def forward(self, x):
163
+ return self.net(x)
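+
+
+ if __name__ == "__main__":
+     # Quick shape sanity check: a minimal sketch assuming the defaults in
+     # constants.py; not part of the training pipeline.
+     vit = ViT()
+     x = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)
+     feats = vit(x)                       # (2, (IMAGE_SIZE // PATCH_SIZE)**2 + 1, HIDDEN_DIM)
+     proj = MultiModalProjector()(feats)  # same sequence length, projected to the decoder dim
+     print(feats.shape, proj.shape)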
train.py ADDED
@@ -0,0 +1,264 @@
1
+ from constants import *
2
+ from dataset import create_train_dataloader, create_test_dataloader
3
+ from vision_language_model import VisionLanguageModel
4
+ from utils import *
5
+ from datetime import datetime
6
+ import wandb
7
+ import torch
8
+ import torch.optim as optim
9
+ from torch.optim.lr_scheduler import OneCycleLR
10
+ from tqdm.auto import tqdm
11
+
12
+ print(f"Using device: {DEVICE}")
13
+ print(f"Vocab size: {vocab_size}")
14
+
15
+ # --- Initialize Model ---
16
+ # Ensure lambda_regression is passed during initialization
17
+ model = VisionLanguageModel(
18
+ n_embd=HIDDEN_DIM,
19
+ vocab_size=vocab_size,
20
+ img_size=IMAGE_SIZE,
21
+ patch_size=PATCH_SIZE,
22
+ num_heads=NUM_HEADS,
23
+ num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
24
+ num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
25
+ emb_dropout=DROPOUT,
26
+ blk_dropout=DROPOUT,
27
+ max_context=CONTEXT_LENGTH,
28
+ shared_embed_dim=SHARED_EMBED_DIM,
29
+ lambda_contrastive=LAMBDA_CONTRASTIVE,
30
+ lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
31
+ ).to(DEVICE)
32
+
33
+ # --- Optimizer ---
34
+ # Optimizer will automatically include all model parameters, including the new regression head
35
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)
36
+
37
+ # --- Dataloaders ---
38
+ # Ensure these functions now return 'continuous_coords' in the batch dictionary
39
+ train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2) # Use num_workers=0 for easier debugging first
40
+ test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
41
+ if train_loader is None: exit("Training loader failed to initialize.")
42
+ test_loader_has_data = test_loader and len(test_loader.dataset) > 0
43
+
44
+ # --- LR Scheduler ---
45
+ if train_loader and len(train_loader) > 0:
46
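+     # Ceil division: one optimizer step per GRAD_ACCUMULATION_STEPS micro-batches,
+     # plus a final partial step when the epoch length is not an exact multiple.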
+ steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0)
47
+ total_steps = steps_per_epoch * NUM_EPOCHS
48
+ # Adjust warmup steps if total steps are very low
49
+ warmup_steps = min(max(1, total_steps // 10), 10000) # Ensure at least 1, max 10k warmup
50
+ print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
51
+ lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_steps, pct_start=warmup_steps/total_steps if total_steps > 0 else 0.1)
52
+ else:
53
+ print("Warning: Train loader empty. Using constant LR.")
54
+ total_steps = 0; warmup_steps = 0
55
+ lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
56
+
57
+ # --- Wandb Setup ---
58
+ try:
59
+ wandb.init(
60
+ # project="point-language-model-dualhead", # Suggest new project name
61
+ project="point-language-model-regression-vast",
62
+ name=f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
63
+ config={ # Add new hyperparameters
64
+ "image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
65
+ "context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
66
+ "num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
67
+ "learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
68
+ "shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
69
+ "lambda_regression": LAMBDA_REGRESSION, # Log regression weight
70
+ "architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
71
+ "num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps
72
+ }
73
+ )
74
+ wandb_enabled = True
75
+ # Watch model gradients and parameters
76
+ # wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
77
+ except Exception as e:
78
+ print(f"Wandb initialization failed: {e}. Running without wandb.")
79
+ wandb_enabled = False
80
+
81
+ # --- Training Loop ---
82
+ print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
83
+ step_counter = 0
84
+ optimizer.zero_grad()
85
+
86
+ for epoch in range(NUM_EPOCHS):
87
+ model.train()
88
+ epoch_total_loss_accum = 0.0
89
+ epoch_class_loss_accum = 0.0
90
+ epoch_con_loss_accum = 0.0
91
+ epoch_reg_loss_accum = 0.0
92
+ batches_since_log = 0
93
+ valid_batches_accum = 0 # Count batches with valid loss for averaging
94
+
95
+ pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
96
+
97
+ for batch_idx, batch in pbar:
98
+ if batch is None: continue
99
+
100
+ # --- Unpack Batch Data ---
101
+ try:
102
+ images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
103
+ prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
104
+ prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
105
+ target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
106
+ target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
107
+ generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
108
+ continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True) # Padded
109
+ coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True) # Mask
110
+ except KeyError as e:
111
+ print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
112
+ continue
113
+
114
+ # Clamp logit_scale
115
+ with torch.no_grad():
116
+ model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0)))
117
+
118
+ # --- Forward Pass ---
119
+ # Model now returns potentially NaN scalar tensors for individual losses if invalid
120
+ logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
121
+ img_array=images,
122
+ prompt_ids=prompt_ids,
123
+ prompt_attention_mask=prompt_attention_mask,
124
+ target_ids=target_ids,
125
+ target_attention_mask=target_attention_mask,
126
+ generative_targets=generative_targets,
127
+ continuous_coords=continuous_coords,
128
+ coords_mask=coords_mask # Pass mask for regression loss calculation
129
+ )
130
+
131
+ # --- Loss Handling & Accumulation ---
132
+ # Check for invalid total loss before backward pass
133
+ if total_loss is None or not torch.isfinite(total_loss):
134
+ print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
135
+ optimizer.zero_grad() # Reset gradients for safety if loss is invalid
136
+ continue # Skip this batch for optimization step
137
+
138
+ # Scale loss for gradient accumulation
139
+ scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS
140
+
141
+ # Accumulate valid loss components for logging
142
+ # Check if the scalar tensor is finite before adding its item()
143
+ if torch.isfinite(total_loss):
144
+ epoch_total_loss_accum += total_loss.item()
145
+ valid_batches_accum += 1 # Increment count of batches contributing to average loss
146
+ if torch.isfinite(class_loss_s):
147
+ epoch_class_loss_accum += class_loss_s.item()
148
+ if torch.isfinite(contrastive_loss_s):
149
+ epoch_con_loss_accum += contrastive_loss_s.item()
150
+ if torch.isfinite(regression_loss_s):
151
+ epoch_reg_loss_accum += regression_loss_s.item()
152
+ batches_since_log += 1
153
+
154
+ # --- Backward Pass ---
155
+ try:
156
+ scaled_loss.backward()
157
+ except Exception as e:
158
+ print(f"Error during backward pass: {e}. Skipping step.")
159
+ optimizer.zero_grad() # Reset gradients if backward failed
160
+ continue
161
+
162
+ # --- Gradient Accumulation Step ---
163
+ if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
164
+ # Clip gradients
165
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
166
+
167
+ # Check for non-finite gradients before stepping
168
+ all_finite = True
169
+ for p in model.parameters():
170
+ if p.grad is not None and not torch.isfinite(p.grad).all():
171
+ all_finite = False
172
+ break
173
+ if not all_finite:
174
+ print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
175
+ optimizer.zero_grad()
176
+ continue # Skip optimizer step and scheduler step
177
+
178
+ # Optimizer step
179
+ optimizer.step()
180
+ lr_scheduler.step()
181
+ optimizer.zero_grad()
182
+
183
+ step_counter += 1
184
+
185
+ # --- Logging ---
186
+ if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0: # Use valid_batches_accum
187
+ # Calculate average losses over the logging period using valid batch count
188
+ avg_total_loss = epoch_total_loss_accum / valid_batches_accum
189
+ avg_class_loss = epoch_class_loss_accum / valid_batches_accum
190
+ avg_con_loss = epoch_con_loss_accum / valid_batches_accum
191
+ avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
192
+ current_lr = optimizer.param_groups[0]['lr']
193
+
194
+ # --- Test Evaluation (Needs modification to handle mask) ---
195
+ test_class_loss_val = float('nan')
196
+ test_con_loss_val = float('nan')
197
+ test_reg_loss_val = float('nan')
198
+ if test_loader_has_data:
199
+ model.eval()
200
+ with torch.no_grad():
201
+ try:
202
+ test_batch = next(iter(test_loader))
203
+ if test_batch:
204
+ t_images = test_batch['image'].to(DEVICE).to(DTYPE)
205
+ t_p_ids = test_batch['prompt_ids'].to(DEVICE)
206
+ t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
207
+ t_t_ids = test_batch['target_ids'].to(DEVICE)
208
+ t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
209
+ t_gen_targets = test_batch['generative_targets'].to(DEVICE)
210
+ t_cont_coords = test_batch['continuous_coords'].to(DEVICE) # Padded
211
+ t_coords_mask = test_batch['coords_mask'].to(DEVICE) # Mask
212
+
213
+ _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
214
+ t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
215
+ t_gen_targets, t_cont_coords, t_coords_mask # Pass mask
216
+ )
217
+ # Use .item() only if the tensor is finite
218
+ test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
219
+ test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
220
+ test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
221
+ # ... (rest of exception handling) ...
222
+ except StopIteration: print("Info: Test loader exhausted during logging.")
223
+ except KeyError as e: print(f"Error: Missing key {e} in test batch.")
224
+ except Exception as e: print(f"Error during test evaluation: {e}")
225
+ model.train()
226
+
227
+ # Prepare data for logging
228
+ log_data = {
229
+ "train/total_loss": avg_total_loss,
230
+ "train/class_loss": avg_class_loss,
231
+ "train/contrastive_loss": avg_con_loss,
232
+ "train/regression_loss": avg_reg_loss,
233
+ "test/class_loss": test_class_loss_val,
234
+ "test/contrastive_loss": test_con_loss_val,
235
+ "test/regression_loss": test_reg_loss_val,
236
+ "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
237
+ "step": step_counter,
238
+ "learning_rate": current_lr,
239
+ "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
240
+ "logit_scale": model.logit_scale.exp().item()
241
+ }
242
+ # Update progress bar
243
+ pbar.set_postfix({
244
+ "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
245
+ "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
246
+ "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
247
+ })
248
+ if wandb_enabled: wandb.log(log_data)
249
+
250
+ # Reset accumulators
251
+ epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
252
+ batches_since_log = 0
253
+ valid_batches_accum = 0 # Reset valid batch count
254
+
255
+ # --- End of Epoch ---
256
+ print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
257
+ # Optional: Add end-of-epoch evaluation or model saving here
258
+ if epoch % 5 == 0:
259
+ torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth")
260
+
261
+ # --- End of Training ---
262
+ print("\nTraining completed!")
263
+ if wandb_enabled:
264
+ wandb.finish()
train_stage_2.py ADDED
@@ -0,0 +1,267 @@
1
+ from constants import *
2
+ from dataset import create_train_dataloader, create_test_dataloader
3
+ from vision_language_model import VisionLanguageModel
4
+ from utils import *
5
+ from datetime import datetime
6
+ import wandb
7
+ import torch
8
+ import torch.optim as optim
9
+ from torch.optim.lr_scheduler import OneCycleLR
10
+ from tqdm.auto import tqdm
11
+
12
+ print(f"Using device: {DEVICE}")
13
+ print(f"Vocab size: {vocab_size}")
14
+
15
+ # --- Initialize Model ---
16
+ # Ensure lambda_regression is passed during initialization
17
+ model = VisionLanguageModel(
18
+ n_embd=HIDDEN_DIM,
19
+ vocab_size=vocab_size,
20
+ img_size=IMAGE_SIZE,
21
+ patch_size=PATCH_SIZE,
22
+ num_heads=NUM_HEADS,
23
+ num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
24
+ num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
25
+ emb_dropout=0.0,
26
+ blk_dropout=0.0,
27
+ max_context=CONTEXT_LENGTH,
28
+ shared_embed_dim=SHARED_EMBED_DIM,
29
+ lambda_contrastive=LAMBDA_CONTRASTIVE,
30
+ lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
31
+ ).to(DEVICE)
32
+
33
+ NUM_EPOCHS = 100
34
+ model.load_state_dict(torch.load("model_regression_multi_16.pth", weights_only=True)) # we ran till 15 before it over fitted with higher learning rate
35
+
36
+ # --- Optimizer ---
37
+ # Optimizer will automatically include all model parameters, including the new regression head
38
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1) # lower learning rate for second stage
39
+
40
+ # --- Dataloaders ---
41
+ # Ensure these functions now return 'continuous_coords' in the batch dictionary
42
+ train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2) # Use num_workers=0 for easier debugging first
43
+ test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
44
+ if train_loader is None: exit("Training loader failed to initialize.")
45
+ test_loader_has_data = test_loader and len(test_loader.dataset) > 0
46
+
47
+ # --- LR Scheduler ---
48
+ if train_loader and len(train_loader) > 0:
49
+ steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0)
50
+ total_steps = steps_per_epoch * NUM_EPOCHS
51
+ # Adjust warmup steps if total steps are very low
52
+ warmup_steps = min(max(1, total_steps // 10), 10000) # Ensure at least 1, max 10k warmup
53
+ print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
54
+ lr_scheduler = OneCycleLR(optimizer, max_lr=3e-4, total_steps=total_steps, pct_start=warmup_steps/total_steps if total_steps > 0 else 0.1) # use the lower stage-2 rate; the stage-1 LEARNING_RATE here would override the optimizer's lr
55
+ else:
56
+ print("Warning: Train loader empty. Using constant LR.")
57
+ total_steps = 0; warmup_steps = 0
58
+ lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
59
+
60
+ # --- Wandb Setup ---
61
+ try:
62
+ wandb.init(
63
+ # project="point-language-model-dualhead", # Suggest new project name
64
+ project="point-language-model-regression-vast",
65
+ name=f"point-vlm-dual-stage-2-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
66
+ config={ # Add new hyperparameters
67
+ "image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
68
+ "context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
69
+ "num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
70
+ "learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
71
+ "shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
72
+ "lambda_regression": LAMBDA_REGRESSION, # Log regression weight
73
+ "architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
74
+ "num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps
75
+ }
76
+ )
77
+ wandb_enabled = True
78
+ # Watch model gradients and parameters
79
+ # wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
80
+ except Exception as e:
81
+ print(f"Wandb initialization failed: {e}. Running without wandb.")
82
+ wandb_enabled = False
83
+
84
+ # --- Training Loop ---
85
+ print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
86
+ step_counter = 0
87
+ optimizer.zero_grad()
88
+
89
+ for epoch in range(NUM_EPOCHS):
90
+ model.train()
91
+ epoch_total_loss_accum = 0.0
92
+ epoch_class_loss_accum = 0.0
93
+ epoch_con_loss_accum = 0.0
94
+ epoch_reg_loss_accum = 0.0
95
+ batches_since_log = 0
96
+ valid_batches_accum = 0 # Count batches with valid loss for averaging
97
+
98
+ pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
99
+
100
+ for batch_idx, batch in pbar:
101
+ if batch is None: continue
102
+
103
+ # --- Unpack Batch Data ---
104
+ try:
105
+ images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
106
+ prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
107
+ prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
108
+ target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
109
+ target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
110
+ generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
111
+ continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True) # Padded
112
+ coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True) # Mask
113
+ except KeyError as e:
114
+ print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
115
+ continue
116
+
117
+ # Clamp logit_scale
118
+ with torch.no_grad():
119
+ model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0)))
120
+
121
+ # --- Forward Pass ---
122
+ # Model now returns potentially NaN scalar tensors for individual losses if invalid
123
+ logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
124
+ img_array=images,
125
+ prompt_ids=prompt_ids,
126
+ prompt_attention_mask=prompt_attention_mask,
127
+ target_ids=target_ids,
128
+ target_attention_mask=target_attention_mask,
129
+ generative_targets=generative_targets,
130
+ continuous_coords=continuous_coords,
131
+ coords_mask=coords_mask # Pass mask for regression loss calculation
132
+ )
133
+
134
+ # --- Loss Handling & Accumulation ---
135
+ # Check for invalid total loss before backward pass
136
+ if total_loss is None or not torch.isfinite(total_loss):
137
+ print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
138
+ optimizer.zero_grad() # Reset gradients for safety if loss is invalid
139
+ continue # Skip this batch for optimization step
140
+
141
+ # Scale loss for gradient accumulation
142
+ scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS
143
+
144
+ # Accumulate valid loss components for logging
145
+ # Check if the scalar tensor is finite before adding its item()
146
+ if torch.isfinite(total_loss):
147
+ epoch_total_loss_accum += total_loss.item()
148
+ valid_batches_accum += 1 # Increment count of batches contributing to average loss
149
+ if torch.isfinite(class_loss_s):
150
+ epoch_class_loss_accum += class_loss_s.item()
151
+ if torch.isfinite(contrastive_loss_s):
152
+ epoch_con_loss_accum += contrastive_loss_s.item()
153
+ if torch.isfinite(regression_loss_s):
154
+ epoch_reg_loss_accum += regression_loss_s.item()
155
+ batches_since_log += 1
156
+
157
+ # --- Backward Pass ---
158
+ try:
159
+ scaled_loss.backward()
160
+ except Exception as e:
161
+ print(f"Error during backward pass: {e}. Skipping step.")
162
+ optimizer.zero_grad() # Reset gradients if backward failed
163
+ continue
164
+
165
+ # --- Gradient Accumulation Step ---
166
+ if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
167
+ # Clip gradients
168
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
169
+
170
+ # Check for non-finite gradients before stepping
171
+ all_finite = True
172
+ for p in model.parameters():
173
+ if p.grad is not None and not torch.isfinite(p.grad).all():
174
+ all_finite = False
175
+ break
176
+ if not all_finite:
177
+ print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
178
+ optimizer.zero_grad()
179
+ continue # Skip optimizer step and scheduler step
180
+
181
+ # Optimizer step
182
+ optimizer.step()
183
+ lr_scheduler.step()
184
+ optimizer.zero_grad()
185
+
186
+ step_counter += 1
187
+
188
+ # --- Logging ---
189
+ if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0: # Use valid_batches_accum
190
+ # Calculate average losses over the logging period using valid batch count
191
+ avg_total_loss = epoch_total_loss_accum / valid_batches_accum
192
+ avg_class_loss = epoch_class_loss_accum / valid_batches_accum
193
+ avg_con_loss = epoch_con_loss_accum / valid_batches_accum
194
+ avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
195
+ current_lr = optimizer.param_groups[0]['lr']
196
+
197
+ # --- Test Evaluation (Needs modification to handle mask) ---
198
+ test_class_loss_val = float('nan')
199
+ test_con_loss_val = float('nan')
200
+ test_reg_loss_val = float('nan')
201
+ if test_loader_has_data:
202
+ model.eval()
203
+ with torch.no_grad():
204
+ try:
205
+ test_batch = next(iter(test_loader))
206
+ if test_batch:
207
+ t_images = test_batch['image'].to(DEVICE).to(DTYPE)
208
+ t_p_ids = test_batch['prompt_ids'].to(DEVICE)
209
+ t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
210
+ t_t_ids = test_batch['target_ids'].to(DEVICE)
211
+ t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
212
+ t_gen_targets = test_batch['generative_targets'].to(DEVICE)
213
+ t_cont_coords = test_batch['continuous_coords'].to(DEVICE) # Padded
214
+ t_coords_mask = test_batch['coords_mask'].to(DEVICE) # Mask
215
+
216
+ _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
217
+ t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
218
+ t_gen_targets, t_cont_coords, t_coords_mask # Pass mask
219
+ )
220
+ # Use .item() only if the tensor is finite
221
+ test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
222
+ test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
223
+ test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
224
+ # ... (rest of exception handling) ...
225
+ except StopIteration: print("Info: Test loader exhausted during logging.")
226
+ except KeyError as e: print(f"Error: Missing key {e} in test batch.")
227
+ except Exception as e: print(f"Error during test evaluation: {e}")
228
+ model.train()
229
+
230
+ # Prepare data for logging
231
+ log_data = {
232
+ "train/total_loss": avg_total_loss,
233
+ "train/class_loss": avg_class_loss,
234
+ "train/contrastive_loss": avg_con_loss,
235
+ "train/regression_loss": avg_reg_loss,
236
+ "test/class_loss": test_class_loss_val,
237
+ "test/contrastive_loss": test_con_loss_val,
238
+ "test/regression_loss": test_reg_loss_val,
239
+ "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
240
+ "step": step_counter,
241
+ "learning_rate": current_lr,
242
+ "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
243
+ "logit_scale": model.logit_scale.exp().item()
244
+ }
245
+ # Update progress bar
246
+ pbar.set_postfix({
247
+ "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
248
+ "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
249
+ "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
250
+ })
251
+ if wandb_enabled: wandb.log(log_data)
252
+
253
+ # Reset accumulators
254
+ epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
255
+ batches_since_log = 0
256
+ valid_batches_accum = 0 # Reset valid batch count
257
+
258
+ # --- End of Epoch ---
259
+ print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
260
+ # Optional: Add end-of-epoch evaluation or model saving here
261
+ if epoch % 5 == 0:
262
+ torch.save(model.state_dict(), f"model_regression_multi_stage_2_{epoch+1}.pth")
263
+
264
+ # --- End of Training ---
265
+ print("\nTraining completed!")
266
+ if wandb_enabled:
267
+ wandb.finish()
utils.py ADDED
@@ -0,0 +1,57 @@
1
+ from constants import *
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+
9
+ def get_tokenizer():
10
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
11
+ point_tokens = [f"coord_bin_{i}" for i in range(0, NUM_BINS)]
12
+ new_tokens = [
13
+ "<point_start>", "<point_end>", "<result_start>",
14
+ "<result_end>", "<pointx_start>", "<pointx_end>",
15
+ "<pointy_start>", "<pointy_end>",
16
+ *point_tokens
17
+ ]
18
+ tokenizer.add_tokens(new_tokens)
19
+ # Ensure pad token is set (GPT2 usually doesn't have one by default)
20
+ if tokenizer.pad_token is None:
21
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # Or use eos_token if preferred
22
+ # tokenizer.pad_token_id = tokenizer.eos_token_id # Alternative if we want padding to be EOS
23
+
24
+ print(f"Tokenizer pad token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
25
+ print(f"Tokenizer EOS token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
26
+
27
+ # Check if pad token ID is valid
28
+ if tokenizer.pad_token_id is None:
29
+ raise ValueError("Tokenizer pad token ID is not set!")
30
+
31
+ return tokenizer, len(tokenizer)
32
+
33
+ def image_to_tensor(image, image_size=IMAGE_SIZE):
34
+ if image.mode != 'RGB':
35
+ image = image.convert('RGB')
36
+ # We avoid the hassle of calculating
37
+ # changed co-ordinates for rotation etc for now. Can be added later.
38
+ transform = transforms.Compose([
39
+ transforms.Resize((image_size, image_size)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD)
42
+ ])
43
+ return transform(image)
44
+
45
+ def tensor_to_image(tensor):
46
+ tensor = tensor.clone().detach()
47
+ if tensor.is_cuda:
48
+ tensor = tensor.cpu()
49
+ mean = torch.tensor(IMAGE_MEAN).view(3, 1, 1)
50
+ std = torch.tensor(IMAGE_STD).view(3, 1, 1)
51
+ tensor = tensor * std + mean
52
+ tensor = torch.clamp(tensor, 0, 1)
53
+ image_np = tensor.numpy().transpose(1, 2, 0)
54
+ image_np = (image_np * 255).astype(np.uint8)
55
+ return Image.fromarray(image_np)
56
+
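+ # Roundtrip sketch (illustration only; "example.png" is a hypothetical path):
+ #   img = Image.open("example.png")
+ #   t = image_to_tensor(img)     # (3, IMAGE_SIZE, IMAGE_SIZE), normalized
+ #   back = tensor_to_image(t)    # PIL image, de-normalized and clamped to [0, 1]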
57
+ tokenizer, vocab_size = get_tokenizer() # Initialize tokenizer globally
vision_language_model.py ADDED
@@ -0,0 +1,400 @@
1
+ from model_components import ViT, MultiModalProjector
2
+ from decoder_language_model import DecoderLanguageModel
3
+ from constants import *
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from utils import tokenizer, vocab_size
8
+
9
+
10
+ class VisionLanguageModel(nn.Module):
11
+ """
12
+ Vision Language Model integrating ViT, Projector, Contrastive Loss, Decoder (Class + Reg).
13
+ Handles multiple points via padded regression targets and masked loss.
14
+ """
15
+ def __init__(self,
16
+ n_embd=HIDDEN_DIM,
17
+ vocab_size=vocab_size,
18
+ img_size=IMAGE_SIZE,
19
+ patch_size=PATCH_SIZE,
20
+ num_heads=NUM_HEADS,
21
+ num_blks_vit=NUM_LAYERS,
22
+ num_blks_dec=NUM_LAYERS,
23
+ emb_dropout=DROPOUT,
24
+ blk_dropout=DROPOUT,
25
+ max_context=CONTEXT_LENGTH,
26
+ shared_embed_dim=SHARED_EMBED_DIM,
27
+ lambda_contrastive=LAMBDA_CONTRASTIVE,
28
+ lambda_regression=LAMBDA_REGRESSION, # Use the updated constant
29
+ max_points = MAX_POINTS # Store max points
30
+ ):
31
+ super().__init__()
32
+
33
+ # --- Vision Backbone ---
34
+ self.vision_encoder = ViT(
35
+ img_size=img_size,
36
+ patch_size=patch_size,
37
+ num_hiddens=n_embd, # Assuming ViT output dim matches decoder embed dim
38
+ num_heads=num_heads,
39
+ num_blks=num_blks_vit,
40
+ emb_dropout=emb_dropout,
41
+ blk_dropout=blk_dropout
42
+ )
43
+
44
+ # --- Multimodal Components ---
45
+ self.multimodal_projector = MultiModalProjector(
46
+ image_embed_dim=n_embd, # Input from ViT
47
+ text_embed_dim=n_embd, # Output matches decoder dim
48
+ dropout=emb_dropout
49
+ )
50
+ self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
51
+ self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
52
+ self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
53
+
54
+ # --- Text Decoder ---
55
+ # DecoderLanguageModel now has regression head outputting MAX_POINTS*2
56
+ self.decoder = DecoderLanguageModel(
57
+ n_embd=n_embd,
58
+ vocab_size=vocab_size,
59
+ num_heads=num_heads,
60
+ n_layer=num_blks_dec,
61
+ max_context=max_context,
62
+ dropout=blk_dropout # Use block dropout for decoder consistency
63
+ )
64
+
65
+ # --- Store Configuration ---
66
+ self.n_embd = n_embd
67
+ self.vocab_size = vocab_size
68
+ self.num_patches = (img_size // patch_size)**2 + 1 # +1 for the CLS token
69
+ self.lambda_contrastive = lambda_contrastive
70
+ self.lambda_regression = lambda_regression
71
+ self.max_points = max_points # Store max points
72
+
73
+ self._resize_embeddings_if_needed(self.vocab_size)
74
+ print("VisionLanguageModel initialized.")
75
+
76
+
77
+ def _resize_embeddings_if_needed(self, current_vocab_size):
78
+ """ Resizes decoder token embeddings if vocab size changed after init. """
79
+ decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings
80
+ if decoder_embedding_size != current_vocab_size:
81
+ print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}")
82
+ # Freeze original weights before replacing layers
83
+ self.decoder.token_embedding_table.weight.requires_grad = False
84
+ self.decoder.lm_head.weight.requires_grad = False
85
+ # Create new layers
86
+ new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE)
87
+ new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE)
88
+ # Assign new layers
89
+ self.decoder.token_embedding_table = new_embedding
90
+ self.decoder.lm_head = new_lm_head
91
+ # Re-tie weights
92
+ self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight
93
+ print("VLM decoder embeddings resized and weights retied.")
94
+
95
+
96
+ def _calculate_contrastive_loss(self, image_features, text_features):
97
+ """ Calculates the symmetric InfoNCE loss. """
98
+ # Assumes features are already projected to shared_embed_dim
99
+ # image_features: (B, E)
100
+ # text_features: (B, E)
101
+
102
+ # Normalize features
103
+ image_features = F.normalize(image_features, dim=-1)
104
+ text_features = F.normalize(text_features, dim=-1)
105
+
106
+ # Cosine similarity as logits (using learnable temperature)
107
+ logit_scale = self.logit_scale.exp()
108
+ logits_per_image = logit_scale * image_features @ text_features.t()
109
+ logits_per_text = logits_per_image.t()
110
+
111
+ # Calculate symmetric cross-entropy loss
112
+ labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
113
+ loss_i = F.cross_entropy(logits_per_image, labels)
114
+ loss_t = F.cross_entropy(logits_per_text, labels)
115
+ contrastive_loss = (loss_i + loss_t) / 2.0
116
+
117
+ # Handle potential NaNs
118
+ if torch.isnan(contrastive_loss):
119
+ print("Warning: Contrastive loss is NaN.")
120
+ return None # Return None or zero tensor
121
+
122
+ return contrastive_loss
123
+
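+     # InfoNCE in brief: for a batch of B matched (image_i, text_i) pairs, the B x B
+     # scaled similarity matrix defines B classification problems whose correct class
+     # is the diagonal entry (labels = arange(B)); the loss above averages the
+     # image->text and text->image directions.
+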
124
+ def forward(self,
125
+ img_array,
126
+ prompt_ids,
127
+ prompt_attention_mask,
128
+ target_ids,
129
+ target_attention_mask,
130
+ generative_targets=None,
131
+ continuous_coords=None, # Now expects shape (B, MAX_POINTS, 2), padded
132
+ coords_mask=None # Mask for valid points (B, MAX_POINTS)
133
+ ):
134
+ """
135
+ Main forward pass for training. Calculates combined loss with masked regression loss.
136
+ """
137
+
138
+ # --- 1. Encode Image ---
139
+ image_embeds_raw = self.vision_encoder(img_array) # (B, N_img, C)
140
+ B, N_img, C_img = image_embeds_raw.shape
141
+ img_cls_token = image_embeds_raw[:, 0]
142
+
143
+ # --- 2. Contrastive Loss Path ---
144
+ contrastive_loss = None
145
+ # ... (contrastive loss calculation - same as before) ...
146
+ image_features_contrast = self.image_contrastive_head(img_cls_token)
147
+ with torch.no_grad(): # Keep no_grad here for efficiency if prompt embeddings aren't trained via contrastive
148
+ prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)
149
+ prompt_lengths = prompt_attention_mask.sum(dim=1)
150
+ last_token_indices = (prompt_lengths - 1).clamp(min=0)
151
+ gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, C_img)
152
+ prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)
153
+ text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)
154
+ contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)
155
+
156
+
157
+ # --- 3. Generative / Regression Path ---
158
+ image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
159
+ prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)
160
+ target_embeds_decoder = self.decoder.token_embedding_table(target_ids)
161
+ B, T_prompt, C = prompt_embeds_decoder.shape
162
+ B, T_target, _ = target_embeds_decoder.shape
163
+
164
+ # Prepare combined input sequence and attention mask for the decoder
165
+ combined_embeds = torch.cat([
166
+ image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder
167
+ ], dim=1)
168
+ combined_attention_mask = torch.cat([
169
+ torch.ones(B, N_img, dtype=torch.long, device=DEVICE),
170
+ prompt_attention_mask,
171
+ target_attention_mask
172
+ ], dim=1)
173
+ T_combined = combined_embeds.shape[1]
174
+
175
+ # Prepare combined targets for the classification loss
176
+ combined_class_targets = None
177
+ if generative_targets is not None:
178
+ combined_class_targets = torch.cat([
179
+ torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
180
+ generative_targets
181
+ ], dim=1)
182
+
183
+ # --- Pass through Decoder ---
184
+ logits, class_loss, x_norm = self.decoder(
185
+ combined_embeds,
186
+ attention_mask=combined_attention_mask,
187
+ targets=combined_class_targets
188
+ )
189
+ # x_norm shape: (B, T_combined, C)
190
+
191
+ # --- Calculate Regression Output & Loss (Modified for multiple points) ---
192
+ regression_loss = None
193
+ regression_output = None
194
+ if continuous_coords is not None and coords_mask is not None and x_norm is not None:
195
+ # Strategy: Use hidden state corresponding to token *before* <result_end> (or <eos>)
196
+ # This single state predicts coordinates for *all* MAX_POINTS.
197
+ target_lengths = target_attention_mask.sum(dim=1) # Length of actual target tokens (B,)
198
+ # Index relative to start of *target sequence* is length - 2 (token before <eos>/<result_end>)
199
+ relative_target_idx = (target_lengths - 2).clamp(min=0)
200
+ # Absolute index in the combined sequence's hidden states (x_norm)
201
+ absolute_idx = N_img + T_prompt + relative_target_idx
202
+ absolute_idx = absolute_idx.clamp(max=T_combined - 1) # Clamp index
203
+
204
+ # Gather the hidden states at these specific indices
205
+ gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)
206
+ try:
207
+ hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1) # Shape: (B, C)
208
+ # Pass through the regression head
209
+ regression_output_flat = self.decoder.regression_head(hidden_state_for_regression) # Shape: (B, MAX_POINTS * 2)
210
+ # Reshape to (B, MAX_POINTS, 2)
211
+ regression_output = regression_output_flat.view(B, self.max_points, 2)
212
+
213
+ # --- Calculate MASKED regression loss (L1 - Mean Absolute Error) ---
214
+ loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none') # (B, MAX_POINTS, 2)
215
+ # Apply mask (mask is (B, MAX_POINTS), need to broadcast to (B, MAX_POINTS, 2))
216
+ masked_loss = loss_per_coord * coords_mask.unsqueeze(-1)
217
+ # Sum loss over valid points and coordinates, divide by number of valid coordinates
218
+ num_valid_coords = coords_mask.sum() * 2 # Total number of valid x,y values in batch
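+                 # e.g. a sample annotated with 3 points contributes 6 valid x,y values;
+                 # padded slots are zeroed by the mask and excluded from the denominator.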
219
+ if num_valid_coords > 0:
220
+ regression_loss = masked_loss.sum() / num_valid_coords
221
+ else:
222
+ regression_loss = torch.tensor(0.0, device=DEVICE) # No valid points in batch
+
+                if torch.isnan(regression_loss):
+                    print("Warning: Regression loss is NaN.")
+                    regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)  # Zero out a NaN loss
+
+            except Exception as e:
+                print(f"Error during regression calculation: {e}")
+                print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}")
+                regression_loss = None
+                regression_output = None  # Ensure output is None if an error occurs
+
+        # --- 4. Combine All Losses ---
+        # Gradients flow into total_loss through the component losses added below,
+        # so the zero initializer itself does not need requires_grad=True
+        total_loss = torch.tensor(0.0, device=DEVICE)
+        # Add valid losses with their respective weights
+        loss_log = {}
+        if class_loss is not None and torch.isfinite(class_loss):
+            total_loss += class_loss  # Weight of 1.0 assumed
+            loss_log["class_loss"] = class_loss.item()
+        else:
+            # If class_loss is None or NaN/Inf, skip it and log NaN
+            loss_log["class_loss"] = float('nan')
+            print(f"Warning: Invalid class_loss ({class_loss})")
+
+        if contrastive_loss is not None and torch.isfinite(contrastive_loss):
+            total_loss += self.lambda_contrastive * contrastive_loss
+            loss_log["contrastive_loss"] = contrastive_loss.item()
+        else:
+            loss_log["contrastive_loss"] = float('nan')
+            print(f"Warning: Invalid contrastive_loss ({contrastive_loss})")
+
+        if regression_loss is not None and torch.isfinite(regression_loss):
+            total_loss += self.lambda_regression * regression_loss
+            loss_log["regression_loss"] = regression_loss.item()
+        else:
+            loss_log["regression_loss"] = float('nan')
+            # A finite 0.0 loss (the no-valid-points case) takes the branch above,
+            # so only None or NaN/Inf reaches here; num_valid_coords is not
+            # consulted because it is undefined when regression was skipped
+            if regression_loss is not None:
+                print(f"Warning: Invalid regression_loss ({regression_loss})")
+
+        # Handle the case where the total loss becomes NaN/Inf
+        if not torch.isfinite(total_loss):
+            print(f"Warning: Total loss became non-finite ({total_loss}). Replacing it with zero.")
+            total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
+            # Safer still to skip the optimizer step entirely for such a batch;
+            # that decision belongs in the training loop (see the sketch below)
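+        # Hedged sketch of how a training loop could skip such a batch; the
+        # names `optimizer` and `batch_loss` are placeholders for illustration,
+        # not definitions from this repo:
+        #
+        #     if torch.isfinite(batch_loss):
+        #         (batch_loss / GRAD_ACCUMULATION_STEPS).backward()
+        #     else:
+        #         optimizer.zero_grad(set_to_none=True)  # drop partial grads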
+
+        # Use the loss_log dictionary for clearer logging later
+        class_loss_val = loss_log["class_loss"]
+        contrastive_loss_val = loss_log["contrastive_loss"]
+        regression_loss_val = loss_log["regression_loss"]
+
+        # Return all relevant outputs (detached scalar tensors for loss logging)
+        return logits, regression_output, total_loss, \
+            torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val), torch.tensor(regression_loss_val)
+
+
+    # --- Generation Method ---
+    @torch.no_grad()  # Ensure no gradients are computed during generation
+    def generate(self, img_array, idx_prompt, max_new_tokens,
+                 temperature=1.0, top_k=None,
+                 force_result_start=True  # Option to manually add <result_start>
+                 ):
+        """
+        Generates token sequences autoregressively from an image and a prompt,
+        using the classification head (lm_head).
+
+        Args:
+            img_array (torch.Tensor): Input image tensor (B, 3, H, W). B must be 1 for this implementation.
+            idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt).
+            max_new_tokens (int): Maximum number of new tokens to generate.
+            temperature (float): Softmax temperature. 1.0 leaves the distribution
+                unchanged, lower values sharpen it, and 0.0 selects greedy decoding.
+            top_k (int | None): If set, restricts sampling to the K most likely tokens.
+            force_result_start (bool): If True, appends the <result_start> embedding
+                after the prompt before starting the generation loop.
+
+        Returns:
+            torch.Tensor: Generated sequence IDs, including the prompt (B, T_prompt + T_generated).
+        """
+        self.eval()  # Ensure model is in eval mode
+        B = img_array.shape[0]
+        if B > 1:
+            # This simplified generation loop assumes B=1 for clarity; batched
+            # generation would need per-sequence EOS and padding handling
+            print("Warning: Generation function currently assumes batch size B=1.")
+            # Process only the first item for now
+            img_array = img_array[:1]
+            idx_prompt = idx_prompt[:1]
+            B = 1
+
+        # --- 1. Prepare Initial Embeddings ---
+        image_embeds_raw = self.vision_encoder(img_array)
+        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
+        prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)
+
+        # Initial sequence for the decoder loop
+        current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)
+        generated_ids_list = []  # Store newly generated IDs as a list
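+        # At this point the context is [image patch embeds | prompt embeds];
+        # each newly generated token embedding is appended on the right,
+        # mirroring the [image | prompt | target] layout used in training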
+
+        # Manually add <result_start> if forced
+        if force_result_start:
+            try:
+                result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
+                result_start_embed = self.decoder.token_embedding_table(
+                    torch.tensor([[result_start_token_id]], device=DEVICE)
+                )
+                current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)
+                # Also record this token ID since we injected it
+                generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE))
+            except Exception as e:
+                print(f"Warning: Could not encode or add <result_start>: {e}")
+
+        # --- 2. Autoregressive Loop ---
+        for _ in range(max_new_tokens):
+            T_current = current_embeds.shape[1]
+
+            # Context truncation
+            if T_current > self.decoder.max_context:
+                current_embeds = current_embeds[:, -self.decoder.max_context:, :]
+                T_current = self.decoder.max_context
+
+            # Prepare inputs for decoder blocks
+            pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE)
+            pos = pos.clamp(max=self.decoder.max_context - 1)
+            pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0)
+            x = current_embeds + pos_emb
+            attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long)  # No padding needed
+
+            # Pass through decoder blocks
+            for block in self.decoder.blocks:
+                x = block(x, attention_mask=attention_mask)
+
+            # Get logits for the last position
+            x = self.decoder.ln_f(x[:, -1:, :])  # (B, 1, C)
+            logits = self.decoder.lm_head(x)  # (B, 1, V)
+            logits = logits.squeeze(1)  # (B, V)
+            if temperature > 0:
+                logits = logits / temperature  # Apply temperature; skipped at 0.0 to avoid dividing by zero
+
+            # --- Sampling / Decoding ---
+            # Optional: top-k filtering
+            if top_k is not None and top_k > 0:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float('Inf')  # Mask everything below the k-th logit
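+                # Worked example (comments only): with logits [2.0, 1.0, 0.5, -1.0]
+                # and top_k=2, torch.topk yields v = [2.0, 1.0], so the cutoff
+                # v[:, [-1]] is 1.0; the logits become [2.0, 1.0, -inf, -inf] and
+                # softmax spreads all probability mass over the two survivors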
+
+            # Select the next token ID: greedy argmax when temperature is 0.0
+            # (or top_k == 1), otherwise sample from the softmax distribution
+            if temperature == 0.0 or top_k == 1:
+                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
+            else:
+                probs = F.softmax(logits, dim=-1)
+                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+
+            # Append the generated token ID
+            generated_ids_list.append(idx_next)
+
+            # Stop if EOS is generated
+            if hasattr(tokenizer, 'eos_token_id') and idx_next.item() == tokenizer.eos_token_id:
+                break
+
+            # Prepare for the next iteration
+            next_token_embed = self.decoder.token_embedding_table(idx_next)
+            current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)
+
+        # --- 3. Combine results ---
+        if generated_ids_list:
+            generated_ids_tensor = torch.cat(generated_ids_list, dim=1)  # (B, T_generated)
+            full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1)
+        else:
+            full_sequence_ids = idx_prompt  # Return only the prompt if nothing was generated
+
+        self.train()  # Set model back to training mode
+        return full_sequence_ids
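+
+
+# Hedged usage sketch (commented out). It assumes the model class in this file
+# has been instantiated as `model`, that `transform` is an eval-time image
+# preprocessing pipeline, and that `tokenizer` is the shared tokenizer from
+# utils -- treat every binding here as illustrative, not as repo definitions:
+#
+#     img = transform(Image.open("example.png").convert("RGB")).unsqueeze(0).to(DEVICE)
+#     prompt_ids = torch.tensor(
+#         [tokenizer.encode("click the submit button", add_special_tokens=False)],
+#         device=DEVICE,
+#     )
+#     out_ids = model.generate(img, prompt_ids, max_new_tokens=64, temperature=0.0)
+#     print(tokenizer.decode(out_ids[0].tolist()))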