from constants import * from dataset import create_train_dataloader, create_test_dataloader from vision_language_model import VisionLanguageModel from utils import * from datetime import datetime import wandb import torch import torch.optim as optim from torch.optim.lr_scheduler import OneCycleLR from tqdm.auto import tqdm print(f"Using device: {DEVICE}") print(f"Vocab size: {vocab_size}") # --- Initialize Model --- # Ensure lambda_regression is passed during initialization model = VisionLanguageModel( n_embd=HIDDEN_DIM, vocab_size=vocab_size, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_heads=NUM_HEADS, num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers emb_dropout=DROPOUT, blk_dropout=DROPOUT, max_context=CONTEXT_LENGTH, shared_embed_dim=SHARED_EMBED_DIM, lambda_contrastive=LAMBDA_CONTRASTIVE, lambda_regression=LAMBDA_REGRESSION # Pass the regression weight ).to(DEVICE) # --- Optimizer --- # Optimizer will automatically include all model parameters, including the new regression head optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1) # --- Dataloaders --- # Ensure these functions now return 'continuous_coords' in the batch dictionary train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2) # Use num_workers=0 for easier debugging first test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2) if train_loader is None: exit("Training loader failed to initialize.") test_loader_has_data = test_loader and len(test_loader.dataset) > 0 # --- LR Scheduler --- if train_loader and len(train_loader) > 0: steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0) total_steps = steps_per_epoch * NUM_EPOCHS # Adjust warmup steps if total steps are very low warmup_steps = min(max(1, total_steps // 10), 10000) # Ensure at least 1, max 10k warmup print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}") lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_steps, pct_start=warmup_steps/total_steps if total_steps > 0 else 0.1) else: print("Warning: Train loader empty. Using constant LR.") total_steps = 0; warmup_steps = 0 lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0) # --- Wandb Setup --- try: wandb.init( # project="point-language-model-dualhead", # Suggest new project name project="point-language-model-regression-vast", name=f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}", config={ # Add new hyperparameters "image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM, "context_length": CONTEXT_LENGTH, "dropout": DROPOUT, "num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE, "learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS, "shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE, "lambda_regression": LAMBDA_REGRESSION, # Log regression weight "architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW", "num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps } ) wandb_enabled = True # Watch model gradients and parameters # wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS) except Exception as e: print(f"Wandb initialization failed: {e}. Running without wandb.") wandb_enabled = False # --- Training Loop --- print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...") step_counter = 0 optimizer.zero_grad() for epoch in range(NUM_EPOCHS): model.train() epoch_total_loss_accum = 0.0 epoch_class_loss_accum = 0.0 epoch_con_loss_accum = 0.0 epoch_reg_loss_accum = 0.0 batches_since_log = 0 valid_batches_accum = 0 # Count batches with valid loss for averaging pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False) for batch_idx, batch in pbar: if batch is None: continue # --- Unpack Batch Data --- try: images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE) prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True) prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True) target_ids = batch['target_ids'].to(DEVICE, non_blocking=True) target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True) generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True) continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True) # Padded coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True) # Mask except KeyError as e: print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.") continue # Clamp logit_scale with torch.no_grad(): model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0))) # --- Forward Pass --- # Model now returns potentially NaN scalar tensors for individual losses if invalid logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model( img_array=images, prompt_ids=prompt_ids, prompt_attention_mask=prompt_attention_mask, target_ids=target_ids, target_attention_mask=target_attention_mask, generative_targets=generative_targets, continuous_coords=continuous_coords, coords_mask=coords_mask # Pass mask for regression loss calculation ) # --- Loss Handling & Accumulation --- # Check for invalid total loss before backward pass if total_loss is None or not torch.isfinite(total_loss): print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.") optimizer.zero_grad() # Reset gradients for safety if loss is invalid continue # Skip this batch for optimization step # Scale loss for gradient accumulation scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS # Accumulate valid loss components for logging # Check if the scalar tensor is finite before adding its item() if torch.isfinite(total_loss): epoch_total_loss_accum += total_loss.item() valid_batches_accum += 1 # Increment count of batches contributing to average loss if torch.isfinite(class_loss_s): epoch_class_loss_accum += class_loss_s.item() if torch.isfinite(contrastive_loss_s): epoch_con_loss_accum += contrastive_loss_s.item() if torch.isfinite(regression_loss_s): epoch_reg_loss_accum += regression_loss_s.item() batches_since_log += 1 # --- Backward Pass --- try: scaled_loss.backward() except Exception as e: print(f"Error during backward pass: {e}. Skipping step.") optimizer.zero_grad() # Reset gradients if backward failed continue # --- Gradient Accumulation Step --- if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader): # Clip gradients grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM) # Check for non-finite gradients before stepping all_finite = True for p in model.parameters(): if p.grad is not None and not torch.isfinite(p.grad).all(): all_finite = False break if not all_finite: print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.") optimizer.zero_grad() continue # Skip optimizer step and scheduler step # Optimizer step optimizer.step() lr_scheduler.step() optimizer.zero_grad() step_counter += 1 # --- Logging --- if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0: # Use valid_batches_accum # Calculate average losses over the logging period using valid batch count avg_total_loss = epoch_total_loss_accum / valid_batches_accum avg_class_loss = epoch_class_loss_accum / valid_batches_accum avg_con_loss = epoch_con_loss_accum / valid_batches_accum avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum current_lr = optimizer.param_groups[0]['lr'] # --- Test Evaluation (Needs modification to handle mask) --- test_class_loss_val = float('nan') test_con_loss_val = float('nan') test_reg_loss_val = float('nan') if test_loader_has_data: model.eval() with torch.no_grad(): try: test_batch = next(iter(test_loader)) if test_batch: t_images = test_batch['image'].to(DEVICE).to(DTYPE) t_p_ids = test_batch['prompt_ids'].to(DEVICE) t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE) t_t_ids = test_batch['target_ids'].to(DEVICE) t_t_mask = test_batch['target_attention_mask'].to(DEVICE) t_gen_targets = test_batch['generative_targets'].to(DEVICE) t_cont_coords = test_batch['continuous_coords'].to(DEVICE) # Padded t_coords_mask = test_batch['coords_mask'].to(DEVICE) # Mask _, _, _, t_class_loss, t_con_loss, t_reg_loss = model( t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask, t_gen_targets, t_cont_coords, t_coords_mask # Pass mask ) # Use .item() only if the tensor is finite test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan') test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan') test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan') # ... (rest of exception handling) ... except StopIteration: print("Info: Test loader exhausted during logging.") except KeyError as e: print(f"Error: Missing key {e} in test batch.") except Exception as e: print(f"Error during test evaluation: {e}") model.train() # Prepare data for logging log_data = { "train/total_loss": avg_total_loss, "train/class_loss": avg_class_loss, "train/contrastive_loss": avg_con_loss, "train/regression_loss": avg_reg_loss, "test/class_loss": test_class_loss_val, "test/contrastive_loss": test_con_loss_val, "test/regression_loss": test_reg_loss_val, "epoch": epoch + ((batch_idx + 1) / len(train_loader)), "step": step_counter, "learning_rate": current_lr, "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm, "logit_scale": model.logit_scale.exp().item() } # Update progress bar pbar.set_postfix({ "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}", "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}", "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}" }) if wandb_enabled: wandb.log(log_data) # Reset accumulators epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0 batches_since_log = 0 valid_batches_accum = 0 # Reset valid batch count # --- End of Epoch --- print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.") # Optional: Add end-of-epoch evaluation or model saving here if epoch % 5 == 0: torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth") # --- End of Training --- print("\nTraining completed!") if wandb_enabled: wandb.finish()