File size: 13,347 Bytes
b781107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from constants import *
from dataset import create_train_dataloader, create_test_dataloader
from vision_language_model import VisionLanguageModel
from utils import *
from datetime import datetime
import wandb
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm

print(f"Using device: {DEVICE}")
print(f"Vocab size: {vocab_size}")

# --- Initialize Model ---
# All architecture hyperparameters come from `constants`. The ViT encoder and
# the text decoder currently share the same depth (NUM_LAYERS). The contrastive
# and regression weights (LAMBDA_CONTRASTIVE / LAMBDA_REGRESSION) are passed to
# the model as well.
_model_hparams = dict(
    n_embd=HIDDEN_DIM,
    vocab_size=vocab_size,
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    num_blks_vit=NUM_LAYERS,  # ViT encoder blocks (could differ from the decoder)
    num_blks_dec=NUM_LAYERS,  # decoder blocks
    emb_dropout=DROPOUT,
    blk_dropout=DROPOUT,
    max_context=CONTEXT_LENGTH,
    shared_embed_dim=SHARED_EMBED_DIM,
    lambda_contrastive=LAMBDA_CONTRASTIVE,
    lambda_regression=LAMBDA_REGRESSION,
)
model = VisionLanguageModel(**_model_hparams).to(DEVICE)

# --- Optimizer ---
# AdamW over every parameter the model exposes (the regression head included),
# with betas (0.9, 0.95) and decoupled weight decay 0.1.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

# --- Dataloaders ---
# The training loop expects each batch dict to also carry 'continuous_coords'
# (padded) and 'coords_mask' for the regression loss.
train_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=2)  # use num_workers=0 for easier debugging
test_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=2)
if train_loader is None:
    # `exit()` is the interactive-shell helper installed by the `site` module
    # and is not meant for programs (it is absent under `python -S`). Raising
    # SystemExit directly has the same effect: the message goes to stderr and
    # the process exits with status 1.
    raise SystemExit("Training loader failed to initialize.")
# True only when a test loader exists, reports at least one batch, and wraps a
# non-empty dataset. This is spelled out explicitly instead of relying on the
# loader's __len__-based truthiness, and it always yields a plain bool.
test_loader_has_data = (
    test_loader is not None
    and len(test_loader) > 0
    and len(test_loader.dataset) > 0
)

# --- LR Scheduler ---
# One-cycle schedule over the whole run: ~10% warm-up (capped at 10k steps),
# then annealing. Falls back to a constant LR when there are too few steps.
warmup_steps = 0
if train_loader is not None and len(train_loader) > 0:
    # One optimizer step per GRAD_ACCUMULATION_STEPS micro-batches, plus one for
    # a trailing partial window at the end of each epoch (ceiling division).
    steps_per_epoch = (len(train_loader) + GRAD_ACCUMULATION_STEPS - 1) // GRAD_ACCUMULATION_STEPS
    total_steps = steps_per_epoch * NUM_EPOCHS
    if total_steps >= 3:
        # OneCycleLR divides by the length of each phase. Phase 1 spans
        # warmup_steps - 1 steps, so warm-up needs at least 2 steps, and it
        # must leave at least one step for the annealing phase. A 1-step
        # warm-up (what total_steps // 10 gives below 20 steps) raises
        # ZeroDivisionError inside OneCycleLR.
        warmup_steps = min(max(2, total_steps // 10), 10000, total_steps - 1)
    print(f"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}")
else:
    print("Warning: Train loader empty. Using constant LR.")
    total_steps = 0

if warmup_steps > 0:
    # NOTE(review): OneCycleLR's default cycle_momentum=True also cycles AdamW's
    # beta1 (between 0.85 and 0.95), overriding betas[0]=0.9 set on the
    # optimizer. Confirm this is intended.
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps,
    )
else:
    if total_steps > 0:
        print(f"Warning: {total_steps} optimization step(s) is too few for OneCycleLR. Using constant LR.")
    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

# --- Wandb Setup ---
# Experiment tracking is best-effort. If building the run config or starting
# the run raises for any reason, the error is printed and training continues
# with wandb_enabled = False.
try:
    _wandb_run_name = f"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    _wandb_config = {
        # model / training hyperparameters
        "image_size": IMAGE_SIZE,
        "patch_size": PATCH_SIZE,
        "hidden_dim": HIDDEN_DIM,
        "context_length": CONTEXT_LENGTH,
        "dropout": DROPOUT,
        "num_heads": NUM_HEADS,
        "num_layers": NUM_LAYERS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "grad_accum_steps": GRAD_ACCUMULATION_STEPS,
        "shared_embed_dim": SHARED_EMBED_DIM,
        "lambda_contrastive": LAMBDA_CONTRASTIVE,
        "lambda_regression": LAMBDA_REGRESSION,
        # run metadata
        "architecture": "VisionLanguageModel (Dual Head)",
        "optimizer": "AdamW",
        "num_epochs": NUM_EPOCHS,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }
    wandb.init(
        project="point-language-model-regression-vast",
        name=_wandb_run_name,
        config=_wandb_config,
    )
    # Gradient/parameter watching is available but disabled, e.g.:
    # wandb.watch(model, log="all", log_freq=LOGGING_STEPS * GRAD_ACCUMULATION_STEPS)
except Exception as e:
    print(f"Wandb initialization failed: {e}. Running without wandb.")
    wandb_enabled = False
else:
    wandb_enabled = True

# --- Training Loop ---
# Trains for NUM_EPOCHS with gradient accumulation. Each micro-batch's loss is
# divided by GRAD_ACCUMULATION_STEPS before backward. An optimizer + scheduler
# step is taken every GRAD_ACCUMULATION_STEPS micro-batches and at the end of
# each epoch. Every LOGGING_STEPS optimizer steps, windowed train-loss averages
# and the losses on one test batch go to the progress bar and, when enabled,
# to wandb.
print("Starting training with Classification + Contrastive + Regression Loss (Multi-Point)...")
step_counter = 0
optimizer.zero_grad()

# Upper bound for model.logit_scale, which is kept in log space (it is
# exp()-ed for logging), so exp(logit_scale) <= 100. Computed once as a
# Python float instead of building a tensor on every batch.
max_log_logit_scale = torch.log(torch.tensor(100.0)).item()
num_train_batches = len(train_loader)

# Return codes of _train_micro_batch.
_MB_OK = "ok"            # backward ran; gradients were accumulated
_MB_SKIPPED = "skipped"  # batch skipped; existing gradients left untouched
_MB_RESET = "reset"      # backward failed; accumulated gradients were cleared


def _cycle_loader(loader):
    """Yield batches from ``loader`` indefinitely, restarting it when exhausted.

    Stops (so ``next`` raises StopIteration) if a full pass yields nothing,
    rather than spinning forever on an empty loader.
    """
    while True:
        produced_any = False
        for item in loader:
            produced_any = True
            yield item
        if not produced_any:
            return


# One persistent test iterator. Calling iter(test_loader) at every logging step
# would restart the loader's worker processes each time and always evaluate
# the same first batch.
test_batches = _cycle_loader(test_loader) if test_loader_has_data else None


def _batch_to_model_inputs(batch, non_blocking=False):
    """Move a collated batch dict to DEVICE and map it to model keyword args.

    Images are also cast to DTYPE. Raises KeyError if an expected key is missing.
    """
    return {
        "img_array": batch['image'].to(DEVICE, non_blocking=non_blocking).to(DTYPE),
        "prompt_ids": batch['prompt_ids'].to(DEVICE, non_blocking=non_blocking),
        "prompt_attention_mask": batch['prompt_attention_mask'].to(DEVICE, non_blocking=non_blocking),
        "target_ids": batch['target_ids'].to(DEVICE, non_blocking=non_blocking),
        "target_attention_mask": batch['target_attention_mask'].to(DEVICE, non_blocking=non_blocking),
        "generative_targets": batch['generative_targets'].to(DEVICE, non_blocking=non_blocking),
        "continuous_coords": batch['continuous_coords'].to(DEVICE, non_blocking=non_blocking),  # padded
        "coords_mask": batch['coords_mask'].to(DEVICE, non_blocking=non_blocking),  # marks valid coords
    }


def _finite_item(value):
    """Return ``value.item()`` for a finite scalar tensor, otherwise NaN."""
    if value is not None and torch.isfinite(value):
        return value.item()
    return float('nan')


def _new_loss_window():
    """Fresh [sum, count] accumulators for each train loss in a logging window."""
    return {name: [0.0, 0] for name in ("total", "class", "contrastive", "regression")}


def _accumulate(window, name, value):
    """Add a scalar loss tensor to ``window[name]`` only if it is finite."""
    if value is not None and torch.isfinite(value):
        window[name][0] += value.item()
        window[name][1] += 1


def _window_mean(window, name):
    """Mean of the finite values recorded for ``name``; NaN if there were none."""
    total, count = window[name]
    return total / count if count else float('nan')


def _train_micro_batch(batch, epoch, batch_idx, loss_window):
    """Run forward + backward for one training micro-batch.

    Finite loss components are recorded in ``loss_window``. Returns _MB_OK when
    gradients were accumulated, _MB_SKIPPED when the batch was skipped without
    touching gradients, or _MB_RESET when backward failed and gradients were
    cleared.
    """
    try:
        inputs = _batch_to_model_inputs(batch, non_blocking=True)
    except KeyError as e:
        print(f"Error: Missing key {e} in batch. Check dataloader and collate_fn.")
        return _MB_SKIPPED

    # Keep the learned log-scale in [0, log(100)].
    with torch.no_grad():
        model.logit_scale.clamp_(0, max_log_logit_scale)

    # Individual loss terms may come back as NaN scalar tensors when a term is
    # invalid for this batch; only the total loss gates the backward pass.
    _, _, total_loss, class_loss_s, contrastive_loss_s, regression_loss_s = model(**inputs)

    if total_loss is None or not torch.isfinite(total_loss):
        print(f"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward for this batch.")
        # No backward ran, so gradients accumulated from earlier micro-batches
        # in this window are still valid. Keep them.
        return _MB_SKIPPED

    # Each component has its own count, so a component that is NaN for some
    # batches does not drag its average toward zero.
    _accumulate(loss_window, "total", total_loss)
    _accumulate(loss_window, "class", class_loss_s)
    _accumulate(loss_window, "contrastive", contrastive_loss_s)
    _accumulate(loss_window, "regression", regression_loss_s)

    try:
        (total_loss / GRAD_ACCUMULATION_STEPS).backward()
    except Exception as e:
        print(f"Error during backward pass: {e}. Skipping step.")
        optimizer.zero_grad()  # a failed backward may leave partial gradients behind
        return _MB_RESET
    return _MB_OK


def _evaluate_test_batch():
    """Compute (class, contrastive, regression) losses on the next test batch.

    Each value is NaN if that loss is non-finite. All three are NaN when there
    is no test data or evaluation fails; failures are printed, not raised.
    The model is always returned to train mode.
    """
    nan = float('nan')
    losses = (nan, nan, nan)
    if test_batches is None:
        return losses
    model.eval()
    try:
        with torch.no_grad():
            test_batch = next(test_batches)
            if test_batch:
                # Same keyword mapping as training, so no tensor can bind to
                # the wrong forward() parameter.
                _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(**_batch_to_model_inputs(test_batch))
                losses = (_finite_item(t_class_loss), _finite_item(t_con_loss), _finite_item(t_reg_loss))
    except StopIteration:
        print("Info: Test loader exhausted during logging.")
    except KeyError as e:
        print(f"Error: Missing key {e} in test batch.")
    except Exception as e:
        print(f"Error during test evaluation: {e}")
    finally:
        model.train()
    return losses


for epoch in range(NUM_EPOCHS):
    model.train()
    # Train-loss sums/counts for the current logging window (reset each epoch and after each log).
    loss_window = _new_loss_window()
    # Micro-batches whose gradients are accumulated but not yet applied.
    pending_micro_batches = 0

    pbar = tqdm(enumerate(train_loader), total=num_train_batches, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)

    for batch_idx, batch in pbar:
        if batch is not None:
            status = _train_micro_batch(batch, epoch, batch_idx, loss_window)
            if status == _MB_OK:
                pending_micro_batches += 1
            elif status == _MB_RESET:
                pending_micro_batches = 0

        # --- Gradient Accumulation Step ---
        # Checked even when this micro-batch was skipped. Otherwise gradients
        # from earlier micro-batches in the window would leak into the next
        # window (or the next epoch).
        is_accum_boundary = (
            (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0
            or (batch_idx + 1) == num_train_batches
        )
        if not is_accum_boundary or pending_micro_batches == 0:
            continue
        pending_micro_batches = 0

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

        # Skip the update entirely if any gradient contains NaN/Inf.
        grads_finite = all(
            p.grad is None or torch.isfinite(p.grad).all()
            for p in model.parameters()
        )
        if not grads_finite:
            print(f"Warning: Non-finite gradients detected at step {step_counter}. Skipping optimizer step.")
            optimizer.zero_grad()
            continue  # no optimizer step and no scheduler step

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        step_counter += 1

        # --- Logging ---
        if step_counter % LOGGING_STEPS != 0 or loss_window["total"][1] == 0:
            continue

        avg_total_loss = _window_mean(loss_window, "total")
        avg_class_loss = _window_mean(loss_window, "class")
        avg_con_loss = _window_mean(loss_window, "contrastive")
        avg_reg_loss = _window_mean(loss_window, "regression")
        current_lr = optimizer.param_groups[0]['lr']
        test_class_loss_val, test_con_loss_val, test_reg_loss_val = _evaluate_test_batch()

        log_data = {
            "train/total_loss": avg_total_loss,
            "train/class_loss": avg_class_loss,
            "train/contrastive_loss": avg_con_loss,
            "train/regression_loss": avg_reg_loss,
            "test/class_loss": test_class_loss_val,
            "test/contrastive_loss": test_con_loss_val,
            "test/regression_loss": test_reg_loss_val,
            "epoch": epoch + ((batch_idx + 1) / num_train_batches),
            "step": step_counter,
            "learning_rate": current_lr,
            "gradient_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
            "logit_scale": model.logit_scale.exp().item()
        }
        pbar.set_postfix({
            "lr": f"{current_lr:.2e}", "loss": f"{avg_total_loss:.3f}",
            "cls": f"{avg_class_loss:.3f}", "con": f"{avg_con_loss:.3f}",
            "reg": f"{avg_reg_loss:.3f}", "gnorm": f"{log_data['gradient_norm']:.2f}"
        })
        if wandb_enabled:
            wandb.log(log_data)

        loss_window = _new_loss_window()

    # --- End of Epoch ---
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} completed.")
    # Checkpoint every 5 epochs (after epochs 1, 6, 11, ...) and always after
    # the final epoch, so the trained weights are never lost.
    if epoch % 5 == 0 or epoch == NUM_EPOCHS - 1:
        torch.save(model.state_dict(), f"model_regression_multi_{epoch+1}.pth")

# --- End of Training ---
print("\nTraining completed!")
if wandb_enabled:
    wandb.finish()