# query-fanout / app.py
# Hugging Face Space by dejanseo (upload "Upload 2 files", commit 5adc166, verified)
# app.py
import os
import json
import math
import time
import difflib
import torch
import streamlit as st
from typing import List, Tuple, Dict, Any
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch.nn.functional as F
import pandas as pd
# ------------------ CONSTANTS ------------------
MODEL_PATH = "dejanseo/query-fanout"  # Hugging Face repo id of the fine-tuned mT5 model
MAX_INPUT_LENGTH = 32  # token cap applied when encoding prompts
MAX_TARGET_LENGTH = 16  # token cap for generated fan-out queries
PRESETS_FILE = "generation_presets.json"  # on-disk store for user-saved presets
# ------------------------------------------------
# ------------------ BUILT-IN PRESETS ------------------
# Conservative baseline: moderate sampling, relevance-leaning MMR (λ=0.70),
# strict near-duplicate threshold.
DEFAULT_PRESET: Dict[str, Any] = {
    "name": "Default",
    "max_candidates": 50,
    "temperature": 0.9,
    "top_p": 0.95,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.1,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.70,
    "dup_ratio": 0.92,
    "embedding_mode": "plain_both",  # embedding toggle
}
# Wider exploration: more candidates, hotter sampling, diversity-leaning MMR
# (λ=0.50) and a looser duplicate threshold.
DIVERSE_PRESET: Dict[str, Any] = {
    "name": "Diverse",
    "max_candidates": 200,
    "temperature": 1.10,
    "top_p": 0.98,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10,
    "seed": 42,
    "sort_by": "logp/len",
    "select_k": 20,
    "mmr_lambda": 0.50,
    "dup_ratio": 0.88,
    "embedding_mode": "plain_both",  # embedding toggle
}
# Built-ins are never written to disk; user presets are merged on top of them.
BUILT_IN_PRESETS = {"Default": DEFAULT_PRESET, "Diverse": DIVERSE_PRESET}
# ------------------ PRESET IO ------------------
def load_user_presets() -> Dict[str, Dict[str, Any]]:
    """Read user-saved presets from PRESETS_FILE.

    Returns an empty dict when the file is missing, unreadable, malformed,
    or not a JSON object. Each preset dict is backfilled with the default
    "embedding_mode" when absent (older files predate that field).
    """
    if not os.path.exists(PRESETS_FILE):
        return {}
    try:
        with open(PRESETS_FILE, "r", encoding="utf-8") as f:
            raw = json.load(f)
    except Exception:
        # Best-effort load: a corrupt presets file should not crash the app.
        return {}
    if not isinstance(raw, dict):
        return {}
    presets: Dict[str, Dict[str, Any]] = {}
    for preset_name, preset_cfg in raw.items():
        if not isinstance(preset_cfg, dict):
            continue  # skip malformed entries
        preset_cfg.setdefault("embedding_mode", "plain_both")
        presets[preset_name] = preset_cfg
    return presets
def save_user_preset(name: str, cfg: Dict[str, Any]) -> None:
    """Persist *cfg* under *name* in PRESETS_FILE, merging with existing presets.

    The stored entry always carries its own "name" key; an existing preset
    with the same name is overwritten.
    """
    presets = load_user_presets()
    presets[name] = {**cfg, "name": name}
    with open(PRESETS_FILE, "w", encoding="utf-8") as f:
        json.dump(presets, f, ensure_ascii=False, indent=2)
def all_presets() -> Dict[str, Dict[str, Any]]:
    """Built-in presets overlaid with user presets (user entries win on name clash)."""
    return {**BUILT_IN_PRESETS, **load_user_presets()}
# ------------------ MODEL LOADING ------------------
@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    """Load the fan-out tokenizer and model once per process (cached by Streamlit).

    Returns (tokenizer, model, device); the model is moved to GPU when
    available and put in eval mode so scoring is deterministic (no dropout).
    """
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return tok, model, device
# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: "MT5Tokenizer", url: str, query: str, device: torch.device):
    """Build the instruction prompt for (url, query) and tokenize it.

    Returns (tensors moved to *device*, the raw prompt string). The prompt is
    truncated to MAX_INPUT_LENGTH tokens.
    """
    prompt = f"For URL: {url} diversify query: {query}"
    encoded = tok(prompt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
    return on_device, prompt
def decode_sequences(tok: "MT5Tokenizer", seqs: torch.Tensor) -> List[str]:
    """Turn a batch of generated token-id sequences into plain strings."""
    decoded = tok.batch_decode(seqs, skip_special_tokens=True)
    return decoded
def avg_logprobs_from_generate(tok: "MT5Tokenizer", gen) -> List[float]:
    """Mean per-token log-probability for each sequence in a generate() output.

    Pad tokens and anything after EOS are excluded from the average. When the
    output carries no per-step scores, returns NaN for every sequence.
    """
    scores = getattr(gen, "scores", None)
    if not scores:
        return [float("nan")] * gen.sequences.size(0)
    seqs = gen.sequences
    n = seqs.size(0)
    dev = scores[0].device
    # Fall back to 1 (the usual T5 EOS id) if the tokenizer has no EOS set.
    eos_id = 1 if tok.eos_token_id is None else tok.eos_token_id
    pad_id = tok.pad_token_id
    total = torch.zeros(n, dtype=torch.float32, device=dev)
    n_tok = torch.zeros(n, dtype=torch.float32, device=dev)
    done = torch.zeros(n, dtype=torch.bool, device=dev)
    for step, logits in enumerate(scores):
        logprobs = F.log_softmax(logits, dim=-1)
        # sequences[:, 0] is the decoder start token, so step t emitted token t+1.
        chosen = seqs[:, step + 1]
        active = chosen.ne(pad_id) & (~done)
        if active.any():
            picked = logprobs.gather(1, chosen.unsqueeze(1)).squeeze(1)
            total += torch.where(active, picked, torch.zeros_like(picked))
            n_tok += active.float()
        # The EOS token itself is counted; everything after it is not.
        done |= chosen.eq(eos_id)
    # Avoid division by zero for sequences with no counted tokens.
    n_tok = torch.where(n_tok.eq(0), torch.ones_like(n_tok), n_tok)
    return [(s / c).item() for s, c in zip(total, n_tok)]
def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p,
                      no_repeat_ngram_size=0, repetition_penalty=1.0):
    """Nucleus-sample *top_n* sequences for *inputs*.

    Returns (decoded texts, average log-prob per token for each text).
    The anti-repetition knobs are only forwarded when active so generate()
    keeps its library defaults otherwise.
    """
    gen_kwargs: Dict[str, Any] = {
        "max_length": MAX_TARGET_LENGTH,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "num_return_sequences": top_n,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    if no_repeat_ngram_size > 0:
        gen_kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **gen_kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
def get_encoder_embedding(tok, model, text: str, device: torch.device):
    """Mean-pooled encoder hidden state for *text* as a 1-D tensor on *device*."""
    enc_inputs = tok(text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        hidden = model.get_encoder()(**enc_inputs).last_hidden_state
    return hidden.mean(dim=1).squeeze(0)
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity between two 1-D embedding vectors, as a Python float."""
    sim = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return float(sim.item())
def fmt_score(x: float) -> str:
    """Format an average log-probability for display.

    Returns "n/a" for NaN or infinite scores (e.g. when generate() produced
    no usable scores); otherwise shows the score and its per-token probability.
    """
    # math.isfinite covers both NaN and ±inf; replaces the `x != x` NaN idiom.
    if not math.isfinite(x):
        return "n/a"
    p = math.exp(x)
    return f"logp/len={x:.3f} | p≈{p:.3f}"
# ------------------ RERANK (MMR + DEDUP) ------------------
def normalize_text(s: str) -> str:
    """Canonical form for duplicate checks: lower-cased, trimmed, single-spaced."""
    tokens = s.strip().lower().split()
    return " ".join(tokens)
def is_near_duplicate(a: str, b: str, ratio_thresh: float) -> bool:
    """True when the normalized strings match with SequenceMatcher ratio >= threshold."""
    matcher = difflib.SequenceMatcher(None, normalize_text(a), normalize_text(b))
    return matcher.ratio() >= ratio_thresh
def mmr_select(
    cand_texts: List[str],
    cand_embs: List[torch.Tensor],
    query_emb: torch.Tensor,
    k: int,
    lambd: float
) -> List[int]:
    """Greedy Maximal Marginal Relevance selection.

    Returns up to *k* candidate indices, balancing relevance to the query
    (weight *lambd*) against similarity to already-picked items (weight
    1-lambd). The first pick is by pure relevance.
    """
    relevance = [cosine_similarity(query_emb, emb) for emb in cand_embs]
    picked: List[int] = []
    remaining = set(range(len(cand_texts)))
    while remaining and len(picked) < k:
        if picked:
            def mmr_score(i: int) -> float:
                # Penalize by the closest already-selected candidate.
                redundancy = max(
                    cosine_similarity(cand_embs[i], cand_embs[j]) for j in picked
                )
                return lambd * relevance[i] - (1.0 - lambd) * redundancy
            choice = max(remaining, key=mmr_score)
        else:
            choice = max(remaining, key=lambda i: relevance[i])
        picked.append(choice)
        remaining.remove(choice)
    return picked
def distinct_n(texts: List[str], n: int) -> float:
    """Fraction of unique token n-grams across *texts* (0.0 when none exist).

    A standard diversity diagnostic: 1.0 means every n-gram is unique.
    """
    total = 0
    unique = set()
    for text in texts:
        tokens = text.strip().split()
        # range is empty when the text has fewer than n tokens.
        for i in range(len(tokens) - n + 1):
            total += 1
            unique.add(tuple(tokens[i:i + n]))
    return len(unique) / total if total else 0.0
# ------------------ EMBEDDING MODE HELPERS (TOGGLE) ------------------
def embed_text_for_mode(url: str, text: str, mode: str, tok: "MT5Tokenizer", model: "MT5ForConditionalGeneration", device: torch.device) -> torch.Tensor:
    """Encoder embedding for *text* under the selected embedding mode.

    "template_both" wraps the text in the same instruction template used for
    generation inputs; any other mode (i.e. "plain_both") embeds the raw text.
    """
    if mode == "template_both":
        text = f"For URL: {url} diversify query: {text}"
    return get_encoder_embedding(tok, model, text, device)
# ------------------ TESTING HELPERS (DEFINED) ------------------
def single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty):
    """Deterministic beam search; returns the single best decoded string."""
    gen_kwargs = {
        "max_length": MAX_TARGET_LENGTH,
        "do_sample": False,
        "num_beams": num_beams,
        "num_return_sequences": 1,
    }
    # Only forward the anti-repetition knobs when they are active.
    if no_repeat_ngram_size > 0:
        gen_kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = float(repetition_penalty)
    out = model.generate(**inputs, **gen_kwargs)
    return decode_sequences(tok, out)[0]
def topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty):
    """Deterministic beam search returning (texts, avg logprobs) for top_n beams.

    num_beams is raised to at least top_n because generate() cannot return
    more sequences than it has beams.
    """
    gen_kwargs = {
        "max_length": MAX_TARGET_LENGTH,
        "do_sample": False,
        "num_beams": max(num_beams, top_n),
        "num_return_sequences": top_n,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    if no_repeat_ngram_size > 0:
        gen_kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **gen_kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
def topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    """Top-N nucleus-sampled outputs with average log-prob scores.

    This was a byte-for-byte duplicate of sampling_generate(); it now
    delegates to it so the generate() call setup lives in one place. The
    name is kept because the Testing tab calls it directly.
    """
    return sampling_generate(
        tok, model, device, inputs,
        top_n=top_n,
        temperature=temperature,
        top_p=top_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
    )
def score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty):
    """Sample top_n candidates and return (texts, scores) sorted best-first."""
    texts, scores = topn_outputs_sampling(
        tok, model, device, inputs, top_n, temperature, top_p,
        no_repeat_ngram_size, repetition_penalty,
    )
    # Stable sort preserves sampling order among equal scores.
    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)
    return [t for t, _ in ranked], [s for _, s in ranked]
def diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty):
    """Diverse beam search; returns (texts, avg logprobs) for up to top_n sequences.

    num_beams is grown so every beam group gets at least one beam and the
    total is an exact multiple of num_beam_groups (a generate() requirement);
    top_n is then capped at the final beam count.
    """
    beams_per_group = max(1, top_n // max(1, num_beam_groups))
    num_beams = max(num_beams, num_beam_groups * beams_per_group)
    # Round up to the nearest multiple of num_beam_groups.
    if num_beams % num_beam_groups != 0:
        num_beams = (num_beams // num_beam_groups + 1) * num_beam_groups
    top_n = min(top_n, num_beams)
    gen_kwargs = {
        "max_length": MAX_TARGET_LENGTH,
        "do_sample": False,
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "diversity_penalty": diversity_penalty,
        "num_return_sequences": top_n,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    if no_repeat_ngram_size > 0:
        gen_kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = float(repetition_penalty)
    gen = model.generate(**inputs, **gen_kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
def token_by_token_probabilities(tok, model, device, inputs):
    """Greedy decode; returns [(sentencepiece_token, probability)] per step."""
    gen = model.generate(
        **inputs,
        max_length=MAX_TARGET_LENGTH,
        do_sample=False,
        num_beams=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    token_ids = gen.sequences[0].tolist()
    steps = []
    for offset, logits in enumerate(gen.scores):
        # sequences[0] starts with the decoder start token, so step t emits id t+1.
        chosen_id = token_ids[offset + 1]
        prob = float(F.softmax(logits[0], dim=-1)[chosen_id].detach().cpu())
        piece = tok.convert_ids_to_tokens([chosen_id])[0]
        steps.append((piece, prob))
    return steps
# ------------------ STREAMLIT APP ------------------
st.set_page_config(page_title="Query Fanout – Generation & Testing", layout="wide")
# Tokenizer/model load once per process (st.cache_resource on load_model).
tok, model, device = load_model()
tab1, tab2 = st.tabs(["Generation", "Testing"])
# ----------- COMMON GENERATION RUNNER -----------
def run_generation(url: str, query: str, cfg: Dict[str, Any], show_save_controls: bool) -> None:
    """End-to-end fan-out run: sample, score, dedup, MMR-rerank, and render.

    Args:
        url, query: seed page and query used to build the generation prompt.
        cfg: generation + rerank settings (a preset dict or manual values).
        show_save_controls: when True, render a form to save cfg as a preset.

    Side effects: renders Streamlit widgets and writes generation_output.txt
    and generation_selected.txt in the working directory.
    """
    # Seed CPU and CUDA RNGs so sampling is reproducible for a given cfg.
    torch.manual_seed(int(cfg["seed"]))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(int(cfg["seed"]))
    start_ts = time.time()
    inputs, prompt_txt = build_inputs(tok, url, query, device)
    embedding_mode = cfg.get("embedding_mode", "plain_both")
    # Embedding of the original query: the relevance anchor for cosine/MMR.
    orig_emb = embed_text_for_mode(url, query, embedding_mode, tok, model, device)
    # Oversample 2x the target count to survive exact-duplicate removal below.
    texts, scores = sampling_generate(
        tok, model, device, inputs,
        top_n=int(cfg["max_candidates"]) * 2,
        temperature=float(cfg["temperature"]),
        top_p=float(cfg["top_p"]),
        no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
        repetition_penalty=float(cfg["repetition_penalty"]),
    )
    # Exact-duplicate removal on normalized text, keeping first occurrence;
    # each kept candidate is embedded and scored against the original query.
    seen = set()
    enriched: List[Dict[str, Any]] = []
    for txt, sc in zip(texts, scores):
        norm = normalize_text(txt)
        if norm not in seen:
            seen.add(norm)
            cand_emb = embed_text_for_mode(url, txt, embedding_mode, tok, model, device)
            cos_sim = cosine_similarity(orig_emb, cand_emb)
            enriched.append({"logp/len": sc, "p≈": math.exp(sc), "cos≈": cos_sim, "text": txt, "emb": cand_emb})
        if len(enriched) >= int(cfg["max_candidates"]):
            break
    # Order the display table by model score or by similarity to the query.
    if cfg["sort_by"] == "logp/len":
        enriched.sort(key=lambda x: x["logp/len"], reverse=True)
    else:
        enriched.sort(key=lambda x: x["cos≈"], reverse=True)
    df = pd.DataFrame([{"logp/len": e["logp/len"], "p≈": e["p≈"], "cos≈": e["cos≈"], "text": e["text"]} for e in enriched])
    df.index = range(1, len(df) + 1)  # 1-based ranks in the table
    elapsed = time.time() - start_ts
    st.caption(f"Generated {len(df)} unique fan-out queries in {elapsed:.2f}s")
    st.dataframe(df, use_container_width=True)
    # Fuzzy near-duplicate filtering (SequenceMatcher ratio >= dup_ratio),
    # greedy: each candidate is compared against everything already kept.
    filtered: List[Dict[str, Any]] = []
    for cand in enriched:
        keep = True
        for kept in filtered:
            if is_near_duplicate(cand["text"], kept["text"], float(cfg["dup_ratio"])):
                keep = False
                break
        if keep:
            filtered.append(cand)
    if filtered:
        k_eff = min(int(cfg["select_k"]), len(filtered))
        cand_texts = [c["text"] for c in filtered]
        cand_embs = [c["emb"] for c in filtered]
        sel_idx = mmr_select(cand_texts, cand_embs, orig_emb, k=k_eff, lambd=float(cfg["mmr_lambda"]))
        selected = [filtered[i] for i in sel_idx]
        st.markdown("### Reranked Top-K (MMR + Dedup)")
        st.caption(f"Mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}")
        df_sel = pd.DataFrame(
            [{"rank": i+1, "cos≈": s["cos≈"], "text": s["text"]} for i, s in enumerate(selected)]
        )
        df_sel.set_index("rank", inplace=True)
        st.dataframe(df_sel, use_container_width=True)
        # Diversity diagnostics over the final selection.
        sel_texts = [s["text"] for s in selected]
        d1 = distinct_n(sel_texts, 1)
        d2 = distinct_n(sel_texts, 2)
        st.caption(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f} on selected {len(sel_texts)}")
        combined_output = [f"Input: {prompt_txt}"]
        for rank, row in df.iterrows():
            combined_output.append(f"#{rank} logp/len={row['logp/len']:.3f} | p≈{row['p≈']:.3f} | cos≈{row['cos≈']:.3f} — {row['text']}")
        block = "\n".join(combined_output)
        st.markdown("### Copy/Paste Summary")
        st.code(block, language="text")
        # Persist the full candidate table plus the MMR selection for reuse.
        with open("generation_output.txt", "w", encoding="utf-8") as f:
            f.write(block)
            f.write("\n\n[MMR selection]\n")
            f.write(f"mode={embedding_mode} | λ={float(cfg['mmr_lambda']):.2f} | dup_ratio≥{float(cfg['dup_ratio']):.2f} | K={k_eff}\n")
            for i, s in enumerate(selected, 1):
                f.write(f"#{i} cos≈={s['cos≈']:.3f} — {s['text']}\n")
            f.write(f"Distinct-1={d1:.3f} | Distinct-2={d2:.3f}\n")
        with open("generation_selected.txt", "w", encoding="utf-8") as f:
            for i, s in enumerate(selected, 1):
                f.write(f"{i}\t{s['text']}\n")
        st.success("Saved summary to generation_output.txt and selection to generation_selected.txt")
    else:
        st.warning("All candidates filtered as near-duplicates. Lower the duplicate threshold or increase max candidates.")
    if show_save_controls:
        st.markdown("---")
        with st.form(key="save_preset_form"):
            new_name = st.text_input("Preset Name", value="", placeholder="Enter a preset name")
            submitted = st.form_submit_button("Save as Preset")
            if submitted:
                if not new_name.strip():
                    st.error("Preset name cannot be empty.")
                elif new_name in BUILT_IN_PRESETS:
                    st.error("Cannot overwrite built-in presets (Default, Diverse). Use a different name.")
                else:
                    # Store only the coerced, serializable fields of cfg.
                    to_save = {
                        "max_candidates": int(cfg["max_candidates"]),
                        "temperature": float(cfg["temperature"]),
                        "top_p": float(cfg["top_p"]),
                        "no_repeat_ngram_size": int(cfg["no_repeat_ngram_size"]),
                        "repetition_penalty": float(cfg["repetition_penalty"]),
                        "seed": int(cfg["seed"]),
                        "sort_by": str(cfg["sort_by"]),
                        "select_k": int(cfg["select_k"]),
                        "mmr_lambda": float(cfg["mmr_lambda"]),
                        "dup_ratio": float(cfg["dup_ratio"]),
                        "embedding_mode": str(cfg.get("embedding_mode", "plain_both")),
                    }
                    save_user_preset(new_name.strip(), to_save)
                    st.success(f"Preset '{new_name.strip()}' saved.")
# ----------- TAB 1: GENERATION -----------
with tab1:
    st.header("Generation Mode — Large Diverse Fan-out")
    url = st.text_input("URL", value="airbnb.com", key="gen_url")
    query = st.text_input("Query", value="airbnb reviews", key="gen_query")
    subtab_presets, subtab_manual = st.tabs(["Presets", "Manual Settings"])
    # ----- Presets sub-tab: pick a stored configuration and run it -----
    with subtab_presets:
        all_p = all_presets()
        preset_names = list(all_p.keys())
        preset_choice = st.selectbox(
            "Choose a preset",
            preset_names,
            index=preset_names.index("Default") if "Default" in preset_names else 0
        )
        sel = dict(all_p[preset_choice])  # copy to allow local edits
        # The embedding mode can be overridden without editing the preset.
        emb_mode_preset = st.selectbox(
            "Embedding mode for reranking",
            options=["plain_both", "template_both"],
            index=0 if sel.get("embedding_mode", "plain_both") == "plain_both" else 1,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template"
        )
        sel["embedding_mode"] = emb_mode_preset
        # Read-only summary of the chosen preset's settings.
        cols = st.columns(3)
        with cols[0]:
            st.write(f"**Max candidates:** {sel['max_candidates']}")
            st.write(f"**Temperature:** {sel['temperature']}")
            st.write(f"**Top-p:** {sel['top_p']}")
            st.write(f"**Seed:** {sel['seed']}")
        with cols[1]:
            st.write(f"**No repeat n-gram:** {sel['no_repeat_ngram_size']}")
            st.write(f"**Repetition penalty:** {sel['repetition_penalty']}")
            st.write(f"**Sort by:** {sel['sort_by']}")
        with cols[2]:
            st.write(f"**Select K:** {sel['select_k']}")
            st.write(f"**λ (MMR):** {sel['mmr_lambda']}")
            st.write(f"**Dup ratio:** {sel['dup_ratio']}")
            st.write(f"**Embedding:** {sel['embedding_mode']}")
        run_gen_preset = st.button("Generate Fan-out (Preset)", key="run_gen_preset")
        if run_gen_preset:
            run_generation(url, query, sel, show_save_controls=False)
    # ----- Manual Settings sub-tab: tweak every knob, optionally save as preset -----
    with subtab_manual:
        base = DEFAULT_PRESET  # manual widgets start from the Default preset's values
        max_candidates = st.number_input("Max candidates", min_value=1, max_value=200, value=int(base["max_candidates"]), step=1)
        temperature = st.number_input("Temperature", min_value=0.1, max_value=2.0, value=float(base["temperature"]), step=0.1)
        top_p = st.number_input("Top-p", min_value=0.1, max_value=1.0, value=float(base["top_p"]), step=0.01)
        no_repeat_ngram_size = st.number_input("No repeat n-gram size (0=off)", min_value=0, max_value=10, value=int(base["no_repeat_ngram_size"]), step=1)
        repetition_penalty = st.number_input("Repetition penalty (1.0=off)", min_value=1.0, max_value=2.0, value=float(base["repetition_penalty"]), step=0.1)
        seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=int(base["seed"]), step=1, key="gen_seed_manual")
        sort_by = st.selectbox("Sort by", ["logp/len", "cosine similarity"], index=0)
        st.subheader("Diversity-aware Reranking (MMR on internal encoder vectors)")
        embedding_mode_manual = st.selectbox(
            "Embedding mode",
            options=["plain_both", "template_both"],
            index=0,
            help="plain_both=embed raw query/candidates; template_both=embed with instruction template"
        )
        select_k = st.number_input("Select top K after rerank", min_value=1, max_value=200, value=int(base["select_k"]), step=1)
        mmr_lambda = st.number_input("MMR relevance weight λ (higher = more on-topic, lower = more diverse)", min_value=0.0, max_value=1.0, value=float(base["mmr_lambda"]), step=0.01)
        dup_ratio = st.number_input("Near-duplicate threshold (SequenceMatcher ratio)", min_value=0.0, max_value=1.0, value=float(base["dup_ratio"]), step=0.01)
        run_gen_manual = st.button("Generate Fan-out (Manual Settings)", key="run_gen_manual")
        if run_gen_manual:
            # Assemble a cfg dict in the same shape as the presets.
            cfg = {
                "max_candidates": int(max_candidates),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "no_repeat_ngram_size": int(no_repeat_ngram_size),
                "repetition_penalty": float(repetition_penalty),
                "seed": int(seed_value),
                "sort_by": str(sort_by),
                "select_k": int(select_k),
                "mmr_lambda": float(mmr_lambda),
                "dup_ratio": float(dup_ratio),
                "embedding_mode": str(embedding_mode_manual),
            }
            run_generation(url, query, cfg, show_save_controls=True)
# ----------- TAB 2: TESTING -----------
with tab2:
    st.header("Testing Mode — Method Comparison")
    url = st.text_input("URL", value="airbnb.com", key="test_url")
    query = st.text_input("Query", value="airbnb reviews", key="test_query")
    # Decoding knobs shared across the six comparison methods below.
    num_beams = st.number_input("num_beams", min_value=1, max_value=20, value=5, step=1)
    top_n = st.number_input("top_n", min_value=1, max_value=20, value=5, step=1)
    temperature = st.number_input("temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    num_beam_groups = st.number_input("num_beam_groups", min_value=1, max_value=20, value=5, step=1)
    diversity_penalty = st.number_input("diversity_penalty", min_value=0.0, max_value=5.0, value=1.0, step=0.1)
    no_repeat_ngram_size = st.number_input("no_repeat_ngram_size", min_value=0, max_value=10, value=0, step=1)
    repetition_penalty = st.number_input("repetition_penalty", min_value=1.0, max_value=2.0, value=1.0, step=0.1)
    seed_value = st.number_input("Seed", min_value=0, max_value=2**31 - 1, value=42, step=1, key="test_seed")
    run_test = st.button("Run Comparison", key="run_test")
    if run_test:
        # Seed RNGs so the sampling-based methods are reproducible.
        torch.manual_seed(int(seed_value))
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(int(seed_value))
        inputs, prompt_txt = build_inputs(tok, url, query, device)
        # Run all six decoding strategies on the same prompt.
        best_det = single_best_output(tok, model, device, inputs, num_beams, no_repeat_ngram_size, repetition_penalty)
        topn_beam_txts, topn_beam_scores = topn_outputs_beam(tok, model, device, inputs, num_beams, top_n, no_repeat_ngram_size, repetition_penalty)
        topn_samp_txts, topn_samp_scores = topn_outputs_sampling(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        ranked_txts, ranked_scores = score_ranked_outputs(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty)
        div_txts, div_scores = diverse_beams(tok, model, device, inputs, num_beams, num_beam_groups, diversity_penalty, top_n, no_repeat_ngram_size, repetition_penalty)
        per_token = token_by_token_probabilities(tok, model, device, inputs)
        # One copy/paste-friendly text report covering every method.
        combined_output = [f"Input: {prompt_txt}",
                           "\n[1] Single best (deterministic beam)", best_det,
                           "\n[2] Top-N (beam)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(topn_beam_txts, topn_beam_scores))] + \
            ["\n[3] Top-N (sampling)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(topn_samp_txts, topn_samp_scores))] + \
            ["\n[4] Score-ranked (sampling)"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(ranked_txts, ranked_scores))] + \
            ["\n[5] Diverse beams"] + [f"#{i+1} {fmt_score(sc)} — {txt}" for i, (txt, sc) in enumerate(zip(div_txts, div_scores))] + \
            ["\n[6] Token-by-token probabilities (greedy)"] + [f"{t} — {p:.4f}" for t, p in per_token]
        st.markdown("### Copy/Paste Summary")
        st.code("\n".join(combined_output), language="text")
        with open("testing_output.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(combined_output))
        st.success("Saved summary to testing_output.txt")