from constants import *
from dataset import create_train_dataloader, create_test_dataloader
from vision_language_model import VisionLanguageModel
from utils import *
from datetime import datetime
import wandb
import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR, ConstantLR
from tqdm.auto import tqdm
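# DEVICE, DTYPE, vocab_size, and the hyperparameter constants referenced below
# are assumed to come from the star imports above (constants / dataset / utils).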
print(f"Using device: {DEVICE}")
print(f"Vocab size: {vocab_size}")
# --- Initialize Model ---
# Ensure lambda_regression is passed during initialization
model = VisionLanguageModel(
    n_embd=HIDDEN_DIM,
    vocab_size=vocab_size,
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    num_blks_vit=NUM_LAYERS,  # Or a specific value for ViT layers
    num_blks_dec=NUM_LAYERS,  # Or a specific value for decoder layers
    emb_dropout=DROPOUT,
    blk_dropout=DROPOUT,
    max_context=CONTEXT_LENGTH,
    shared_embed_dim=SHARED_EMBED_DIM,
    lambda_contrastive=LAMBDA_CONTRASTIVE,
    lambda_regression=LAMBDA_REGRESSION,  # Weight for the regression loss term
).to(DEVICE)
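# Optional sanity check: report the trainable parameter count so runs with
# different configs are easy to compare.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params / 1e6:.2f}M")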
# --- Optimizer ---
# AdamW over all model parameters, which automatically includes the new regression head
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)
# --- Dataloaders ---
# Ensure these functions now return 'continuous_coords' in the batch dictionary
train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2)  # Use num_workers=0 when debugging
test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
if train_loader is None:
    raise SystemExit("Training loader failed to initialize.")
test_loader_has_data = test_loader is not None and len(test_loader.dataset) > 0
# --- LR Scheduler ---
if train_loader and len(train_loader) > 0:
    # Ceiling division: a trailing partial accumulation window still costs one optimizer step
    steps_per_epoch = (len(train_loader) + GRAD_ACCUMULATION_STEPS - 1) // GRAD_ACCUMULATION_STEPS
    total_steps = steps_per_epoch * NUM_EPOCHS
    # Scale warmup with total steps: at least 1, at most 10k
    warmup_steps = min(max(1, total_steps // 10), 10000)
    print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps if total_steps > 0 else 0.1,
    )
else:
    print("Warning: Train loader empty. Using constant LR.")
    total_steps = 0
    warmup_steps = 0
    lr_scheduler = ConstantLR(optimizer, factor=1.0)
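# Note: lr_scheduler.step() is called once per optimizer step (i.e. once per
# gradient-accumulation window), which matches the total_steps estimate above.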
# --- Wandb Setup ---
try:
    wandb.init(
        # project="point-language-model-dualhead",  # Alternative project name
        project="point-language-model-regression-vast",
        name=f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        config={  # Includes the new dual-head hyperparameters
            "image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
            "context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
            "num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
            "shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
            "lambda_regression": LAMBDA_REGRESSION,  # Log the regression weight
            "architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
            "num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps,
        },
    )
    wandb_enabled = True
    # Optionally watch model gradients and parameters:
    # wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
except Exception as e:
    print(f"Wandb initialization failed: {e}. Running without wandb.")
    wandb_enabled = False
# --- Training Loop ---
print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
step_counter = 0
# Persistent test iterator: the periodic evaluation below cycles through the
# test set instead of re-reading the first test batch every time.
test_iter = iter(test_loader) if test_loader_has_data else None
optimizer.zero_grad()
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_total_loss_accum = 0.0
    epoch_class_loss_accum = 0.0
    epoch_con_loss_accum = 0.0
    epoch_reg_loss_accum = 0.0
    batches_since_log = 0
    valid_batches_accum = 0  # Batches with a finite loss; used as the averaging denominator
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    for batch_idx, batch in pbar:
        if batch is None:
            continue
        # --- Unpack Batch Data ---
        try:
            images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
            prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
            prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
            target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
            target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
            generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
            continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True)  # Padded coordinates
            coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True)  # Marks valid (non-padded) points
        except KeyError as e:
            print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
            continue
        # Clamp logit_scale so the contrastive temperature exp(logit_scale) stays in [1, 100]
        with torch.no_grad():
            model.logit_scale.clamp_(0, math.log(100.0))
        # --- Forward Pass ---
        # The model may return NaN scalar tensors for individual losses when they are invalid
        logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
            img_array=images,
            prompt_ids=prompt_ids,
            prompt_attention_mask=prompt_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
            generative_targets=generative_targets,
            continuous_coords=continuous_coords,
            coords_mask=coords_mask,  # Needed for the masked regression loss
        )
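        # Assumption (based on the lambda_* constructor arguments): the model combines
        # the heads internally, roughly as
        #   total_loss = class_loss + LAMBDA_CONTRASTIVE * contrastive_loss + LAMBDA_REGRESSION * regression_loss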
        # --- Loss Handling & Accumulation ---
        # Check for an invalid total loss before the backward pass
        if total_loss is None or not torch.isfinite(total_loss):
            print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
            optimizer.zero_grad()  # Discard any partially accumulated gradients for safety
            continue  # Skip this batch for the optimization step
        # Scale loss for gradient accumulation
        scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS
        # Accumulate loss components for logging; total_loss is known finite here,
        # but individual components may still be NaN and are checked separately
        epoch_total_loss_accum += total_loss.item()
        valid_batches_accum += 1
        if torch.isfinite(class_loss_s):
            epoch_class_loss_accum += class_loss_s.item()
        if torch.isfinite(contrastive_loss_s):
            epoch_con_loss_accum += contrastive_loss_s.item()
        if torch.isfinite(regression_loss_s):
            epoch_reg_loss_accum += regression_loss_s.item()
        batches_since_log += 1
        # --- Backward Pass ---
        try:
            scaled_loss.backward()
        except Exception as e:
            print(f"Error during backward pass: {e}. Skipping step.")
            optimizer.zero_grad()  # Reset gradients if backward failed
            continue
        # --- Gradient Accumulation Step ---
        if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
            # Clip gradients; the returned pre-clip norm is logged below
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            # Check for non-finite gradients before stepping (all() short-circuits)
            all_finite = all(
                torch.isfinite(p.grad).all() for p in model.parameters() if p.grad is not None
            )
            if not all_finite:
                print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
                optimizer.zero_grad()
                continue  # Skip optimizer and scheduler steps for this window
            # Optimizer step
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            step_counter += 1
            # --- Logging ---
            if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0:
                # Average losses over the logging window, using only batches with a finite loss
                avg_total_loss = epoch_total_loss_accum / valid_batches_accum
                avg_class_loss = epoch_class_loss_accum / valid_batches_accum
                avg_con_loss = epoch_con_loss_accum / valid_batches_accum
                avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
                current_lr = optimizer.param_groups[0]['lr']
                # --- Test Evaluation (single batch, with the coords mask) ---
                test_class_loss_val = float('nan')
                test_con_loss_val = float('nan')
                test_reg_loss_val = float('nan')
                if test_loader_has_data:
                    model.eval()
                    with torch.no_grad():
                        try:
                            test_batch = next(test_iter)
                            if test_batch:
                                t_images = test_batch['image'].to(DEVICE).to(DTYPE)
                                t_p_ids = test_batch['prompt_ids'].to(DEVICE)
                                t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
                                t_t_ids = test_batch['target_ids'].to(DEVICE)
                                t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
                                t_gen_targets = test_batch['generative_targets'].to(DEVICE)
                                t_cont_coords = test_batch['continuous_coords'].to(DEVICE)  # Padded
                                t_coords_mask = test_batch['coords_mask'].to(DEVICE)  # Mask
                                _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
                                    t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
                                    t_gen_targets, t_cont_coords, t_coords_mask  # Pass the mask
                                )
                                # Call .item() only on finite tensors
                                test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
                                test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
                                test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
                        except StopIteration:
                            # Test set exhausted; restart the iterator for the next logging step
                            test_iter = iter(test_loader)
                            print("Info: Test loader exhausted; restarting iterator.")
                        except KeyError as e:
                            print(f"Error: Missing key {e} in test batch.")
                        except Exception as e:
                            print(f"Error during test evaluation: {e}")
                    model.train()
                # Prepare data for logging
                log_data = {
                    "train/total_loss": avg_total_loss,
                    "train/class_loss": avg_class_loss,
                    "train/contrastive_loss": avg_con_loss,
                    "train/regression_loss": avg_reg_loss,
                    "test/class_loss": test_class_loss_val,
                    "test/contrastive_loss": test_con_loss_val,
                    "test/regression_loss": test_reg_loss_val,
                    "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
                    "step": step_counter,
                    "learning_rate": current_lr,
                    "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
                    "logit_scale": model.logit_scale.exp().item(),
                }
                # Update progress bar
                pbar.set_postfix({
                    "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
                    "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
                    "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}",
                })
                if wandb_enabled:
                    wandb.log(log_data)
                # Reset accumulators for the next logging window
                epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
                batches_since_log = 0
                valid_batches_accum = 0
    # --- End of Epoch ---
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
    # Optional: add end-of-epoch evaluation here
    if epoch % 5 == 0:
        # Saves after epochs 1, 6, 11, ... (filename uses the 1-based epoch index)
        torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth")
# --- End of Training ---
print("\nTraining completed!")
if wandb_enabled:
    wandb.finish()