#!/usr/bin/env python3
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from tqdm.auto import tqdm
from multiprocessing import freeze_support


def main():
    # Config
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"        # one sequence per line
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1               # 10% held out for validation
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.0
    PROJ_HIDDEN = 512             # intermediate MLP width of the projection head
    TEMP = 0.05                   # softmax temperature for the contrastive loss
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    SIM_CLAMP_MIN = -10.0
    SIM_CLAMP_MAX = 10.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # enable TF32 and the cuDNN autotuner on CUDA
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    # tokenizer + base model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",
    )

    # LoRA on q_proj & v_proj
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)

    # SimCSE encoder: base model + MLP projection head (hidden_size -> PROJ_HIDDEN -> hidden_size)
    class GemmaSimCSE(nn.Module):
        def __init__(self, base):
            super().__init__()
            self.base = base
            hs = base.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, PROJ_HIDDEN),
                nn.ReLU(),
                nn.Linear(PROJ_HIDDEN, hs),
            )

        def forward(self, input_ids, attention_mask):
            out = self.base(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = out.hidden_states[-1]                     # (B, T, H)
            # masked mean-pooling over real tokens; with padding="max_length"
            # below, a plain mean would be diluted by pad-token states
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)
            emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(emb)                                 # (B, H)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm                                    # unit-normalized embeddings

    model = GemmaSimCSE(model_lora).to(device)
    # anomaly detection helps track down NaNs but slows training noticeably
    torch.autograd.set_detect_anomaly(True)

    # Load and split dataset
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]

    # Tokenization
    def tokenize_fn(batch):
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
        )
        return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"]}

    train_ds = train_ds.map(
        tokenize_fn, batched=True, batch_size=1000, num_proc=4, remove_columns=["text"]
    )
    val_ds = val_ds.map(
        tokenize_fn, batched=True, batch_size=1000, num_proc=4, remove_columns=["text"]
    )
    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    # Training + validation loop
    for epoch in range(1, NUM_EPOCHS + 1):
        # --- train ---
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            # two forward passes of the same batch give the two SimCSE "views"
            emb1 = model(ids, mask)
            emb2 = model(ids, mask)
            emb = torch.cat([emb1, emb2], dim=0)               # (2B, H)

            # temperature-scaled cosine similarity; mask out self-similarity
            sim = (emb @ emb.T) / TEMP
            sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
            sim.fill_diagonal_(-1e9)

            # positive of row i is i+B (first half) or i-B (second half)
            B = emb1.size(0)
            labels = torch.cat([
                torch.arange(B, device=device) + B,
                torch.arange(B, device=device),
            ], dim=0)

            loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")

        # --- validate ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)

                emb1 = model(ids, mask)
                emb2 = model(ids, mask)
                emb = torch.cat([emb1, emb2], dim=0)

                sim = (emb @ emb.T) / TEMP
                sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
                sim.fill_diagonal_(-1e9)

                B = emb1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device),
                ], dim=0)

                loss = F.cross_entropy(sim, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")

        # save per-epoch checkpoint (LoRA adapter + tokenizer)
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)

    # save final model
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Training and validation complete. Final model saved to", final_dir)


if __name__ == "__main__":
    freeze_support()
    main()
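

# ---------------------------------------------------------------------------
# Hedged usage sketch (kept commented out so the script above stays unchanged):
# one way to reload the adapter saved to OUTPUT_DIR/final and embed new text.
# Note that model_lora.save_pretrained() stores only the LoRA adapter and the
# tokenizer, not the projection head (self.proj), so this sketch pools the
# base+LoRA hidden states directly. The path and the masked-mean pooling here
# are illustrative assumptions, not a fixed API of the script above.
#
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained(
#       "google/gemma-3-1b-pt", attn_implementation="eager"
#   )
#   encoder = PeftModel.from_pretrained(base, "stage1_simcse/final").eval()
#   tok = AutoTokenizer.from_pretrained("stage1_simcse/final")
#
#   with torch.no_grad():
#       batch = tok(
#           ["example sentence"], return_tensors="pt",
#           padding=True, truncation=True, max_length=128,
#       )
#       out = encoder(**batch, output_hidden_states=True, return_dict=True)
#       hidden = out.hidden_states[-1]                          # (B, T, H)
#       m = batch["attention_mask"].unsqueeze(-1).float()
#       emb = (hidden * m).sum(dim=1) / m.sum(dim=1).clamp_min(1e-6)
#       emb = F.normalize(emb, dim=1)                           # unit-norm embeddings
# ---------------------------------------------------------------------------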