Upload 10 files
Browse files
- constants.py +30 -0
- dataset.py +349 -0
- decoder_language_model.py +165 -0
- finetune_lm_head_ce_loss.py +418 -0
- infer.py +504 -0
- model_components.py +163 -0
- train.py +264 -0
- train_stage_2.py +267 -0
- utils.py +57 -0
- vision_language_model.py +400 -0
constants.py
ADDED
@@ -0,0 +1,30 @@
import torch

IMAGE_SIZE = 512
PATCH_SIZE = 16
HIDDEN_DIM = 256
CONTEXT_LENGTH = 1536
TEXT_LENGTH = 512  # Max length for *target* sequence (coords)
PROMPT_LENGTH = 64  # Max length for *prompt* sequence (description) - adjust as needed
DROPOUT = 0.1
NUM_HEADS = 8
NUM_LAYERS = 12  # Keep a moderate number of layers
BATCH_SIZE = 16
LEARNING_RATE = 1e-3  # A lower LR may be needed with contrastive loss
DTYPE = torch.float32  # torch.bfloat16 created some instability - why?
GRAD_ACCUMULATION_STEPS = 16
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMAGE_LOCATION = "./images/"
NUM_BINS = 32
SHARED_EMBED_DIM = 256  # Dimension for the contrastive space
MAX_POINTS = 10  # Maximum number of points per image to handle

# Training loop constants
NUM_EPOCHS = 400  # Desired number of epochs
LOGGING_STEPS = 1  # Log every N optimization steps
MAX_GRAD_NORM = 1.0
LAMBDA_CONTRASTIVE = 2  # Weight for contrastive loss - tune this
LAMBDA_REGRESSION = 2  # Works but noisy
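A quick back-of-envelope check of these constants (not part of the upload), assuming the usual ViT patch count of (img_size / patch_size)², which is how the finetuning script below also computes the combined sequence length: the worst-case sequence of image patches plus prompt plus target can exceed CONTEXT_LENGTH, which is why the decoder's forward pass carries a truncation branch.

# Context-budget sanity check for the constants above (illustrative only).
IMAGE_SIZE, PATCH_SIZE = 512, 16
PROMPT_LENGTH, TEXT_LENGTH, CONTEXT_LENGTH = 64, 512, 1536

num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2             # 32 * 32 = 1024 patch tokens
worst_case = num_patches + PROMPT_LENGTH + TEXT_LENGTH    # 1024 + 64 + 512 = 1600
print(worst_case > CONTEXT_LENGTH)                        # True -> long samples get truncated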
dataset.py
ADDED
@@ -0,0 +1,349 @@
from tqdm.auto import tqdm
from constants import *
from utils import *
import pickle
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

def format_point_text(points):
    # Handles any number of points, including zero
    text = "<result_start>"
    for point in points:
        # Clamp point coordinates (given in percent, [0, 100]) to the valid pixel range
        px = min(max(int(point.get('x', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)  # .get for safety
        py = min(max(int(point.get('y', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
        x_bin = min(px // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
        y_bin = min(py // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
        text += f"<pointx_start><coord_bin_{x_bin}><pointx_end><pointy_start><coord_bin_{y_bin}><pointy_end>"
    text += "<result_end>" + tokenizer.eos_token
    return text
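# Worked example (hypothetical input): with IMAGE_SIZE=512 and NUM_BINS=32,
# a point {'x': 50, 'y': 25} (percent coordinates) maps to pixel (256, 128);
# the bin width is 512 // 32 = 16, so the bins are (16, 8), and the function yields:
#   "<result_start><pointx_start><coord_bin_16><pointx_end>"
#   "<pointy_start><coord_bin_8><pointy_end><result_end>" + tokenizer.eos_token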
def format_data_for_training(sample):
    """Format a data sample for training, handling 0 to MAX_POINTS continuous coordinates."""
    try:
        # Check that 'points' exists and is a list; otherwise treat it as 0 points
        sample_points = sample.get('points', [])
        if not isinstance(sample_points, list):
            print(f"Warning: Invalid 'points' type for {sample.get('image_url', 'N/A')}. Treating as 0 points.")
            sample_points = []

        # Limit the number of points processed
        points_to_process = sample_points[:MAX_POINTS]
        num_points = len(points_to_process)

        # Load the image - this is where most memory is used
        image_path = f"{IMAGE_LOCATION}{sample['image_url']}"

        # Check that the file exists before attempting to open it
        if not os.path.exists(image_path):
            print(f"Warning: Image not found: {image_path}. Skipping.")
            return None

        # Open the image with error handling
        try:
            image = Image.open(image_path)
            # Convert grayscale to RGB if needed
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_tensor = image_to_tensor(image)
            # Explicitly delete the PIL image to free memory
            del image
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

        # Process the text with memory efficiency in mind
        prompt_text = f"<point_start>{sample['label']}<point_end>"
        # format_point_text correctly handles an empty points_to_process list
        target_text = format_point_text(points_to_process)

        # Tokenize with explicit max lengths
        prompt_tokens = tokenizer(prompt_text, return_tensors="pt", max_length=PROMPT_LENGTH,
                                  truncation=True, padding=False)
        target_tokens = tokenizer(target_text, return_tensors="pt", max_length=TEXT_LENGTH,
                                  truncation=True, padding=False)

        # Check for empty tokens after tokenization
        if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:
            print(f"Warning: Empty tokens after tokenization for {sample.get('image_url', 'N/A')}. Skipping.")
            return None

        # --- Handle multiple continuous coordinates with padding (num_points == 0 is handled correctly) ---
        continuous_coords_list = []
        for point in points_to_process:  # This loop won't run if num_points is 0
            coord_x = min(max(point.get('x', 50) / 100.0, 0.0), 1.0)
            coord_y = min(max(point.get('y', 50) / 100.0, 0.0), 1.0)
            continuous_coords_list.append([coord_x, coord_y])

        # Pad the coordinates and create a mask; with no points, build an
        # all-padding tensor of the right shape.
        if num_points == 0:
            padded_coords = torch.full((MAX_POINTS, 2), -1.0)
            coords_mask = torch.zeros(MAX_POINTS)
        else:
            coords_tensor = torch.tensor(continuous_coords_list, dtype=torch.float32)
            padding_needed = MAX_POINTS - num_points
            padded_coords = F.pad(coords_tensor, (0, 0, 0, padding_needed), value=-1.0)
            coords_mask = torch.cat([torch.ones(num_points, dtype=torch.float32),
                                     torch.zeros(padding_needed, dtype=torch.float32)])

        # Create and return the formatted sample
        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens.input_ids[0],
            "target_ids": target_tokens.input_ids[0],
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": num_points,
            "label": sample['label'],
            "image_url": sample['image_url']
        }
    except FileNotFoundError:
        print(f"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.")
        return None
    except Exception as e:
        print(f"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.")
        import traceback
        traceback.print_exc()
        return None


class PointDataset(Dataset):
    def __init__(self, data_path="active_point_dataset.pkl", split="train", test_size=1000):
        with open(data_path, "rb") as f:
            raw_data = pickle.load(f)

        # Keep samples with 0 to MAX_POINTS points; check isinstance first so a
        # non-list 'points' value cannot crash the len() call.
        original_count = len(raw_data)
        raw_data = [sample for sample in raw_data
                    if isinstance(sample.get('points', []), list)
                    and 0 <= len(sample.get('points', [])) <= MAX_POINTS]
        filtered_count = len(raw_data)
        print(f"Original raw data size: {original_count}")
        print(f"Filtered raw data to {filtered_count} samples with 0 to {MAX_POINTS} points.")

        total_samples = len(raw_data)
        if total_samples == 0:
            raise ValueError("No samples left after filtering. Check the data or MAX_POINTS.")

        if total_samples <= test_size:
            print(f"Warning: Dataset size {total_samples} <= test_size {test_size}.")
            test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0
        train_end = total_samples - test_size
        print(f"Dataset: {total_samples} total (0 to {MAX_POINTS} points), {train_end} train, {test_size} test")

        # Split using the actual train/test counts
        if split == "train":
            if train_end <= 0:
                print("Warning: No samples allocated for the training split.")
            self.raw_data = raw_data[:train_end]
        elif split == "test":
            if test_size <= 0:
                print("Warning: No samples allocated for the test split.")
            self.raw_data = raw_data[train_end:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        # Do NOT preprocess the data here - just store the raw records.
        # This is the key choice: images are not all loaded at once.
        print(f"Dataset initialized with {len(self.raw_data)} samples for {split}")

        # Optional: cache recently processed items to speed up repeated access.
        # Eviction is oldest-inserted-first (FIFO), not true LRU.
        self.cache_size = 8000  # Adjust based on memory constraints
        self.cache = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, idx):
        # Return the cached item if present
        if idx in self.cache:
            return self.cache[idx]

        # Process the sample on demand
        sample = self.raw_data[idx]
        formatted = format_data_for_training(sample)

        # If processing failed, try the following samples (with wrapping)
        if formatted is None:
            next_idx = (idx + 1) % len(self.raw_data)

            # Cap the attempts to prevent an infinite loop if all samples are invalid
            attempts = 0
            while formatted is None and attempts < min(10, len(self.raw_data)):
                sample = self.raw_data[next_idx]
                formatted = format_data_for_training(sample)
                next_idx = (next_idx + 1) % len(self.raw_data)
                attempts += 1

            # If there is still no valid sample, fall back to a dummy sample
            if formatted is None:
                print(f"Warning: Failed to find a valid sample after {attempts} attempts")
                formatted = self._create_dummy_sample()

        # Update the cache, evicting the oldest entry (first key) when full
        if len(self.cache) >= self.cache_size:
            if self.cache:
                oldest_key = next(iter(self.cache))
                del self.cache[oldest_key]

        self.cache[idx] = formatted
        return formatted

    def _create_dummy_sample(self):
        """Creates a minimal valid sample when all else fails."""
        # Empty image tensor
        image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)

        # Minimal tokens
        prompt_text = "<point_start>dummy<point_end>"
        target_text = "<result_start><result_end>" + tokenizer.eos_token

        prompt_tokens = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
        target_tokens = tokenizer(target_text, return_tensors="pt").input_ids[0]

        # Empty coordinates
        padded_coords = torch.full((MAX_POINTS, 2), -1.0)
        coords_mask = torch.zeros(MAX_POINTS)

        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens,
            "target_ids": target_tokens,
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": 0,
            "label": "dummy",
            "image_url": "none"
        }

    @staticmethod
    def collate_fn(batch):
        # Drop failed samples; padded coords and masks are stacked below
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        images = torch.stack([item['image'] for item in batch]).to(DTYPE)

        # --- Pad prompt IDs ---
        max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)
        prompt_ids_padded, prompt_attention_mask = [], []
        for item in batch:
            ids, pad_len = item['prompt_ids'], max_prompt_len - item['prompt_ids'].size(0)
            prompt_ids_padded.append(torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
            prompt_attention_mask.append(torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)]))
        prompt_ids = torch.stack(prompt_ids_padded)
        prompt_attention_mask = torch.stack(prompt_attention_mask)

        # --- Pad target IDs & create generative (next-token) targets ---
        max_target_len = max(item['target_ids'].size(0) for item in batch)
        target_ids_padded, target_attention_mask, generative_targets = [], [], []
        for item in batch:
            ids, pad_len = item['target_ids'], max_target_len - item['target_ids'].size(0)
            padded_ids = torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)])
            target_ids_padded.append(padded_ids)
            mask = torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)])
            target_attention_mask.append(mask)
            # Shift left by one for next-token prediction; ignored positions are -100
            targets = torch.full_like(padded_ids, -100)
            if ids.size(0) > 1:
                targets[:ids.size(0) - 1] = ids[1:]
            if ids.numel() > 0 and ids[-1] == tokenizer.eos_token_id:
                if ids.size(0) > 1:
                    targets[ids.size(0) - 1] = tokenizer.eos_token_id
                else:
                    targets[0] = -100
            generative_targets.append(targets)
        target_ids = torch.stack(target_ids_padded)
        target_attention_mask = torch.stack(target_attention_mask)
        generative_targets = torch.stack(generative_targets)

        # --- Stack continuous coords and masks ---
        continuous_coords = torch.stack([item['continuous_coords'] for item in batch])
        coords_mask = torch.stack([item['coords_mask'] for item in batch])
        num_points = [item['num_points'] for item in batch]

        labels = [item['label'] for item in batch]
        image_urls = [item.get('image_url', '') for item in batch]

        return {
            'image': images,
            'prompt_ids': prompt_ids,
            'prompt_attention_mask': prompt_attention_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_attention_mask,
            'generative_targets': generative_targets,
            'continuous_coords': continuous_coords,
            'coords_mask': coords_mask,
            'num_points': num_points,
            'label': labels,
            'image_url': image_urls
        }
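# Worked example of the next-token shift in collate_fn above (hypothetical ids, EOS = 2, PAD = 0):
#   padded_ids = [5, 7, 9, 2, 0]      target ids ending in EOS, then one pad token
#   targets    = [7, 9, 2, 2, -100]   shifted left by one; the EOS position predicts
#                                     EOS again, and pad positions are ignored (-100)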
def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the training dataloader with memory-efficient settings.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes for data loading.
        prefetch_factor: Number of batches to prefetch per worker.

    Returns:
        DataLoader instance, or None if the dataset is empty.
    """
    dataset = PointDataset(split="train")
    if len(dataset) == 0:
        return None

    # Configure the DataLoader for memory efficiency
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=PointDataset.collate_fn,
        pin_memory=True,  # Speeds up CPU-to-GPU transfer
        num_workers=num_workers,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,  # Only valid with workers
        persistent_workers=num_workers > 0,  # Keep workers alive between epochs
        drop_last=False  # Don't drop the last incomplete batch
    )

def create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the test dataloader with memory-efficient settings.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes for data loading.
        prefetch_factor: Number of batches to prefetch per worker.

    Returns:
        DataLoader instance, or None if the dataset is empty.
    """
    dataset = PointDataset(split="test")
    if len(dataset) == 0:
        print("Warning: Test dataset is empty. Returning None.")
        return None

    # Same memory settings as the train loader, but no shuffling
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=PointDataset.collate_fn,
        pin_memory=True,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
        drop_last=False
    )
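A minimal usage sketch for these factories (not part of the upload), assuming active_point_dataset.pkl and the ./images/ directory exist and that utils.image_to_tensor returns a (3, 512, 512) tensor:

from dataset import create_train_dataloader

loader = create_train_dataloader(batch_size=4, num_workers=0)
batch = next(iter(loader))
print(batch['image'].shape)               # torch.Size([4, 3, 512, 512])
print(batch['continuous_coords'].shape)   # torch.Size([4, 10, 2]), i.e. (MAX_POINTS, 2) per sample
print(batch['generative_targets'].shape)  # same shape as batch['target_ids']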
decoder_language_model.py
ADDED
@@ -0,0 +1,165 @@
from model_components import Block
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size

class DecoderLanguageModel(nn.Module):
    """
    Transformer decoder language model with an optional coordinate regression head.
    Processes a combined sequence of embeddings.
    Outputs logits for token prediction and, optionally, regressed coordinates (for MAX_POINTS).
    """
    def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
                 n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
        super().__init__()
        # --- Input embeddings ---
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(max_context, n_embd)
        self.dropout = nn.Dropout(dropout)

        # --- Transformer blocks ---
        self.blocks = nn.ModuleList([
            Block(n_embd, num_heads, dropout, is_decoder=True)
            for _ in range(n_layer)
        ])

        # --- Final layer norm ---
        self.ln_f = nn.LayerNorm(n_embd)

        # --- Output heads ---
        # 1. Head for token classification
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # 2. Head for direct coordinate regression (predicting MAX_POINTS * 2 values)
        self.regression_head = nn.Sequential(
            nn.Linear(n_embd, n_embd // 2),
            nn.GELU(),
            nn.Linear(n_embd // 2, MAX_POINTS * 2),  # Output MAX_POINTS * (x, y)
            nn.Sigmoid()  # Output activation in [0, 1]
        )
        # --- End output heads ---

        self.n_embd = n_embd
        self.max_context = max_context
        # Tie the token embedding weights to the LM head
        self.token_embedding_table.weight = self.lm_head.weight
        self.apply(self._init_weights)
        print(f"DecoderLanguageModel initialized with {n_layer} layers.")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Forward pass for training or inference where the loss is calculated.
        Regression output is handled *outside* this module, by the VLM.
        """
        # --- Input validation & processing ---
        if combined_embeds.ndim != 3:
            raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
        B, T, C = combined_embeds.shape
        if T > self.max_context:
            # Context truncation: keep the most recent max_context positions
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None:
                targets = targets[:, -self.max_context:]
            T = self.max_context

        # --- Positional encoding ---
        pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
        pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
        pos_emb = self.position_embedding_table(pos)  # Shape: (T, C)
        x = combined_embeds + pos_emb.unsqueeze(0)
        x = self.dropout(x)

        # --- Transformer blocks ---
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)

        # --- Final layer norm ---
        x_norm = self.ln_f(x)  # Shape: (B, T, C) - passed out for the VLM's regression head

        # --- Classification head output ---
        logits = self.lm_head(x_norm)  # Shape: (B, T, vocab_size)

        # --- Classification loss calculation ---
        class_loss = None
        if targets is not None:
            try:
                class_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1),
                    ignore_index=-100
                )
                if torch.isnan(class_loss):
                    print("Warning: class_loss is NaN.")
                    class_loss = None
            except Exception as e:
                print(f"Error calculating cross_entropy: {e}")
                print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
                class_loss = None

        # Return logits, class_loss, and the final normalized hidden states
        return logits, class_loss, x_norm

    # --- Generation method (used if the VLM needs this class to generate from token IDs) ---
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressive generation from starting token IDs.
        NOTE: This version doesn't handle combined embeddings directly.
        The VisionLanguageModel should ideally use a method like
        generate_from_embeddings, or implement the loop externally.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # --- Context management: crop idx if longer than the context length ---
            idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]

            # --- Forward pass ---
            tok_embeds = self.token_embedding_table(idx_cond)  # (B, T, C)
            pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
            pos = pos.clamp(max=self.max_context - 1)
            pos_emb = self.position_embedding_table(pos).unsqueeze(0)  # (1, T, C)
            x = self.dropout(tok_embeds + pos_emb)
            # Pass through the blocks (no padding mask needed for a single sequence)
            for block in self.blocks:
                x = block(x, attention_mask=None)  # The causal mask is internal to the block/head
            # Final layer norm and head for the last token only
            x = self.ln_f(x[:, -1:, :])  # (B, 1, C)
            logits = self.lm_head(x)     # (B, 1, V)
            logits = logits.squeeze(1)   # (B, V)

            # --- Sampling ---
            logits = logits / temperature
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append the sampled token
            idx = torch.cat((idx, idx_next), dim=1)

            # Stop once every sequence has produced EOS
            if hasattr(tokenizer, 'eos_token_id') and (idx_next == tokenizer.eos_token_id).all():
                break
        self.train()
        return idx
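A minimal smoke test for the decoder (not part of the upload), assuming model_components.Block matches the (n_embd, num_heads, dropout, is_decoder) constructor and the block(x, attention_mask=...) call used above:

import torch
from decoder_language_model import DecoderLanguageModel

model = DecoderLanguageModel()
embeds = torch.randn(2, 10, 256)      # (B, T, n_embd); 256 = HIDDEN_DIM
logits, loss, hidden = model(embeds)  # no targets -> loss is None
print(logits.shape, hidden.shape)     # (2, 10, vocab_size) and (2, 10, 256)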
finetune_lm_head_ce_loss.py
ADDED
@@ -0,0 +1,418 @@
# finetune_lm_head_ce_loss.py
# Usage: python finetune_lm_head_ce_loss.py --pretrained_model_path model_regression_multi_stage_2_11.pth

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR  # Cosine decay for fine-tuning
from tqdm.auto import tqdm
import wandb
from datetime import datetime
import numpy as np
import argparse
import os
import math
import traceback  # For detailed error printing
from constants import *

try:
    # get_tokenizer should define the global tokenizer and vocab_size
    from utils import get_tokenizer, tokenizer, vocab_size, tensor_to_image, image_to_tensor
    if 'tokenizer' not in globals() or 'vocab_size' not in globals():
        print("Initializing tokenizer...")
        tokenizer, vocab_size = get_tokenizer()
except ImportError:
    print("Error: Could not import required functions/variables from utils.py.")
    exit()
except NameError:
    print("Error: tokenizer or vocab_size not defined after importing utils.")
    exit()
except Exception as e:
    print(f"Error during utils import or tokenizer init: {e}")
    exit()

try:
    # The dataset must handle 0 points and the MAX_POINTS filter;
    # collate_fn must return all necessary keys, including 'generative_targets'.
    from dataset import create_train_dataloader, create_test_dataloader, PointDataset
except ImportError:
    print("Error: Could not import from dataset.py.")
    exit()

try:
    # VisionLanguageModel.__init__ must match the one used in the training script;
    # DecoderLanguageModel etc. must also be available.
    from vision_language_model import VisionLanguageModel
except ImportError:
    print("Error: Could not import VisionLanguageModel from vision_language_model.py.")
    exit()

# --- Fine-tuning specific arguments ---
parser = argparse.ArgumentParser(description="Re-initialize and fine-tune the LM head using ONLY classification loss.")
parser.add_argument("--pretrained_model_path", type=str, required=True, help="Path to the pre-trained model state_dict (.pth file).")
parser.add_argument("--output_model_path", type=str, default="model_lm_reinit_ce_finetuned.pth", help="Path to save the fine-tuned model.")
parser.add_argument("--ft_epochs", type=int, default=10, help="Number of epochs for fine-tuning.")
parser.add_argument("--ft_lr", type=float, default=5e-4, help="Learning rate for fine-tuning.")
parser.add_argument("--ft_batch_size", type=int, default=BATCH_SIZE, help="Batch size for fine-tuning.")
parser.add_argument("--ft_grad_accum", type=int, default=GRAD_ACCUMULATION_STEPS, help="Gradient accumulation steps.")
parser.add_argument("--ft_log_steps", type=int, default=1, help="Logging frequency.")
parser.add_argument("--train_final_ln", action='store_true', help="Also train the final LayerNorm (ln_f) before the lm_head.")
parser.add_argument("--wandb_project", type=str, default="point-lm-head-reinit-ce-finetune", help="WandB project name.")


if __name__ == "__main__":
    args = parser.parse_args()

    # Use constants/args consistently
    FT_BATCH_SIZE = args.ft_batch_size
    FT_GRAD_ACCUM = args.ft_grad_accum
    FT_LOG_STEPS = args.ft_log_steps

    print(f"Using device: {DEVICE}")
    print(f"Re-initializing and fine-tuning the LM head (and final LN: {args.train_final_ln})")
    print("Using ONLY classification (cross-entropy) loss")
    print(f"Pretrained model: {args.pretrained_model_path}")
    print(f"Output model: {args.output_model_path}")
    print(f"Epochs: {args.ft_epochs}, LR: {args.ft_lr}, Batch size: {FT_BATCH_SIZE}, Grad accum: {FT_GRAD_ACCUM}")

    # --- Load the model definition ---
    print("Loading model definition...")
    try:
        # Use parameters consistent with the pre-trained model's architecture
        model_args = {
            'n_embd': HIDDEN_DIM, 'vocab_size': vocab_size, 'img_size': IMAGE_SIZE, 'patch_size': PATCH_SIZE,
            'num_heads': NUM_HEADS, 'num_blks_vit': NUM_LAYERS, 'num_blks_dec': NUM_LAYERS,
            'emb_dropout': 0.0, 'blk_dropout': 0.0, 'max_context': CONTEXT_LENGTH,
            'shared_embed_dim': SHARED_EMBED_DIM,
            # Auxiliary losses are disabled for this fine-tuning stage
            'lambda_contrastive': 0.0,
            'lambda_regression': 0.0,
            'max_points': MAX_POINTS
        }
        model = VisionLanguageModel(**model_args).to(DEVICE)
    except Exception as e:
        print(f"Error initializing model structure: {e}")
        exit()

    # --- Load the pre-trained weights ---
    print(f"Loading pre-trained weights from: {args.pretrained_model_path}")
    try:
        # strict=False in case parameter names differ slightly
        model.load_state_dict(torch.load(args.pretrained_model_path, map_location=DEVICE, weights_only=True), strict=False)
        print("Pre-trained weights loaded successfully.")
    except FileNotFoundError:
        print(f"Error: Pre-trained model file not found at {args.pretrained_model_path}")
        exit()
    except Exception as e:
        print(f"Error loading model state_dict: {e}")
        exit()

    # --- Reinitialize the LM head ---
    print("Reinitializing LM head...")
    model.decoder.lm_head.reset_parameters()
    # Explicitly re-tie the embedding weights AFTER reinitialization
    model.decoder.token_embedding_table.weight = model.decoder.lm_head.weight
    print("LM head reinitialized and weights explicitly retied.")

    # --- Freeze/unfreeze parameters (done ONCE before the loop) ---
    print("Setting requires_grad flags...")
    params_to_optimize = []
    trainable_param_names = []
    total_params = sum(p.numel() for p in model.parameters())

    for param in model.parameters():
        param.requires_grad = False  # Freeze everything first

    print("\nParameters explicitly marked as trainable:")
    for name, param in model.decoder.lm_head.named_parameters():
        param.requires_grad = True
        params_to_optimize.append(param)
        trainable_param_names.append(f"decoder.lm_head.{name}")

    if args.train_final_ln:
        for name, param in model.decoder.ln_f.named_parameters():
            param.requires_grad = True
            params_to_optimize.append(param)
            trainable_param_names.append(f"decoder.ln_f.{name}")

    # --- Create the optimizer over exactly this parameter list ---
    print("\nParameters passed to optimizer:")
    for name in trainable_param_names:
        print(f"- {name}")
    trainable_params_count = sum(p.numel() for p in params_to_optimize)
    print(f"\nTotal parameters: {total_params}")
    print(f"Trainable parameters (optimizer target): {trainable_params_count} ({100 * trainable_params_count / total_params:.2f}%)")

    # Verification print
    print("\nVerification: all parameters with requires_grad=True:")
    actual_trainable_count = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            is_in_optimize_list = any(p is param for p in params_to_optimize)
            print(f"- {name} (Requires grad: {param.requires_grad}, In optimizer list: {is_in_optimize_list})")
            actual_trainable_count += param.numel()
    print(f"Actual trainable count (incl. tied): {actual_trainable_count}")

    if not params_to_optimize:
        print("Error: No parameters collected for the optimizer.")
        exit()
    optimizer = torch.optim.AdamW(params_to_optimize, lr=args.ft_lr, betas=(0.9, 0.95), weight_decay=0.1)
    print("Optimizer created.")

    # --- Dataloaders & scheduler ---
    print("Creating dataloaders...")
    train_loader = create_train_dataloader(batch_size=FT_BATCH_SIZE, num_workers=4)
    test_loader = create_test_dataloader(batch_size=FT_BATCH_SIZE, num_workers=2)
    if train_loader is None:
        exit("Training loader failed to initialize.")
    test_loader_has_data = test_loader and len(test_loader.dataset) > 0
    scheduler = None
    if train_loader and len(train_loader) > 0:
        steps_per_epoch = (len(train_loader) // FT_GRAD_ACCUM) + (1 if len(train_loader) % FT_GRAD_ACCUM != 0 else 0)
        total_steps = steps_per_epoch * args.ft_epochs
        print(f"Fine-tuning: total estimated optimization steps: {total_steps}")
        scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=args.ft_lr / 10)
    else:
        print("Warning: Train loader empty. Cannot set up the scheduler.")

    # --- Wandb setup ---
    wandb_enabled = False
    try:
        wandb.init(
            project=args.wandb_project,
            name=f"lm-head-reinit-ce-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config={"fine_tuning_lr": args.ft_lr, "fine_tuning_epochs": args.ft_epochs, "batch_size": FT_BATCH_SIZE,
                    "grad_accum": FT_GRAD_ACCUM, "pretrained_model": args.pretrained_model_path,
                    "train_final_ln": args.train_final_ln, "loss": "Classification Only"})
        wandb_enabled = True
    except Exception as e:
        print(f"Wandb initialization failed: {e}.")

    # --- Fine-tuning loop ---
    print("Starting LM head re-init fine-tuning with classification loss...")
    torch.autograd.set_detect_anomaly(True)
    step_counter = 0
    optimizer.zero_grad()

    for epoch in range(args.ft_epochs):
        model.train()  # Set dropout/layernorm layers to train mode

        epoch_class_loss_accum = 0.0
        valid_batches_accum = 0
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"FT Epoch {epoch+1}/{args.ft_epochs}", leave=False)

        for batch_idx, batch in pbar:
            if batch is None:
                continue
            # --- Unpack the data ---
            try:
                images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
                prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
                prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
                target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
                target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
                generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)  # Needed for the loss
                # Auxiliary inputs may be None; the model forward should handle that
                continuous_coords = batch.get('continuous_coords')
                coords_mask = batch.get('coords_mask')
                num_points_list = batch.get('num_points')
                if continuous_coords is not None:
                    continuous_coords = continuous_coords.to(DEVICE, non_blocking=True)
                if coords_mask is not None:
                    coords_mask = coords_mask.to(DEVICE, non_blocking=True)
            except KeyError as e:
                print(f"KeyError unpacking batch: {e}")
                continue
            except Exception as e:
                print(f"Error unpacking batch: {e}")
                continue

            # --- Forward pass ---
            # Run the full model normally; autograd respects the requires_grad flags.
            try:
                # Only the logits output is needed from the main model call
                logits, _, _, _, _, _, *_ = model(
                    img_array=images, prompt_ids=prompt_ids, prompt_attention_mask=prompt_attention_mask,
                    target_ids=target_ids, target_attention_mask=target_attention_mask,
                    generative_targets=generative_targets,  # The model may use these internally
                    continuous_coords=continuous_coords, coords_mask=coords_mask,
                )
                if logits is None or not torch.isfinite(logits).all():
                    print(f"!!! ERROR: NaN/Inf/None detected in logits. Skipping batch {batch_idx}. !!!")
                    optimizer.zero_grad()
                    continue
            except Exception as e:
                print(f"!!! ERROR during forward pass: {e} !!!")
                traceback.print_exc()
                optimizer.zero_grad()
                continue

            # --- Calculate the classification loss EXTERNALLY ---
            loss_to_backward = None
            try:
                # Get the batch size and vocab size from the logits
                B, T_logits, V = logits.shape

                # --- Prepare PADDED targets for the external CE loss ---
                # Pad generative_targets to match T_logits
                B_targ, T_target_orig = generative_targets.shape
                N_img = model.num_patches
                T_prompt = prompt_ids.shape[1]
                T_combined_expected = N_img + T_prompt + T_target_orig  # Expected full length

                if T_logits != T_combined_expected:
                    # Handle potential truncation due to the context length
                    print(f"Warning: Logits length {T_logits} != expected combined length {T_combined_expected}. Adjusting targets.")
                    T_target_in_logits = max(0, T_logits - (N_img + T_prompt))
                    generative_targets_sliced = generative_targets[:, :T_target_in_logits]
                    combined_class_targets = torch.cat([
                        torch.full((B, T_logits - T_target_in_logits), -100, dtype=torch.long, device=DEVICE),
                        generative_targets_sliced
                    ], dim=1)
                else:
                    # Pad generative_targets normally: image and prompt positions are ignored
                    combined_class_targets = torch.cat([
                        torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
                        generative_targets
                    ], dim=1)

                # Verify the shapes before the loss calculation
                if logits.shape[1] != combined_class_targets.shape[1]:
                    raise ValueError(f"Shape mismatch before CE loss! Logits T={logits.shape[1]}, Targets T={combined_class_targets.shape[1]}")

                # Calculate the loss on the logits that require grad and the padded targets
                loss_to_backward = F.cross_entropy(
                    logits.view(-1, V),               # Shape (B * T_logits, V)
                    combined_class_targets.view(-1),  # Shape (B * T_logits)
                    ignore_index=-100
                )

                if not torch.isfinite(loss_to_backward):
                    print(f"Warning: NaN/Inf detected in calculated class_loss ({loss_to_backward}).")
                    loss_to_backward = None
            except Exception as e:
                print(f"Error calculating external CE loss: {e}")
                loss_to_backward = None

            # Check the loss before backward
            if loss_to_backward is None:
                print(f"Warning: Skipping batch {batch_idx} due to invalid loss calculation.")
                optimizer.zero_grad()
                continue

            # --- Verification ---
            if loss_to_backward.grad_fn is None:
                print(f"!!! ERROR: loss_to_backward (value: {loss_to_backward.item()}) has no grad_fn! Batch {batch_idx} !!!")
                optimizer.zero_grad()
                continue

            # Accumulate for logging
            epoch_class_loss_accum += loss_to_backward.item()
            valid_batches_accum += 1
            scaled_loss = loss_to_backward / FT_GRAD_ACCUM

            # --- Backward pass ---
            try:
                scaled_loss.backward()
            except RuntimeError as e:
                print(f"!!! RUNTIME ERROR in backward: {e} !!!")
                optimizer.zero_grad()
                continue

            # --- Gradient accumulation step ---
            if (batch_idx + 1) % FT_GRAD_ACCUM == 0 or (batch_idx + 1) == len(train_loader):
                # Check/clip the gradients of the OPTIMIZED parameters
                found_non_finite_grad = False
                for p in params_to_optimize:
                    if p.grad is not None and not torch.isfinite(p.grad).all():
                        print("!!! WARNING: NaN/Inf gradient BEFORE step. Skipping step. !!!")
                        found_non_finite_grad = True
                        break
                if found_non_finite_grad:
                    optimizer.zero_grad()
                    continue

                grad_norm = torch.nn.utils.clip_grad_norm_(params_to_optimize, MAX_GRAD_NORM)
                if not torch.isfinite(grad_norm):
                    print(f"!!! WARNING: Grad norm NaN/Inf ({grad_norm.item()}) AFTER clipping. Skipping step. !!!")
                    optimizer.zero_grad()
                    continue

                optimizer.step()
                if scheduler:
                    scheduler.step()
                optimizer.zero_grad()
                step_counter += 1

                # --- Logging ---
                if step_counter % FT_LOG_STEPS == 0 and valid_batches_accum > 0:
                    avg_class_loss = epoch_class_loss_accum / valid_batches_accum
                    current_lr = optimizer.param_groups[0]['lr']
                    # --- Test evaluation (class loss only) ---
                    test_class_loss_val = float('nan')
                    if test_loader_has_data:
                        model.eval()
                        with torch.no_grad():
                            try:
                                test_batch = next(iter(test_loader))
                                if test_batch:
                                    # Unpack the test data needed for a forward pass -> logits
                                    t_images = test_batch['image'].to(DEVICE).to(DTYPE)
                                    t_p_ids = test_batch['prompt_ids'].to(DEVICE)
                                    t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
                                    t_t_ids = test_batch['target_ids'].to(DEVICE)
                                    t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
                                    t_gen_targets = test_batch['generative_targets'].to(DEVICE)  # Needed for the external CE calc
                                    # Other args may be None if the model handles it
                                    t_cont_coords = test_batch.get('continuous_coords')
                                    t_coords_mask = test_batch.get('coords_mask')
                                    t_num_pts = test_batch.get('num_points')
                                    if t_cont_coords is not None:
                                        t_cont_coords = t_cont_coords.to(DEVICE)
                                    if t_coords_mask is not None:
                                        t_coords_mask = t_coords_mask.to(DEVICE)

                                    # Run forward just to get the logits
                                    logits_t, _, _, _, _, _, *_ = model(
                                        t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask, t_gen_targets,
                                        t_cont_coords, t_coords_mask
                                    )

                                    # Calculate the CE loss externally for logging
                                    if logits_t is not None and t_gen_targets is not None:
                                        try:
                                            # Prepare padded targets matching the logits_t shape
                                            B_test, T_logits_t, V_test = logits_t.shape
                                            _, T_target_orig_t = t_gen_targets.shape
                                            N_img_test = model.num_patches
                                            T_prompt_test = t_p_ids.shape[1]
                                            T_combined_expected_t = N_img_test + T_prompt_test + T_target_orig_t

                                            if T_logits_t != T_combined_expected_t:
                                                T_target_in_logits_t = max(0, T_logits_t - (N_img_test + T_prompt_test))
                                                generative_targets_sliced_t = t_gen_targets[:, :T_target_in_logits_t]
                                                combined_class_targets_t = torch.cat([
                                                    torch.full((B_test, T_logits_t - T_target_in_logits_t), -100, dtype=torch.long, device=DEVICE),
                                                    generative_targets_sliced_t
                                                ], dim=1)
                                            else:
                                                combined_class_targets_t = torch.cat([
                                                    torch.full((B_test, N_img_test + T_prompt_test), -100, dtype=torch.long, device=DEVICE),
                                                    t_gen_targets
                                                ], dim=1)

                                            if logits_t.shape[1] != combined_class_targets_t.shape[1]:
                                                raise ValueError("Shape mismatch in test CE!")

                                            t_class_loss = F.cross_entropy(logits_t.view(-1, V_test), combined_class_targets_t.view(-1), ignore_index=-100)
                                            test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
                                        except Exception as e_ce_test:
                                            print(f"Error in test CE: {e_ce_test}")
                            except StopIteration:
                                print("Info: Test loader exhausted during logging.")
                            except Exception as e:
                                print(f"Error during test eval: {e}")
                        model.train()  # Set back to train mode

                    # Log data
                    log_data = {
                        "train/class_loss": avg_class_loss,
                        "test/class_loss": test_class_loss_val,
                        "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
                        "step": step_counter,
                        "learning_rate": current_lr,
                        "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else float('nan'),
                    }
                    pbar.set_postfix({"lr": f"{current_lr:.2e}", "cls_loss": f"{avg_class_loss:.4f}", "gnorm": f"{log_data['gradient_norm']:.3f}"})
                    if wandb_enabled:
                        wandb.log(log_data, step=step_counter)

                    # Reset the accumulators
                    epoch_class_loss_accum = 0.0
                    valid_batches_accum = 0

        # --- End of epoch ---
        print(f"\nFT Epoch {epoch+1}/{args.ft_epochs} completed.")
        # Optional: save a checkpoint periodically
        if (epoch + 1) % 5 == 0 or (epoch + 1) == args.ft_epochs:
            chkpt_path = args.output_model_path.replace(".pth", f"_epoch{epoch+1}.pth")
            try:
                torch.save(model.state_dict(), chkpt_path)
                print(f"Checkpoint saved to: {chkpt_path}")
            except Exception as e:
                print(f"Error saving checkpoint: {e}")

    # --- End of fine-tuning ---
    print("\nLM head fine-tuning with CE loss completed!")
    try:
        torch.save(model.state_dict(), args.output_model_path)
        print(f"Fine-tuned model saved to: {args.output_model_path}")
    except Exception as e:
        print(f"Error saving fine-tuned model: {e}")

    if wandb_enabled:
        wandb.finish()
    torch.autograd.set_detect_anomaly(False)  # Disable anomaly detection
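The external CE loss above hinges on masking every image-patch and prompt position with -100 so that only the target span contributes to cross-entropy. A tiny self-contained illustration with hypothetical sizes:

import torch

N_img, T_prompt, B = 4, 2, 1                 # hypothetical patch/prompt lengths
gen_targets = torch.tensor([[11, 12, 13]])   # shifted next-token ids for the target span
combined = torch.cat([
    torch.full((B, N_img + T_prompt), -100, dtype=torch.long),
    gen_targets,
], dim=1)
print(combined)  # tensor([[-100, -100, -100, -100, -100, -100, 11, 12, 13]])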
infer.py
ADDED
@@ -0,0 +1,504 @@
| 1 |
+
from constants import *
|
| 2 |
+
from utils import image_to_tensor, tokenizer, tensor_to_image, vocab_size, tokenizer
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from PIL import ImageDraw, Image
|
| 6 |
+
from dataset import create_test_dataloader
|
| 7 |
+
from vision_language_model import VisionLanguageModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
model = VisionLanguageModel(
|
| 11 |
+
n_embd=HIDDEN_DIM,
|
| 12 |
+
vocab_size=vocab_size,
|
| 13 |
+
img_size=IMAGE_SIZE,
|
| 14 |
+
patch_size=PATCH_SIZE,
|
| 15 |
+
num_heads=NUM_HEADS,
|
| 16 |
+
num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
|
| 17 |
+
num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
|
| 18 |
+
emb_dropout=DROPOUT,
|
| 19 |
+
blk_dropout=DROPOUT,
|
| 20 |
+
max_context=CONTEXT_LENGTH,
|
| 21 |
+
shared_embed_dim=SHARED_EMBED_DIM,
|
| 22 |
+
lambda_contrastive=LAMBDA_CONTRASTIVE,
|
| 23 |
+
lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
|
| 24 |
+
).to(DEVICE)
|
| 25 |
+
|
| 26 |
+
MODEL_PATH = "model_regression_multi_first_100.pth" # "model_regression_multi_16.pth"
|
| 27 |
+
|
| 28 |
+
if DEVICE == "cuda":
|
| 29 |
+
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
|
| 30 |
+
else:
|
| 31 |
+
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True, map_location=torch.device('cpu')))
|
| 32 |
+
model.eval()
|
| 33 |
+
|
| 34 |
+
def generate_sample_from_image_text(
|
| 35 |
+
model,
|
| 36 |
+
image_path,
|
| 37 |
+
prompt_label,
|
| 38 |
+
tokenizer,
|
| 39 |
+
device,
|
| 40 |
+
max_new_tokens=70,
|
| 41 |
+
temperature=0.8,
|
| 42 |
+
top_k=10,
|
| 43 |
+
output_path="generated_output.png"
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
Generates a prediction for an image and prompt text and saves it to a file.
|
| 47 |
+
Generation loop is implemented *within* this function.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
model: The trained VisionLanguageModel.
|
| 51 |
+
image_path: Path to the input image.
|
| 52 |
+
prompt_label: Text prompt/label to use.
|
| 53 |
+
tokenizer: The tokenizer used for training.
|
| 54 |
+
device: The computation device ('cuda' or 'cpu').
|
| 55 |
+
max_new_tokens (int): Max tokens to generate after the prompt.
|
| 56 |
+
temperature (float): Softmax temperature for sampling.
|
| 57 |
+
top_k (int): K for top-k sampling (0 or None to disable).
|
| 58 |
+
output_path (str): Path where to save the output image.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
None. Saves the image with prompt and generated output to a file.
|
| 62 |
+
"""
|
| 63 |
+
model.eval() # Set the model to evaluation mode
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
with torch.no_grad(): # No need to track gradients during inference
|
| 67 |
+
# --- 1. Prepare Initial Inputs ---
|
| 68 |
+
# Load and process image
|
| 69 |
+
image = Image.open(image_path)
|
| 70 |
+
image_tensor = image_to_tensor(image).unsqueeze(0).to(device) # Add batch dim
|
| 71 |
+
|
| 72 |
+
# Tokenize prompt
|
| 73 |
+
prompt_text = f"<point_start>{prompt_label}<point_end>"
|
| 74 |
+
prompt_tokens = tokenizer(prompt_text, return_tensors="pt", truncation=True, padding=False)
|
| 75 |
+
prompt_ids = prompt_tokens.input_ids.to(device)
|
| 76 |
+
prompt_attention_mask = prompt_tokens.attention_mask.to(device)
|
| 77 |
+
B = 1 # We are processing one sample at a time
|
| 78 |
+
|
| 79 |
+
print(f"--- Generating Sample (Manual Loop) ---")
|
| 80 |
+
print(f"Original Label/Prompt Hint: {prompt_label}")
|
| 81 |
+
print(f"Input Prompt Tokens Decoded: {prompt_text}")
|
| 82 |
+
|
| 83 |
+
# --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
|
| 84 |
+
image_embeds_raw = model.vision_encoder(image_tensor) # (1, N_img, C)
|
| 85 |
+
image_embeds_decoder = model.multimodal_projector(image_embeds_raw) # (1, N_img, C)
|
| 86 |
+
prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) # (1, T_prompt, C)
|
| 87 |
+
|
| 88 |
+
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
|
| 89 |
+
result_start_embed = model.decoder.token_embedding_table(
|
| 90 |
+
torch.tensor([[result_start_token_id]], device=device) # Shape (1, 1, C)
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# The initial sequence fed to the decoder blocks consists of image + prompt
|
| 94 |
+
current_embeds = torch.cat([
|
| 95 |
+
image_embeds_decoder,
|
| 96 |
+
prompt_embeds_decoder,
|
| 97 |
+
result_start_embed # Add the embedding for the first expected output token
|
| 98 |
+
], dim=1)
|
| 99 |
+
generated_ids = [] # Store newly generated IDs
|
| 100 |
+
|
| 101 |
+
# --- 3. Autoregressive Generation Loop ---
|
| 102 |
+
for _ in range(max_new_tokens):
|
| 103 |
+
T_current = current_embeds.shape[1]
|
| 104 |
+
|
| 105 |
+
# Truncate if necessary (keep recent context)
|
| 106 |
+
if T_current > model.decoder.max_context: # Access max_context from decoder
|
| 107 |
+
print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
|
| 108 |
+
current_embeds = current_embeds[:, -model.decoder.max_context:, :]
|
| 109 |
+
T_current = model.decoder.max_context
|
| 110 |
+
|
| 111 |
+
# Prepare positional embeddings for current length
|
| 112 |
+
pos = torch.arange(0, T_current, dtype=torch.long, device=device)
|
| 113 |
+
pos = pos.clamp(max=model.decoder.max_context - 1) # Clamp indices
|
| 114 |
+
pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) # (1, T_current, C)
|
| 115 |
+
x = current_embeds + pos_emb
|
| 116 |
+
|
| 117 |
+
# Create attention mask (all ones, causal handles future)
|
| 118 |
+
# Note: We don't need padding mask here as we handle one sequence without padding
|
| 119 |
+
attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)
|
| 120 |
+
|
| 121 |
+
# Pass through Decoder Blocks
|
| 122 |
+
for block in model.decoder.blocks:
|
| 123 |
+
# We assume the block forward takes (x, attention_mask)
|
| 124 |
+
x = block(x, attention_mask=attention_mask)
|
| 125 |
+
|
| 126 |
+
# Final Layer Norm and LM Head for the *last* token prediction
|
| 127 |
+
x = model.decoder.ln_f(x[:, -1:, :]) # (B, 1, C) -> (1, 1, C)
|
| 128 |
+
logits = model.decoder.lm_head(x) # (B, 1, V) -> (1, 1, V)
|
| 129 |
+
logits = logits.squeeze(1) # (B, V) -> (1, V)
|
| 130 |
+
|
| 131 |
+
# Sampling
|
| 132 |
+
logits = logits / temperature
|
| 133 |
+
if top_k is not None and top_k > 0:
|
| 134 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 135 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 136 |
+
|
| 137 |
+
probs = F.softmax(logits, dim=-1)
|
| 138 |
+
# idx_next = torch.multinomial(probs, num_samples=1) # (1, 1) # test distribution
|
| 139 |
+
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # test deterministic
|
| 140 |
+
|
| 141 |
+
# Store generated ID
|
| 142 |
+
generated_ids.append(idx_next)
|
| 143 |
+
|
| 144 |
+
# Stop if EOS token is generated
|
| 145 |
+
if idx_next.item() == tokenizer.eos_token_id:
|
| 146 |
+
print("EOS token generated.")
|
| 147 |
+
break
|
| 148 |
+
|
| 149 |
+
# Prepare for next iteration: Append embedding of new token
|
| 150 |
+
next_token_embed = model.decoder.token_embedding_table(idx_next) # (1, 1, C)
|
| 151 |
+
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) # Append along sequence dim
|
| 152 |
+
|
| 153 |
+
# --- 4. Combine and Decode Results ---
|
| 154 |
+
if generated_ids:
|
| 155 |
+
generated_ids_tensor = torch.cat(generated_ids, dim=1) # (1, T_generated)
|
| 156 |
+
initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
|
| 157 |
+
full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1)
|
| 158 |
+
else:
|
| 159 |
+
full_generated_sequence_ids = prompt_ids # Nothing was generated
|
| 160 |
+
|
| 161 |
+
full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
|
| 162 |
+
print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")
|
| 163 |
+
|
| 164 |
+
# --- 5. Save visualization to file ---
|
| 165 |
+
save_coords_visualization(
|
| 166 |
+
image_tensor=image_tensor[0], # Remove batch dim for visualization
|
| 167 |
+
full_decoded_text=full_decoded_text,
|
| 168 |
+
tokenizer=tokenizer,
|
| 169 |
+
image_size=IMAGE_SIZE, # Assumes IMAGE_SIZE is globally defined
|
| 170 |
+
num_bins=NUM_BINS, # Assumes NUM_BINS is globally defined
|
| 171 |
+
output_path=output_path
|
| 172 |
+
)
|
| 173 |
+
print(f"Visualization saved to: {output_path}")
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
print(f"An error occurred during sample generation: {e}")
|
| 177 |
+
import traceback
|
| 178 |
+
traceback.print_exc()
|
| 179 |
+
|
| 180 |
+
def generate_sample_from_test_loader(
|
| 181 |
+
model,
|
| 182 |
+
test_loader,
|
| 183 |
+
tokenizer,
|
| 184 |
+
device,
|
| 185 |
+
max_new_tokens=70,
|
| 186 |
+
temperature=0.8,
|
| 187 |
+
top_k=10,
|
| 188 |
+
output_path="generated_output.png",
|
| 189 |
+
TEST_BATCH=8,
|
| 190 |
+
TEST_IDX=1
|
| 191 |
+
):
|
| 192 |
+
"""
|
| 193 |
+
Generates a prediction for one sample from the test loader and saves it to a file.
|
| 194 |
+
Generation loop is implemented *within* this function.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
model: The trained VisionLanguageModel.
|
| 198 |
+
test_loader: DataLoader for the test set.
|
| 199 |
+
tokenizer: The tokenizer used for training.
|
| 200 |
+
device: The computation device ('cuda' or 'cpu').
|
| 201 |
+
max_new_tokens (int): Max tokens to generate after the prompt.
|
| 202 |
+
temperature (float): Softmax temperature for sampling.
|
| 203 |
+
top_k (int): K for top-k sampling (0 or None to disable).
|
| 204 |
+
output_path (str): Path where to save the output image.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
None. Saves the image with prompt and generated output to a file.
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
if not test_loader or len(test_loader.dataset) == 0:
|
| 211 |
+
print("Test loader is empty or not available.")
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
model.eval() # Set the model to evaluation mode
|
| 215 |
+
|
| 216 |
+
try:
|
| 217 |
+
# Get a single batch from the test loader
|
| 218 |
+
with torch.no_grad(): # No need to track gradients during inference
|
| 219 |
+
my_iter = iter(test_loader)
|
| 220 |
+
for i in range(TEST_BATCH):
|
| 221 |
+
_ = next(my_iter)
|
| 222 |
+
batch = next(my_iter)
|
| 223 |
+
|
| 224 |
+
if batch is None:
|
| 225 |
+
print("Test loader yielded an empty batch.")
|
| 226 |
+
return
|
| 227 |
+
if batch['image'].shape[0] == 0:
|
| 228 |
+
print("Test loader yielded a batch with 0 items.")
|
| 229 |
+
return
|
| 230 |
+
|
| 231 |
+
# --- 1. Prepare Initial Inputs ---
|
| 232 |
+
image_tensor = batch['image'][TEST_IDX:TEST_IDX+1].to(device) # (1, 3, H, W)
|
| 233 |
+
prompt_ids = batch['prompt_ids'][TEST_IDX:TEST_IDX+1].to(device) # (1, T_prompt)
|
| 234 |
+
prompt_attention_mask = batch['prompt_attention_mask'][TEST_IDX:TEST_IDX+1].to(device) # (1, T_prompt)
|
| 235 |
+
label = batch['label'][TEST_IDX]
|
| 236 |
+
B = 1 # We are processing one sample at a time
|
| 237 |
+
|
| 238 |
+
print(f"--- Generating Sample (Manual Loop) ---")
|
| 239 |
+
print(f"Original Label/Prompt Hint: {label}")
|
| 240 |
+
prompt_text = tokenizer.decode(prompt_ids[0], skip_special_tokens=False)
|
| 241 |
+
print(f"Input Prompt Tokens Decoded: {prompt_text}")
|
| 242 |
+
|
| 243 |
+
# --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
|
| 244 |
+
image_embeds_raw = model.vision_encoder(image_tensor) # (1, N_img, C)
|
| 245 |
+
image_embeds_decoder = model.multimodal_projector(image_embeds_raw) # (1, N_img, C)
|
| 246 |
+
prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) # (1, T_prompt, C)
|
| 247 |
+
|
| 248 |
+
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
|
| 249 |
+
result_start_embed = model.decoder.token_embedding_table(
|
| 250 |
+
torch.tensor([[result_start_token_id]], device=device) # Shape (1, 1, C)
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# The initial sequence fed to the decoder blocks consists of image + prompt
|
| 254 |
+
current_embeds = torch.cat([
|
| 255 |
+
image_embeds_decoder,
|
| 256 |
+
prompt_embeds_decoder,
|
| 257 |
+
result_start_embed # Add the embedding for the first expected output token
|
| 258 |
+
], dim=1)
|
| 259 |
+
# current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1) # (1, T_initial, C)
|
| 260 |
+
generated_ids = [] # Store newly generated IDs
|
| 261 |
+
|
| 262 |
+
# --- 3. Autoregressive Generation Loop ---
|
| 263 |
+
for _ in range(max_new_tokens):
|
| 264 |
+
T_current = current_embeds.shape[1]
|
| 265 |
+
|
| 266 |
+
# Truncate if necessary (keep recent context)
|
| 267 |
+
if T_current > model.decoder.max_context: # Access max_context from decoder
|
| 268 |
+
print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
|
| 269 |
+
current_embeds = current_embeds[:, -model.decoder.max_context:, :]
|
| 270 |
+
T_current = model.decoder.max_context
|
| 271 |
+
|
| 272 |
+
# Prepare positional embeddings for current length
|
| 273 |
+
pos = torch.arange(0, T_current, dtype=torch.long, device=device)
|
| 274 |
+
pos = pos.clamp(max=model.decoder.max_context - 1) # Clamp indices
|
| 275 |
+
pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) # (1, T_current, C)
|
| 276 |
+
x = current_embeds + pos_emb
|
| 277 |
+
|
| 278 |
+
# Create attention mask (all ones, causal handles future)
|
| 279 |
+
# Note: We don't need padding mask here as we handle one sequence without padding
|
| 280 |
+
attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)
|
| 281 |
+
|
| 282 |
+
# Pass through Decoder Blocks
|
| 283 |
+
for block in model.decoder.blocks:
|
| 284 |
+
# We assume the block forward takes (x, attention_mask)
|
| 285 |
+
x = block(x, attention_mask=attention_mask)
|
| 286 |
+
|
| 287 |
+
# Final Layer Norm and LM Head for the *last* token prediction
|
| 288 |
+
x = model.decoder.ln_f(x[:, -1:, :]) # (B, 1, C) -> (1, 1, C)
|
| 289 |
+
logits = model.decoder.lm_head(x) # (B, 1, V) -> (1, 1, V)
|
| 290 |
+
logits = logits.squeeze(1) # (B, V) -> (1, V)
|
| 291 |
+
|
| 292 |
+
# Sampling
|
| 293 |
+
logits = logits / temperature
|
| 294 |
+
if top_k is not None and top_k > 0:
|
| 295 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 296 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 297 |
+
|
| 298 |
+
probs = F.softmax(logits, dim=-1)
|
| 299 |
+
# idx_next = torch.multinomial(probs, num_samples=1) # (1, 1) # test distribution
|
| 300 |
+
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # test deterministic
|
| 301 |
+
|
| 302 |
+
# Store generated ID
|
| 303 |
+
generated_ids.append(idx_next)
|
| 304 |
+
|
| 305 |
+
# Stop if EOS token is generated
|
| 306 |
+
if idx_next.item() == tokenizer.eos_token_id:
|
| 307 |
+
print("EOS token generated.")
|
| 308 |
+
break
|
| 309 |
+
|
| 310 |
+
# Prepare for next iteration: Append embedding of new token
|
| 311 |
+
next_token_embed = model.decoder.token_embedding_table(idx_next) # (1, 1, C)
|
| 312 |
+
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) # Append along sequence dim
|
| 313 |
+
|
| 314 |
+
# --- 4. Combine and Decode Results ---
|
| 315 |
+
if generated_ids:
|
| 316 |
+
generated_ids_tensor = torch.cat(generated_ids, dim=1) # (1, T_generated)
|
| 317 |
+
initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
|
| 318 |
+
full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1)
|
| 319 |
+
else:
|
| 320 |
+
full_generated_sequence_ids = prompt_ids # Nothing was generated
|
| 321 |
+
|
| 322 |
+
full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
|
| 323 |
+
print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")
|
| 324 |
+
|
| 325 |
+
# --- 5. Save visualization to file ---
|
| 326 |
+
save_coords_visualization(
|
| 327 |
+
image_tensor=image_tensor[0], # Remove batch dim for visualization
|
| 328 |
+
full_decoded_text=full_decoded_text,
|
| 329 |
+
tokenizer=tokenizer,
|
| 330 |
+
image_size=IMAGE_SIZE, # Assumes IMAGE_SIZE is globally defined
|
| 331 |
+
num_bins=NUM_BINS, # Assumes NUM_BINS is globally defined
|
| 332 |
+
output_path=output_path
|
| 333 |
+
)
|
| 334 |
+
print(f"Visualization saved to: {output_path}")
|
| 335 |
+
|
| 336 |
+
except StopIteration:
|
| 337 |
+
print("Test loader is exhausted.")
|
| 338 |
+
except Exception as e:
|
| 339 |
+
print(f"An error occurred during sample generation: {e}")
|
| 340 |
+
import traceback
|
| 341 |
+
traceback.print_exc()
|
| 342 |
+
|
| 343 |
+
def parse_coordinate_tokens(text, tokenizer, num_bins):
|
| 344 |
+
"""
|
| 345 |
+
Parses generated text to extract coordinate bin tokens.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
text (str): The decoded output text from the model.
|
| 349 |
+
tokenizer: The tokenizer.
|
| 350 |
+
num_bins (int): The number of coordinate bins used.
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
list[tuple(int, int)]: A list of (x_bin, y_bin) tuples, or None if parsing fails.
|
| 354 |
+
"""
|
| 355 |
+
coords = []
|
| 356 |
+
try:
|
| 357 |
+
# Basic parsing - look for the pattern
|
| 358 |
+
x_start_token = "<pointx_start>"
|
| 359 |
+
x_end_token = "<pointx_end>"
|
| 360 |
+
y_start_token = "<pointy_start>"
|
| 361 |
+
y_end_token = "<pointy_end>"
|
| 362 |
+
result_end_token = "<result_end>"
|
| 363 |
+
|
| 364 |
+
# Find where the actual results start
|
| 365 |
+
try:
|
| 366 |
+
start_index = text.index("<result_start>") + len("<result_start>")
|
| 367 |
+
except ValueError:
|
| 368 |
+
print("Warning: <result_start> not found in generated text.")
|
| 369 |
+
return None
|
| 370 |
+
|
| 371 |
+
# Find where results end
|
| 372 |
+
try:
|
| 373 |
+
end_index = text.index(result_end_token, start_index)
|
| 374 |
+
except ValueError:
|
| 375 |
+
end_index = len(text) # Use end of string if <result_end> is missing
|
| 376 |
+
print(f"Warning: {result_end_token} not found. Parsing until end of string.")
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
current_pos = start_index
|
| 380 |
+
while current_pos < end_index:
|
| 381 |
+
# Find next X coordinate
|
| 382 |
+
x_start_idx = text.find(x_start_token, current_pos)
|
| 383 |
+
if x_start_idx == -1 or x_start_idx >= end_index: break # No more x points found
|
| 384 |
+
x_start_idx += len(x_start_token)
|
| 385 |
+
|
| 386 |
+
x_end_idx = text.find(x_end_token, x_start_idx)
|
| 387 |
+
if x_end_idx == -1 or x_end_idx >= end_index: break # Malformed
|
| 388 |
+
|
| 389 |
+
x_token_str = text[x_start_idx:x_end_idx].strip()
|
| 390 |
+
|
| 391 |
+
# Find next Y coordinate (must follow X)
|
| 392 |
+
y_start_idx = text.find(y_start_token, x_end_idx)
|
| 393 |
+
if y_start_idx == -1 or y_start_idx >= end_index: break # No corresponding y point
|
| 394 |
+
y_start_idx += len(y_start_token)
|
| 395 |
+
|
| 396 |
+
y_end_idx = text.find(y_end_token, y_start_idx)
|
| 397 |
+
if y_end_idx == -1 or y_end_idx >= end_index: break # Malformed
|
| 398 |
+
|
| 399 |
+
y_token_str = text[y_start_idx:y_end_idx].strip()
|
| 400 |
+
|
| 401 |
+
x_token_str = x_token_str[:-1]
|
| 402 |
+
y_token_str = y_token_str[:-1]
|
| 403 |
+
|
| 404 |
+
# Convert token strings to bin numbers
|
| 405 |
+
try:
|
| 406 |
+
x_bin = int(x_token_str.split("_")[-1])
|
| 407 |
+
y_bin = int(y_token_str.split("_")[-1])
|
| 408 |
+
if 0 <= x_bin < num_bins and 0 <= y_bin < num_bins:
|
| 409 |
+
coords.append((x_bin, y_bin))
|
| 410 |
+
else:
|
| 411 |
+
print(f"Warning: Parsed bin indices out of range ({x_bin}, {y_bin}). Skipping.")
|
| 412 |
+
except (ValueError, IndexError):
|
| 413 |
+
print(f"Warning: Could not parse bins from tokens '{x_token_str}', '{y_token_str}'. Skipping.")
|
| 414 |
+
|
| 415 |
+
# Move search position past the found Y token
|
| 416 |
+
current_pos = y_end_idx + len(y_end_token)
|
| 417 |
+
|
| 418 |
+
return coords if coords else None
|
| 419 |
+
|
| 420 |
+
except Exception as e:
|
| 421 |
+
print(f"Error during coordinate parsing: {e}")
|
| 422 |
+
return None
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def save_coords_visualization(image_tensor, full_decoded_text, tokenizer, image_size, num_bins, output_path):
|
| 426 |
+
"""Parses coords, draws them on the image, and saves to a file."""
|
| 427 |
+
parsed_bins = parse_coordinate_tokens(full_decoded_text, tokenizer, num_bins)
|
| 428 |
+
|
| 429 |
+
# Convert tensor to PIL image for drawing
|
| 430 |
+
try:
|
| 431 |
+
pil_image = tensor_to_image(image_tensor.cpu()) # Ensure tensor is on CPU
|
| 432 |
+
except Exception as e:
|
| 433 |
+
print(f"Error converting tensor to image: {e}")
|
| 434 |
+
# Create a placeholder image if conversion fails
|
| 435 |
+
pil_image = Image.new('RGB', (image_size, image_size), color='white')
|
| 436 |
+
draw = ImageDraw.Draw(pil_image)
|
| 437 |
+
draw.text((10, 10), "Image conversion failed", fill="black")
|
| 438 |
+
pil_image.save(output_path)
|
| 439 |
+
return
|
| 440 |
+
|
| 441 |
+
draw = ImageDraw.Draw(pil_image)
|
| 442 |
+
radius = 5 # Radius of the drawn point
|
| 443 |
+
|
| 444 |
+
if parsed_bins:
|
| 445 |
+
print(f"\nParsed Coordinate Bins: {parsed_bins}")
|
| 446 |
+
bin_size_pixels = image_size / num_bins
|
| 447 |
+
for x_bin, y_bin in parsed_bins:
|
| 448 |
+
# Calculate center of the bin in pixels
|
| 449 |
+
center_x = (x_bin + 0.5) * bin_size_pixels
|
| 450 |
+
center_y = (y_bin + 0.5) * bin_size_pixels
|
| 451 |
+
|
| 452 |
+
# Draw a circle
|
| 453 |
+
bbox = [center_x - radius, center_y - radius, center_x + radius, center_y + radius]
|
| 454 |
+
draw.ellipse(bbox, outline="red", width=3)
|
| 455 |
+
# Optional: Draw bin boundaries for debugging
|
| 456 |
+
# draw.rectangle([x_bin*bin_size_pixels, y_bin*bin_size_pixels, (x_bin+1)*bin_size_pixels, (y_bin+1)*bin_size_pixels], outline="blue", width=1)
|
| 457 |
+
|
| 458 |
+
# Add a text label with the coordinates at the top of the image
|
| 459 |
+
coord_text = f"Generated Point(s): {parsed_bins}"
|
| 460 |
+
draw.text((10, 10), coord_text, fill="red")
|
| 461 |
+
else:
|
| 462 |
+
print("\nCould not parse valid coordinates from the generated text.")
|
| 463 |
+
# Add a text label indicating no coordinates were found
|
| 464 |
+
draw.text((10, 10), "No Coordinates Parsed", fill="red")
|
| 465 |
+
|
| 466 |
+
# Save the image to file
|
| 467 |
+
pil_image.save(output_path)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
import argparse
|
| 471 |
+
|
| 472 |
+
# --- Example Usage ---
|
| 473 |
+
# python infer.py --image ./data/test_images/image_1.png --prompt "a red apple"
|
| 474 |
+
if __name__ == "__main__":
|
| 475 |
+
parser = argparse.ArgumentParser()
|
| 476 |
+
parser.add_argument('--image', type=str, help='Path to input image')
|
| 477 |
+
parser.add_argument('--prompt', type=str, help='Prompt label for generation')
|
| 478 |
+
args = parser.parse_args()
|
| 479 |
+
if args.image and args.prompt:
|
| 480 |
+
# Use image and prompt based generation
|
| 481 |
+
if 'model' in locals() and 'tokenizer' in locals():
|
| 482 |
+
generate_sample_from_image_text(
|
| 483 |
+
model=model,
|
| 484 |
+
image_path=args.image,
|
| 485 |
+
prompt_label=args.prompt,
|
| 486 |
+
tokenizer=tokenizer,
|
| 487 |
+
device=DEVICE,
|
| 488 |
+
output_path="model_prediction.png"
|
| 489 |
+
)
|
| 490 |
+
else:
|
| 491 |
+
print("Please ensure 'model' and 'tokenizer' are loaded before running generation.")
|
| 492 |
+
else:
|
| 493 |
+
# Use test loader based generation
|
| 494 |
+
if 'model' in locals() and 'test_loader' in locals() and 'tokenizer' in locals():
|
| 495 |
+
test_loader = create_test_dataloader(batch_size=2, num_workers=0)
|
| 496 |
+
generate_sample_from_test_loader(
|
| 497 |
+
model=model,
|
| 498 |
+
test_loader=test_loader,
|
| 499 |
+
tokenizer=tokenizer,
|
| 500 |
+
device=DEVICE,
|
| 501 |
+
output_path="model_prediction.png"
|
| 502 |
+
)
|
| 503 |
+
else:
|
| 504 |
+
print("Please ensure 'model', 'test_loader', and 'tokenizer' are loaded before running generation.")
|
model_components.py
ADDED
@@ -0,0 +1,163 @@
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddings(nn.Module):
    def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X) # (B, C, H/P, W/P)
        X = X.flatten(2) # (B, C, N) where N = (H/P)*(W/P)
        X = X.transpose(1, 2) # (B, N, C)
        return X

class Head(nn.Module):
    def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder
        # causal mask is registered persistent=False so it's not saved in state_dict
        if self.is_decoder:
            self.register_buffer("bias", torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
                                 .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH), persistent=False)


    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        # print(f"B = {B} T={T}, C={C}")
        k = self.key(x) # (B, T, hs)
        q = self.query(x) # (B, T, hs)
        v = self.value(x) # (B, T, hs)

        # Compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * (k.size(-1)**-0.5) # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        if self.is_decoder:
            # Apply causal mask
            # Ensure the mask is sliced correctly if T < CONTEXT_LENGTH
            causal_mask = self.bias[:, :T, :T]
            wei = wei.masked_fill(causal_mask == 0, float('-inf'))

        if attention_mask is not None:
            # Apply padding mask (for text tokens)
            # attention_mask shape: (B, T_combined) -> needs expansion
            # Expand mask: (B, T) -> (B, 1, 1, T) or (B, 1, T, T) depending on what needs masking
            # Mask where attention_mask is 0
            # attention_mask shape: (B, T) == (B, T_key)
            # Expand mask to align with wei's key dimension for broadcasting across queries
            # Target shape for mask: [B, 1, T_key]
            # print(f"attn mask = {attention_mask.shape}")
            # print(f"wei shape = {wei.shape}")
            mask = attention_mask.unsqueeze(1) # Shape [B, 1, T]
            # Apply mask using broadcasting rules. masked_fill condition needs to be broadcastable to wei [B, T_query, T_key]
            # (mask == 0) gives a boolean tensor of shape [B, 1, T]
            # This broadcasts correctly: dim 2 (T vs T) matches, dim 1 (1 vs T) broadcasts 1->T, dim 0 (B vs B) matches.
            wei = wei.masked_fill(mask == 0, float('-inf'))


        # Apply softmax
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Perform weighted aggregation of values
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        # print(f"out shape = {out.shape}")
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd) # n_embd = num_heads * head_size
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Pass attention_mask only if it's a decoder block dealing with combined sequence
        out = torch.cat([h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """
    def __init__(self, n_embd, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(), # Changed from ReLU to GELU, common in transformers
            nn.Linear(4 * n_embd, n_embd), # Projection back to residual stream
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, dropout)
        self.is_decoder = is_decoder # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Pass attention_mask only if it's a decoder block
        # print(f"is decoder = {self.is_decoder} input shape = {x.shape}")
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
        x = x + self.ffn(self.ln2(x))
        # print(f"output shape = {x.shape}")
        return x

class ViT(nn.Module):
    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
                 num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02) # Smaller init
        self.dropout = nn.Dropout(emb_dropout)
        # ViT blocks are NOT decoders (no causal mask)
        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
        self.layer_norm = nn.LayerNorm(num_hiddens) # Final LN

    def forward(self, X):
        x = self.patch_embedding(X) # (B, N, C)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, C)
        # Add positional embedding
        x = x + self.pos_embedding # Uses broadcasting
        x = self.dropout(x)
        for block in self.blocks:
            # ViT blocks don't need attention_mask
            x = block(x)
        x = self.layer_norm(x) # Apply final layer norm
        return x

class MultiModalProjector(nn.Module):
    # Projects image embedding dim to text embedding dim
    def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, text_embed_dim * 4), # Intermediate expansion
            nn.GELU(),
            nn.Linear(text_embed_dim * 4, text_embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
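A shape sanity check for the vision tower defined above; the expected sequence length is the patch count plus the CLS token, and the projector preserves the shape because the image and text widths are both HIDDEN_DIM in this configuration:

# Sketch: verify ViT / MultiModalProjector output shapes under the constants.py defaults.
import torch
from constants import IMAGE_SIZE, PATCH_SIZE, HIDDEN_DIM
from model_components import ViT, MultiModalProjector

vit, proj = ViT(), MultiModalProjector()
dummy = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)              # (B, 3, 512, 512)
feats = vit(dummy)                                             # CLS token + (512/16)^2 patches
assert feats.shape == (2, (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1, HIDDEN_DIM)  # (2, 1025, 256)
assert proj(feats).shape == feats.shape                        # projection keeps (B, N+1, C)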
train.py
ADDED
@@ -0,0 +1,264 @@
from constants import *
from dataset import create_train_dataloader, create_test_dataloader
from vision_language_model import VisionLanguageModel
from utils import *
from datetime import datetime
import wandb
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm

print(f"Using device: {DEVICE}")
print(f"Vocab size: {vocab_size}")

# --- Initialize Model ---
# Ensure lambda_regression is passed during initialization
model = VisionLanguageModel(
    n_embd=HIDDEN_DIM,
    vocab_size=vocab_size,
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
    num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
    emb_dropout=DROPOUT,
    blk_dropout=DROPOUT,
    max_context=CONTEXT_LENGTH,
    shared_embed_dim=SHARED_EMBED_DIM,
    lambda_contrastive=LAMBDA_CONTRASTIVE,
    lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
).to(DEVICE)

# --- Optimizer ---
# Optimizer will automatically include all model parameters, including the new regression head
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)

# --- Dataloaders ---
# Ensure these functions now return 'continuous_coords' in the batch dictionary
train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2) # Use num_workers=0 for easier debugging first
test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
if train_loader is None: exit("Training loader failed to initialize.")
test_loader_has_data = test_loader and len(test_loader.dataset) > 0

# --- LR Scheduler ---
if train_loader and len(train_loader) > 0:
    steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0)
    total_steps = steps_per_epoch * NUM_EPOCHS
    # Adjust warmup steps if total steps are very low
    warmup_steps = min(max(1, total_steps // 10), 10000) # Ensure at least 1, max 10k warmup
    print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
    lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_steps, pct_start=warmup_steps/total_steps if total_steps > 0 else 0.1)
else:
    print("Warning: Train loader empty. Using constant LR.")
    total_steps = 0; warmup_steps = 0
    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

# --- Wandb Setup ---
try:
    wandb.init(
        # project="point-language-model-dualhead", # Suggest new project name
        project="point-language-model-regression-vast",
        name=f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        config={ # Add new hyperparameters
            "image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
            "context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
            "num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
            "shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
            "lambda_regression": LAMBDA_REGRESSION, # Log regression weight
            "architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
            "num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps
        }
    )
    wandb_enabled = True
    # Watch model gradients and parameters
    # wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
except Exception as e:
    print(f"Wandb initialization failed: {e}. Running without wandb.")
    wandb_enabled = False

# --- Training Loop ---
print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
step_counter = 0
optimizer.zero_grad()

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_total_loss_accum = 0.0
    epoch_class_loss_accum = 0.0
    epoch_con_loss_accum = 0.0
    epoch_reg_loss_accum = 0.0
    batches_since_log = 0
    valid_batches_accum = 0 # Count batches with valid loss for averaging

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)

    for batch_idx, batch in pbar:
        if batch is None: continue

        # --- Unpack Batch Data ---
        try:
            images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
            prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
            prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
            target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
            target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
            generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
            continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True) # Padded
            coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True) # Mask
        except KeyError as e:
            print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
            continue

        # Clamp logit_scale
        with torch.no_grad():
            model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0)))

        # --- Forward Pass ---
        # Model now returns potentially NaN scalar tensors for individual losses if invalid
        logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
            img_array=images,
            prompt_ids=prompt_ids,
            prompt_attention_mask=prompt_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
            generative_targets=generative_targets,
            continuous_coords=continuous_coords,
            coords_mask=coords_mask # Pass mask for regression loss calculation
        )

        # --- Loss Handling & Accumulation ---
        # Check for invalid total loss before backward pass
        if total_loss is None or not torch.isfinite(total_loss):
            print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
            optimizer.zero_grad() # Reset gradients for safety if loss is invalid
            continue # Skip this batch for optimization step

        # Scale loss for gradient accumulation
        scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS

        # Accumulate valid loss components for logging
        # Check if the scalar tensor is finite before adding its item()
        if torch.isfinite(total_loss):
            epoch_total_loss_accum += total_loss.item()
            valid_batches_accum += 1 # Increment count of batches contributing to average loss
        if torch.isfinite(class_loss_s):
            epoch_class_loss_accum += class_loss_s.item()
        if torch.isfinite(contrastive_loss_s):
            epoch_con_loss_accum += contrastive_loss_s.item()
        if torch.isfinite(regression_loss_s):
            epoch_reg_loss_accum += regression_loss_s.item()
        batches_since_log += 1

        # --- Backward Pass ---
        try:
            scaled_loss.backward()
        except Exception as e:
            print(f"Error during backward pass: {e}. Skipping step.")
            optimizer.zero_grad() # Reset gradients if backward failed
            continue

        # --- Gradient Accumulation Step ---
        if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
            # Clip gradients
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

            # Check for non-finite gradients before stepping
            all_finite = True
            for p in model.parameters():
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    all_finite = False
                    break
            if not all_finite:
                print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
                optimizer.zero_grad()
                continue # Skip optimizer step and scheduler step

            # Optimizer step
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            step_counter += 1

            # --- Logging ---
            if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0: # Use valid_batches_accum
                # Calculate average losses over the logging period using valid batch count
                avg_total_loss = epoch_total_loss_accum / valid_batches_accum
                avg_class_loss = epoch_class_loss_accum / valid_batches_accum
                avg_con_loss = epoch_con_loss_accum / valid_batches_accum
                avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
                current_lr = optimizer.param_groups[0]['lr']

                # --- Test Evaluation (Needs modification to handle mask) ---
                test_class_loss_val = float('nan')
                test_con_loss_val = float('nan')
                test_reg_loss_val = float('nan')
                if test_loader_has_data:
                    model.eval()
                    with torch.no_grad():
                        try:
                            test_batch = next(iter(test_loader))
                            if test_batch:
                                t_images = test_batch['image'].to(DEVICE).to(DTYPE)
                                t_p_ids = test_batch['prompt_ids'].to(DEVICE)
                                t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
                                t_t_ids = test_batch['target_ids'].to(DEVICE)
                                t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
                                t_gen_targets = test_batch['generative_targets'].to(DEVICE)
                                t_cont_coords = test_batch['continuous_coords'].to(DEVICE) # Padded
                                t_coords_mask = test_batch['coords_mask'].to(DEVICE) # Mask

                                _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
                                    t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
                                    t_gen_targets, t_cont_coords, t_coords_mask # Pass mask
                                )
                                # Use .item() only if the tensor is finite
                                test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
                                test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
                                test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
                            # ... (rest of exception handling) ...
                        except StopIteration: print("Info: Test loader exhausted during logging.")
                        except KeyError as e: print(f"Error: Missing key {e} in test batch.")
                        except Exception as e: print(f"Error during test evaluation: {e}")
                    model.train()

                # Prepare data for logging
                log_data = {
                    "train/total_loss": avg_total_loss,
                    "train/class_loss": avg_class_loss,
                    "train/contrastive_loss": avg_con_loss,
                    "train/regression_loss": avg_reg_loss,
                    "test/class_loss": test_class_loss_val,
                    "test/contrastive_loss": test_con_loss_val,
                    "test/regression_loss": test_reg_loss_val,
                    "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
                    "step": step_counter,
                    "learning_rate": current_lr,
                    "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
                    "logit_scale": model.logit_scale.exp().item()
                }
                # Update progress bar
                pbar.set_postfix({
                    "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
                    "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
                    "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
                })
                if wandb_enabled: wandb.log(log_data)

                # Reset accumulators
                epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
                batches_since_log = 0
                valid_batches_accum = 0 # Reset valid batch count

    # --- End of Epoch ---
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
    # Optional: Add end-of-epoch evaluation or model saving here
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth")

# --- End of Training ---
print("\nTraining completed!")
if wandb_enabled:
    wandb.finish()
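A worked check of the accumulation arithmetic feeding OneCycleLR above: each optimizer step sees BATCH_SIZE x GRAD_ACCUMULATION_STEPS samples, and steps_per_epoch rounds up so a trailing partial accumulation window still counts as a step. The dataset size is a hypothetical placeholder, not taken from the upload:

# Sketch: effective batch and scheduler step count (num_train_samples is assumed).
BATCH_SIZE, GRAD_ACCUMULATION_STEPS, NUM_EPOCHS = 16, 16, 400   # constants.py defaults
num_train_samples = 10_000                                      # hypothetical
batches_per_epoch = -(-num_train_samples // BATCH_SIZE)         # ceil division -> 625
effective_batch = BATCH_SIZE * GRAD_ACCUMULATION_STEPS          # 256 samples per optimizer step
steps_per_epoch = (batches_per_epoch // GRAD_ACCUMULATION_STEPS
                   + (1 if batches_per_epoch % GRAD_ACCUMULATION_STEPS else 0))  # 39 full + 1 partial = 40
total_steps = steps_per_epoch * NUM_EPOCHS                      # 16000, as passed to OneCycleLR
warmup_steps = min(max(1, total_steps // 10), 10_000)           # 1600
print(effective_batch, steps_per_epoch, total_steps, warmup_steps)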
train_stage_2.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from constants import *
|
| 2 |
+
from dataset import create_train_dataloader, create_test_dataloader
|
| 3 |
+
from vision_language_model import VisionLanguageModel
|
| 4 |
+
from utils import *
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import wandb
|
| 7 |
+
import torch
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
| 10 |
+
from tqdm.auto import tqdm
|
| 11 |
+
|
| 12 |
+
print(f"Using device: {DEVICE}")
|
| 13 |
+
print(f"Vocab size: {vocab_size}")
|
| 14 |
+
|
| 15 |
+
# --- Initialize Model ---
|
| 16 |
+
# Ensure lambda_regression is passed during initialization
|
| 17 |
+
model = VisionLanguageModel(
|
| 18 |
+
n_embd=HIDDEN_DIM,
|
| 19 |
+
vocab_size=vocab_size,
|
| 20 |
+
img_size=IMAGE_SIZE,
|
| 21 |
+
patch_size=PATCH_SIZE,
|
| 22 |
+
num_heads=NUM_HEADS,
|
| 23 |
+
num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
|
| 24 |
+
num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
|
| 25 |
+
emb_dropout=0.0,
|
| 26 |
+
blk_dropout=0.0,
|
| 27 |
+
max_context=CONTEXT_LENGTH,
|
| 28 |
+
shared_embed_dim=SHARED_EMBED_DIM,
|
| 29 |
+
lambda_contrastive=LAMBDA_CONTRASTIVE,
|
| 30 |
+
lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
|
| 31 |
+
).to(DEVICE)
|
| 32 |
+
|
| 33 |
+
NUM_EPOCHS = 100
|
| 34 |
+
model.load_state_dict(torch.load("model_regression_multi_16.pth", weights_only=True)) # we ran till 15 before it over fitted with higher learning rate
|
| 35 |
+
|
| 36 |
+
# --- Optimizer ---
|
| 37 |
+
# Optimizer will automatically include all model parameters, including the new regression head
|
| 38 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1) # lower learning rate for second stage
|
| 39 |
+
|
| 40 |
+
# --- Dataloaders ---
|
| 41 |
+
# Ensure these functions now return 'continuous_coords' in the batch dictionary
|
| 42 |
+
train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2) # Use num_workers=0 for easier debugging first
|
| 43 |
+
test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
|
| 44 |
+
if train_loader is None: exit("Training loader failed to initialize.")
|
| 45 |
+
test_loader_has_data = test_loader and len(test_loader.dataset) > 0
|
| 46 |
+
|
| 47 |
+
# --- LR Scheduler ---
|
| 48 |
+
if train_loader and len(train_loader) > 0:
|
| 49 |
+
steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0)
|
| 50 |
+
total_steps = steps_per_epoch * NUM_EPOCHS
|
| 51 |
+
# Adjust warmup steps if total steps are very low
|
| 52 |
+
warmup_steps = min(max(1, total_steps // 10), 10000) # Ensure at least 1, max 10k warmup
|
| 53 |
+
print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
|
| 54 |
+
lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_steps, pct_start=warmup_steps/total_steps if total_steps > 0 else 0.1)
|
| 55 |
+
else:
|
| 56 |
+
print("Warning: Train loader empty. Using constant LR.")
|
| 57 |
+
total_steps = 0; warmup_steps = 0
|
| 58 |
+
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
|
| 59 |
+
|
| 60 |
+
# --- Wandb Setup ---
|
| 61 |
+
try:
|
| 62 |
+
wandb.init(
|
| 63 |
+
# project="point-language-model-dualhead", # Suggest new project name
|
| 64 |
+
project="point-language-model-regression-vast",
|
| 65 |
+
name=f"point-vlm-dual-stage-2-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
|
| 66 |
+
config={ # Add new hyperparameters
|
| 67 |
+
"image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
|
| 68 |
+
"context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
|
| 69 |
+
"num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
|
| 70 |
+
"learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
|
| 71 |
+
"shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
|
| 72 |
+
"lambda_regression": LAMBDA_REGRESSION, # Log regression weight
|
| 73 |
+
"architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
|
| 74 |
+
"num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps
|
| 75 |
+
}
|
| 76 |
+
)
|
| 77 |
+
wandb_enabled = True
|
| 78 |
+
# Watch model gradients and parameters
|
| 79 |
+
# wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
|
| 80 |
+
except Exception as e:
|
| 81 |
+
print(f"Wandb initialization failed: {e}. Running without wandb.")
|
| 82 |
+
wandb_enabled = False
|
| 83 |
+
|

# --- Training Loop ---
print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
step_counter = 0
optimizer.zero_grad()

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_total_loss_accum = 0.0
    epoch_class_loss_accum = 0.0
    epoch_con_loss_accum = 0.0
    epoch_reg_loss_accum = 0.0
    batches_since_log = 0
    valid_batches_accum = 0  # Count batches with a valid loss, for averaging

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)

    for batch_idx, batch in pbar:
        if batch is None:
            continue

        # --- Unpack Batch Data ---
        try:
            images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
            prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
            prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
            target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
            target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
            generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
            continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True)  # Padded (B, MAX_POINTS, 2)
            coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True)  # Valid-point mask (B, MAX_POINTS)
        except KeyError as e:
            print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
            continue

        # Clamp logit_scale so the contrastive temperature stays bounded
        with torch.no_grad():
            model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0)))

        # --- Forward Pass ---
        # The model returns possibly-NaN scalar tensors for the individual losses when they are invalid
        logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
            img_array=images,
            prompt_ids=prompt_ids,
            prompt_attention_mask=prompt_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
            generative_targets=generative_targets,
            continuous_coords=continuous_coords,
            coords_mask=coords_mask  # Pass mask for regression loss calculation
        )

        # --- Loss Handling & Accumulation ---
        # Check for an invalid total loss before the backward pass
        if total_loss is None or not torch.isfinite(total_loss):
            print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
            optimizer.zero_grad()  # Reset gradients for safety if loss is invalid
            continue  # Skip this batch for the optimization step

        # Scale loss for gradient accumulation
        scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS

        # Accumulate valid loss components for logging;
        # only call .item() on scalar tensors that are finite
        if torch.isfinite(total_loss):
            epoch_total_loss_accum += total_loss.item()
            valid_batches_accum += 1  # This batch contributes to the running average
        if torch.isfinite(class_loss_s):
            epoch_class_loss_accum += class_loss_s.item()
        if torch.isfinite(contrastive_loss_s):
            epoch_con_loss_accum += contrastive_loss_s.item()
        if torch.isfinite(regression_loss_s):
            epoch_reg_loss_accum += regression_loss_s.item()
        batches_since_log += 1

        # --- Backward Pass ---
        try:
            scaled_loss.backward()
        except Exception as e:
            print(f"Error during backward pass: {e}. Skipping step.")
            optimizer.zero_grad()  # Reset gradients if backward failed
            continue
        # --- Gradient Accumulation Step ---
        if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
            # Clip gradients
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

            # Check for non-finite gradients before stepping
            all_finite = True
            for p in model.parameters():
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    all_finite = False
                    break
            if not all_finite:
                print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
                optimizer.zero_grad()
                continue  # Skip optimizer step and scheduler step

            # Optimizer step
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            step_counter += 1

            # --- Logging ---
            if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0:
                # Average losses over the logging period, using only batches with a valid loss
                avg_total_loss = epoch_total_loss_accum / valid_batches_accum
                avg_class_loss = epoch_class_loss_accum / valid_batches_accum
                avg_con_loss = epoch_con_loss_accum / valid_batches_accum
                avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
                current_lr = optimizer.param_groups[0]['lr']

                # --- Test Evaluation (single batch, mask-aware) ---
                # Quick health check only: next(iter(test_loader)) re-creates the
                # iterator each time, so without shuffling this scores the same
                # first test batch on every log step, not a full evaluation pass.
                test_class_loss_val = float('nan')
                test_con_loss_val = float('nan')
                test_reg_loss_val = float('nan')
                if test_loader_has_data:
                    model.eval()
                    with torch.no_grad():
                        try:
                            test_batch = next(iter(test_loader))
                            if test_batch:
                                t_images = test_batch['image'].to(DEVICE).to(DTYPE)
                                t_p_ids = test_batch['prompt_ids'].to(DEVICE)
                                t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
                                t_t_ids = test_batch['target_ids'].to(DEVICE)
                                t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
                                t_gen_targets = test_batch['generative_targets'].to(DEVICE)
                                t_cont_coords = test_batch['continuous_coords'].to(DEVICE)  # Padded
                                t_coords_mask = test_batch['coords_mask'].to(DEVICE)  # Mask

                                _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
                                    t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
                                    t_gen_targets, t_cont_coords, t_coords_mask  # Pass mask
                                )
                                # Use .item() only if the tensor is finite
                                test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
                                test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
                                test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
                        except StopIteration:
                            print("Info: Test loader exhausted during logging.")
                        except KeyError as e:
                            print(f"Error: Missing key {e} in test batch.")
                        except Exception as e:
                            print(f"Error during test evaluation: {e}")
                    model.train()

                # Prepare data for logging
                log_data = {
                    "train/total_loss": avg_total_loss,
                    "train/class_loss": avg_class_loss,
                    "train/contrastive_loss": avg_con_loss,
                    "train/regression_loss": avg_reg_loss,
                    "test/class_loss": test_class_loss_val,
                    "test/contrastive_loss": test_con_loss_val,
                    "test/regression_loss": test_reg_loss_val,
                    "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
                    "step": step_counter,
                    "learning_rate": current_lr,
                    "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
                    "logit_scale": model.logit_scale.exp().item()
                }
                # Update progress bar
                pbar.set_postfix({
                    "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
                    "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
                    "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
                })
                if wandb_enabled:
                    wandb.log(log_data)

                # Reset accumulators
                epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
                batches_since_log = 0
                valid_batches_accum = 0  # Reset valid batch count
    # --- End of Epoch ---
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
    # Optional: add end-of-epoch evaluation here. Checkpoints are written every
    # 5 epochs (at epochs 0, 5, 10, ..., saved under the 1-indexed epoch number).
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"model_regression_multi_stage_2_{epoch+1}.pth")

# --- End of Training ---
print("\nTraining completed!")
if wandb_enabled:
    wandb.finish()
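# Resuming from a checkpoint (illustrative; standard PyTorch state_dict loading,
# with the filename following the save pattern above, e.g. the epoch-1 file):
#   state = torch.load("model_regression_multi_stage_2_1.pth", map_location=DEVICE)
#   model.load_state_dict(state)
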
utils.py
ADDED
@@ -0,0 +1,57 @@
from constants import *
from transformers import AutoTokenizer
import torch
import numpy as np
from PIL import Image
from torchvision import transforms


def get_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    point_tokens = [f"coord_bin_{i}" for i in range(0, NUM_BINS)]
    new_tokens = [
        "<point_start>", "<point_end>", "<result_start>",
        "<result_end>", "<pointx_start>", "<pointx_end>",
        "<pointy_start>", "<pointy_end>",
        *point_tokens
    ]
    tokenizer.add_tokens(new_tokens)
    # Ensure a pad token is set (GPT-2 does not have one by default)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # Or reuse eos_token if preferred
        # tokenizer.pad_token_id = tokenizer.eos_token_id  # Alternative: pad with EOS

    print(f"Tokenizer pad token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    print(f"Tokenizer EOS token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")

    # Check that the pad token ID is valid
    if tokenizer.pad_token_id is None:
        raise ValueError("Tokenizer pad token ID is not set!")

    return tokenizer, len(tokenizer)

def image_to_tensor(image, image_size=IMAGE_SIZE):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # No rotation/flip augmentation for now, to avoid having to recompute the
    # point coordinates after the transform; this can be added later.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD)
    ])
    return transform(image)

def tensor_to_image(tensor):
    tensor = tensor.clone().detach()
    if tensor.is_cuda:
        tensor = tensor.cpu()
    # Undo the normalization applied in image_to_tensor
    mean = torch.tensor(IMAGE_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGE_STD).view(3, 1, 1)
    tensor = tensor * std + mean
    tensor = torch.clamp(tensor, 0, 1)
    image_np = tensor.numpy().transpose(1, 2, 0)
    image_np = (image_np * 255).astype(np.uint8)
    return Image.fromarray(image_np)

tokenizer, vocab_size = get_tokenizer()  # Initialize tokenizer globally
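
# Illustrative usage (assumes this file's tokenizer; the exact target string
# format lives in dataset.py, so the sample below is hypothetical):
if __name__ == "__main__":
    sample = ("<result_start><point_start><pointx_start>coord_bin_3<pointx_end>"
              "<pointy_start>coord_bin_17<pointy_end><point_end><result_end>")
    ids = tokenizer.encode(sample, add_special_tokens=False)
    # Each added marker and coord_bin_* token maps to a single ID
    print(f"{len(ids)} tokens -> {ids}")
    print(tokenizer.decode(ids))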

vision_language_model.py
ADDED
@@ -0,0 +1,400 @@
from model_components import ViT, MultiModalProjector
from decoder_language_model import DecoderLanguageModel
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size


class VisionLanguageModel(nn.Module):
    """
    Vision-language model integrating ViT, projector, contrastive loss, and a
    decoder with classification and regression heads.
    Handles multiple points via padded regression targets and a masked loss.
    """
    def __init__(self,
                 n_embd=HIDDEN_DIM,
                 vocab_size=vocab_size,
                 img_size=IMAGE_SIZE,
                 patch_size=PATCH_SIZE,
                 num_heads=NUM_HEADS,
                 num_blks_vit=NUM_LAYERS,
                 num_blks_dec=NUM_LAYERS,
                 emb_dropout=DROPOUT,
                 blk_dropout=DROPOUT,
                 max_context=CONTEXT_LENGTH,
                 shared_embed_dim=SHARED_EMBED_DIM,
                 lambda_contrastive=LAMBDA_CONTRASTIVE,
                 lambda_regression=LAMBDA_REGRESSION,  # Use the updated constant
                 max_points=MAX_POINTS  # Maximum number of points per image
                 ):
        super().__init__()

        # --- Vision Backbone ---
        self.vision_encoder = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_hiddens=n_embd,  # Assumes ViT output dim matches the decoder embed dim
            num_heads=num_heads,
            num_blks=num_blks_vit,
            emb_dropout=emb_dropout,
            blk_dropout=blk_dropout
        )

        # --- Multimodal Components ---
        self.multimodal_projector = MultiModalProjector(
            image_embed_dim=n_embd,  # Input from ViT
            text_embed_dim=n_embd,   # Output matches decoder dim
            dropout=emb_dropout
        )
        self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
        self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

        # --- Text Decoder ---
        # DecoderLanguageModel has a regression head outputting MAX_POINTS * 2 values
        self.decoder = DecoderLanguageModel(
            n_embd=n_embd,
            vocab_size=vocab_size,
            num_heads=num_heads,
            n_layer=num_blks_dec,
            max_context=max_context,
            dropout=blk_dropout  # Use block dropout for decoder consistency
        )

        # --- Store Configuration ---
        self.n_embd = n_embd
        self.vocab_size = vocab_size
        self.num_patches = (img_size // patch_size) ** 2 + 1
        self.lambda_contrastive = lambda_contrastive
        self.lambda_regression = lambda_regression
        self.max_points = max_points

        self._resize_embeddings_if_needed(self.vocab_size)
        print("VisionLanguageModel initialized.")


    def _resize_embeddings_if_needed(self, current_vocab_size):
        """ Resizes decoder token embeddings if the vocab size changed after init. """
        decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings
        if decoder_embedding_size != current_vocab_size:
            print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}")
            # Freeze original weights before replacing the layers
            self.decoder.token_embedding_table.weight.requires_grad = False
            self.decoder.lm_head.weight.requires_grad = False
            # Create new layers
            new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE)
            new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE)
            # Assign new layers
            self.decoder.token_embedding_table = new_embedding
            self.decoder.lm_head = new_lm_head
            # Re-tie weights
            self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight
            print("VLM decoder embeddings resized and weights retied.")

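    # Note (illustrative check): after retying, the embedding table and lm_head
    # share a single weight tensor, so
    #   model.decoder.lm_head.weight is model.decoder.token_embedding_table.weight
    # should hold, and a gradient update to either layer updates both.
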
    def _calculate_contrastive_loss(self, image_features, text_features):
        """ Calculates the symmetric InfoNCE loss. """
        # Assumes features are already projected to shared_embed_dim:
        # image_features: (B, E), text_features: (B, E)

        # Normalize features
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # Cosine similarity as logits (scaled by the learnable temperature)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # Calculate the symmetric cross-entropy loss
        labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        contrastive_loss = (loss_i + loss_t) / 2.0

        # Handle potential NaNs
        if torch.isnan(contrastive_loss):
            print("Warning: Contrastive loss is NaN.")
            return None  # The caller treats None as "skip this loss term"

        return contrastive_loss

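    # Shape walk-through (illustrative): with a batch of B = 4 image-prompt pairs,
    # logits_per_image is (4, 4) and labels = [0, 1, 2, 3] marks the diagonal
    # (image i with text i) as the positive pair, so each row is a 4-way
    # classification over the in-batch texts; averaging with the transposed
    # direction keeps image-to-text and text-to-image symmetric.
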
    def forward(self,
                img_array,
                prompt_ids,
                prompt_attention_mask,
                target_ids,
                target_attention_mask,
                generative_targets=None,
                continuous_coords=None,  # Expects shape (B, MAX_POINTS, 2), padded
                coords_mask=None         # Mask of valid points, shape (B, MAX_POINTS)
                ):
        """
        Main forward pass for training. Calculates the combined loss, with a
        masked regression loss over the valid points.
        """

        # --- 1. Encode Image ---
        image_embeds_raw = self.vision_encoder(img_array)  # (B, N_img, C)
        B, N_img, C_img = image_embeds_raw.shape
        img_cls_token = image_embeds_raw[:, 0]

        # --- 2. Contrastive Loss Path ---
        image_features_contrast = self.image_contrastive_head(img_cls_token)
        with torch.no_grad():  # no_grad for efficiency; prompt embeddings are not trained via the contrastive path
            prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)
            prompt_lengths = prompt_attention_mask.sum(dim=1)
            last_token_indices = (prompt_lengths - 1).clamp(min=0)
            gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, C_img)
            prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)
        text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)
        contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)

        # --- 3. Generative / Regression Path ---
        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
        prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)
        target_embeds_decoder = self.decoder.token_embedding_table(target_ids)
        B, T_prompt, C = prompt_embeds_decoder.shape
        B, T_target, _ = target_embeds_decoder.shape

        # Prepare the combined input sequence and attention mask for the decoder
        combined_embeds = torch.cat([
            image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder
        ], dim=1)
        combined_attention_mask = torch.cat([
            torch.ones(B, N_img, dtype=torch.long, device=DEVICE),
            prompt_attention_mask,
            target_attention_mask
        ], dim=1)
        T_combined = combined_embeds.shape[1]

        # Prepare combined targets for the classification loss
        combined_class_targets = None
        if generative_targets is not None:
            combined_class_targets = torch.cat([
                torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
                generative_targets
            ], dim=1)

        # --- Pass through Decoder ---
        logits, class_loss, x_norm = self.decoder(
            combined_embeds,
            attention_mask=combined_attention_mask,
            targets=combined_class_targets
        )
        # x_norm shape: (B, T_combined, C)

        # --- Calculate Regression Output & Loss (handles multiple points) ---
        regression_loss = None
        regression_output = None
        num_valid_coords = None  # Defensive init so the warning below cannot hit a NameError
        if continuous_coords is not None and coords_mask is not None and x_norm is not None:
            # Strategy: use the hidden state of the token *before* <result_end> (or <eos>).
            # This single state predicts the coordinates for *all* MAX_POINTS.
            target_lengths = target_attention_mask.sum(dim=1)  # Length of actual target tokens (B,)
            # Index relative to the start of the *target sequence*: length - 2 (token before <eos>/<result_end>)
            relative_target_idx = (target_lengths - 2).clamp(min=0)
            # Absolute index into the combined sequence's hidden states (x_norm)
            absolute_idx = N_img + T_prompt + relative_target_idx
            absolute_idx = absolute_idx.clamp(max=T_combined - 1)  # Clamp index

            # Gather the hidden states at these specific indices
            gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)
            try:
                hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1)  # (B, C)
                # Pass through the regression head
                regression_output_flat = self.decoder.regression_head(hidden_state_for_regression)  # (B, MAX_POINTS * 2)
                # Reshape to (B, MAX_POINTS, 2)
                regression_output = regression_output_flat.view(B, self.max_points, 2)

                # --- Calculate MASKED regression loss (L1, mean absolute error) ---
                loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none')  # (B, MAX_POINTS, 2)
                # Apply the mask ((B, MAX_POINTS), broadcast to (B, MAX_POINTS, 2))
                masked_loss = loss_per_coord * coords_mask.unsqueeze(-1)
                # Sum over valid points and coordinates, divide by the number of valid coordinates
                num_valid_coords = coords_mask.sum() * 2  # Total number of valid x, y values in the batch
                if num_valid_coords > 0:
                    regression_loss = masked_loss.sum() / num_valid_coords
                else:
                    regression_loss = torch.tensor(0.0, device=DEVICE)  # No valid points in the batch

                if torch.isnan(regression_loss):
                    print("Warning: Regression loss is NaN.")
                    regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)  # Replace NaN with a zero tensor

            except Exception as e:
                print(f"Error during regression calculation: {e}")
                print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}")
                regression_loss = None
                regression_output = None  # Ensure output is None if an error occurs

        # --- 4. Combine All Losses ---
        # Starts as a plain zero tensor; adding any valid loss below makes it differentiable
        total_loss = torch.tensor(0.0, device=DEVICE)
        # Add valid losses with their respective weights
        loss_log = {}
        if class_loss is not None and torch.isfinite(class_loss):
            total_loss += class_loss  # Weight of 1.0 assumed
            loss_log["class_loss"] = class_loss.item()
        else:
            # If class_loss is None or NaN/Inf, don't add it; log NaN
            loss_log["class_loss"] = float('nan')
            print(f"Warning: Invalid class_loss ({class_loss})")

        if contrastive_loss is not None and torch.isfinite(contrastive_loss):
            total_loss += self.lambda_contrastive * contrastive_loss
            loss_log["contrastive_loss"] = contrastive_loss.item()
        else:
            loss_log["contrastive_loss"] = float('nan')
            print(f"Warning: Invalid contrastive_loss ({contrastive_loss})")

        if regression_loss is not None and torch.isfinite(regression_loss):
            total_loss += self.lambda_regression * regression_loss
            loss_log["regression_loss"] = regression_loss.item()
        else:
            loss_log["regression_loss"] = float('nan')
            # Don't warn if the loss was intentionally set to 0 because no points were valid
            if regression_loss is not None and not (regression_loss == 0.0 and num_valid_coords == 0):
                print(f"Warning: Invalid regression_loss ({regression_loss})")

        # Handle the case where the total loss becomes NaN/Inf
        if not torch.isfinite(total_loss):
            print(f"Warning: Total loss became non-finite ({total_loss}). Setting to zero; the training loop also skips the optimizer step for invalid losses.")
            total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)

        # Use the loss_log dictionary for clearer logging later
        class_loss_val = loss_log["class_loss"]
        contrastive_loss_val = loss_log["contrastive_loss"]
        regression_loss_val = loss_log["regression_loss"]

        # Return all relevant outputs (detached scalar tensors for loss logging)
        return logits, regression_output, total_loss, \
               torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val), torch.tensor(regression_loss_val)

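    # Masked-L1 walk-through (illustrative): with MAX_POINTS = 2 and a sample that
    # has one real point, coords_mask = [[1., 0.]]; loss_per_coord is (1, 2, 2),
    # multiplying by the unsqueezed mask zeroes the padded point's two entries,
    # and num_valid_coords = 1 * 2 = 2, so the mean is taken over the real (x, y)
    # values only; padded points contribute nothing to the gradient.
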
    # --- Generation Method ---
    @torch.no_grad()  # Ensure no gradients are computed during generation
    def generate(self, img_array, idx_prompt, max_new_tokens,
                 temperature=1.0, top_k=None,
                 force_result_start=True  # Option to manually add <result_start>
                 ):
        """
        Generates token sequences autoregressively from an image and a prompt,
        using the classification head (lm_head).

        Args:
            img_array (torch.Tensor): Input image tensor (B, 3, H, W). B should be 1 for this implementation.
            idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt).
            max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Softmax temperature. 1.0 leaves the logits unchanged; lower values
                sharpen the distribution; 0.0 selects greedy decoding.
            top_k (int | None): If set, restricts sampling to the top K most likely tokens.
            force_result_start (bool): If True, appends the <result_start> embedding
                after the prompt before starting the generation loop.

        Returns:
            torch.Tensor: Generated sequence IDs, including the prompt (B, T_prompt + T_generated).
        """
        self.eval()  # Ensure model is in eval mode
        B = img_array.shape[0]
        if B > 1:
            # This simplified generation loop assumes B=1 for clarity; batched
            # generation would need careful handling of EOS and padding within the loop
            print("Warning: Generation function currently assumes batch size B=1.")
            # Process only the first item for now
            img_array = img_array[:1]
            idx_prompt = idx_prompt[:1]
            B = 1

        # --- 1. Prepare Initial Embeddings ---
        image_embeds_raw = self.vision_encoder(img_array)
        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
        prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)

        # Initial sequence for the decoder loop
        current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)
        generated_ids_list = []  # Newly generated IDs, collected as a list

        # Manually add <result_start> if forced
        if force_result_start:
            try:
                result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
                result_start_embed = self.decoder.token_embedding_table(
                    torch.tensor([[result_start_token_id]], device=DEVICE)
                )
                current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)
                # Also store this token ID, since we added it to the sequence
                generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE))
            except Exception as e:
                print(f"Warning: Could not encode or add <result_start>: {e}")

        # --- 2. Autoregressive Loop ---
        for _ in range(max_new_tokens):
            T_current = current_embeds.shape[1]

            # Context truncation
            if T_current > self.decoder.max_context:
                current_embeds = current_embeds[:, -self.decoder.max_context:, :]
                T_current = self.decoder.max_context

            # Prepare inputs for the decoder blocks
            pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE)
            pos = pos.clamp(max=self.decoder.max_context - 1)
            pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0)
            x = current_embeds + pos_emb
            attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long)  # No padding needed

            # Pass through decoder blocks
            for block in self.decoder.blocks:
                x = block(x, attention_mask=attention_mask)

            # Get logits for the last token
            x = self.decoder.ln_f(x[:, -1:, :])  # (B, 1, C)
            logits = self.decoder.lm_head(x)  # (B, 1, V)
            logits = logits.squeeze(1)  # (B, V)
            if temperature > 0:
                # Apply temperature; dividing by 0 would produce NaNs, so
                # temperature=0 skips this and decodes greedily below
                logits = logits / temperature

            # --- Sampling / Decoding ---
            # Optional: top-k filtering
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')  # Mask everything below the k-th logit

            # Get probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token ID (greedy when temperature is 0 or top_k == 1)
            if temperature == 0.0 or top_k == 1:
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append the generated token ID
            generated_ids_list.append(idx_next)

            # Stop if EOS is generated
            if hasattr(tokenizer, 'eos_token_id') and idx_next.item() == tokenizer.eos_token_id:
                break

            # Prepare for the next iteration
            next_token_embed = self.decoder.token_embedding_table(idx_next)
            current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)

        # --- 3. Combine results ---
        if generated_ids_list:
            generated_ids_tensor = torch.cat(generated_ids_list, dim=1)  # (B, T_generated)
            full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1)
        else:
            full_sequence_ids = idx_prompt  # Return only the prompt if nothing was generated

        self.train()  # Set model back to training mode
        return full_sequence_ids
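
if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes model_components and the decoder
    # match the shapes configured above). The prompt text is hypothetical.
    model = VisionLanguageModel().to(DEVICE).to(DTYPE)
    dummy_img = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE, dtype=DTYPE)
    prompt = torch.tensor([tokenizer.encode("point to the red mug", add_special_tokens=False)], device=DEVICE)
    out_ids = model.generate(dummy_img, prompt, max_new_tokens=8, temperature=0.0)  # Greedy decoding
    print(tokenizer.decode(out_ids[0].tolist()))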