File size: 8,739 Bytes
4cabce9
 
 
 
 
 
 
 
 
9b9f13c
4cabce9
 
 
 
 
 
9b9f13c
 
4cabce9
 
 
 
 
9b9f13c
 
4cabce9
 
9b9f13c
4cabce9
9b9f13c
4cabce9
 
9b9f13c
 
 
 
 
4cabce9
 
 
 
 
9b9f13c
 
 
 
 
 
 
 
 
 
 
4cabce9
9b9f13c
4cabce9
 
 
 
 
9b9f13c
4cabce9
9b9f13c
 
4cabce9
 
9b9f13c
4cabce9
9b9f13c
4cabce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b9f13c
 
 
4cabce9
 
 
 
 
 
 
 
 
 
9b9f13c
 
 
 
4cabce9
 
 
9b9f13c
4cabce9
 
9b9f13c
 
4cabce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b9f13c
4cabce9
 
 
 
9b9f13c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cabce9
 
 
9b9f13c
 
 
 
 
 
 
 
4cabce9
 
9b9f13c
 
4cabce9
9b9f13c
 
 
 
 
 
 
 
 
4cabce9
 
 
9b9f13c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3
import os
# prevent HF tokenizers threads from hanging the process
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from multiprocessing import freeze_support
import shutil # Import shutil for removing old checkpoints
import collections # Import collections for deque

def main() -> None:
    """Phase-3 self-contrastive fine-tuning of the query encoder.

    Loads the Phase-2 LoRA adapters on top of the Gemma base model, adds a
    small MLP projection head, and trains with an NT-Xent (SimCLR-style)
    objective: each query is encoded twice per batch and the two embeddings
    of the same query are treated as the positive pair (stochasticity of the
    two passes presumably comes from dropout in train mode — TODO confirm
    the adapters/model actually have active dropout).

    Side effects: creates OUTPUT_DIR, writes rotating step checkpoints and a
    final model there, and streams metrics to Weights & Biases.
    """
    # --- Config ---
    PRET_FILE      = "pretokenized_queries.pt"
    MODEL_NAME     = "google/gemma-3-1b-pt"
    LORA_DIR       = "phase2_triplet_amp/final" # Adapters from previous stage
    BATCH_SIZE     = 200
    LR             = 1e-5
    WEIGHT_DECAY   = 0.01
    NUM_EPOCHS     = 1 # As per our discussion, 1 epoch is likely sufficient given fast convergence
    TEMP           = 0.05  # softmax temperature for the contrastive logits
    OUTPUT_DIR     = "phase3_self_contrast_wandb"
    GRAD_CLIP_NORM = 1.0
    SEED           = 42
    WANDB_PROJECT  = "query-encoder-phase3"

    # --- Checkpointing Configuration ---
    SAVE_INTERVAL = 1000 # Save a checkpoint every N steps
    KEEP_LAST_CKPTS = 5  # Keep only the last N checkpoints (to save disk space)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    # --- Initialize WandB ---
    wandb.init(
        project=WANDB_PROJECT,
        config={
            "model_name": MODEL_NAME, "lora_dir": LORA_DIR, "batch_size": BATCH_SIZE,
            "lr": LR, "num_epochs": NUM_EPOCHS, "seed": SEED,
            "save_interval_steps": SAVE_INTERVAL,
            "keep_last_checkpoints": KEEP_LAST_CKPTS,
        }
    )

    # --- Load pretokenized queries safely ---
    # weights_only=True avoids arbitrary-code execution from a pickled file.
    print(f"Loading pretokenized queries from {PRET_FILE}...")
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids      = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader  = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(f"Loaded {len(dataset)} samples.")

    # --- Load base model + LoRA adapters from previous stage ---
    print(f"Loading base model '{MODEL_NAME}' and LoRA adapters from '{LORA_DIR}'...")
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)
    print("LoRA adapters loaded.")

    # --- Projection head now outputs hidden_size ---
    class GemmaSelfContrast(nn.Module):
        """Wraps the PEFT model with a 2-layer MLP head; returns L2-normalized
        sentence embeddings of shape (B, hidden_size)."""
        def __init__(self, peft_model: PeftModel) -> None:
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            # hidden_size -> 512 -> hidden_size bottleneck projection head
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, hs),
            )
        def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True
            )
            # Mean-pool the last hidden layer over the sequence dimension.
            # NOTE(review): this averages over ALL positions, including padding;
            # a mask-weighted mean may be intended — confirm against how the
            # Phase-2 encoder pooled before changing anything here.
            h = out.hidden_states[-1].mean(dim=1)
            # Clamp NaN/Inf activations so one bad sample can't poison the batch.
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)               # now (B, hidden_size)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            # L2-normalize; clamp_min guards against division by a zero norm.
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)
    print("Encoder model (with projection head) initialized.")
    # Watch the model with wandb (optional, can be slow, but good for tracking gradients)
    # wandb.watch(model, log="all", log_freq=100) # Commented out due to potential slowdown

    # --- Optimizer, scheduler, AMP scaler ---
    # NOTE(review): torch.cuda.amp.GradScaler/autocast are deprecated in newer
    # PyTorch in favor of torch.amp.* — consider migrating when bumping torch.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    # Linear warmup over the first 10% of steps, then linear decay to zero.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    scaler = GradScaler()
    print(f"Training will run for {total_steps} steps.")

    # Deque to manage checkpoint paths and enforce keeping only the last N.
    # The popleft-before-append below is what actually deletes old directories;
    # maxlen alone would silently drop paths without removing them from disk.
    checkpoint_paths = collections.deque(maxlen=KEEP_LAST_CKPTS)

    # --- Training loop ---
    model.train()
    global_step = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch}", unit="batch")
        for ids, mask in pbar:
            ids, mask = ids.to(device), mask.to(device)

            with autocast():
                # Two passes over the same batch produce the two "views".
                e1  = model(ids, mask)
                e2  = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)   # (2B, hidden_size)
                # Cosine-similarity logits (embeddings are unit-norm), scaled by TEMP.
                sim = (emb @ emb.T) / TEMP
                # mask diagonal with -inf so a sample can't match itself
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))

                B = e1.size(0)
                # NT-Xent targets: row i's positive is row i+B, and vice versa.
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer) # Unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # --- Log metrics to WandB at every step ---
            wandb.log({
                "train/loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "train/epoch": epoch,
                "train/global_step": global_step
            }, step=global_step)

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            # --- PERIODIC SAVING BLOCK ---
            # Save checkpoint every SAVE_INTERVAL steps
            if (global_step + 1) % SAVE_INTERVAL == 0:
                # Create a unique directory for this checkpoint
                ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-step-{global_step + 1}")
                os.makedirs(ckpt_dir, exist_ok=True)

                print(f"\nSaving checkpoint to {ckpt_dir}...")
                # Save the PEFT adapters
                peft.save_pretrained(ckpt_dir)
                # Save the trained projection head's state dictionary
                torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "encoder_proj.pth"))

                # Manage old checkpoints
                if len(checkpoint_paths) == KEEP_LAST_CKPTS:
                    oldest_ckpt = checkpoint_paths.popleft() # Remove the oldest path from deque
                    if os.path.isdir(oldest_ckpt):
                        print(f"Removing old checkpoint: {oldest_ckpt}")
                        shutil.rmtree(oldest_ckpt, ignore_errors=True) # Delete the directory
                checkpoint_paths.append(ckpt_dir) # Add new checkpoint path
                print("Checkpoint saved and old ones managed.")
            # --- END PERIODIC SAVING ---

            global_step += 1
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} training complete. Avg loss: {avg_loss:.6f}")
        # Log average epoch loss as well
        wandb.log({"train/epoch_avg_loss": avg_loss, "epoch": epoch}, step=global_step)

    # --- Final Save for the "final" directory ---
    # This ensures that even if you stop mid-epoch (after a checkpoint)
    # or don't stop, there's always a clear 'final' model.
    print("\nTraining finished. Saving final model to 'final' directory...")
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)

    # Save the LoRA adapters
    peft.save_pretrained(final_dir)

    # Save the trained projection head's state dictionary
    torch.save(model.proj.state_dict(), os.path.join(final_dir, "encoder_proj.pth"))

    print(f"Phase 3 complete. LoRA adapters and projection head saved to {final_dir}")

    # --- Finalize WandB run ---
    wandb.finish()

if __name__ == "__main__":
    # freeze_support() is a no-op except in frozen Windows executables, where
    # it must run before any multiprocessing (e.g. DataLoader workers) starts.
    freeze_support()
    main()