#!/usr/bin/env python3
import torch
import numpy as np
# ---- PyTorch 2.6+ checkpoint-resume patches ------------------------------
# PyTorch 2.6 switched torch.load to weights_only=True by default, which
# breaks loading checkpoint RNG-state files that contain pickled numpy objects.
# 1) allow numpy reconstruct in pickle
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])
# 2) force torch.load(weights_only=False) for RNG-state files
_orig_torch_load = torch.load

def _patched_load(*args, **kwargs):
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)

torch.load = _patched_load
# --------------------------------------------------------------------------
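# NOTE: NumPy 2.x renamed numpy.core to numpy._core; if the add_safe_globals
# call above raises AttributeError on a newer NumPy, the equivalent symbol
# should be np._core.multiarray._reconstruct (verify against the installed
# NumPy version).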
"""
Train mT5-large for query diversification with URL context,
with resume-from-checkpoint and additional-epochs support.
"""
import pandas as pd
from transformers import (
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    TrainerCallback,
)
from sklearn.model_selection import train_test_split
from datasets import Dataset as HFDataset
import wandb
import os, json
import gc  # explicit memory cleanup between stages
# --------------------- CONSTANTS ------------------------------------------
MODEL_NAME = "google/mt5-large"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
BATCH_SIZE = 160
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
WARMUP_STEPS = 1000
GRAD_ACC_STEPS = 1
CACHE_DIR = "./tokenized_cache"
OUTPUT_DIR = "./mt5-query-diversification"
# --------------------------------------------------------------------------
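# Sequence budgets are deliberately small: each input is a single
# "For URL: {url} diversify query: {query}" prompt and each target a single
# diversified query (see tok_fn inside main below).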
def prepare_datasets(csv_path: str):
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.01, random_state=42)
    return train_df, val_df
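# train.csv is expected to provide "url", "query", and "fanout" columns;
# tok_fn inside main() reads exactly those fields.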
def compute_metrics(eval_preds, tok):
    preds, labels = eval_preds
    vocab_size = len(tok)
    # Guard against ids outside the tokenizer vocab (the mT5 model vocab is
    # slightly larger than the tokenizer's) and against -100 padding
    preds = np.where(preds < vocab_size, preds, tok.pad_token_id)
    preds = np.where(preds >= 0, preds, tok.pad_token_id)
    labels = np.where(labels != -100, labels, tok.pad_token_id)
    pred_str = tok.batch_decode(preds, skip_special_tokens=True)
    label_str = tok.batch_decode(labels, skip_special_tokens=True)
    exact = sum(p.strip() == l.strip() for p, l in zip(pred_str, label_str)) / len(pred_str)
    diff = np.mean([len(p.split()) - len(l.split()) for p, l in zip(pred_str, label_str)])
    return {"exact_match": exact, "avg_length_diff": diff}
def list_checkpoints(out_dir):
    if not os.path.isdir(out_dir):
        return []
    cps = [d for d in os.listdir(out_dir)
           if d.startswith("checkpoint-") and os.path.isdir(os.path.join(out_dir, d))]
    cps.sort(key=lambda x: int(x.split("-")[1]))
    return cps
def select_checkpoint(cps):
    print("\nAvailable checkpoints:")
    for i, cp in enumerate(cps):
        print(f"  [{i}] {cp}")
    print("  [n] Start training from scratch")
    sel = input(f"Select checkpoint [0-{len(cps)-1}, n]: ").strip()
    if sel.lower() in {"", "n"}:
        return None
    try:
        idx = int(sel)
    except ValueError:
        return None
    return cps[idx] if 0 <= idx < len(cps) else None
def last_epoch(ckpt_path):
    ts = os.path.join(ckpt_path, "trainer_state.json")
    if not os.path.isfile(ts):
        return 0
    with open(ts, "r", encoding="utf-8") as f:
        st = json.load(f)
    if "epoch" in st:
        return float(st["epoch"])
    epochs = [e.get("epoch", 0) for e in st.get("log_history", []) if "epoch" in e]
    return max(epochs) if epochs else 0
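# trainer_state.json stores "epoch" as a float (fractional when a save lands
# mid-epoch), hence the float() conversion above.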
def main():
    # Clear GPU memory before starting
    torch.cuda.empty_cache()
    gc.collect()
    wandb.init(project="query-diversification", name="mt5-large-url-context")
    tok = MT5Tokenizer.from_pretrained(MODEL_NAME)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    # Gradient checkpointing is enabled via TrainingArguments below,
    # so model.gradient_checkpointing_enable() is not called here.
    model.config.use_cache = False  # disable KV cache during training
    torch.cuda.empty_cache()  # clear cache after model loading
    print(f"Model loaded. GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    # ----- dataset --------------------------------------------------------
    if os.path.exists(os.path.join(CACHE_DIR, "train")):
        train_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "train"))
        val_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "val"))
    else:
        tr_df, va_df = prepare_datasets("train.csv")
        train_ds = HFDataset.from_pandas(tr_df)
        val_ds = HFDataset.from_pandas(va_df)

        def tok_fn(ex):
            ins = [f"For URL: {u} diversify query: {q}" for u, q in zip(ex["url"], ex["query"])]
            tars = ex["fanout"]
            mi = tok(ins, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
            lbl = tok(text_target=tars, max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")
            # Mask pad tokens so the loss ignores them
            lbl["input_ids"] = [[(x if x != tok.pad_token_id else -100) for x in l] for l in lbl["input_ids"]]
            mi["labels"] = lbl["input_ids"]
            return mi

        train_ds = train_ds.map(tok_fn, batched=True, num_proc=4)
        val_ds = val_ds.map(tok_fn, batched=True, num_proc=4)
        os.makedirs(CACHE_DIR, exist_ok=True)
        train_ds.save_to_disk(os.path.join(CACHE_DIR, "train"))
        val_ds.save_to_disk(os.path.join(CACHE_DIR, "val"))

    collator = DataCollatorForSeq2Seq(tok, model=model, padding=True)
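    # tok_fn already pads to fixed lengths, so the collator's dynamic padding
    # is effectively a no-op here; it still tensorizes batches and derives
    # decoder_input_ids from the labels because model= is passed.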
    # ----- checkpoint handling -------------------------------------------
    cps = list_checkpoints(OUTPUT_DIR)
    resume = None
    n_epochs = NUM_EPOCHS
    if cps:
        chosen = select_checkpoint(cps)
        if chosen:
            resume = os.path.join(OUTPUT_DIR, chosen)
            le = last_epoch(resume)
            print(f"\nResuming from {resume} (epoch {le})")
            if le >= NUM_EPOCHS:
                extra = int(input("How many extra epochs? [0]: ").strip() or "0")
                if extra == 0:
                    print("No extra epochs. Exit.")
                    return
                n_epochs = le + extra
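    # num_train_epochs is a total, not an increment: on resume the Trainer
    # fast-forwards past the checkpoint's completed steps, so the extra
    # epochs are added on top of the completed count.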
    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="steps",
        eval_steps=5000,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=n_epochs,
        warmup_steps=WARMUP_STEPS,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1,
        save_steps=5000,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LENGTH,
        generation_num_beams=3,  # 3 instead of 5 to cut eval memory
        bf16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        gradient_checkpointing=True,
        optim="adafactor",  # saves ~30% memory vs. the default AdamW
        tf32=True,  # enable TF32 matmuls (Ampere+ GPUs such as the RTX 4090)
        dataloader_pin_memory=False,  # reduce memory fragmentation
        full_determinism=False,  # allow non-deterministic ops for memory efficiency
    )
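    # Effective batch size: BATCH_SIZE * GRAD_ACC_STEPS = 160 * 1 = 160
    # sequences per optimizer step (assuming a single GPU).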
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        data_collator=collator,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,
        compute_metrics=lambda p: compute_metrics(p, tok),
    )
    # Clear the CUDA cache periodically during training. Wrapping
    # trainer.train would only check the step counter once at call time,
    # so use a TrainerCallback that runs after every optimizer step.
    class EmptyCacheCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 100 == 0:
                torch.cuda.empty_cache()

    trainer.add_callback(EmptyCacheCallback())
    trainer.train(resume_from_checkpoint=resume)
    trainer.save_model("./mt5-query-diversification-final")
    tok.save_pretrained("./mt5-query-diversification-final")
    # ---- quick sanity generation ----------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    model.config.use_cache = True  # re-enable cache for inference
    samples = [("python.org", "python tutorial"),
               ("amazon.com", "laptop deals"),
               ("wikipedia.org", "machine learning")]
    for url, q in samples:
        txt = f"For URL: {url} diversify query: {q}"
        ins = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
        ins = {k: v.to(device) for k, v in ins.items()}
        out = model.generate(**ins, max_length=MAX_TARGET_LENGTH,
                             num_beams=5, temperature=0.7,
                             do_sample=True, top_p=0.9)
        print(f"\nInput: {txt}\nOutput: {tok.decode(out[0], skip_special_tokens=True)}")
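    # Passing do_sample=True together with num_beams > 1 selects transformers'
    # beam-search multinomial sampling, so outputs vary from run to run.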
if __name__ == "__main__":
    main()