dejanseo committed
Commit 5adc166 · verified · Parent: 7d1765d

Upload 2 files

Files changed (2)
  1. app.py +565 -0
  2. train.py +237 -0
app.py ADDED
@@ -0,0 +1,565 @@
# app.py
import os
import json
import math
import time
import difflib
import torch
import streamlit as st
from typing import List, Tuple, Dict, Any
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch.nn.functional as F
import pandas as pd

# ------------------ CONSTANTS ------------------
MODEL_PATH = "dejanseo/query-fanout"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
PRESETS_FILE = "generation_presets.json"
# ------------------------------------------------

# ------------------ BUILT-IN PRESETS ------------------
DEFAULT_PRESET: Dict[str, Any] = {
    "name": "Default",
    "max_candidates": 50,
    "temperature": 0.9,
    "top_p": 0.95,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.1,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.70,
    "dup_ratio": 0.92,
    "embedding_mode": "plain_both",  # embedding toggle
}
DIVERSE_PRESET: Dict[str, Any] = {
    "name": "Diverse",
    "max_candidates": 200,
    "temperature": 1.10,
    "top_p": 0.98,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.50,
    "dup_ratio": 0.88,
    "embedding_mode": "plain_both",  # embedding toggle
}
BUILT_IN_PRESETS = {"Default": DEFAULT_PRESET, "Diverse": DIVERSE_PRESET}

# ------------------ PRESET IO ------------------
def load_user_presets() -> Dict[str, Dict[str, Any]]:
    if not os.path.exists(PRESETS_FILE):
        return {}
    try:
        with open(PRESETS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict):
            cleaned: Dict[str, Dict[str, Any]] = {}
            for k, v in data.items():
                if isinstance(v, dict):
                    if "embedding_mode" not in v:
                        v["embedding_mode"] = "plain_both"
                    cleaned[k] = v
            return cleaned
        return {}
    except Exception:
        return {}

def save_user_preset(name: str, cfg: Dict[str, Any]) -> None:
    data = load_user_presets()
    data[name] = dict(cfg, name=name)
    with open(PRESETS_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def all_presets() -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    out.update(BUILT_IN_PRESETS)
    out.update(load_user_presets())
    return out

# ------------------ MODEL LOADING ------------------
@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return tok, model, device

# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    txt = f"For URL: {url} diversify query: {query}"
    enc = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    return {k: v.to(device) for k, v in enc.items()}, txt
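# Example: build_inputs(tok, "airbnb.com", "airbnb reviews", device) tokenizes the
# prompt "For URL: airbnb.com diversify query: airbnb reviews", the same template
# train.py uses to build its training inputs.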

def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    return tok.batch_decode(seqs, skip_special_tokens=True)

def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
    if not hasattr(gen, "scores") or not gen.scores:
        return [float("nan")] * gen.sequences.size(0)
    scores = gen.scores
    seqs = gen.sequences
    nseq = seqs.size(0)
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    count = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    finished = torch.zeros(nseq, dtype=torch.bool, device=scores[0].device)
    for t in range(len(scores)):
        step_logits = scores[t]
        step_logprobs = F.log_softmax(step_logits, dim=-1)
        step_tok = seqs[:, t + 1]
        valid = step_tok.ne(pad_id) & (~finished)
        if valid.any():
            gather = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, gather, torch.zeros_like(gather))
            count += valid.float()
        finished |= step_tok.eq(eos_id)
    count = torch.where(count.eq(0), torch.ones_like(count), count)
    return [(lp / c).item() for lp, c in zip(sum_logp, count)]

def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p,
                      no_repeat_ngram_size=0, repetition_penalty=1.0):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    texts = decode_sequences(tok, gen.sequences)
    scores = avg_logprobs_from_generate(tok, gen)
    return texts, scores

def get_encoder_embedding(tok, model, text: str, device: torch.device):
    inputs = tok(text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
    with torch.no_grad():
        enc_out = model.get_encoder()(**inputs)
    return enc_out.last_hidden_state.mean(dim=1).squeeze(0)

def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    return float(F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item())

def fmt_score(x: float) -> str:
    if x != x or math.isinf(x):
        return "n/a"
    p = math.exp(x)
    return f"logp/len={x:.3f} | p≈{p:.3f}"

# ------------------ RERANK (MMR + DEDUP) ------------------
def normalize_text(s: str) -> str:
    return " ".join(s.strip().lower().split())

def is_near_duplicate(a: str, b: str, ratio_thresh: float) -> bool:
    return difflib.SequenceMatcher(None, normalize_text(a), normalize_text(b)).ratio() >= ratio_thresh

def mmr_select(
    cand_texts: List[str],
    cand_embs: List[torch.Tensor],
    query_emb: torch.Tensor,
    k: int,
    lambd: float
) -> List[int]:
    rel = [cosine_similarity(query_emb, e) for e in cand_embs]
    selected: List[int] = []
    available = set(range(len(cand_texts)))
    while available and len(selected) < k:
        if not selected:
            idx = max(available, key=lambda i: rel[i])
            selected.append(idx)
            available.remove(idx)
            continue
        best_idx = None
        best_score = -1e9
        for i in list(available):
            max_sim_to_sel = max(cosine_similarity(cand_embs[i], cand_embs[j]) for j in selected)
            score = lambd * rel[i] - (1.0 - lambd) * max_sim_to_sel
            if score > best_score:
                best_score = score
                best_idx = i
        selected.append(best_idx)
        available.remove(best_idx)
    return selected
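# Worked example with hypothetical numbers: at lambd=0.7, a candidate with
# relevance 0.9 that is 0.8-similar to an already selected item scores
# 0.7*0.9 - 0.3*0.8 = 0.39, while a less relevant (0.8) but novel (0.2-similar)
# candidate scores 0.7*0.8 - 0.3*0.2 = 0.50 and is selected first.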

def distinct_n(texts: List[str], n: int) -> float:
    total = 0
    uniq = set()
    for t in texts:
        toks = t.strip().split()
        if len(toks) < n:
            continue
        for i in range(len(toks) - n + 1):
            total += 1
            uniq.add(tuple(toks[i:i+n]))
    return (len(uniq) / total) if total > 0 else 0.0

# ------------------ EMBEDDING MODE HELPERS (TOGGLE) ------------------
def embed_text_for_mode(url: str, text: str, mode: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device) -> torch.Tensor:
    """
    mode:
      - "plain_both": embed raw text
      - "template_both": embed with the same instruction template used for inputs
    """
    if mode == "template_both":
        templated = f"For URL: {url} diversify query: {text}"
        return get_encoder_embedding(tok, model, templated, device)
    return get_encoder_embedding(tok, model, text, device)

# ------------------ TESTING HELPERS ------------------
def single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=num_beams,
        num_return_sequences=1,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    out = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, out)[0]

def topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=max(num_beams, top_n),
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    texts, scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
    order = sorted(range(len(texts)), key=lambda i: scores[i], reverse=True)
    return [texts[i] for i in order], [scores[i] for i in order]

def diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty):
    num_beams = max(num_beams, num_beam_groups * max(1, top_n // max(1, num_beam_groups)))
    if num_beams % num_beam_groups != 0:
        num_beams = (num_beams // num_beam_groups + 1) * num_beam_groups
    top_n = min(top_n, num_beams)
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def token_by_token_probabilities(tok, model, device, inputs):
    gen = model.generate(
        **inputs,
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    seq = gen.sequences[0]
    token_ids = seq.tolist()
    per_token = []
    for t, logits in enumerate(gen.scores):
        tok_id = token_ids[t + 1]
        probs = F.softmax(logits[0], dim=-1)
        prob = float(probs[tok_id].detach().cpu())
        sp_token = tok.convert_ids_to_tokens([tok_id])[0]
        per_token.append((sp_token, prob))
    return per_token

# ------------------ STREAMLIT APP ------------------
st.set_page_config(page_title="Query Fanout – Generation & Testing", layout="wide")
tok, model, device = load_model()
tab1, tab2 = st.tabs(["Generation", "Testing"])

# ----------- COMMON GENERATION RUNNER -----------
def run_generation(url: str, query: str, cfg: Dict[str, Any], show_save_controls: bool) -> None:
    torch.manual_seed(int(cfg["seed"]))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(int(cfg["seed"]))
    start_ts = time.time()
    inputs, prompt_txt = build_inputs(tok, url, query, device)
    embedding_mode = cfg.get("embedding_mode", "plain_both")
    orig_emb = embed_text_for_mode(url, query, embedding_mode, tok, model, device)

    texts, scores = sampling_generate(
        tok, model, device, inputs,
        top_n=int(cfg["max_candidates"]) * 2,
        temperature=float(cfg["temperature"]),
        top_p=float(cfg["top_p"]),
        no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
        repetition_penalty=float(cfg["repetition_penalty"]),
    )

    seen = set()
    enriched: List[Dict[str, Any]] = []
    for txt, sc in zip(texts, scores):
        norm = normalize_text(txt)
        if norm not in seen:
            seen.add(norm)
            cand_emb = embed_text_for_mode(url, txt, embedding_mode, tok, model, device)
            cos_sim = cosine_similarity(orig_emb, cand_emb)
            enriched.append({"logp/len": sc, "p≈": math.exp(sc), "cos≈": cos_sim, "text": txt, "emb": cand_emb})
        if len(enriched) >= int(cfg["max_candidates"]):
            break

    if cfg["sort_by"] == "logp/len":
        enriched.sort(key=lambda x: x["logp/len"], reverse=True)
    else:
        enriched.sort(key=lambda x: x["cos≈"], reverse=True)

    df = pd.DataFrame([{"logp/len": e["logp/len"], "p≈": e["p≈"], "cos≈": e["cos≈"], "text": e["text"]} for e in enriched])
    df.index = range(1, len(df) + 1)
    elapsed = time.time() - start_ts
    st.caption(f"Generated {len(df)} unique fan-out queries in {elapsed:.2f}s")
    st.dataframe(df, use_container_width=True)

    filtered: List[Dict[str, Any]] = []
    for cand in enriched:
        keep = True
        for kept in filtered:
            if is_near_duplicate(cand["text"], kept["text"], float(cfg["dup_ratio"])):
                keep = False
                break
        if keep:
            filtered.append(cand)

    if filtered:
        k_eff = min(int(cfg["select_k"]), len(filtered))
        cand_texts = [c["text"] for c in filtered]
        cand_embs = [c["emb"] for c in filtered]
        sel_idx = mmr_select(cand_texts, cand_embs, orig_emb, k=k_eff, lambd=float(cfg["mmr_lambda"]))
        selected = [filtered[i] for i in sel_idx]

        st.markdown("### Reranked Top-K (MMR + Dedup)")
        st.caption(f"Mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}")
        df_sel = pd.DataFrame(
            [{"rank": i+1, "cos≈": s["cos≈"], "text": s["text"]} for i, s in enumerate(selected)]
        )
        df_sel.set_index("rank", inplace=True)
        st.dataframe(df_sel, use_container_width=True)

        sel_texts = [s["text"] for s in selected]
        d1 = distinct_n(sel_texts, 1)
        d2 = distinct_n(sel_texts, 2)
        st.caption(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f} on selected {len(sel_texts)}")

        combined_output = [f"Input: {prompt_txt}"]
        for rank, row in df.iterrows():
            combined_output.append(f"#{rank} logp/len={row['logp/len']:.3f} | p≈{row['p≈']:.3f} | cos≈{row['cos≈']:.3f} — {row['text']}")
        block = "\n".join(combined_output)

        st.markdown("### Copy/Paste Summary")
        st.code(block, language="text")
        with open("generation_output.txt", "w", encoding="utf-8") as f:
            f.write(block)
            f.write("\n\n[MMR selection]\n")
            f.write(f"mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}\n")
            for i, s in enumerate(selected, 1):
                f.write(f"#{i} cos≈={s['cos≈']:.3f} — {s['text']}\n")
            f.write(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f}\n")
        with open("generation_selected.txt", "w", encoding="utf-8") as f:
            for i, s in enumerate(selected, 1):
                f.write(f"{i}\t{s['text']}\n")
        st.success("Saved summary to generation_output.txt and selection to generation_selected.txt")
    else:
        st.warning("All candidates filtered as near-duplicates. Lower the duplicate threshold or increase max candidates.")

    if show_save_controls:
        st.markdown("---")
        with st.form(key="save_preset_form"):
            new_name = st.text_input("Preset Name", value="", placeholder="Enter a preset name")
            submitted = st.form_submit_button("Save as Preset")
            if submitted:
                if not new_name.strip():
                    st.error("Preset name cannot be empty.")
                elif new_name in BUILT_IN_PRESETS:
                    st.error("Cannot overwrite built-in presets (Default, Diverse). Use a different name.")
                else:
                    to_save = {
                        "max_candidates": int(cfg["max_candidates"]),
                        "temperature": float(cfg["temperature"]),
                        "top_p": float(cfg["top_p"]),
                        "no_repeat_ngram_size": int(cfg["no_repeat_ngram_size"]),
                        "repetition_penalty": float(cfg["repetition_penalty"]),
                        "seed": int(cfg["seed"]),
                        "sort_by": str(cfg["sort_by"]),
                        "select_k": int(cfg["select_k"]),
                        "mmr_lambda": float(cfg["mmr_lambda"]),
                        "dup_ratio": float(cfg["dup_ratio"]),
                        "embedding_mode": str(cfg.get("embedding_mode", "plain_both")),
                    }
                    save_user_preset(new_name.strip(), to_save)
                    st.success(f"Preset '{new_name.strip()}' saved.")

# ----------- TAB 1: GENERATION -----------
with tab1:
    st.header("Generation Mode — Large Diverse Fan-out")
    url = st.text_input("URL", value="airbnb.com", key="gen_url")
    query = st.text_input("Query", value="airbnb reviews", key="gen_query")

    subtab_presets, subtab_manual = st.tabs(["Presets", "Manual Settings"])

    # ----- Presets sub-tab -----
    with subtab_presets:
        all_p = all_presets()
        preset_names = list(all_p.keys())
        preset_choice = st.selectbox(
            "Choose a preset",
            preset_names,
            index=preset_names.index("Default") if "Default" in preset_names else 0
        )
        sel = dict(all_p[preset_choice])  # copy to allow local edits
        emb_mode_preset = st.selectbox(
            "Embedding mode for reranking",
            options=["plain_both", "template_both"],
            index=0 if sel.get("embedding_mode", "plain_both") == "plain_both" else 1,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template"
        )
        sel["embedding_mode"] = emb_mode_preset

        cols = st.columns(3)
        with cols[0]:
            st.write(f"**Max candidates:** {sel['max_candidates']}")
            st.write(f"**Temperature:** {sel['temperature']}")
            st.write(f"**Top-p:** {sel['top_p']}")
            st.write(f"**Seed:** {sel['seed']}")
        with cols[1]:
            st.write(f"**No repeat n-gram:** {sel['no_repeat_ngram_size']}")
            st.write(f"**Repetition penalty:** {sel['repetition_penalty']}")
            st.write(f"**Sort by:** {sel['sort_by']}")
        with cols[2]:
            st.write(f"**Select K:** {sel['select_k']}")
            st.write(f"**λ (MMR):** {sel['mmr_lambda']}")
            st.write(f"**Dup ratio:** {sel['dup_ratio']}")
            st.write(f"**Embedding:** {sel['embedding_mode']}")

        run_gen_preset = st.button("Generate Fan-out (Preset)", key="run_gen_preset")
        if run_gen_preset:
            run_generation(url, query, sel, show_save_controls=False)

    # ----- Manual Settings sub-tab -----
    with subtab_manual:
        base = DEFAULT_PRESET
        max_candidates = st.number_input("Max candidates", min_value=1, max_value=200, value=int(base["max_candidates"]), step=1)
        temperature = st.number_input("Temperature", min_value=0.1, max_value=2.0, value=float(base["temperature"]), step=0.1)
        top_p = st.number_input("Top-p", min_value=0.1, max_value=1.0, value=float(base["top_p"]), step=0.01)
        no_repeat_ngram_size = st.number_input("No repeat n-gram size (0=off)", min_value=0, max_value=10, value=int(base["no_repeat_ngram_size"]), step=1)
        repetition_penalty = st.number_input("Repetition penalty (1.0=off)", min_value=1.0, max_value=2.0, value=float(base["repetition_penalty"]), step=0.1)
        seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=int(base["seed"]), step=1, key="gen_seed_manual")
        sort_by = st.selectbox("Sort by", ["logp/len", "cosine similarity"], index=0)

        st.subheader("Diversity-aware Reranking (MMR on internal encoder vectors)")
        embedding_mode_manual = st.selectbox(
            "Embedding mode",
            options=["plain_both", "template_both"],
            index=0,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template"
        )
        select_k = st.number_input("Select top K after rerank", min_value=1, max_value=200, value=int(base["select_k"]), step=1)
        mmr_lambda = st.number_input("MMR relevance weight λ (higher = more on-topic, lower = more diverse)", min_value=0.0, max_value=1.0, value=float(base["mmr_lambda"]), step=0.01)
        dup_ratio = st.number_input("Near-duplicate threshold (SequenceMatcher ratio)", min_value=0.0, max_value=1.0, value=float(base["dup_ratio"]), step=0.01)

        run_gen_manual = st.button("Generate Fan-out (Manual Settings)", key="run_gen_manual")
        if run_gen_manual:
            cfg = {
                "max_candidates": int(max_candidates),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "no_repeat_ngram_size": int(no_repeat_ngram_size),
                "repetition_penalty": float(repetition_penalty),
                "seed": int(seed_value),
                "sort_by": str(sort_by),
                "select_k": int(select_k),
                "mmr_lambda": float(mmr_lambda),
                "dup_ratio": float(dup_ratio),
                "embedding_mode": str(embedding_mode_manual),
            }
            run_generation(url, query, cfg, show_save_controls=True)

# ----------- TAB 2: TESTING -----------
with tab2:
    st.header("Testing Mode — Method Comparison")
    url = st.text_input("URL", value="airbnb.com", key="test_url")
    query = st.text_input("Query", value="airbnb reviews", key="test_query")
    num_beams = st.number_input("num_beams", min_value=1, max_value=20, value=5, step=1)
    top_n = st.number_input("top_n", min_value=1, max_value=20, value=5, step=1)
    temperature = st.number_input("temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    num_beam_groups = st.number_input("num_beam_groups", min_value=1, max_value=20, value=5, step=1)
    diversity_penalty = st.number_input("diversity_penalty", min_value=0.0, max_value=5.0, value=1.0, step=0.1)
    no_repeat_ngram_size = st.number_input("no_repeat_ngram_size", min_value=0, max_value=10, value=0, step=1)
    repetition_penalty = st.number_input("repetition_penalty", min_value=1.0, max_value=2.0, value=1.0, step=0.1)
    seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=42, step=1, key="test_seed")
    run_test = st.button("Run Comparison", key="run_test")

    if run_test:
        torch.manual_seed(int(seed_value))
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(int(seed_value))
        inputs, prompt_txt = build_inputs(tok, url, query, device)

        best_det = single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty)
        topn_beam_txts, topn_beam_scores = topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty)
        topn_samp_txts, topn_samp_scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        ranked_txts, ranked_scores = score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        div_txts, div_scores = diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty)
        per_token = token_by_token_probabilities(tok, model, device, inputs)

        combined_output = [f"Input: {prompt_txt}",
                           "\n[1] Single best (deterministic beam)", best_det,
                           "\n[2] Top-N (beam)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(topn_beam_txts, topn_beam_scores))] + \
                          ["\n[3] Top-N (sampling)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(topn_samp_txts, topn_samp_scores))] + \
                          ["\n[4] Score-ranked (sampling)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(ranked_txts, ranked_scores))] + \
                          ["\n[5] Diverse beams"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(div_txts, div_scores))] + \
                          ["\n[6] Token-by-token probabilities (greedy)"] + [f"{t} — {p:.4f}" for t, p in per_token]

        st.markdown("### Copy/Paste Summary")
        st.code("\n".join(combined_output), language="text")
        with open("testing_output.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(combined_output))
        st.success("Saved summary to testing_output.txt")
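
The app is launched with "streamlit run app.py". For a quick check outside the UI, a minimal sketch that samples a handful of fan-out queries directly (assuming the dejanseo/query-fanout checkpoint is reachable on the Hugging Face Hub; lengths and sampling values mirror the app's defaults):

# probe.py - hypothetical standalone sketch, not part of the upload
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

tok = MT5Tokenizer.from_pretrained("dejanseo/query-fanout")
model = MT5ForConditionalGeneration.from_pretrained("dejanseo/query-fanout").eval()

prompt = "For URL: airbnb.com diversify query: airbnb reviews"  # template from app.py
enc = tok(prompt, return_tensors="pt", max_length=32, truncation=True)
out = model.generate(**enc, max_length=16, do_sample=True,
                     temperature=0.9, top_p=0.95, num_return_sequences=5)
for line in tok.batch_decode(out, skip_special_tokens=True):
    print(line)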
train.py ADDED
@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Train mT5-large for query diversification with URL context,
with resume-from-checkpoint and additional-epochs support.
"""
import torch
import numpy as np

# ---- PyTorch 2.6+ checkpoint-resume patches ------------------------------
# 1) allow numpy reconstruct in pickle
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])
# 2) force torch.load(weights_only=False) for RNG-state files
_orig_torch_load = torch.load
def _patched_load(*args, **kwargs):
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)
torch.load = _patched_load
# --------------------------------------------------------------------------

import pandas as pd
from transformers import (
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    TrainerCallback,
)
from sklearn.model_selection import train_test_split
from datasets import Dataset as HFDataset
import wandb
import os, json
import gc  # for explicit memory cleanup

# --------------------- CONSTANTS ------------------------------------------
MODEL_NAME = "google/mt5-large"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
BATCH_SIZE = 160
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
WARMUP_STEPS = 1000
GRAD_ACC_STEPS = 1
CACHE_DIR = "./tokenized_cache"
OUTPUT_DIR = "./mt5-query-diversification"
# --------------------------------------------------------------------------


def prepare_datasets(csv_path: str):
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.01, random_state=42)
    return train_df, val_df


def compute_metrics(eval_preds, tok):
    preds, labels = eval_preds
    vs = len(tok)
    preds = np.where(preds < vs, preds, tok.pad_token_id)
    preds = np.where(preds >= 0, preds, tok.pad_token_id)
    labels = np.where(labels != -100, labels, tok.pad_token_id)
    pred_str = tok.batch_decode(preds, skip_special_tokens=True)
    label_str = tok.batch_decode(labels, skip_special_tokens=True)
    exact = sum(p.strip() == l.strip() for p, l in zip(pred_str, label_str)) / len(pred_str)
    diff = np.mean([len(p.split()) - len(l.split()) for p, l in zip(pred_str, label_str)])
    return {"exact_match": exact, "avg_length_diff": diff}


def list_checkpoints(out_dir):
    if not os.path.isdir(out_dir):
        return []
    cps = [d for d in os.listdir(out_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(out_dir, d))]
    cps.sort(key=lambda x: int(x.split("-")[1]))
    return cps


def select_checkpoint(cps):
    print("\nAvailable checkpoints:")
    for i, cp in enumerate(cps):
        print(f"  [{i}] {cp}")
    print("  [n] Start training from scratch")
    sel = input(f"Select checkpoint [0-{len(cps)-1}, n]: ").strip()
    if sel.lower() in {"", "n"}:
        return None
    try:
        idx = int(sel)
    except ValueError:
        return None
    return cps[idx] if 0 <= idx < len(cps) else None


def last_epoch(ckpt_path):
    ts = os.path.join(ckpt_path, "trainer_state.json")
    if not os.path.isfile(ts):
        return 0
    with open(ts, "r", encoding="utf-8") as f:
        state = json.load(f)
    if "epoch" in state:
        return float(state["epoch"])
    epochs = [e.get("epoch", 0) for e in state.get("log_history", []) if "epoch" in e]
    return max(epochs) if epochs else 0


def main():
    # Clear GPU memory before starting
    torch.cuda.empty_cache()
    gc.collect()

    wandb.init(project="query-diversification", name="mt5-large-url-context")
    tok = MT5Tokenizer.from_pretrained(MODEL_NAME)

    # Load model with memory optimizations
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    # model.gradient_checkpointing_enable()  # handled via TrainingArguments below
    model.config.use_cache = False  # disable KV cache during training
    torch.cuda.empty_cache()  # clear cache after model loading

    # Print memory usage
    print(f"Model loaded. GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # ----- dataset --------------------------------------------------------
    if os.path.exists(os.path.join(CACHE_DIR, "train")):
        train_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "train"))
        val_ds = HFDataset.load_from_disk(os.path.join(CACHE_DIR, "val"))
    else:
        tr_df, va_df = prepare_datasets("train.csv")
        train_ds = HFDataset.from_pandas(tr_df)
        val_ds = HFDataset.from_pandas(va_df)

        def tok_fn(ex):
            ins = [f"For URL: {u} diversify query: {q}" for u, q in zip(ex["url"], ex["query"])]
            tars = ex["fanout"]
            mi = tok(ins, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
            lbl = tok(text_target=tars, max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")
            # Replace pad tokens in labels with -100 so the loss ignores them
            lbl["input_ids"] = [[(x if x != tok.pad_token_id else -100) for x in l] for l in lbl["input_ids"]]
            mi["labels"] = lbl["input_ids"]
            return mi

        train_ds = train_ds.map(tok_fn, batched=True, num_proc=4)
        val_ds = val_ds.map(tok_fn, batched=True, num_proc=4)
        os.makedirs(CACHE_DIR, exist_ok=True)
        train_ds.save_to_disk(os.path.join(CACHE_DIR, "train"))
        val_ds.save_to_disk(os.path.join(CACHE_DIR, "val"))

    collator = DataCollatorForSeq2Seq(tok, model=model, padding=True)

    # ----- checkpoint handling -------------------------------------------
    cps = list_checkpoints(OUTPUT_DIR)
    resume = None
    n_epochs = NUM_EPOCHS
    if cps:
        chosen = select_checkpoint(cps)
        if chosen:
            resume = os.path.join(OUTPUT_DIR, chosen)
            le = last_epoch(resume)
            print(f"\nResuming from {resume} (epoch {le})")
            if le >= NUM_EPOCHS:
                extra = int(input("How many extra epochs? [0]: ").strip() or "0")
                if extra == 0:
                    print("No extra epochs. Exit.")
                    return
                n_epochs = le + extra

    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="steps",
        eval_steps=5000,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=n_epochs,
        warmup_steps=WARMUP_STEPS,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=1,
        save_steps=5000,
        save_total_limit=3,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LENGTH,
        generation_num_beams=5,
        bf16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        gradient_checkpointing=True,
        optim="adafactor",  # saves roughly 30% memory vs. the default AdamW
        tf32=True,  # enable TF32 matmuls (e.g. on RTX 4090)
        dataloader_pin_memory=False,  # reduce memory fragmentation
        full_determinism=False,  # allow non-deterministic ops for memory efficiency
    )

    # Use fewer beams during evaluation to save memory
    args.generation_num_beams = 3

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        data_collator=collator,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,
        compute_metrics=lambda p: compute_metrics(p, tok),
    )

    # Clear the CUDA cache every 100 steps during training. A TrainerCallback
    # fires on every optimizer step; a plain wrapper around trainer.train would
    # only run its check once, when training starts.
    class ClearCacheCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 100 == 0:
                torch.cuda.empty_cache()

    trainer.add_callback(ClearCacheCallback())

    trainer.train(resume_from_checkpoint=resume)
    trainer.save_model("./mt5-query-diversification-final")
    tok.save_pretrained("./mt5-query-diversification-final")

    # ---- quick sanity generation ----------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    model.config.use_cache = True  # re-enable cache for inference

    samples = [("python.org", "python tutorial"),
               ("amazon.com", "laptop deals"),
               ("wikipedia.org", "machine learning")]
    for url, q in samples:
        txt = f"For URL: {url} diversify query: {q}"
        ins = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
        ins = {k: v.to(device) for k, v in ins.items()}
        out = model.generate(**ins, max_length=MAX_TARGET_LENGTH,
                             num_beams=5, temperature=0.7,
                             do_sample=True, top_p=0.9)
        print(f"\nInput: {txt}\nOutput: {tok.decode(out[0], skip_special_tokens=True)}")


if __name__ == "__main__":
    main()
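
train.py expects a train.csv with url, query, and fanout columns (see prepare_datasets and tok_fn). A hypothetical smoke-test file, not real training data, could be produced like this:

# make_dummy_csv.py - hypothetical rows, only for exercising the tokenization path
import pandas as pd

rows = [
    {"url": "python.org", "query": "python tutorial", "fanout": "python basics for beginners"},
    {"url": "amazon.com", "query": "laptop deals", "fanout": "best budget laptops"},
]
pd.DataFrame(rows).to_csv("train.csv", index=False)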