#!/usr/bin/env python3
"""
Train mT5-large for query diversification with URL context, with
resume-from-checkpoint and additional-epochs support.
"""

import gc
import json
import os

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset as HFDataset
from sklearn.model_selection import train_test_split
from transformers import (
    DataCollatorForSeq2Seq,
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
)

# ---- PyTorch 2.6+ checkpoint-resume patches --------------------------------
# PyTorch 2.6 changed torch.load's default to weights_only=True, which breaks
# loading the pickled RNG-state files stored inside Trainer checkpoints.
# 1) allow numpy's pickle reconstruct helper through the safe-globals allowlist
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])

# 2) force torch.load(weights_only=False) for RNG-state files
_orig_torch_load = torch.load


def _patched_load(*args, **kwargs):
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)


torch.load = _patched_load
# -----------------------------------------------------------------------------

# --------------------- CONSTANTS ---------------------------------------------
MODEL_NAME = "google/mt5-large"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
BATCH_SIZE = 160
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
WARMUP_STEPS = 1000
GRAD_ACC_STEPS = 1
CACHE_DIR = "./tokenized_cache"
OUTPUT_DIR = "./mt5-query-diversification"
# -----------------------------------------------------------------------------


def prepare_datasets(csv_path: str):
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.01, random_state=42)
    return train_df, val_df


def compute_metrics(eval_preds, tok):
    preds, labels = eval_preds
    vocab_size = len(tok)
    # Clamp out-of-range / negative ids so batch_decode cannot fail.
    preds = np.where(preds < vocab_size, preds, tok.pad_token_id)
    preds = np.where(preds >= 0, preds, tok.pad_token_id)
    labels = np.where(labels != -100, labels, tok.pad_token_id)
    pred_str = tok.batch_decode(preds, skip_special_tokens=True)
    label_str = tok.batch_decode(labels, skip_special_tokens=True)
    exact = sum(p.strip() == l.strip() for p, l in zip(pred_str, label_str)) / len(pred_str)
    diff = np.mean([len(p.split()) - len(l.split()) for p, l in zip(pred_str, label_str)])
    return {"exact_match": exact, "avg_length_diff": diff}


def list_checkpoints(out_dir):
    if not os.path.isdir(out_dir):
        return []
    cps = [
        d
        for d in os.listdir(out_dir)
        if d.startswith("checkpoint-") and os.path.isdir(os.path.join(out_dir, d))
    ]
    cps.sort(key=lambda x: int(x.split("-")[1]))
    return cps


def select_checkpoint(cps):
    print("\nAvailable checkpoints:")
    for i, cp in enumerate(cps):
        print(f"  [{i}] {cp}")
    print("  [n] Start training from scratch")
    sel = input(f"Select checkpoint [0-{len(cps) - 1}, n]: ").strip()
    if sel.lower() in {"", "n"}:
        return None
    try:
        idx = int(sel)
    except ValueError:
        return None
    return cps[idx] if 0 <= idx < len(cps) else None


def last_epoch(ckpt_path):
    ts = os.path.join(ckpt_path, "trainer_state.json")
    if not os.path.isfile(ts):
        return 0
    with open(ts, "r", encoding="utf-8") as f:
        st = json.load(f)
    if "epoch" in st:
        return float(st["epoch"])
    epochs = [e.get("epoch", 0) for e in st.get("log_history", []) if "epoch" in e]
    return max(epochs) if epochs else 0
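
# ---- expected data format (illustrative) -------------------------------------
# prepare_datasets() reads train.csv, and tok_fn() in main() assumes three
# columns: "url", "query", and "fanout" (the diversified target query). A
# minimal sketch of a compatible file for smoke-testing the pipeline (the
# fanout values here are made-up assumptions, not real data):
#
#   import pandas as pd
#   pd.DataFrame({
#       "url": ["python.org", "amazon.com"],
#       "query": ["python tutorial", "laptop deals"],
#       "fanout": ["python tutorial for beginners", "refurbished laptop deals"],
#   }).to_csv("train.csv", index=False)
# -------------------------------------------------------------------------------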

def main():
    # Clear GPU memory before starting.
    torch.cuda.empty_cache()
    gc.collect()

    wandb.init(project="query-diversification", name="mt5-large-url-context")

    tok = MT5Tokenizer.from_pretrained(MODEL_NAME)

    # Load model; the KV cache is useless during training (and incompatible
    # with gradient checkpointing), so disable it.
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    # model.gradient_checkpointing_enable()  # enabled via TrainingArguments below
    model.config.use_cache = False
    torch.cuda.empty_cache()  # clear cache after model loading
    print(f"Model loaded. GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # ----- dataset -----------------------------------------------------------
    if os.path.exists(os.path.join(CACHE_DIR, "train")):
        train_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "train"))
        val_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "val"))
    else:
        tr_df, va_df = prepare_datasets("train.csv")
        train_ds = HFDataset.from_pandas(tr_df)
        val_ds = HFDataset.from_pandas(va_df)

        def tok_fn(ex):
            ins = [f"For URL: {u} diversify query: {q}" for u, q in zip(ex["url"], ex["query"])]
            tars = ex["fanout"]
            mi = tok(ins, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
            lbl = tok(text_target=tars, max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")
            # Replace pad tokens in the labels with -100 so the loss ignores them.
            lbl["input_ids"] = [
                [(x if x != tok.pad_token_id else -100) for x in l] for l in lbl["input_ids"]
            ]
            mi["labels"] = lbl["input_ids"]
            return mi

        train_ds = train_ds.map(tok_fn, batched=True, num_proc=4)
        val_ds = val_ds.map(tok_fn, batched=True, num_proc=4)
        os.makedirs(CACHE_DIR, exist_ok=True)
        train_ds.save_to_disk(os.path.join(CACHE_DIR, "train"))
        val_ds.save_to_disk(os.path.join(CACHE_DIR, "val"))

    collator = DataCollatorForSeq2Seq(tok, model=model, padding=True)

    # ----- checkpoint handling ------------------------------------------------
    cps = list_checkpoints(OUTPUT_DIR)
    resume = None
    n_epochs = NUM_EPOCHS
    if cps:
        chosen = select_checkpoint(cps)
        if chosen:
            resume = os.path.join(OUTPUT_DIR, chosen)
            le = last_epoch(resume)
            print(f"\nResuming from {resume} (epoch {le})")
            if le >= NUM_EPOCHS:
                extra = int(input("How many extra epochs? [0]: ").strip() or "0")
                if extra == 0:
                    print("No extra epochs. Exit.")
                    return
                n_epochs = le + extra

    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="steps",
        eval_steps=5000,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=n_epochs,
        warmup_steps=WARMUP_STEPS,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1,
        save_steps=5000,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LENGTH,
        generation_num_beams=3,  # reduced from 5 to cut evaluation memory
        bf16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        gradient_checkpointing=True,
        optim="adafactor",  # instead of default AdamW; saves ~30% optimizer memory
        tf32=True,  # enable TF32 matmuls (Ampere+ GPUs, e.g. RTX 4090)
        dataloader_pin_memory=False,  # reduce memory fragmentation
        full_determinism=False,  # allow non-deterministic ops for memory efficiency
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        data_collator=collator,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,
        compute_metrics=lambda p: compute_metrics(p, tok),
    )

    # Release cached CUDA blocks every 100 optimizer steps to limit
    # fragmentation during training.
    class EmptyCacheCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 100 == 0:
                torch.cuda.empty_cache()

    trainer.add_callback(EmptyCacheCallback())

    trainer.train(resume_from_checkpoint=resume)

    trainer.save_model("./mt5-query-diversification-final")
    tok.save_pretrained("./mt5-query-diversification-final")
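
    # Optional final held-out evaluation (an illustrative addition;
    # trainer.evaluate() is standard Trainer API and reports eval_loss plus
    # the custom exact_match / avg_length_diff metrics).
    final_metrics = trainer.evaluate()
    print(f"Final eval metrics: {final_metrics}")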

    # ---- quick sanity generation ----------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    model.config.use_cache = True  # re-enable KV cache for inference

    samples = [
        ("python.org", "python tutorial"),
        ("amazon.com", "laptop deals"),
        ("wikipedia.org", "machine learning"),
    ]
    for url, q in samples:
        txt = f"For URL: {url} diversify query: {q}"
        ins = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
        ins = {k: v.to(device) for k, v in ins.items()}
        out = model.generate(
            **ins,
            max_length=MAX_TARGET_LENGTH,
            num_beams=5,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )
        print(f"\nInput: {txt}\nOutput: {tok.decode(out[0], skip_special_tokens=True)}")


if __name__ == "__main__":
    main()
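
# ------------------------------------------------------------------------------
# Usage sketch (the filename below is an assumption):
#   python train_mt5_large.py
# A fresh run tokenizes train.csv (cached under ./tokenized_cache) and trains
# from scratch. On restart, any ./mt5-query-diversification/checkpoint-* dirs
# are listed and you are prompted to pick one to resume from; answering "n"
# (or just Enter) starts over, and if the chosen checkpoint already reached
# NUM_EPOCHS you are asked how many extra epochs to run.
# ------------------------------------------------------------------------------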