"""
Train mT5-large for query diversification with URL context,
with resume-from-checkpoint and additional-epochs support.
"""

import gc
import json
import os

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset as HFDataset
from sklearn.model_selection import train_test_split
from transformers import (
    DataCollatorForSeq2Seq,
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
)

# Recent PyTorch releases default torch.load to weights_only=True, which rejects
# the pickled NumPy objects stored inside older Trainer checkpoints. Allow-list
# the NumPy reconstruct helper and default weights_only back to False so that
# resume-from-checkpoint keeps working.
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])

_orig_torch_load = torch.load


def _patched_load(*args, **kwargs):
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)


torch.load = _patched_load

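# The training data is expected to be a CSV with at least the columns consumed
# by tok_fn below: "url", "query", and "fanout" (the target diversified query).
# Illustrative rows (the fanout values here are hypothetical):
#
#   url,query,fanout
#   python.org,python tutorial,python tutorial for beginners
#   amazon.com,laptop deals,cheap gaming laptop discounts
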
MODEL_NAME = "google/mt5-large"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
BATCH_SIZE = 160
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
WARMUP_STEPS = 1000
GRAD_ACC_STEPS = 1
CACHE_DIR = "./tokenized_cache"
OUTPUT_DIR = "./mt5-query-diversification"


def prepare_datasets(csv_path: str):
    """Load the training CSV and split off a small validation set."""
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.01, random_state=42)
    return train_df, val_df


def compute_metrics(eval_preds, tok):
    preds, labels = eval_preds
    # Map any out-of-range or -100 padding ids to the pad token so decoding works.
    vs = len(tok)
    preds = np.where(preds < vs, preds, tok.pad_token_id)
    preds = np.where(preds >= 0, preds, tok.pad_token_id)
    labels = np.where(labels != -100, labels, tok.pad_token_id)
    pred_str = tok.batch_decode(preds, skip_special_tokens=True)
    label_str = tok.batch_decode(labels, skip_special_tokens=True)
    exact = sum(p.strip() == l.strip() for p, l in zip(pred_str, label_str)) / len(pred_str)
    diff = np.mean([len(p.split()) - len(l.split()) for p, l in zip(pred_str, label_str)])
    return {"exact_match": exact, "avg_length_diff": diff}


def list_checkpoints(out_dir):
    if not os.path.isdir(out_dir):
        return []
    cps = [
        d for d in os.listdir(out_dir)
        if d.startswith("checkpoint-") and os.path.isdir(os.path.join(out_dir, d))
    ]
    cps.sort(key=lambda x: int(x.split("-")[1]))
    return cps


def select_checkpoint(cps):
    print("\nAvailable checkpoints:")
    for i, cp in enumerate(cps):
        print(f" [{i}] {cp}")
    print(" [n] Start training from scratch")
    sel = input(f"Select checkpoint [0-{len(cps)-1}, n]: ").strip()
    if sel.lower() in {"", "n"}:
        return None
    try:
        idx = int(sel)
    except ValueError:
        return None
    return cps[idx] if 0 <= idx < len(cps) else None


def last_epoch(ckpt_path):
    """Return the last epoch recorded in a checkpoint's trainer_state.json."""
    ts = os.path.join(ckpt_path, "trainer_state.json")
    if not os.path.isfile(ts):
        return 0
    with open(ts, "r", encoding="utf-8") as f:
        st = json.load(f)
    if "epoch" in st:
        return float(st["epoch"])
    epochs = [e.get("epoch", 0) for e in st.get("log_history", []) if "epoch" in e]
    return max(epochs) if epochs else 0


def main():
    # Release any leftover GPU memory from a previous run before loading the model.
    torch.cuda.empty_cache()
    gc.collect()

    wandb.init(project="query-diversification", name="mt5-large-url-context")

    tok = MT5Tokenizer.from_pretrained(MODEL_NAME)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    # The KV cache is only useful for generation and is incompatible with
    # gradient checkpointing, so disable it for training.
    model.config.use_cache = False
    torch.cuda.empty_cache()

    print(f"Model loaded. GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # Reuse tokenized datasets cached by a previous run; otherwise build them
    # from train.csv and cache the result.
    if os.path.exists(os.path.join(CACHE_DIR, "train")):
        train_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "train"))
        val_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "val"))
    else:
        tr_df, va_df = prepare_datasets("train.csv")
        train_ds = HFDataset.from_pandas(tr_df)
        val_ds = HFDataset.from_pandas(va_df)

        def tok_fn(ex):
            ins = [f"For URL: {u} diversify query: {q}" for u, q in zip(ex["url"], ex["query"])]
            tars = ex["fanout"]
            mi = tok(ins, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
            lbl = tok(text_target=tars, max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")
            # Replace label padding ids with -100 so they are ignored by the loss.
            lbl["input_ids"] = [[(x if x != tok.pad_token_id else -100) for x in l] for l in lbl["input_ids"]]
            mi["labels"] = lbl["input_ids"]
            return mi

        train_ds = train_ds.map(tok_fn, batched=True, num_proc=4)
        val_ds = val_ds.map(tok_fn, batched=True, num_proc=4)
        os.makedirs(CACHE_DIR, exist_ok=True)
        train_ds.save_to_disk(os.path.join(CACHE_DIR, "train"))
        val_ds.save_to_disk(os.path.join(CACHE_DIR, "val"))

    collator = DataCollatorForSeq2Seq(tok, model=model, padding=True)

    # Offer to resume from an existing checkpoint; if the chosen checkpoint has
    # already reached NUM_EPOCHS, ask how many extra epochs to train.
    cps = list_checkpoints(OUTPUT_DIR)
    resume = None
    n_epochs = NUM_EPOCHS
    if cps:
        chosen = select_checkpoint(cps)
        if chosen:
            resume = os.path.join(OUTPUT_DIR, chosen)
            le = last_epoch(resume)
            print(f"\nResuming from {resume} (epoch {le})")
            if le >= NUM_EPOCHS:
                extra = int(input("How many extra epochs? [0]: ").strip() or "0")
                if extra == 0:
                    print("No extra epochs. Exit.")
                    return
                n_epochs = le + extra

    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="steps",
        eval_steps=5000,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=n_epochs,
        warmup_steps=WARMUP_STEPS,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1,
        save_steps=5000,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LENGTH,
        generation_num_beams=3,  # beams used for evaluation-time generation
        bf16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        gradient_checkpointing=True,
        optim="adafactor",
        tf32=True,
        dataloader_pin_memory=False,
        full_determinism=False,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        data_collator=collator,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,
        compute_metrics=lambda p: compute_metrics(p, tok),
    )

    # Periodically release cached GPU memory while training.
    class EmptyCacheCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 100 == 0:
                torch.cuda.empty_cache()
            return control

    trainer.add_callback(EmptyCacheCallback())

    trainer.train(resume_from_checkpoint=resume)
    trainer.save_model("./mt5-query-diversification-final")
    tok.save_pretrained("./mt5-query-diversification-final")

    # Quick sanity check: generate a diversified query for a few sample inputs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    model.config.use_cache = True

    samples = [
        ("python.org", "python tutorial"),
        ("amazon.com", "laptop deals"),
        ("wikipedia.org", "machine learning"),
    ]
    for url, q in samples:
        txt = f"For URL: {url} diversify query: {q}"
        ins = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
        ins = {k: v.to(device) for k, v in ins.items()}
        out = model.generate(
            **ins,
            max_length=MAX_TARGET_LENGTH,
            num_beams=5,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )
        print(f"\nInput: {txt}\nOutput: {tok.decode(out[0], skip_special_tokens=True)}")


if __name__ == "__main__":
    main()
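
# Example invocation (assumes a CUDA GPU and train.csv in the working directory;
# the script filename below is illustrative):
#
#   python train_mt5_diversification.py
#
# Rerunning after training lists the checkpoints saved under OUTPUT_DIR and
# offers to resume, optionally adding extra epochs.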