#!/usr/bin/env python3
"""Phase 3: SimCSE-style self-contrastive fine-tuning of a LoRA-adapted Gemma model.

Loads pre-tokenized queries, encodes each batch twice (dropout supplies the two
stochastic "views" of the same query), and optimizes an InfoNCE loss so the two
views of a query attract while all other in-batch embeddings repel. Only the
LoRA adapter weights are saved at the end.
"""

import os

# Prevent HF tokenizers threads from hanging the (possibly forked) process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
# NOTE(review): torch.cuda.amp is deprecated in favor of torch.amp in recent
# PyTorch releases; kept to match the project's existing import style.
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from multiprocessing import freeze_support


def main():
    """Run one phase of self-contrastive training and save the LoRA adapters."""
    # --- Config ---
    PRET_FILE = "pretokenized_queries.pt"
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"
    BATCH_SIZE = 64
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    TEMP = 0.05  # InfoNCE temperature: lower -> sharper softmax over similarities
    OUTPUT_DIR = "phase3_self_contrast"
    GRAD_CLIP_NORM = 1.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    torch.manual_seed(SEED)

    # --- Load pretokenized queries safely ---
    # weights_only=True restricts torch.load to plain tensors (no arbitrary pickle).
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    # Seeded generator so the shuffle order is reproducible across runs
    # (torch.manual_seed alone does not pin the DataLoader's shuffle).
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )

    # --- Load base model + LoRA adapters ---
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)

    class GemmaSelfContrast(nn.Module):
        """PEFT model plus a 2-layer projection head producing L2-normalized
        256-d embeddings from mask-aware mean pooling of the last hidden state."""

        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
            )

        def forward(self, ids, mask):
            """Return unit-norm embeddings of shape (batch, 256)."""
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True,
            )
            last = out.hidden_states[-1]
            # BUGFIX: mean-pool only over real tokens. The previous plain
            # .mean(dim=1) averaged padding positions into the embedding;
            # with an all-ones mask this reduces to the same plain mean.
            m = mask.unsqueeze(-1).to(last.dtype)
            h = (last * m).sum(dim=1) / m.sum(dim=1).clamp_min(1e-6)
            # nan_to_num guards against fp16 overflow artifacts under AMP.
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)

    # --- Optimizer, scheduler, AMP scaler ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),  # 10% linear warmup
        num_training_steps=total_steps,
    )
    # Disable the AMP machinery on CPU so the script degrades gracefully.
    scaler = GradScaler(enabled=use_amp)

    # --- Training loop ---
    model.train()
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        for ids, mask in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
            ids, mask = ids.to(device), mask.to(device)
            with autocast(enabled=use_amp):
                # Two forward passes of the same batch: train-mode dropout
                # makes e1 != e2, forming the positive pair (SimCSE).
                e1 = model(ids, mask)
                e2 = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)
                sim = (emb @ emb.T) / TEMP
                # Mask self-similarity so a row can never pick itself
                # as its "positive".
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))
                B = e1.size(0)
                # Row i's positive is its dropout twin: i -> i+B, i+B -> i.
                labels = torch.cat(
                    [
                        torch.arange(B, device=device) + B,
                        torch.arange(B, device=device),
                    ],
                    dim=0,
                )
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so GRAD_CLIP_NORM applies to true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} avg loss: {avg_loss:.6f}")

    # --- Save only LoRA adapters ---
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    peft.save_pretrained(final_dir)
    print("Phase 3 complete. LoRA adapters saved to", final_dir)


if __name__ == "__main__":
    freeze_support()  # no-op outside frozen Windows executables
    main()
    sys.exit(0)