#!/usr/bin/env python3
import os
# Prevent HF tokenizers thread parallelism from hanging forked processes;
# must be set before transformers/tokenizers is imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.cuda.amp import GradScaler, autocast  # deprecated alias in PyTorch >= 2.4; torch.amp is the replacement
from tqdm.auto import tqdm
from multiprocessing import freeze_support
import shutil  # for removing old checkpoint directories
import collections  # deque for tracking recent checkpoint paths
def main():
    # --- Config ---
    PRET_FILE = "pretokenized_queries.pt"
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"  # adapters from the previous stage
    BATCH_SIZE = 200
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1  # one epoch is typically sufficient given how quickly this stage converges
    TEMP = 0.05
    OUTPUT_DIR = "phase3_self_contrast_wandb"
    GRAD_CLIP_NORM = 1.0
    SEED = 42
    WANDB_PROJECT = "query-encoder-phase3"

    # --- Checkpointing configuration ---
    SAVE_INTERVAL = 1000  # save a checkpoint every N steps
    KEEP_LAST_CKPTS = 5  # keep only the last N checkpoints to save disk space
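    # Expected on-disk layout after a run (illustrative, not created up front):
    #
    #   phase3_self_contrast_wandb/
    #       checkpoint-step-1000/   # LoRA adapter files + encoder_proj.pth
    #       checkpoint-step-2000/
    #       ...
    #       final/                  # LoRA adapter files + encoder_proj.pth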
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)
    # --- Initialize WandB ---
    wandb.init(
        project=WANDB_PROJECT,
        config={
            "model_name": MODEL_NAME, "lora_dir": LORA_DIR, "batch_size": BATCH_SIZE,
            "lr": LR, "num_epochs": NUM_EPOCHS, "seed": SEED,
            "save_interval_steps": SAVE_INTERVAL,
            "keep_last_checkpoints": KEEP_LAST_CKPTS,
        },
    )
    # --- Load pretokenized queries safely ---
    print(f"Loading pretokenized queries from {PRET_FILE}...")
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(f"Loaded {len(dataset)} samples.")
    # --- Load base model + LoRA adapters from the previous stage ---
    print(f"Loading base model '{MODEL_NAME}' and LoRA adapters from '{LORA_DIR}'...")
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)
    print("LoRA adapters loaded.")
    # --- Projection head now outputs hidden_size ---
    class GemmaSelfContrast(nn.Module):
        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            # Bottleneck MLP: hidden_size -> 512 -> hidden_size, so the output
            # dimensionality matches the backbone's hidden states.
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, hs),
            )

        def forward(self, ids, mask):
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True,
            )
            # Mean-pool the last hidden layer over the sequence dimension.
            # Note: this averages over all positions, padding included.
            h = out.hidden_states[-1].mean(dim=1)
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)  # (B, hidden_size)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            # L2-normalize so dot products below are cosine similarities.
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)
    print("Encoder model (with projection head) initialized.")
    # Watching the model with wandb tracks gradients but can slow training:
    # wandb.watch(model, log="all", log_freq=100)
    # --- Optimizer, scheduler, AMP scaler ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
    scaler = GradScaler()
    print(f"Training will run for {total_steps} steps.")

    # Deque to track checkpoint paths and keep only the last N on disk
    checkpoint_paths = collections.deque(maxlen=KEEP_LAST_CKPTS)
    # --- Training loop ---
    model.train()
    global_step = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch}", unit="batch")
        for ids, mask in pbar:
            ids, mask = ids.to(device), mask.to(device)
            with autocast():
                e1 = model(ids, mask)
                e2 = model(ids, mask)
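                # Two forward passes of the same batch give two "views" of each
                # query (SimCSE-style self-contrast); they differ only through
                # train-mode stochasticity such as dropout in the LoRA adapters,
                # so some dropout must be active for the pairs to be non-identical.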
                emb = torch.cat([e1, e2], dim=0)
                sim = (emb @ emb.T) / TEMP
                # Mask the diagonal (self-similarity) with -inf
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))
                B = e1.size(0)
                # Row i's positive is its counterpart from the other view:
                # i + B for the first half, i - B for the second half.
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)
                loss = F.cross_entropy(sim, labels)
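                # Worked example with B = 2: rows of emb are
                # [q0_v1, q1_v1, q0_v2, q1_v2] and labels = [2, 3, 0, 1],
                # i.e. each row's target is the other dropout view of the
                # same query (NT-Xent / InfoNCE loss).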
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # --- Log metrics to WandB at every step ---
            wandb.log({
                "train/loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "train/epoch": epoch,
                "train/global_step": global_step,
            }, step=global_step)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            # --- Periodic checkpointing: save every SAVE_INTERVAL steps ---
            if (global_step + 1) % SAVE_INTERVAL == 0:
                # Create a unique directory for this checkpoint
                ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-step-{global_step + 1}")
                os.makedirs(ckpt_dir, exist_ok=True)
                print(f"\nSaving checkpoint to {ckpt_dir}...")
                # Save the PEFT adapters
                peft.save_pretrained(ckpt_dir)
                # Save the trained projection head's state dict
                torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "encoder_proj.pth"))
                # Prune old checkpoints beyond the last KEEP_LAST_CKPTS
                if len(checkpoint_paths) == KEEP_LAST_CKPTS:
                    oldest_ckpt = checkpoint_paths.popleft()
                    if os.path.isdir(oldest_ckpt):
                        print(f"Removing old checkpoint: {oldest_ckpt}")
                        shutil.rmtree(oldest_ckpt, ignore_errors=True)
                checkpoint_paths.append(ckpt_dir)
                print("Checkpoint saved and old ones pruned.")

            global_step += 1
            total_loss += loss.item()
        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} training complete. Avg loss: {avg_loss:.6f}")
        # Log the average epoch loss as well
        wandb.log({"train/epoch_avg_loss": avg_loss, "epoch": epoch}, step=global_step)
    # --- Final save ---
    # Regardless of whether training stopped mid-epoch after a checkpoint,
    # there is always an unambiguous 'final' model directory.
    print("\nTraining finished. Saving final model to 'final' directory...")
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    # Save the LoRA adapters
    peft.save_pretrained(final_dir)
    # Save the trained projection head's state dict
    torch.save(model.proj.state_dict(), os.path.join(final_dir, "encoder_proj.pth"))
    print(f"Phase 3 complete. LoRA adapters and projection head saved to {final_dir}")
    # --- Finalize WandB run ---
    wandb.finish()
if __name__ == "__main__":
    freeze_support()
    main()