#!/usr/bin/env python3
import os
# prevent HF tokenizers threads from hanging the process
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.amp import GradScaler, autocast # torch.cuda.amp is deprecated; torch.amp requires PyTorch >= 2.3
from tqdm.auto import tqdm
from multiprocessing import freeze_support
import shutil # Import shutil for removing old checkpoints
import collections # Import collections for deque
def main():
# --- Config ---
PRET_FILE = "pretokenized_queries.pt"
MODEL_NAME = "google/gemma-3-1b-pt"
LORA_DIR = "phase2_triplet_amp/final" # Adapters from previous stage
BATCH_SIZE = 200
LR = 1e-5
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 1 # one epoch is likely sufficient given the fast convergence
TEMP = 0.05
OUTPUT_DIR = "phase3_self_contrast_wandb"
GRAD_CLIP_NORM = 1.0
SEED = 42
WANDB_PROJECT = "query-encoder-phase3"
# --- Checkpointing Configuration ---
SAVE_INTERVAL = 1000 # Save a checkpoint every N steps
KEEP_LAST_CKPTS = 5 # Keep only the last N checkpoints (to save disk space)
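# With BATCH_SIZE = 200, each step builds a 400x400 similarity matrix:
# every row has exactly 1 positive (its twin view) and 398 in-batch negatives.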
os.makedirs(OUTPUT_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)
# --- Initialize WandB ---
wandb.init(
project=WANDB_PROJECT,
config={
"model_name": MODEL_NAME, "lora_dir": LORA_DIR, "batch_size": BATCH_SIZE,
"lr": LR, "num_epochs": NUM_EPOCHS, "seed": SEED,
"save_interval_steps": SAVE_INTERVAL,
"keep_last_checkpoints": KEEP_LAST_CKPTS,
}
)
# --- Load pretokenized queries safely ---
print(f"Loading pretokenized queries from {PRET_FILE}...")
data = torch.load(PRET_FILE, weights_only=True)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
dataset = TensorDataset(input_ids, attention_mask)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print(f"Loaded {len(dataset)} samples.")
# --- Load base model + LoRA adapters from previous stage ---
print(f"Loading base model '{MODEL_NAME}' and LoRA adapters from '{LORA_DIR}'...")
base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
peft = PeftModel.from_pretrained(base, LORA_DIR, is_trainable=True).to(device) # is_trainable=True keeps the adapters trainable; the default loads them frozen for inference
print("LoRA adapters loaded.")
# --- Projection head now outputs hidden_size ---
class GemmaSelfContrast(nn.Module):
def __init__(self, peft_model):
super().__init__()
self.peft = peft_model
hs = peft_model.base_model.config.hidden_size
self.proj = nn.Sequential(
nn.Linear(hs, 512),
nn.ReLU(),
nn.Linear(512, hs),
)
def forward(self, ids, mask):
out = self.peft.base_model(
input_ids=ids,
attention_mask=mask,
output_hidden_states=True,
return_dict=True
)
h = out.hidden_states[-1].mean(dim=1) # mean-pool the final hidden states over the sequence dimension
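# Note: this plain mean pools over every position, padding included. A
# mask-weighted mean that ignores padding tokens is a common alternative:
#   summed = (out.hidden_states[-1] * mask.unsqueeze(-1)).sum(dim=1)
#   h = summed / mask.sum(dim=1, keepdim=True).clamp_min(1)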
h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
z = self.proj(h) # now (B, hidden_size)
z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
return z / norm
model = GemmaSelfContrast(peft).to(device)
print("Encoder model (with projection head) initialized.")
# Watch the model with wandb (optional, can be slow, but good for tracking gradients)
# wandb.watch(model, log="all", log_freq=100) # Commented out due to potential slowdown
# --- Optimizer, scheduler, AMP scaler ---
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
total_steps = len(loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=int(0.1 * total_steps),
num_training_steps=total_steps
)
scaler = GradScaler("cuda")
print(f"Training will run for {total_steps} steps.")
# Deque to manage checkpoint paths and enforce keeping only the last N
checkpoint_paths = collections.deque(maxlen=KEEP_LAST_CKPTS)
# --- Training loop ---
model.train()
global_step = 0
for epoch in range(1, NUM_EPOCHS + 1):
total_loss = 0.0
pbar = tqdm(loader, desc=f"Epoch {epoch}", unit="batch")
for ids, mask in pbar:
ids, mask = ids.to(device), mask.to(device)
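# Two stochastic forward passes over the same batch: with the model in
# train() mode, dropout (e.g. LoRA dropout) makes e1 and e2 differ, giving
# each query a "twin" view to contrast against (SimCLR-style).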
with autocast("cuda"):
e1 = model(ids, mask)
e2 = model(ids, mask)
emb = torch.cat([e1, e2], dim=0)
sim = (emb @ emb.T) / TEMP
# mask diagonal with -inf
mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
sim = sim.masked_fill(mask_eye, float('-inf'))
B = e1.size(0)
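# The positive for row i (view 1) sits at column i+B (view 2), and vice versa.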
labels = torch.cat([
torch.arange(B, device=device) + B,
torch.arange(B, device=device)
], dim=0)
loss = F.cross_entropy(sim, labels)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer) # Unscale gradients before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
scaler.step(optimizer)
scaler.update()
scheduler.step()
# --- Log metrics to WandB at every step ---
wandb.log({
"train/loss": loss.item(),
"train/lr": scheduler.get_last_lr()[0],
"train/epoch": epoch,
"train/global_step": global_step
}, step=global_step)
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
# --- PERIODIC SAVING BLOCK ---
# Save checkpoint every SAVE_INTERVAL steps
if (global_step + 1) % SAVE_INTERVAL == 0:
# Create a unique directory for this checkpoint
ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-step-{global_step + 1}")
os.makedirs(ckpt_dir, exist_ok=True)
print(f"\nSaving checkpoint to {ckpt_dir}...")
# Save the PEFT adapters
peft.save_pretrained(ckpt_dir)
# Save the trained projection head's state dictionary
torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "encoder_proj.pth"))
# Manage old checkpoints
if len(checkpoint_paths) == KEEP_LAST_CKPTS:
oldest_ckpt = checkpoint_paths.popleft() # Remove the oldest path from deque
if os.path.isdir(oldest_ckpt):
print(f"Removing old checkpoint: {oldest_ckpt}")
shutil.rmtree(oldest_ckpt, ignore_errors=True) # Delete the directory
checkpoint_paths.append(ckpt_dir) # Add new checkpoint path
print("Checkpoint saved and old ones managed.")
# --- END PERIODIC SAVING ---
global_step += 1
total_loss += loss.item()
avg_loss = total_loss / len(loader)
print(f"Epoch {epoch} training complete. Avg loss: {avg_loss:.6f}")
# Log average epoch loss as well
wandb.log({"train/epoch_avg_loss": avg_loss, "epoch": epoch}, step=global_step)
# --- Final Save for the "final" directory ---
# Always write a clearly labelled 'final' model once training completes,
# separate from the rolling step checkpoints above.
print("\nTraining finished. Saving final model to 'final' directory...")
final_dir = os.path.join(OUTPUT_DIR, "final")
os.makedirs(final_dir, exist_ok=True)
# Save the LoRA adapters
peft.save_pretrained(final_dir)
# Save the trained projection head's state dictionary
torch.save(model.proj.state_dict(), os.path.join(final_dir, "encoder_proj.pth"))
print(f"Phase 3 complete. LoRA adapters and projection head saved to {final_dir}")
# --- Finalize WandB run ---
wandb.finish()
if __name__ == "__main__":
freeze_support()
main() |