Upload 2 files
app.py
ADDED
@@ -0,0 +1,565 @@
# app.py
import os
import json
import math
import time
import difflib
import torch
import streamlit as st
from typing import List, Tuple, Dict, Any
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch.nn.functional as F
import pandas as pd

# ------------------ CONSTANTS ------------------
MODEL_PATH = "dejanseo/query-fanout"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
PRESETS_FILE = "generation_presets.json"
# ------------------------------------------------

# ------------------ BUILT-IN PRESETS ------------------
DEFAULT_PRESET: Dict[str, Any] = {
    "name": "Default",
    "max_candidates": 50,
    "temperature": 0.9,
    "top_p": 0.95,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.1,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.70,
    "dup_ratio": 0.92,
    "embedding_mode": "plain_both",  # embedding toggle
}
DIVERSE_PRESET: Dict[str, Any] = {
    "name": "Diverse",
    "max_candidates": 200,
    "temperature": 1.10,
    "top_p": 0.98,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.50,
    "dup_ratio": 0.88,
    "embedding_mode": "plain_both",  # embedding toggle
}
BUILT_IN_PRESETS = {"Default": DEFAULT_PRESET, "Diverse": DIVERSE_PRESET}

# ------------------ PRESET IO ------------------
def load_user_presets() -> Dict[str, Dict[str, Any]]:
    if not os.path.exists(PRESETS_FILE):
        return {}
    try:
        with open(PRESETS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict):
            cleaned: Dict[str, Dict[str, Any]] = {}
            for k, v in data.items():
                if isinstance(v, dict):
                    # Backfill the embedding toggle for presets saved by older versions
                    if "embedding_mode" not in v:
                        v["embedding_mode"] = "plain_both"
                    cleaned[k] = v
            return cleaned
        return {}
    except Exception:
        return {}

def save_user_preset(name: str, cfg: Dict[str, Any]) -> None:
    data = load_user_presets()
    data[name] = dict(cfg, name=name)
    with open(PRESETS_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def all_presets() -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    out.update(BUILT_IN_PRESETS)
    out.update(load_user_presets())  # user presets shadow built-ins of the same name
    return out

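# Illustrative shape of generation_presets.json as written by save_user_preset()
# (a hypothetical preset named "MyPreset"; the values are examples, not defaults):
# {
#   "MyPreset": {
#     "max_candidates": 100, "temperature": 1.0, "top_p": 0.97,
#     "no_repeat_ngram_size": 2, "repetition_penalty": 1.1, "seed": 7,
#     "sort_by": "logp/len", "select_k": 20, "mmr_lambda": 0.6,
#     "dup_ratio": 0.9, "embedding_mode": "plain_both", "name": "MyPreset"
#   }
# }
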
# ------------------ MODEL LOADING ------------------
@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return tok, model, device

# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    txt = f"For URL: {url} diversify query: {query}"
    enc = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    return {k: v.to(device) for k, v in enc.items()}, txt

def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    return tok.batch_decode(seqs, skip_special_tokens=True)

def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
    # Length-normalized log-probability per returned sequence, computed from the
    # per-step logits in gen.scores. sequences[:, t + 1] is the token emitted at
    # step t (position 0 is the decoder start token).
    if not hasattr(gen, "scores") or not gen.scores:
        return [float("nan")] * gen.sequences.size(0)
    scores = gen.scores
    seqs = gen.sequences
    nseq = seqs.size(0)
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    count = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    finished = torch.zeros(nseq, dtype=torch.bool, device=scores[0].device)
    for t in range(len(scores)):
        step_logits = scores[t]
        step_logprobs = F.log_softmax(step_logits, dim=-1)
        step_tok = seqs[:, t + 1]
        valid = step_tok.ne(pad_id) & (~finished)
        if valid.any():
            gather = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, gather, torch.zeros_like(gather))
            count += valid.float()
        finished |= step_tok.eq(eos_id)
    # Guard against division by zero for sequences with no valid tokens
    count = torch.where(count.eq(0), torch.ones_like(count), count)
    return [(lp / c).item() for lp, c in zip(sum_logp, count)]

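# A quick worked example of the score this returns: a 4-token candidate with
# total log-probability -2.0 has logp/len = -0.5, i.e. an average per-token
# probability of about exp(-0.5) ≈ 0.607 (the "p≈" column shown by fmt_score
# and in the UI tables).
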
def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p,
                      no_repeat_ngram_size=0, repetition_penalty=1.0):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    texts = decode_sequences(tok, gen.sequences)
    scores = avg_logprobs_from_generate(tok, gen)
    return texts, scores

def get_encoder_embedding(tok, model, text: str, device: torch.device):
    inputs = tok(text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
    with torch.no_grad():
        enc_out = model.get_encoder()(**inputs)
    return enc_out.last_hidden_state.mean(dim=1).squeeze(0)

def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    return float(F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item())

def fmt_score(x: float) -> str:
    if x != x or math.isinf(x):
        return "n/a"
    p = math.exp(x)
    return f"logp/len={x:.3f} | p≈{p:.3f}"

# ------------------ RERANK (MMR + DEDUP) ------------------
def normalize_text(s: str) -> str:
    return " ".join(s.strip().lower().split())

def is_near_duplicate(a: str, b: str, ratio_thresh: float) -> bool:
    return difflib.SequenceMatcher(None, normalize_text(a), normalize_text(b)).ratio() >= ratio_thresh

def mmr_select(
    cand_texts: List[str],
    cand_embs: List[torch.Tensor],
    query_emb: torch.Tensor,
    k: int,
    lambd: float
) -> List[int]:
    # Maximal Marginal Relevance: greedily pick candidates that balance
    # relevance to the query against similarity to what is already selected.
    rel = [cosine_similarity(query_emb, e) for e in cand_embs]
    selected: List[int] = []
    available = set(range(len(cand_texts)))
    while available and len(selected) < k:
        if not selected:
            # Seed the selection with the most relevant candidate
            idx = max(available, key=lambda i: rel[i])
            selected.append(idx)
            available.remove(idx)
            continue
        best_idx = None
        best_score = -1e9
        for i in list(available):
            max_sim_to_sel = max(cosine_similarity(cand_embs[i], cand_embs[j]) for j in selected)
            # MMR score: lambd * relevance - (1 - lambd) * redundancy
            score = lambd * rel[i] - (1.0 - lambd) * max_sim_to_sel
            if score > best_score:
                best_score = score
                best_idx = i
        selected.append(best_idx)
        available.remove(best_idx)
    return selected

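# Small numeric illustration of the MMR trade-off (hypothetical values): with
# lambd = 0.7, a candidate with rel = 0.90 that is 0.95-similar to an already
# selected one scores 0.7*0.90 - 0.3*0.95 = 0.345, while a less relevant but
# fresher candidate with rel = 0.80 and max similarity 0.40 scores
# 0.7*0.80 - 0.3*0.40 = 0.44 and wins the slot.
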
def distinct_n(texts: List[str], n: int) -> float:
    total = 0
    uniq = set()
    for t in texts:
        toks = t.strip().split()
        if len(toks) < n:
            continue
        for i in range(len(toks) - n + 1):
            total += 1
            uniq.add(tuple(toks[i:i+n]))
    return (len(uniq) / total) if total > 0 else 0.0

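# Worked example, purely to document the metric: for the two texts
# "airbnb reviews" and "airbnb guest reviews", distinct_n(texts, 1) counts
# 5 unigrams of which 3 are unique ({airbnb, reviews, guest}), so it returns
# 3/5 = 0.6; higher values mean less repetition across the selected set.
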
# ------------------ EMBEDDING MODE HELPERS (TOGGLE) ------------------
def embed_text_for_mode(url: str, text: str, mode: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device) -> torch.Tensor:
    """
    mode:
      - "plain_both": embed the raw text
      - "template_both": embed with the same instruction template used for inputs
    """
    if mode == "template_both":
        templated = f"For URL: {url} diversify query: {text}"
        return get_encoder_embedding(tok, model, templated, device)
    return get_encoder_embedding(tok, model, text, device)

# ------------------ TESTING HELPERS (DEFINED) ------------------
def single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=num_beams,
        num_return_sequences=1,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    out = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, out)[0]

def topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=max(num_beams, top_n),  # beams must cover the requested top_n
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    texts, scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
    order = sorted(range(len(texts)), key=lambda i: scores[i], reverse=True)
    return [texts[i] for i in order], [scores[i] for i in order]

def diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty):
    # Group beam search requires num_beams to be divisible by num_beam_groups;
    # round the beam count up to the next multiple if needed.
    num_beams = max(num_beams, num_beam_groups * max(1, top_n // max(1, num_beam_groups)))
    if num_beams % num_beam_groups != 0:
        num_beams = (num_beams // num_beam_groups + 1) * num_beam_groups
    top_n = min(top_n, num_beams)
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def token_by_token_probabilities(tok, model, device, inputs):
    gen = model.generate(
        **inputs,
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    seq = gen.sequences[0]
    token_ids = seq.tolist()
    per_token = []
    for t, logits in enumerate(gen.scores):
        tok_id = token_ids[t + 1]  # scores[t] predicts the token at position t + 1
        probs = F.softmax(logits[0], dim=-1)
        prob = float(probs[tok_id].detach().cpu())
        sp_token = tok.convert_ids_to_tokens([tok_id])[0]  # SentencePiece piece, e.g. "▁word"
        per_token.append((sp_token, prob))
    return per_token

# ------------------ STREAMLIT APP ------------------
st.set_page_config(page_title="Query Fanout – Generation & Testing", layout="wide")
tok, model, device = load_model()
tab1, tab2 = st.tabs(["Generation", "Testing"])

# ----------- COMMON GENERATION RUNNER -----------
def run_generation(url: str, query: str, cfg: Dict[str, Any], show_save_controls: bool) -> None:
    torch.manual_seed(int(cfg["seed"]))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(int(cfg["seed"]))
    start_ts = time.time()
    inputs, prompt_txt = build_inputs(tok, url, query, device)
    embedding_mode = cfg.get("embedding_mode", "plain_both")
    orig_emb = embed_text_for_mode(url, query, embedding_mode, tok, model, device)

    # Oversample (2x) so enough unique candidates survive exact deduplication
    texts, scores = sampling_generate(
        tok, model, device, inputs,
        top_n=int(cfg["max_candidates"]) * 2,
        temperature=float(cfg["temperature"]),
        top_p=float(cfg["top_p"]),
        no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
        repetition_penalty=float(cfg["repetition_penalty"]),
    )

    seen = set()
    enriched: List[Dict[str, Any]] = []
    for txt, sc in zip(texts, scores):
        norm = normalize_text(txt)
        if norm not in seen:
            seen.add(norm)
            cand_emb = embed_text_for_mode(url, txt, embedding_mode, tok, model, device)
            cos_sim = cosine_similarity(orig_emb, cand_emb)
            enriched.append({"logp/len": sc, "p≈": math.exp(sc), "cos≈": cos_sim, "text": txt, "emb": cand_emb})
        if len(enriched) >= int(cfg["max_candidates"]):
            break

    if cfg["sort_by"] == "logp/len":
        enriched.sort(key=lambda x: x["logp/len"], reverse=True)
    else:
        enriched.sort(key=lambda x: x["cos≈"], reverse=True)

    df = pd.DataFrame([{"logp/len": e["logp/len"], "p≈": e["p≈"], "cos≈": e["cos≈"], "text": e["text"]} for e in enriched])
    df.index = range(1, len(df) + 1)
    elapsed = time.time() - start_ts
    st.caption(f"Generated {len(df)} unique fan-out queries in {elapsed:.2f}s")
    st.dataframe(df, use_container_width=True)

    # Fuzzy near-duplicate filtering, on top of the exact dedup above
    filtered: List[Dict[str, Any]] = []
    for cand in enriched:
        keep = True
        for kept in filtered:
            if is_near_duplicate(cand["text"], kept["text"], float(cfg["dup_ratio"])):
                keep = False
                break
        if keep:
            filtered.append(cand)

    if filtered:
        k_eff = min(int(cfg["select_k"]), len(filtered))
        cand_texts = [c["text"] for c in filtered]
        cand_embs = [c["emb"] for c in filtered]
        sel_idx = mmr_select(cand_texts, cand_embs, orig_emb, k=k_eff, lambd=float(cfg["mmr_lambda"]))
        selected = [filtered[i] for i in sel_idx]

        st.markdown("### Reranked Top-K (MMR + Dedup)")
        st.caption(f"Mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}")
        df_sel = pd.DataFrame(
            [{"rank": i+1, "cos≈": s["cos≈"], "text": s["text"]} for i, s in enumerate(selected)]
        )
        df_sel.set_index("rank", inplace=True)
        st.dataframe(df_sel, use_container_width=True)

        sel_texts = [s["text"] for s in selected]
        d1 = distinct_n(sel_texts, 1)
        d2 = distinct_n(sel_texts, 2)
        st.caption(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f} on selected {len(sel_texts)}")

        combined_output = [f"Input: {prompt_txt}"]
        for rank, row in df.iterrows():
            combined_output.append(f"#{rank} logp/len={row['logp/len']:.3f} | p≈{row['p≈']:.3f} | cos≈{row['cos≈']:.3f} — {row['text']}")
        block = "\n".join(combined_output)

        st.markdown("### Copy/Paste Summary")
        st.code(block, language="text")
        with open("generation_output.txt", "w", encoding="utf-8") as f:
            f.write(block)
            f.write("\n\n[MMR selection]\n")
            f.write(f"mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}\n")
            for i, s in enumerate(selected, 1):
                f.write(f"#{i} cos≈={s['cos≈']:.3f} — {s['text']}\n")
            f.write(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f}\n")
        with open("generation_selected.txt", "w", encoding="utf-8") as f:
            for i, s in enumerate(selected, 1):
                f.write(f"{i}\t{s['text']}\n")
        st.success("Saved summary to generation_output.txt and selection to generation_selected.txt")
    else:
        st.warning("All candidates filtered as near-duplicates. Lower the duplicate threshold or increase max candidates.")

    if show_save_controls:
        st.markdown("---")
        with st.form(key="save_preset_form"):
            new_name = st.text_input("Preset Name", value="", placeholder="Enter a preset name")
            submitted = st.form_submit_button("Save as Preset")
            if submitted:
                if not new_name.strip():
                    st.error("Preset name cannot be empty.")
                elif new_name in BUILT_IN_PRESETS:
                    st.error("Cannot overwrite built-in presets (Default, Diverse). Use a different name.")
                else:
                    to_save = {
                        "max_candidates": int(cfg["max_candidates"]),
                        "temperature": float(cfg["temperature"]),
                        "top_p": float(cfg["top_p"]),
                        "no_repeat_ngram_size": int(cfg["no_repeat_ngram_size"]),
                        "repetition_penalty": float(cfg["repetition_penalty"]),
                        "seed": int(cfg["seed"]),
                        "sort_by": str(cfg["sort_by"]),
                        "select_k": int(cfg["select_k"]),
                        "mmr_lambda": float(cfg["mmr_lambda"]),
                        "dup_ratio": float(cfg["dup_ratio"]),
                        "embedding_mode": str(cfg.get("embedding_mode", "plain_both")),
                    }
                    save_user_preset(new_name.strip(), to_save)
                    st.success(f"Preset '{new_name.strip()}' saved.")

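# Illustrative first lines of the "Copy/Paste Summary" block produced above
# (the numbers and the candidate text are made up, just to show the format):
#   Input: For URL: airbnb.com diversify query: airbnb reviews
#   #1 logp/len=-0.412 | p≈0.662 | cos≈0.913 — airbnb guest reviews
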
# ----------- TAB 1: GENERATION -----------
with tab1:
    st.header("Generation Mode — Large Diverse Fan-out")
    url = st.text_input("URL", value="airbnb.com", key="gen_url")
    query = st.text_input("Query", value="airbnb reviews", key="gen_query")

    subtab_presets, subtab_manual = st.tabs(["Presets", "Manual Settings"])

    # ----- Presets sub-tab -----
    with subtab_presets:
        all_p = all_presets()
        preset_names = list(all_p.keys())
        preset_choice = st.selectbox(
            "Choose a preset",
            preset_names,
            index=preset_names.index("Default") if "Default" in preset_names else 0
        )
        sel = dict(all_p[preset_choice])  # copy to allow local edits
        emb_mode_preset = st.selectbox(
            "Embedding mode for reranking",
            options=["plain_both", "template_both"],
            index=0 if sel.get("embedding_mode", "plain_both") == "plain_both" else 1,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template"
        )
        sel["embedding_mode"] = emb_mode_preset

        cols = st.columns(3)
        with cols[0]:
            st.write(f"**Max candidates:** {sel['max_candidates']}")
            st.write(f"**Temperature:** {sel['temperature']}")
            st.write(f"**Top-p:** {sel['top_p']}")
            st.write(f"**Seed:** {sel['seed']}")
        with cols[1]:
            st.write(f"**No repeat n-gram:** {sel['no_repeat_ngram_size']}")
            st.write(f"**Repetition penalty:** {sel['repetition_penalty']}")
            st.write(f"**Sort by:** {sel['sort_by']}")
        with cols[2]:
            st.write(f"**Select K:** {sel['select_k']}")
            st.write(f"**λ (MMR):** {sel['mmr_lambda']}")
            st.write(f"**Dup ratio:** {sel['dup_ratio']}")
            st.write(f"**Embedding:** {sel['embedding_mode']}")

        run_gen_preset = st.button("Generate Fan-out (Preset)", key="run_gen_preset")
        if run_gen_preset:
            run_generation(url, query, sel, show_save_controls=False)

    # ----- Manual Settings sub-tab -----
    with subtab_manual:
        base = DEFAULT_PRESET
        max_candidates = st.number_input("Max candidates", min_value=1, max_value=200, value=int(base["max_candidates"]), step=1)
        temperature = st.number_input("Temperature", min_value=0.1, max_value=2.0, value=float(base["temperature"]), step=0.1)
        top_p = st.number_input("Top-p", min_value=0.1, max_value=1.0, value=float(base["top_p"]), step=0.01)
        no_repeat_ngram_size = st.number_input("No repeat n-gram size (0=off)", min_value=0, max_value=10, value=int(base["no_repeat_ngram_size"]), step=1)
        repetition_penalty = st.number_input("Repetition penalty (1.0=off)", min_value=1.0, max_value=2.0, value=float(base["repetition_penalty"]), step=0.1)
        seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=int(base["seed"]), step=1, key="gen_seed_manual")
        sort_by = st.selectbox("Sort by", ["logp/len", "cosine similarity"], index=0)

        st.subheader("Diversity-aware Reranking (MMR on internal encoder vectors)")
        embedding_mode_manual = st.selectbox(
            "Embedding mode",
            options=["plain_both", "template_both"],
            index=0,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template"
        )
        select_k = st.number_input("Select top K after rerank", min_value=1, max_value=200, value=int(base["select_k"]), step=1)
        mmr_lambda = st.number_input("MMR relevance weight λ (higher = more on-topic, lower = more diverse)", min_value=0.0, max_value=1.0, value=float(base["mmr_lambda"]), step=0.01)
        dup_ratio = st.number_input("Near-duplicate threshold (SequenceMatcher ratio)", min_value=0.0, max_value=1.0, value=float(base["dup_ratio"]), step=0.01)

        run_gen_manual = st.button("Generate Fan-out (Manual Settings)", key="run_gen_manual")
        if run_gen_manual:
            cfg = {
                "max_candidates": int(max_candidates),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "no_repeat_ngram_size": int(no_repeat_ngram_size),
                "repetition_penalty": float(repetition_penalty),
                "seed": int(seed_value),
                "sort_by": str(sort_by),
                "select_k": int(select_k),
                "mmr_lambda": float(mmr_lambda),
                "dup_ratio": float(dup_ratio),
                "embedding_mode": str(embedding_mode_manual),
            }
            run_generation(url, query, cfg, show_save_controls=True)

# ----------- TAB 2: TESTING -----------
with tab2:
    st.header("Testing Mode — Method Comparison")
    url = st.text_input("URL", value="airbnb.com", key="test_url")
    query = st.text_input("Query", value="airbnb reviews", key="test_query")
    num_beams = st.number_input("num_beams", min_value=1, max_value=20, value=5, step=1)
    top_n = st.number_input("top_n", min_value=1, max_value=20, value=5, step=1)
    temperature = st.number_input("temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    num_beam_groups = st.number_input("num_beam_groups", min_value=1, max_value=20, value=5, step=1)
    diversity_penalty = st.number_input("diversity_penalty", min_value=0.0, max_value=5.0, value=1.0, step=0.1)
    no_repeat_ngram_size = st.number_input("no_repeat_ngram_size", min_value=0, max_value=10, value=0, step=1)
    repetition_penalty = st.number_input("repetition_penalty", min_value=1.0, max_value=2.0, value=1.0, step=0.1)
    seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=42, step=1, key="test_seed")
    run_test = st.button("Run Comparison", key="run_test")

    if run_test:
        torch.manual_seed(int(seed_value))
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(int(seed_value))
        inputs, prompt_txt = build_inputs(tok, url, query, device)

        best_det = single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty)
        topn_beam_txts, topn_beam_scores = topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty)
        topn_samp_txts, topn_samp_scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        ranked_txts, ranked_scores = score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        div_txts, div_scores = diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty)
        per_token = token_by_token_probabilities(tok, model, device, inputs)

        combined_output = [f"Input: {prompt_txt}",
                           "\n[1] Single best (deterministic beam)", best_det,
                           "\n[2] Top-N (beam)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(topn_beam_txts, topn_beam_scores))] + \
                          ["\n[3] Top-N (sampling)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(topn_samp_txts, topn_samp_scores))] + \
                          ["\n[4] Score-ranked (sampling)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(ranked_txts, ranked_scores))] + \
                          ["\n[5] Diverse beams"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(div_txts, div_scores))] + \
                          ["\n[6] Token-by-token probabilities (greedy)"] + [f"{t} — {p:.4f}" for t, p in per_token]

        st.markdown("### Copy/Paste Summary")
        st.code("\n".join(combined_output), language="text")
        with open("testing_output.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(combined_output))
        st.success("Saved summary to testing_output.txt")
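
# To try the app locally (assuming Streamlit and the model weights are
# available): streamlit run app.py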
train.py
ADDED
@@ -0,0 +1,237 @@
#!/usr/bin/env python3
import torch
import numpy as np

# ---- PyTorch 2.6+ checkpoint-resume patches ------------------------------
# PyTorch 2.6 flipped the torch.load default to weights_only=True, which breaks
# unpickling the RNG-state files saved inside Trainer checkpoints.
# 1) allow numpy reconstruct in pickle
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])
# 2) force torch.load (weights_only=False) for RNG-state files
_orig_torch_load = torch.load
def _patched_load(*args, **kwargs):
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)
torch.load = _patched_load
# --------------------------------------------------------------------------

"""
Train mT5-large for query diversification with URL context,
with resume-from-checkpoint and additional-epochs support.
"""
import pandas as pd
from transformers import (
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from sklearn.model_selection import train_test_split
import numpy as np2  # metrics helper (second alias of numpy, used in compute_metrics)
from datasets import Dataset as HFDataset
import wandb
import os, json
import gc  # added for memory cleanup

# --------------------- CONSTANTS ------------------------------------------
MODEL_NAME = "google/mt5-large"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
BATCH_SIZE = 160
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
WARMUP_STEPS = 1000
GRAD_ACC_STEPS = 1
CACHE_DIR = "./tokenized_cache"
OUTPUT_DIR = "./mt5-query-diversification"
# --------------------------------------------------------------------------


def prepare_datasets(csv_path: str):
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.01, random_state=42)
    return train_df, val_df


def compute_metrics(eval_preds, tok):
    preds, labels = eval_preds
    vs = len(tok)
    # Generated ids can contain -100 padding or out-of-range values; clamp them
    # to pad_token_id so batch_decode does not fail.
    preds = np2.where(preds < vs, preds, tok.pad_token_id)
    preds = np2.where(preds >= 0, preds, tok.pad_token_id)
    labels = np2.where(labels != -100, labels, tok.pad_token_id)
    pred_str = tok.batch_decode(preds, skip_special_tokens=True)
    label_str = tok.batch_decode(labels, skip_special_tokens=True)
    exact = sum(p.strip() == l.strip() for p, l in zip(pred_str, label_str)) / len(pred_str)
    diff = np2.mean([len(p.split()) - len(l.split()) for p, l in zip(pred_str, label_str)])
    return {"exact_match": exact, "avg_length_diff": diff}

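# Worked example of these metrics (purely illustrative): for predictions
# ["hotel reviews", "cheap flights"] against references
# ["hotel reviews", "budget flights"], exact_match = 1/2 = 0.5 and
# avg_length_diff = mean([0, 0]) = 0.0.
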
def list_checkpoints(out_dir):
    if not os.path.isdir(out_dir):
        return []
    cps = [d for d in os.listdir(out_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(out_dir, d))]
    cps.sort(key=lambda x: int(x.split("-")[1]))
    return cps


def select_checkpoint(cps):
    print("\nAvailable checkpoints:")
    for i, cp in enumerate(cps):
        print(f" [{i}] {cp}")
    print(" [n] Start training from scratch")
    sel = input(f"Select checkpoint [0-{len(cps)-1}, n]: ").strip()
    if sel.lower() in {"", "n"}:
        return None
    try:
        idx = int(sel)
    except ValueError:
        return None  # treat non-numeric input as "start from scratch"
    return cps[idx] if 0 <= idx < len(cps) else None


def last_epoch(ckpt_path):
    ts = os.path.join(ckpt_path, "trainer_state.json")
    if not os.path.isfile(ts):
        return 0
    with open(ts, "r", encoding="utf-8") as f:
        state = json.load(f)
    if "epoch" in state:
        return float(state["epoch"])
    epochs = [e.get("epoch", 0) for e in state.get("log_history", []) if "epoch" in e]
    return max(epochs) if epochs else 0

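# Checkpoint layout these helpers expect (standard Trainer output; step counts
# follow save_steps=5000, so the exact numbers are an assumption):
#   ./mt5-query-diversification/checkpoint-5000/trainer_state.json
#   ./mt5-query-diversification/checkpoint-10000/ ...
# list_checkpoints() sorts the directories by their step number.
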
def main():
    # Clear GPU memory before starting
    torch.cuda.empty_cache()
    gc.collect()

    wandb.init(project="query-diversification", name="mt5-large-url-context")
    tok = MT5Tokenizer.from_pretrained(MODEL_NAME)

    # Load model with memory optimizations
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    # model.gradient_checkpointing_enable()  # handled via Seq2SeqTrainingArguments below
    model.config.use_cache = False  # disable KV cache during training (incompatible with gradient checkpointing)
    torch.cuda.empty_cache()  # clear cache after model loading

    # Print memory usage
    print(f"Model loaded. GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # ----- dataset --------------------------------------------------------
    if os.path.exists(os.path.join(CACHE_DIR, "train")):
        train_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "train"))
        val_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "val"))
    else:
        tr_df, va_df = prepare_datasets("train.csv")
        train_ds = HFDataset.from_pandas(tr_df)
        val_ds = HFDataset.from_pandas(va_df)

        def tok_fn(ex):
            ins = [f"For URL: {u} diversify query: {q}" for u, q in zip(ex["url"], ex["query"])]
            tars = ex["fanout"]
            mi = tok(ins, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
            lbl = tok(text_target=tars, max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")
            # Replace label padding with -100 so those positions are ignored by the loss
            lbl["input_ids"] = [[(x if x != tok.pad_token_id else -100) for x in l] for l in lbl["input_ids"]]
            mi["labels"] = lbl["input_ids"]
            return mi

        train_ds = train_ds.map(tok_fn, batched=True, num_proc=4)
        val_ds = val_ds.map(tok_fn, batched=True, num_proc=4)
        os.makedirs(CACHE_DIR, exist_ok=True)
        train_ds.save_to_disk(os.path.join(CACHE_DIR, "train"))
        val_ds.save_to_disk(os.path.join(CACHE_DIR, "val"))

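    # Example of the label masking above (token ids are made up, just to show
    # the shape): if a target tokenizes to [259, 1175, 1, 0, 0, ...] with
    # pad_token_id = 0, the stored labels become [259, 1175, 1, -100, -100, ...]
    # and cross-entropy skips every -100 position.
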
    collator = DataCollatorForSeq2Seq(tok, model=model, padding=True)

    # ----- checkpoint handling -------------------------------------------
    cps = list_checkpoints(OUTPUT_DIR)
    resume = None
    n_epochs = NUM_EPOCHS
    if cps:
        chosen = select_checkpoint(cps)
        if chosen:
            resume = os.path.join(OUTPUT_DIR, chosen)
            le = last_epoch(resume)
            print(f"\nResuming from {resume} (epoch {le})")
            if le >= NUM_EPOCHS:
                extra = int(input("How many extra epochs? [0]: ").strip() or "0")
                if extra == 0:
                    print("No extra epochs. Exit.")
                    return
                n_epochs = le + extra

    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="steps",
        eval_steps=5000,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=n_epochs,
        warmup_steps=WARMUP_STEPS,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1,
        save_steps=5000,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LENGTH,
        generation_num_beams=5,
        bf16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        gradient_checkpointing=True,
        optim="adafactor",  # changed from default AdamW - saves ~30% memory
        tf32=True,  # enable TF32 for RTX 4090
        dataloader_pin_memory=False,  # reduce memory fragmentation
        full_determinism=False,  # allow non-deterministic ops for memory efficiency
    )

    # Reduce number of beams during evaluation
    args.generation_num_beams = 3  # instead of the 5 set above

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        data_collator=collator,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,
        compute_metrics=lambda p: compute_metrics(p, tok),
    )

    # Clear the CUDA cache every 100 optimizer steps via a TrainerCallback
    from transformers import TrainerCallback

    class CacheClearCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 100 == 0:
                torch.cuda.empty_cache()

    trainer.add_callback(CacheClearCallback())

    trainer.train(resume_from_checkpoint=resume)
    trainer.save_model("./mt5-query-diversification-final")
    tok.save_pretrained("./mt5-query-diversification-final")

    # ---- quick sanity generation ----------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    model.config.use_cache = True  # re-enable cache for inference

    samples = [("python.org", "python tutorial"),
               ("amazon.com", "laptop deals"),
               ("wikipedia.org", "machine learning")]
    for url, q in samples:
        txt = f"For URL: {url} diversify query: {q}"
        ins = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
        ins = {k: v.to(device) for k, v in ins.items()}
        out = model.generate(**ins, max_length=MAX_TARGET_LENGTH,
                             num_beams=5, temperature=0.7,
                             do_sample=True, top_p=0.9)
        print(f"\nInput: {txt}\nOutput: {tok.decode(out[0], skip_special_tokens=True)}")


if __name__ == "__main__":
    main()
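
# Expected input: train.csv with columns url, query, fanout (see tok_fn above).
# Run with: python train.py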