from constants import *
from dataset import create_train_dataloader, create_test_dataloader
from vision_language_model import VisionLanguageModel
from utils import *

from datetime import datetime

import wandb
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm

print(f"Using device: {DEVICE}")
print(f"Vocab size: {vocab_size}")
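# Dual-head VisionLanguageModel: a ViT image encoder and a text decoder (NUM_LAYERS blocks
# each), trained with classification, contrastive, and regression losses weighted by the
# lambda_* arguments below.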
model = VisionLanguageModel(
    n_embd=HIDDEN_DIM,
    vocab_size=vocab_size,
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    num_blks_vit=NUM_LAYERS,
    num_blks_dec=NUM_LAYERS,
    emb_dropout=DROPOUT,
    blk_dropout=DROPOUT,
    max_context=CONTEXT_LENGTH,
    shared_embed_dim=SHARED_EMBED_DIM,
    lambda_contrastive=LAMBDA_CONTRASTIVE,
    lambda_regression=LAMBDA_REGRESSION,
).to(DEVICE)
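# AdamW with betas=(0.9, 0.95) and weight_decay=0.1, a common recipe for transformer training.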
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)
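# Dataloaders. The training loop below skips batches that arrive as None, which presumably
# happens when the collate_fn drops unreadable samples.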
train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2)
test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)

if train_loader is None:
    exit("Training loader failed to initialize.")
test_loader_has_data = test_loader is not None and len(test_loader.dataset) > 0
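# Learning-rate schedule: OneCycleLR with a warmup of ~10% of total optimizer steps
# (capped at 10,000). One scheduler step is taken per optimizer step, so total_steps
# counts accumulated optimizer steps, not raw batches. For example (hypothetical sizes):
# with 1000 batches/epoch, GRAD_ACCUMULATION_STEPS=8 and NUM_EPOCHS=10, this gives
# steps_per_epoch=125, total_steps=1250 and warmup_steps=125.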
if train_loader and len(train_loader) > 0:
    steps_per_epoch = (len(train_loader) // GRAD_ACCUMULATION_STEPS) + (1 if len(train_loader) % GRAD_ACCUMULATION_STEPS != 0 else 0)
    total_steps = steps_per_epoch * NUM_EPOCHS
    warmup_steps = min(max(1, total_steps // 10), 10000)
    print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps if total_steps > 0 else 0.1,
    )
else:
    print("Warning: Train loader empty. Using constant LR.")
    total_steps = 0
    warmup_steps = 0
    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
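# Experiment tracking is optional: if wandb.init() fails (no API key, offline machine, etc.),
# training continues with wandb_enabled = False.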
try:
    wandb.init(
        project="point-language-model-regression-vast",
        name=f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        config={
            "image_size": IMAGE_SIZE, "patch_size": PATCH_SIZE, "hidden_dim": HIDDEN_DIM,
            "context_length": CONTEXT_LENGTH, "dropout": DROPOUT,
            "num_heads": NUM_HEADS, "num_layers": NUM_LAYERS, "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE, "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
            "shared_embed_dim": SHARED_EMBED_DIM, "lambda_contrastive": LAMBDA_CONTRASTIVE,
            "lambda_regression": LAMBDA_REGRESSION,
            "architecture": "VisionLanguageModel (Dual Head)", "optimizer": "AdamW",
            "num_epochs": NUM_EPOCHS, "total_steps": total_steps, "warmup_steps": warmup_steps,
        },
    )
    wandb_enabled = True
except Exception as e:
    print(f"Wandb initialization failed: {e}. Running without wandb.")
    wandb_enabled = False
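# Training loop. Losses are accumulated over GRAD_ACCUMULATION_STEPS micro-batches before
# each optimizer step, and non-finite losses or gradients are skipped rather than allowed
# to corrupt the weights.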
print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
step_counter = 0
optimizer.zero_grad()
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_total_loss_accum = 0.0
    epoch_class_loss_accum = 0.0
    epoch_con_loss_accum = 0.0
    epoch_reg_loss_accum = 0.0
    batches_since_log = 0
    valid_batches_accum = 0

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    for batch_idx, batch in pbar:
        if batch is None:
            continue
        try:
            images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)
            prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)
            prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)
            target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)
            target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)
            generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)
            continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True)
            coords_mask = batch['coords_mask'].to(DEVICE, non_blocking=True)
        except KeyError as e:
            print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
            continue
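        # Clamp the learnable contrastive temperature so exp(logit_scale) stays in [1, 100]
        # (CLIP uses the same upper bound of 100), keeping the contrastive logits bounded.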
        with torch.no_grad():
            model.logit_scale.clamp_(0, torch.log(torch.tensor(100.0)))
        logits, reg_output, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(
            img_array=images,
            prompt_ids=prompt_ids,
            prompt_attention_mask=prompt_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
            generative_targets=generative_targets,
            continuous_coords=continuous_coords,
            coords_mask=coords_mask,
        )
        if total_loss is None or not torch.isfinite(total_loss):
            print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.")
            optimizer.zero_grad()
            continue
        scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS
        if torch.isfinite(total_loss):
            epoch_total_loss_accum += total_loss.item()
            valid_batches_accum += 1
            if torch.isfinite(class_loss_s):
                epoch_class_loss_accum += class_loss_s.item()
            if torch.isfinite(contrastive_loss_s):
                epoch_con_loss_accum += contrastive_loss_s.item()
            if torch.isfinite(regression_loss_s):
                epoch_reg_loss_accum += regression_loss_s.item()
        batches_since_log += 1
        try:
            scaled_loss.backward()
        except Exception as e:
            print(f"Error during backward pass: {e}. Skipping step.")
            optimizer.zero_grad()
            continue
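        # Optimizer step every GRAD_ACCUMULATION_STEPS micro-batches (or at the end of the
        # epoch), so the effective batch size is BATCH_SIZE * GRAD_ACCUMULATION_STEPS
        # (e.g. a hypothetical BATCH_SIZE=32 with GRAD_ACCUMULATION_STEPS=8 gives 256).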
        if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            all_finite = True
            for p in model.parameters():
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    all_finite = False
                    break
            if not all_finite:
                print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
                optimizer.zero_grad()
                continue
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            step_counter += 1
            if step_counter % LOGGING_STEPS == 0 and valid_batches_accum > 0:
                avg_total_loss = epoch_total_loss_accum / valid_batches_accum
                avg_class_loss = epoch_class_loss_accum / valid_batches_accum
                avg_con_loss = epoch_con_loss_accum / valid_batches_accum
                avg_reg_loss = epoch_reg_loss_accum / valid_batches_accum
                current_lr = optimizer.param_groups[0]['lr']
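                # Lightweight held-out check: a fresh iterator is created each time and only
                # a single test batch is scored, not the full test set.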
                test_class_loss_val = float('nan')
                test_con_loss_val = float('nan')
                test_reg_loss_val = float('nan')
                if test_loader_has_data:
                    model.eval()
                    with torch.no_grad():
                        try:
                            test_batch = next(iter(test_loader))
                            if test_batch:
                                t_images = test_batch['image'].to(DEVICE).to(DTYPE)
                                t_p_ids = test_batch['prompt_ids'].to(DEVICE)
                                t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)
                                t_t_ids = test_batch['target_ids'].to(DEVICE)
                                t_t_mask = test_batch['target_attention_mask'].to(DEVICE)
                                t_gen_targets = test_batch['generative_targets'].to(DEVICE)
                                t_cont_coords = test_batch['continuous_coords'].to(DEVICE)
                                t_coords_mask = test_batch['coords_mask'].to(DEVICE)

                                _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(
                                    t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,
                                    t_gen_targets, t_cont_coords, t_coords_mask
                                )

                                test_class_loss_val = t_class_loss.item() if torch.isfinite(t_class_loss) else float('nan')
                                test_con_loss_val = t_con_loss.item() if torch.isfinite(t_con_loss) else float('nan')
                                test_reg_loss_val = t_reg_loss.item() if torch.isfinite(t_reg_loss) else float('nan')
                        except StopIteration:
                            print("Info: Test loader exhausted during logging.")
                        except KeyError as e:
                            print(f"Error: Missing key {e} in test batch.")
                        except Exception as e:
                            print(f"Error during test evaluation: {e}")
                    model.train()
                log_data = {
                    "train/total_loss": avg_total_loss,
                    "train/class_loss": avg_class_loss,
                    "train/contrastive_loss": avg_con_loss,
                    "train/regression_loss": avg_reg_loss,
                    "test/class_loss": test_class_loss_val,
                    "test/contrastive_loss": test_con_loss_val,
                    "test/regression_loss": test_reg_loss_val,
                    "epoch": epoch + ((batch_idx + 1) / len(train_loader)),
                    "step": step_counter,
                    "learning_rate": current_lr,
                    "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
                    "logit_scale": model.logit_scale.exp().item(),
                }
                pbar.set_postfix({
                    "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
                    "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
                    "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
                })
                if wandb_enabled:
                    wandb.log(log_data)

                # Reset the running-average accumulators after each logging event.
                epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0
                batches_since_log = 0
                valid_batches_accum = 0
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
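    # Checkpoint at the end of every 5th epoch (epoch indices 0, 5, 10, ...); the filename
    # records the 1-indexed epoch number.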
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth")
print("\nTraining completed!")
if wandb_enabled:
    wandb.finish()