"""Streamlit app for query fan-out generation and decoding-method testing,
built around the dejanseo/query-fanout MT5 model."""

import os
import json
import math
import time
import difflib
from typing import List, Tuple, Dict, Any

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

MODEL_PATH = "dejanseo/query-fanout"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16
PRESETS_FILE = "generation_presets.json"

DEFAULT_PRESET: Dict[str, Any] = {
    "name": "Default",
    "max_candidates": 50,
    "temperature": 0.9,
    "top_p": 0.95,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.1,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.70,
    "dup_ratio": 0.92,
    "embedding_mode": "plain_both",
}

DIVERSE_PRESET: Dict[str, Any] = {
    "name": "Diverse",
    "max_candidates": 200,
    "temperature": 1.10,
    "top_p": 0.98,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.50,
    "dup_ratio": 0.88,
    "embedding_mode": "plain_both",
}

BUILT_IN_PRESETS = {"Default": DEFAULT_PRESET, "Diverse": DIVERSE_PRESET}

def load_user_presets() -> Dict[str, Dict[str, Any]]:
    if not os.path.exists(PRESETS_FILE):
        return {}
    try:
        with open(PRESETS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict):
            cleaned: Dict[str, Dict[str, Any]] = {}
            for k, v in data.items():
                if isinstance(v, dict):
                    if "embedding_mode" not in v:
                        v["embedding_mode"] = "plain_both"
                    cleaned[k] = v
            return cleaned
        return {}
    except Exception:
        return {}

def save_user_preset(name: str, cfg: Dict[str, Any]) -> None:
    data = load_user_presets()
    data[name] = dict(cfg, name=name)
    with open(PRESETS_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def all_presets() -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    out.update(BUILT_IN_PRESETS)
    out.update(load_user_presets())
    return out

@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return tok, model, device

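# Note: st.cache_resource keeps the tokenizer and model loaded across
# Streamlit reruns, so the checkpoint is fetched and moved to the device
# only once per process.
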
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    txt = f"For URL: {url} diversify query: {query}"
    enc = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    return {k: v.to(device) for k, v in enc.items()}, txt

def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    return tok.batch_decode(seqs, skip_special_tokens=True)

def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
    """Average per-token log-probability for each returned sequence.

    Note: this aligns gen.scores row-for-row with gen.sequences, which holds
    for sampling (one score row per returned draw). Beam search reorders
    hypotheses at the end, so for beam outputs these averages are approximate.
    """
    if not hasattr(gen, "scores") or not gen.scores:
        return [float("nan")] * gen.sequences.size(0)
    scores = gen.scores
    seqs = gen.sequences
    nseq = seqs.size(0)
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    count = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    finished = torch.zeros(nseq, dtype=torch.bool, device=scores[0].device)
    for t in range(len(scores)):
        step_logits = scores[t]
        step_logprobs = F.log_softmax(step_logits, dim=-1)
        step_tok = seqs[:, t + 1]  # sequences[:, 0] is the decoder start token
        valid = step_tok.ne(pad_id) & (~finished)
        if valid.any():
            gather = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, gather, torch.zeros_like(gather))
            count += valid.float()
        finished |= step_tok.eq(eos_id)
    count = torch.where(count.eq(0), torch.ones_like(count), count)
    return [(lp / c).item() for lp, c in zip(sum_logp, count)]

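# A hedged alternative (not wired into the app): transformers >= 4.26 exposes
# model.compute_transition_scores(), which realigns per-step scores with the
# returned sequences via beam_indices, so it is also exact for beam search.
# Minimal sketch; the helper name is ours, not part of the original app.
def avg_logprobs_via_transition_scores(tok: MT5Tokenizer, model, gen) -> List[float]:
    transition = model.compute_transition_scores(
        gen.sequences,
        gen.scores,
        beam_indices=getattr(gen, "beam_indices", None),
        normalize_logits=True,
    )
    # Average over real (non-pad) generated positions; position 0 of
    # gen.sequences is the decoder start token and carries no score.
    gen_tokens = gen.sequences[:, 1 : transition.shape[1] + 1]
    mask = gen_tokens.ne(tok.pad_token_id).float()
    lengths = mask.sum(dim=1).clamp(min=1.0)
    return ((transition * mask).sum(dim=1) / lengths).tolist()
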
def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p,
                      no_repeat_ngram_size=0, repetition_penalty=1.0):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    texts = decode_sequences(tok, gen.sequences)
    scores = avg_logprobs_from_generate(tok, gen)
    return texts, scores

def get_encoder_embedding(tok, model, text: str, device: torch.device):
    inputs = tok(text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
    with torch.no_grad():
        enc_out = model.get_encoder()(**inputs)
    return enc_out.last_hidden_state.mean(dim=1).squeeze(0)

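# Hedged optimization sketch (not used by the app): embedding candidates one
# string at a time re-runs the encoder per call. A batched variant is usually
# much faster; unlike the single-text version above it must mask out padding
# when mean-pooling. The helper name is ours.
def get_encoder_embeddings_batch(tok, model, texts: List[str], device: torch.device) -> torch.Tensor:
    inputs = tok(texts, return_tensors="pt", padding=True,
                 max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
    with torch.no_grad():
        enc_out = model.get_encoder()(**inputs)
    # Mean-pool over real tokens only, using the attention mask.
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (enc_out.last_hidden_state * mask).sum(dim=1)
    return summed / mask.sum(dim=1).clamp(min=1.0)  # shape: (len(texts), hidden)
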
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    return float(F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item())

def fmt_score(x: float) -> str:
    if math.isnan(x) or math.isinf(x):
        return "n/a"
    # exp(avg log-prob) is the per-token (geometric-mean) probability,
    # not the probability of the whole sequence.
    p = math.exp(x)
    return f"logp/len={x:.3f} | p≈{p:.3f}"

def normalize_text(s: str) -> str:
    return " ".join(s.strip().lower().split())

def is_near_duplicate(a: str, b: str, ratio_thresh: float) -> bool:
    return difflib.SequenceMatcher(None, normalize_text(a), normalize_text(b)).ratio() >= ratio_thresh

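# Worked example: is_near_duplicate("Airbnb reviews", "airbnb review", 0.9)
# is True: after normalization the strings share a 13-character match out of
# 14 + 13 characters, so the ratio is 2 * 13 / 27 ≈ 0.963 >= 0.9.
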
def mmr_select(
    cand_texts: List[str],
    cand_embs: List[torch.Tensor],
    query_emb: torch.Tensor,
    k: int,
    lambd: float,
) -> List[int]:
    """Maximal Marginal Relevance: greedily pick k indices maximizing
    lambd * sim(query, cand) - (1 - lambd) * max sim(cand, already selected)."""
    rel = [cosine_similarity(query_emb, e) for e in cand_embs]
    selected: List[int] = []
    available = set(range(len(cand_texts)))
    while available and len(selected) < k:
        if not selected:
            idx = max(available, key=lambda i: rel[i])
            selected.append(idx)
            available.remove(idx)
            continue
        best_idx = None
        best_score = -1e9
        for i in list(available):
            max_sim_to_sel = max(cosine_similarity(cand_embs[i], cand_embs[j]) for j in selected)
            score = lambd * rel[i] - (1.0 - lambd) * max_sim_to_sel
            if score > best_score:
                best_score = score
                best_idx = i
        selected.append(best_idx)
        available.remove(best_idx)
    return selected

def distinct_n(texts: List[str], n: int) -> float:
    total = 0
    uniq = set()
    for t in texts:
        toks = t.strip().split()
        if len(toks) < n:
            continue
        for i in range(len(toks) - n + 1):
            total += 1
            uniq.add(tuple(toks[i:i + n]))
    return (len(uniq) / total) if total > 0 else 0.0

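# Worked example: distinct_n(["red shoes", "red boots"], 1) == 0.75: four
# unigram slots in total, of which three are unique ("red", "shoes", "boots").
# Higher distinct-n means a less repetitive selection.
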
def embed_text_for_mode(url: str, text: str, mode: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device) -> torch.Tensor:
    """
    mode:
      - "plain_both": embed the raw text
      - "template_both": embed with the same instruction template used for inputs
    """
    if mode == "template_both":
        templated = f"For URL: {url} diversify query: {text}"
        return get_encoder_embedding(tok, model, templated, device)
    return get_encoder_embedding(tok, model, text, device)

def single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=num_beams,
        num_return_sequences=1,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    out = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, out)[0]

def topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=max(num_beams, top_n),
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    # See the note in avg_logprobs_from_generate: beam scores are approximate.
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    texts, scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
    order = sorted(range(len(texts)), key=lambda i: scores[i], reverse=True)
    return [texts[i] for i in order], [scores[i] for i in order]

def diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty):
    # num_beams must be divisible by num_beam_groups; round up if needed.
    num_beams = max(num_beams, num_beam_groups * max(1, top_n // max(1, num_beam_groups)))
    if num_beams % num_beam_groups != 0:
        num_beams = (num_beams // num_beam_groups + 1) * num_beam_groups
    top_n = min(top_n, num_beams)
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

def token_by_token_probabilities(tok, model, device, inputs):
    """Greedy decode, returning (SentencePiece token, probability) pairs.
    Tokens are raw pieces, so "▁" marks the start of a word."""
    gen = model.generate(
        **inputs,
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    seq = gen.sequences[0]
    token_ids = seq.tolist()
    per_token = []
    for t, logits in enumerate(gen.scores):
        tok_id = token_ids[t + 1]  # position 0 is the decoder start token
        probs = F.softmax(logits[0], dim=-1)
        prob = float(probs[tok_id].detach().cpu())
        sp_token = tok.convert_ids_to_tokens([tok_id])[0]
        per_token.append((sp_token, prob))
    return per_token

st.set_page_config(page_title="Query Fanout – Generation & Testing", layout="wide")
tok, model, device = load_model()
tab1, tab2 = st.tabs(["Generation", "Testing"])

def run_generation(url: str, query: str, cfg: Dict[str, Any], show_save_controls: bool) -> None:
    torch.manual_seed(int(cfg["seed"]))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(int(cfg["seed"]))
    start_ts = time.time()
    inputs, prompt_txt = build_inputs(tok, url, query, device)
    embedding_mode = cfg.get("embedding_mode", "plain_both")
    orig_emb = embed_text_for_mode(url, query, embedding_mode, tok, model, device)

    # Oversample 2x so enough unique candidates survive deduplication.
    texts, scores = sampling_generate(
        tok, model, device, inputs,
        top_n=int(cfg["max_candidates"]) * 2,
        temperature=float(cfg["temperature"]),
        top_p=float(cfg["top_p"]),
        no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
        repetition_penalty=float(cfg["repetition_penalty"]),
    )
    seen = set()
    enriched: List[Dict[str, Any]] = []
    for txt, sc in zip(texts, scores):
        norm = normalize_text(txt)
        if norm not in seen:
            seen.add(norm)
            cand_emb = embed_text_for_mode(url, txt, embedding_mode, tok, model, device)
            cos_sim = cosine_similarity(orig_emb, cand_emb)
            enriched.append({"logp/len": sc, "p≈": math.exp(sc), "cos≈": cos_sim, "text": txt, "emb": cand_emb})
        if len(enriched) >= int(cfg["max_candidates"]):
            break

    if cfg["sort_by"] == "logp/len":
        enriched.sort(key=lambda x: x["logp/len"], reverse=True)
    else:
        enriched.sort(key=lambda x: x["cos≈"], reverse=True)

    df = pd.DataFrame([{"logp/len": e["logp/len"], "p≈": e["p≈"], "cos≈": e["cos≈"], "text": e["text"]} for e in enriched])
    df.index = range(1, len(df) + 1)
    elapsed = time.time() - start_ts
    st.caption(f"Generated {len(df)} unique fan-out queries in {elapsed:.2f}s")
    st.dataframe(df, use_container_width=True)
    # Drop near-duplicates, keeping the first of each similar pair.
    filtered: List[Dict[str, Any]] = []
    for cand in enriched:
        keep = True
        for kept in filtered:
            if is_near_duplicate(cand["text"], kept["text"], float(cfg["dup_ratio"])):
                keep = False
                break
        if keep:
            filtered.append(cand)
    if filtered:
        k_eff = min(int(cfg["select_k"]), len(filtered))
        cand_texts = [c["text"] for c in filtered]
        cand_embs = [c["emb"] for c in filtered]
        sel_idx = mmr_select(cand_texts, cand_embs, orig_emb, k=k_eff, lambd=float(cfg["mmr_lambda"]))
        selected = [filtered[i] for i in sel_idx]

        st.markdown("### Reranked Top-K (MMR + Dedup)")
        st.caption(f"Mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}")
        df_sel = pd.DataFrame(
            [{"rank": i + 1, "cos≈": s["cos≈"], "text": s["text"]} for i, s in enumerate(selected)]
        )
        df_sel.set_index("rank", inplace=True)
        st.dataframe(df_sel, use_container_width=True)

        sel_texts = [s["text"] for s in selected]
        d1 = distinct_n(sel_texts, 1)
        d2 = distinct_n(sel_texts, 2)
        st.caption(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f} on selected {len(sel_texts)}")

        combined_output = [f"Input: {prompt_txt}"]
        for rank, row in df.iterrows():
            combined_output.append(f"#{rank} logp/len={row['logp/len']:.3f} | p≈{row['p≈']:.3f} | cos≈{row['cos≈']:.3f} — {row['text']}")
        block = "\n".join(combined_output)

        st.markdown("### Copy/Paste Summary")
        st.code(block, language="text")
        with open("generation_output.txt", "w", encoding="utf-8") as f:
            f.write(block)
            f.write("\n\n[MMR selection]\n")
            f.write(f"mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}\n")
            for i, s in enumerate(selected, 1):
                f.write(f"#{i} cos≈={s['cos≈']:.3f} — {s['text']}\n")
            f.write(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f}\n")
        with open("generation_selected.txt", "w", encoding="utf-8") as f:
            for i, s in enumerate(selected, 1):
                f.write(f"{i}\t{s['text']}\n")
        st.success("Saved summary to generation_output.txt and selection to generation_selected.txt")
    else:
        st.warning("All candidates filtered as near-duplicates. Lower the duplicate threshold or increase max candidates.")
    if show_save_controls:
        st.markdown("---")
        with st.form(key="save_preset_form"):
            new_name = st.text_input("Preset Name", value="", placeholder="Enter a preset name")
            submitted = st.form_submit_button("Save as Preset")
            if submitted:
                if not new_name.strip():
                    st.error("Preset name cannot be empty.")
                elif new_name in BUILT_IN_PRESETS:
                    st.error("Cannot overwrite built-in presets (Default, Diverse). Use a different name.")
                else:
                    to_save = {
                        "max_candidates": int(cfg["max_candidates"]),
                        "temperature": float(cfg["temperature"]),
                        "top_p": float(cfg["top_p"]),
                        "no_repeat_ngram_size": int(cfg["no_repeat_ngram_size"]),
                        "repetition_penalty": float(cfg["repetition_penalty"]),
                        "seed": int(cfg["seed"]),
                        "sort_by": str(cfg["sort_by"]),
                        "select_k": int(cfg["select_k"]),
                        "mmr_lambda": float(cfg["mmr_lambda"]),
                        "dup_ratio": float(cfg["dup_ratio"]),
                        "embedding_mode": str(cfg.get("embedding_mode", "plain_both")),
                    }
                    save_user_preset(new_name.strip(), to_save)
                    st.success(f"Preset '{new_name.strip()}' saved.")

with tab1:
    st.header("Generation Mode — Large Diverse Fan-out")
    url = st.text_input("URL", value="airbnb.com", key="gen_url")
    query = st.text_input("Query", value="airbnb reviews", key="gen_query")

    subtab_presets, subtab_manual = st.tabs(["Presets", "Manual Settings"])
    with subtab_presets:
        all_p = all_presets()
        preset_names = list(all_p.keys())
        preset_choice = st.selectbox(
            "Choose a preset",
            preset_names,
            index=preset_names.index("Default") if "Default" in preset_names else 0,
        )
        sel = dict(all_p[preset_choice])
        emb_mode_preset = st.selectbox(
            "Embedding mode for reranking",
            options=["plain_both", "template_both"],
            index=0 if sel.get("embedding_mode", "plain_both") == "plain_both" else 1,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template",
        )
        sel["embedding_mode"] = emb_mode_preset

        cols = st.columns(3)
        with cols[0]:
            st.write(f"**Max candidates:** {sel['max_candidates']}")
            st.write(f"**Temperature:** {sel['temperature']}")
            st.write(f"**Top-p:** {sel['top_p']}")
            st.write(f"**Seed:** {sel['seed']}")
        with cols[1]:
            st.write(f"**No repeat n-gram:** {sel['no_repeat_ngram_size']}")
            st.write(f"**Repetition penalty:** {sel['repetition_penalty']}")
            st.write(f"**Sort by:** {sel['sort_by']}")
        with cols[2]:
            st.write(f"**Select K:** {sel['select_k']}")
            st.write(f"**λ (MMR):** {sel['mmr_lambda']}")
            st.write(f"**Dup ratio:** {sel['dup_ratio']}")
            st.write(f"**Embedding:** {sel['embedding_mode']}")

        run_gen_preset = st.button("Generate Fan-out (Preset)", key="run_gen_preset")
        if run_gen_preset:
            run_generation(url, query, sel, show_save_controls=False)
    with subtab_manual:
        base = DEFAULT_PRESET
        max_candidates = st.number_input("Max candidates", min_value=1, max_value=200, value=int(base["max_candidates"]), step=1)
        temperature = st.number_input("Temperature", min_value=0.1, max_value=2.0, value=float(base["temperature"]), step=0.1)
        top_p = st.number_input("Top-p", min_value=0.1, max_value=1.0, value=float(base["top_p"]), step=0.01)
        no_repeat_ngram_size = st.number_input("No repeat n-gram size (0=off)", min_value=0, max_value=10, value=int(base["no_repeat_ngram_size"]), step=1)
        repetition_penalty = st.number_input("Repetition penalty (1.0=off)", min_value=1.0, max_value=2.0, value=float(base["repetition_penalty"]), step=0.1)
        seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=int(base["seed"]), step=1, key="gen_seed_manual")
        sort_by = st.selectbox("Sort by", ["logp/len", "cosine similarity"], index=0)

        st.subheader("Diversity-aware Reranking (MMR on internal encoder vectors)")
        embedding_mode_manual = st.selectbox(
            "Embedding mode",
            options=["plain_both", "template_both"],
            index=0,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template",
        )
        select_k = st.number_input("Select top K after rerank", min_value=1, max_value=200, value=int(base["select_k"]), step=1)
        mmr_lambda = st.number_input("MMR relevance weight λ (higher = more on-topic, lower = more diverse)", min_value=0.0, max_value=1.0, value=float(base["mmr_lambda"]), step=0.01)
        dup_ratio = st.number_input("Near-duplicate threshold (SequenceMatcher ratio)", min_value=0.0, max_value=1.0, value=float(base["dup_ratio"]), step=0.01)

        run_gen_manual = st.button("Generate Fan-out (Manual Settings)", key="run_gen_manual")
        if run_gen_manual:
            cfg = {
                "max_candidates": int(max_candidates),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "no_repeat_ngram_size": int(no_repeat_ngram_size),
                "repetition_penalty": float(repetition_penalty),
                "seed": int(seed_value),
                "sort_by": str(sort_by),
                "select_k": int(select_k),
                "mmr_lambda": float(mmr_lambda),
                "dup_ratio": float(dup_ratio),
                "embedding_mode": str(embedding_mode_manual),
            }
            run_generation(url, query, cfg, show_save_controls=True)

with tab2:
    st.header("Testing Mode — Method Comparison")
    url = st.text_input("URL", value="airbnb.com", key="test_url")
    query = st.text_input("Query", value="airbnb reviews", key="test_query")
    num_beams = st.number_input("num_beams", min_value=1, max_value=20, value=5, step=1)
    top_n = st.number_input("top_n", min_value=1, max_value=20, value=5, step=1)
    temperature = st.number_input("temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    num_beam_groups = st.number_input("num_beam_groups", min_value=1, max_value=20, value=5, step=1)
    diversity_penalty = st.number_input("diversity_penalty", min_value=0.0, max_value=5.0, value=1.0, step=0.1)
    no_repeat_ngram_size = st.number_input("no_repeat_ngram_size", min_value=0, max_value=10, value=0, step=1)
    repetition_penalty = st.number_input("repetition_penalty", min_value=1.0, max_value=2.0, value=1.0, step=0.1)
    seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=42, step=1, key="test_seed")
    run_test = st.button("Run Comparison", key="run_test")

    if run_test:
        torch.manual_seed(int(seed_value))
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(int(seed_value))
        inputs, prompt_txt = build_inputs(tok, url, query, device)

        best_det = single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty)
        topn_beam_txts, topn_beam_scores = topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty)
        topn_samp_txts, topn_samp_scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        ranked_txts, ranked_scores = score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        div_txts, div_scores = diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty)
        per_token = token_by_token_probabilities(tok, model, device, inputs)

        def _section(title: str, txts: List[str], scs: List[float]) -> List[str]:
            return [f"\n{title}"] + [f"#{i + 1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(txts, scs))]

        combined_output = [f"Input: {prompt_txt}", "\n[1] Single best (deterministic beam)", best_det]
        combined_output += _section("[2] Top-N (beam)", topn_beam_txts, topn_beam_scores)
        combined_output += _section("[3] Top-N (sampling)", topn_samp_txts, topn_samp_scores)
        combined_output += _section("[4] Score-ranked (sampling)", ranked_txts, ranked_scores)
        combined_output += _section("[5] Diverse beams", div_txts, div_scores)
        combined_output += ["\n[6] Token-by-token probabilities (greedy)"] + [f"{t} — {p:.4f}" for t, p in per_token]

        st.markdown("### Copy/Paste Summary")
        st.code("\n".join(combined_output), language="text")
        with open("testing_output.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(combined_output))
        st.success("Saved summary to testing_output.txt")