from constants import *
from dataset import create_train_dataloader, create_test_dataloader
from vision_language_model import VisionLanguageModel
from utils import *
from datetime import datetime
import wandb
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm
print(f"Using device: {DEVICE}")
print(f"Vocab size: {vocab_size}")
# --- Initialize Model ---
# Ensure lambda_regression is passed during initialization
model = VisionLanguageModel(
n_embd=HIDDEN_DIM,
vocab_size=vocab_size,
img_size=IMAGE_SIZE,
patch_size=PATCH_SIZE,
num_heads=NUM_HEADS,
num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
emb_dropout=DROPOUT,
blk_dropout=DROPOUT,
max_context=CONTEXT_LENGTH,
shared_embed_dim=SHARED_EMBED_DIM,
lambda_contrastive=LAMBDA_CONTRASTIVE,
lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
).to(DEVICE)
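# Sanity check (a small sketch): report parameter counts so misconfigured
# dimensions surface before training starts.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable")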
# --- Optimizer ---
# Optimizer will automatically include all model parameters, including the new regression head
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)
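# Optional refinement (a sketch, not wired in): exclude biases and other 1-D
# parameters (e.g. norm weights) from weight decay, as is common for
# transformers. To use it, pass build_param_groups(model) to AdamW above in
# place of model.parameters().
def build_param_groups(model, weight_decay=0.1):
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (no_decay if p.ndim < 2 else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]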
# --- Dataloaders ---
# Ensure these functions now return 'continuous_coords' in the batch dictionary
train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2) # Use num_workers=0 for easier debugging first
test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
if train_loader is None: exit("Training loader failed to initialize.")
test_loader_has_data = test_loader is not None and len(test_loader.dataset) > 0
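# Schema probe (a sketch): fail fast if the collate_fn does not emit the keys
# the loop below expects, in particular the new 'continuous_coords'/'coords_mask'.
EXPECTED_BATCH_KEYS = {
    "image", "prompt_ids", "prompt_attention_mask", "target_ids",
    "target_attention_mask", "generative_targets", "continuous_coords", "coords_mask",
}
try:
    _probe = next(iter(train_loader))
    if _probe is not None:
        _missing = EXPECTED_BATCH_KEYS - set(_probe)
        if _missing:
            print(f"Warning: train batches are missing keys: {_missing}")
except StopIteration:
    print("Warning: train loader yielded no batches for the schema probe.")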
# --- LR Scheduler ---
if train_loader and len(train_loader) > 0:
steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0)
total_steps = steps_per_epoch * NUM_EPOCHS
# Adjust warmup steps if total steps are very low
warmup_steps = min(max(1, total_steps // 10), 10000) # Ensure at least 1, max 10k warmup
print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_steps, pct_start=warmup_steps/total_steps if total_steps > 0 else 0.1)
else:
print("Warning: Train loader empty. Using constant LR.")
total_steps = 0; warmup_steps = 0
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
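# Optional preview (a sketch on a throwaway optimizer, so the real optimizer
# and scheduler are untouched): trace the OneCycle shape to eyeball the warmup
# fraction before committing to a long run.
if total_steps > 0:
    _probe_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=LEARNING_RATE)
    _probe_sched = OneCycleLR(_probe_opt, max_lr=LEARNING_RATE, total_steps=total_steps,
                              pct_start=warmup_steps / total_steps)
    _lrs = []
    for _ in range(total_steps):
        _probe_opt.step()
        _probe_sched.step()
        _lrs.append(_probe_sched.get_last_lr()[0])
    print(f"LR preview: start={_lrs[0]:.2e}, peak={max(_lrs):.2e}, final={_lrs[-1]:.2e}")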
# --- Wandb Setup ---
try:
wandb.init(
# project="point-language-model-dualhead", # Suggest new project name
project="point-language-model-regression-vast",
name=f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
config={ # Add new hyperparameters
"image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
"context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
"num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
"learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
"shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
"lambda_regression": LAMBDA_REGRESSION, # Log regression weight
"architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
"num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps
}
)
wandb_enabled = True
# Watch model gradients and parameters
# wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
except Exception as e:
print(f"Wandb initialization failed: {e}. Running without wandb.")
wandb_enabled = False
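# If runs happen on machines without network access, wandb can be forced into
# offline mode instead (a sketch; logs are uploaded later with `wandb sync`):
#   import os; os.environ["WANDB_MODE"] = "offline"  # set before wandb.init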
# --- Training Loop ---
print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
step_counter = 0
optimizer.zero_grad()
for epoch in range(NUM_EPOCHS):
model.train()
epoch_total_loss_accum = 0.0
epoch_class_loss_accum = 0.0
epoch_con_loss_accum = 0.0
epoch_reg_loss_accum = 0.0
batches_since_log = 0
valid_batches_accum = 0 # Count batches with valid loss for averaging
pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
for batch_idx, batch in pbar:
if batch is None: continue
# --- Unpack Batch Data ---
try:
            images = batch['image'].to(DEVICE, dtype=DTYPE, non_blocking=True)
prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True) # Padded
coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True) # Mask
except KeyError as e:
print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
continue
        # Clamp logit_scale so the contrastive temperature exp(logit_scale)
        # stays in [1, 100], following the CLIP recipe
        with torch.no_grad():
            model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0)))
        # --- Forward Pass ---
        # The model returns scalar tensors for each loss term; any term that
        # is invalid for this batch comes back as NaN
logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
img_array=images,
prompt_ids=prompt_ids,
prompt_attention_mask=prompt_attention_mask,
target_ids=target_ids,
target_attention_mask=target_attention_mask,
generative_targets=generative_targets,
continuous_coords=continuous_coords,
coords_mask=coords_mask # Pass mask for regression loss calculation
)
# --- Loss Handling & Accumulation ---
# Check for invalid total loss before backward pass
if total_loss is None or not torch.isfinite(total_loss):
print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
optimizer.zero_grad() # Reset gradients for safety if loss is invalid
continue # Skip this batch for optimization step
# Scale loss for gradient accumulation
scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS
        # Accumulate loss components for logging; total_loss is known to be
        # finite here (checked above), but individual terms may still be NaN
        # if that component was invalid for this batch
        epoch_total_loss_accum += total_loss.item()
        valid_batches_accum += 1  # Batches contributing to the running averages
        if torch.isfinite(class_loss_s):
            epoch_class_loss_accum += class_loss_s.item()
        if torch.isfinite(contrastive_loss_s):
            epoch_con_loss_accum += contrastive_loss_s.item()
        if torch.isfinite(regression_loss_s):
            epoch_reg_loss_accum += regression_loss_s.item()
batches_since_log += 1
# --- Backward Pass ---
try:
scaled_loss.backward()
except Exception as e:
print(f"Error during backward pass: {e}. Skipping step.")
optimizer.zero_grad() # Reset gradients if backward failed
continue
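        # Note: this script backprops in DTYPE directly. If DTYPE were float16,
        # the usual pattern (a sketch, not wired in here) would wrap the forward
        # in torch.autocast and scale the loss:
        #   scaler = torch.cuda.amp.GradScaler()  # once, before the loop
        #   scaler.scale(scaled_loss).backward()
        #   scaler.unscale_(optimizer)            # before gradient clipping
        #   scaler.step(optimizer); scaler.update()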
# --- Gradient Accumulation Step ---
if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
# Clip gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            # Check for non-finite gradients before stepping; all() with a
            # generator stops at the first offending parameter
            all_finite = all(
                torch.isfinite(p.grad).all()
                for p in model.parameters()
                if p.grad is not None
            )
if not all_finite:
print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
optimizer.zero_grad()
continue # Skip optimizer step and scheduler step
# Optimizer step
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
step_counter += 1
# --- Logging ---
if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0: # Use valid_batches_accum
# Calculate average losses over the logging period using valid batch count
avg_total_loss = epoch_total_loss_accum / valid_batches_accum
avg_class_loss = epoch_class_loss_accum / valid_batches_accum
avg_con_loss = epoch_con_loss_accum / valid_batches_accum
avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
current_lr = optimizer.param_groups[0]['lr']
                # --- Test Evaluation (single batch; passes the coords mask) ---
test_class_loss_val = float('nan')
test_con_loss_val = float('nan')
test_reg_loss_val = float('nan')
if test_loader_has_data:
model.eval()
with torch.no_grad():
try:
                            # iter() is re-created each call, so this samples the
                            # first (or, if shuffled, a random) test batch
                            test_batch = next(iter(test_loader))
if test_batch:
                                t_images = test_batch['image'].to(DEVICE, dtype=DTYPE)
t_p_ids = test_batch['prompt_ids'].to(DEVICE)
t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
t_t_ids = test_batch['target_ids'].to(DEVICE)
t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
t_gen_targets = test_batch['generative_targets'].to(DEVICE)
t_cont_coords = test_batch['continuous_coords'].to(DEVICE) # Padded
t_coords_mask = test_batch['coords_mask'].to(DEVICE) # Mask
_, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
t_gen_targets, t_cont_coords, t_coords_mask # Pass mask
)
# Use .item() only if the tensor is finite
test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
except StopIteration: print("Info: Test loader exhausted during logging.")
except KeyError as e: print(f"Error: Missing key {e} in test batch.")
except Exception as e: print(f"Error during test evaluation: {e}")
model.train()
# Prepare data for logging
log_data = {
"train/total_loss": avg_total_loss,
"train/class_loss": avg_class_loss,
"train/contrastive_loss": avg_con_loss,
"train/regression_loss": avg_reg_loss,
"test/class_loss": test_class_loss_val,
"test/contrastive_loss": test_con_loss_val,
"test/regression_loss": test_reg_loss_val,
"epoch": epoch + ((batch_idx + 1) / len(train_loader)),
"step": step_counter,
"learning_rate": current_lr,
"gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
"logit_scale": model.logit_scale.exp().item()
}
# Update progress bar
pbar.set_postfix({
"lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
"cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
"reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
})
if wandb_enabled: wandb.log(log_data)
# Reset accumulators
epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
batches_since_log = 0
valid_batches_accum = 0 # Reset valid batch count
# --- End of Epoch ---
print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
    # Checkpoint every 5 epochs (fires at 1-indexed epochs 1, 6, 11, ...)
    if epoch % 5 == 0:
torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth")
# --- End of Training ---
print("\nTraining completed!")
if wandb_enabled:
    wandb.finish()
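# To resume or inspect a saved checkpoint later (a usage sketch; file names
# match the epoch checkpoints written above):
#   model = VisionLanguageModel(...)  # same hyperparameters as at the top
#   state = torch.load("model_regression_multi_1.pth", map_location=DEVICE)
#   model.load_state_dict(state)
#   model.eval()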