# Provenance: Hugging Face Space by mlxen — "Update app.py" (commit 9d6065e, verified)
# app.py
import gradio as gr
import traceback
import time
import plotly.graph_objects as go
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from chatroutes_autobranch import BranchSelector, Candidate
from functools import lru_cache
# =====================================================
# 🧠 PRESETS – 10 Scenario Profiles
# =====================================================
# Scenario presets. Field meanings:
#   model             – HF causal-LM checkpoint used for candidate generation
#   embedding         – HF embedding model used by the novelty scorer
#   N / K             – candidates generated / candidates kept
#   T / MaxTok        – sampling temperature / max new tokens per candidate
#   novelty_method    – "cosine" similarity threshold or "mmr" re-ranking
#   novelty_threshold – similarity above which a branch is considered a duplicate
#   weights           – beam-score mix (confidence / relevance / novelty vs. parent);
#                       each preset's weights sum to 1.0
PRESETS = {
    "Reasoning & Problem Solving": {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "embedding": "intfloat/e5-small-v2",
        "N": 8, "K": 3, "T": 0.8, "MaxTok": 96,
        "novelty_method": "cosine", "novelty_threshold": 0.82,
        "weights": {"confidence": 0.55, "relevance": 0.30, "novelty_parent": 0.15},
    },
    "Creative Writing & Storytelling": {
        "model": "HuggingFaceH4/zephyr-7b-beta",
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "N": 10, "K": 3, "T": 1.2, "MaxTok": 160,
        "novelty_method": "cosine", "novelty_threshold": 0.88,
        "weights": {"confidence": 0.30, "relevance": 0.20, "novelty_parent": 0.50},
    },
    "Data Science & Math Explanations": {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "embedding": "intfloat/e5-small-v2",
        "N": 8, "K": 3, "T": 0.7, "MaxTok": 96,
        "novelty_method": "cosine", "novelty_threshold": 0.78,
        "weights": {"confidence": 0.60, "relevance": 0.30, "novelty_parent": 0.10},
    },
    "Business & Product Thinking": {
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "embedding": "intfloat/e5-small-v2",
        "N": 8, "K": 3, "T": 0.85, "MaxTok": 100,
        "novelty_method": "cosine", "novelty_threshold": 0.82,
        "weights": {"confidence": 0.50, "relevance": 0.30, "novelty_parent": 0.20},
    },
    "Engineering & Design Trade-offs": {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "embedding": "intfloat/e5-small-v2",
        "N": 8, "K": 3, "T": 0.75, "MaxTok": 120,
        "novelty_method": "cosine", "novelty_threshold": 0.80,
        "weights": {"confidence": 0.55, "relevance": 0.35, "novelty_parent": 0.10},
    },
    "Ethics & Philosophy": {
        "model": "HuggingFaceH4/zephyr-7b-beta",
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "N": 8, "K": 3, "T": 1.0, "MaxTok": 128,
        "novelty_method": "cosine", "novelty_threshold": 0.85,
        "weights": {"confidence": 0.40, "relevance": 0.30, "novelty_parent": 0.30},
    },
    "Education & Pedagogy": {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "N": 8, "K": 3, "T": 0.9, "MaxTok": 100,
        "novelty_method": "cosine", "novelty_threshold": 0.84,
        "weights": {"confidence": 0.45, "relevance": 0.30, "novelty_parent": 0.25},
    },
    "Marketing & Copywriting": {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "N": 10, "K": 3, "T": 1.1, "MaxTok": 96,
        "novelty_method": "cosine", "novelty_threshold": 0.85,
        "weights": {"confidence": 0.40, "relevance": 0.30, "novelty_parent": 0.30},
    },
    "Code Generation & Refactoring": {
        "model": "microsoft/Phi-3-mini-4k-instruct",
        "embedding": "intfloat/e5-small-v2",
        # Only preset using MMR: code branches need diversity enforced explicitly.
        "N": 6, "K": 3, "T": 0.6, "MaxTok": 120,
        "novelty_method": "mmr", "novelty_threshold": 0.75,
        "weights": {"confidence": 0.60, "relevance": 0.30, "novelty_parent": 0.10},
    },
    "Meta / Self-Exploration": {
        "model": "HuggingFaceH4/zephyr-7b-beta",
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "N": 8, "K": 3, "T": 1.0, "MaxTok": 128,
        "novelty_method": "cosine", "novelty_threshold": 0.86,
        "weights": {"confidence": 0.40, "relevance": 0.30, "novelty_parent": 0.30},
    },
}
# =====================================================
# βš™οΈ HELPERS
# =====================================================
@lru_cache(maxsize=3)
def load_textgen(model_name):
    """Load (and cache) a text-generation pipeline plus its tokenizer.

    Cached with maxsize=3 so switching between presets does not re-download
    or re-instantiate models already seen in this session.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name)
    if tok.pad_token_id is None:
        # Decoder-only checkpoints often ship without a pad token; reuse EOS
        # so batched sampling (num_return_sequences > 1) can pad.
        tok.pad_token = tok.eos_token
    return pipeline("text-generation", model=lm, tokenizer=tok), tok
def strip_echo(text, prompt):
    """Remove a leading, case-insensitive echo of *prompt* from *text*.

    Returns the remainder with surrounding whitespace trimmed; if *text*
    does not start with the prompt, it is returned (stripped) unchanged.
    """
    needle = prompt.lower().strip()
    remainder = text
    if remainder.lower().startswith(needle):
        remainder = remainder[len(needle):].lstrip()
    return remainder.strip()
def apply_preset(name):
    """Return preset *name*'s fields as a tuple, in UI-widget output order.

    Order must match the `outputs=` list of `scenario.change` below:
    model, embedding, N, K, T, MaxTok, novelty_method, novelty_threshold, weights.
    """
    preset = PRESETS[name]
    field_order = (
        "model", "embedding",
        "N", "K", "T", "MaxTok",
        "novelty_method", "novelty_threshold", "weights",
    )
    return tuple(preset[key] for key in field_order)
def make_tree_plot(prompt, kept, pruned):
    """Build a Sankey diagram: prompt node (blue) fanning out to kept (green)
    and pruned (red) candidate branches. Returns an empty Figure when there
    are no candidates at all."""
    if not (kept or pruned):
        return go.Figure()

    # Node 0 is the prompt; kept candidates first, then pruned, so colors
    # and hover texts stay aligned with node order.
    node_labels = [prompt]
    node_colors = ["#2563EB"]
    link_hovers = []
    for c in kept:
        node_labels.append(c.id)
        node_colors.append("#22C55E")
        link_hovers.append(f"βœ… {c.id}: kept")
    for c in pruned:
        node_labels.append(c.id)
        node_colors.append("#EF4444")
        reason = getattr(c, "prune_reason", "low novelty/confidence")
        link_hovers.append(f"πŸ—‚οΈ {c.id}: {reason}")

    n_links = len(node_labels) - 1  # one link per candidate, all from node 0
    fig = go.Figure(go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.3),
            label=node_labels,
            color=node_colors,
            hovertemplate="%{label}<extra></extra>",
        ),
        link=dict(
            source=[0] * n_links,
            target=list(range(1, n_links + 1)),
            value=[1] * n_links,
            # Each link inherits its target candidate's node color.
            color=node_colors[1:],
            hovertemplate=[t + "<extra></extra>" for t in link_hovers],
        ),
    ))
    fig.update_layout(title="Branch Selection Tree", font=dict(size=12))
    return fig
# =====================================================
# πŸš€ MAIN RUN
# =====================================================
def run(prompt, num_candidates, top_k, temperature, max_new_tokens,
        novelty_method, novelty_threshold, model_name, embedding_model, beam_weights):
    """Generate N candidate continuations, then have BranchSelector keep top-K.

    Returns a 4-tuple matching the UI outputs: (kept_text, pruned_text,
    status_markdown, tree_plot). Generation and selection failures are caught
    separately so a selector error still shows the raw candidates.
    """
    start = time.time()
    try:
        gen, tokenizer = load_textgen(model_name)
        out = gen(
            prompt,
            do_sample=True,
            temperature=float(temperature),
            max_new_tokens=int(max_new_tokens),
            num_return_sequences=int(num_candidates),
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=False,  # ⬅️ removes echo
        )
        # strip_echo is a belt-and-braces pass (return_full_text already drops
        # the prompt); very short outputs (<= 30 chars) are discarded as noise.
        candidates = [strip_echo(o["generated_text"], prompt)
                      for o in out if len(o["generated_text"].strip()) > 30]
    except Exception:
        tb = traceback.format_exc()
        return "", "", f"❌ **Generation failed**\n```\n{tb}\n```", go.Figure()
    # Robustness: don't hand an empty candidate list to the selector.
    if not candidates:
        return ("", "",
                "⚠️ No usable candidates were generated — try a higher temperature or more max tokens.",
                go.Figure())
    try:
        cfg = {
            "beam": {"k": int(top_k), "weights": beam_weights},
            "novelty": {"method": str(novelty_method), "threshold": float(novelty_threshold)},
            "entropy": {"min_entropy": 0.5},
            "embeddings": {"provider": "huggingface", "model": embedding_model},
        }
        selector = BranchSelector.from_config(cfg)
        parent = Candidate(id="root", text=prompt)
        cand_objs = [Candidate(id=f"c{i}", text=t) for i, t in enumerate(candidates)]
        result = selector.step(parent, cand_objs)
        # Library versions differ on the attribute name; fall back gracefully.
        kept = getattr(result, "kept", getattr(result, "selected", []))
        kept_ids = {c.id for c in kept}
        pruned = [c for c in cand_objs if c.id not in kept_ids]
        kept_text = "\n\n--- kept ---\n\n".join([c.text for c in kept]) or "β€”"
        pruned_text = "\n\n--- pruned ---\n\n".join([c.text for c in pruned]) or "β€”"
        entropy_val = None
        if hasattr(result, "metrics"):
            ent = result.metrics.get("entropy", None)
            if isinstance(ent, dict) and "value" in ent:
                entropy_val = ent["value"]
        ent_line = (
            f"βœ… Model: {model_name}<br>Embedding: {embedding_model}"
            # `is not None`: a legitimate entropy of 0.0 must still be shown
            # (the old truthiness check hid it).
            + (f"<br>Entropy: **{entropy_val:.3f}**" if entropy_val is not None else "")
            + f"<br>⏱️ {time.time() - start:.1f}s"
        )
        tree_plot = make_tree_plot(prompt, kept, pruned)
        return kept_text, pruned_text, ent_line, tree_plot
    except Exception:
        tb = traceback.format_exc()
        # Selector failed: surface raw candidates so the run isn't wasted.
        return (
            "\n\n---\n\n".join(candidates),
            "β€”",
            f"⚠️ **Selector failed**\n```\n{tb}\n```",
            go.Figure(),
        )
# =====================================================
# πŸ–₯️ UI
# =====================================================
# Build the Gradio UI: preset dropdown drives all parameter widgets; the run
# button wires the generation+selection pipeline to the output panes.
with gr.Blocks(title="AutoBranch β€” Visual Branching Explorer") as demo:
    gr.Markdown(
        """
# 🌳 AutoBranch β€” Visual Branching Explorer
Experiment with how AI reasoning **branches and prunes** through multiple ideas.
Visualize **entropy**, **novelty**, and **confidence** as the model explores diverse paths.
"""
    )
    with gr.Row():
        scenario = gr.Dropdown(
            choices=list(PRESETS.keys()),
            value="Reasoning & Problem Solving",
            label="Scenario Preset",
        )
    with gr.Row():
        # Deduplicate model names across presets for the dropdown choices.
        model_name = gr.Dropdown(
            choices=list({v["model"] for v in PRESETS.values()}),
            value="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            label="Text Generation Model (Hugging Face)",
        )
        embedding_model = gr.Dropdown(
            choices=list({v["embedding"] for v in PRESETS.values()}),
            value="intfloat/e5-small-v2",
            label="Embedding Model",
        )
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Write a short poem about entropy.",
        )
    with gr.Row():
        num_candidates = gr.Slider(3, 12, value=8, step=1, label="Num candidates (N)")
        top_k = gr.Slider(1, 8, value=3, step=1, label="Keep top-K")
        temperature = gr.Slider(0.3, 1.5, value=0.9, step=0.1, label="Creativity (temperature)")
        max_new_tokens = gr.Slider(32, 128, value=96, step=8, label="Max new tokens")
    with gr.Row():
        novelty_method = gr.Dropdown(choices=["cosine", "mmr"], value="cosine", label="Novelty method")
        novelty_threshold = gr.Slider(0.6, 0.95, value=0.85, step=0.01, label="Novelty threshold")
    # Beam weights are not user-editable; held in State and swapped by presets.
    beam_weights_state = gr.State({"confidence": 0.5, "relevance": 0.3, "novelty_parent": 0.2})
    run_btn = gr.Button("Generate ➜ Select", variant="primary")
    kept = gr.Textbox(label="βœ… Kept (diverse & high-score)", lines=12)
    pruned = gr.Textbox(label="πŸ—‚οΈ Pruned (too similar / low score)", lines=8)
    status = gr.Markdown()
    tree_plot = gr.Plot(label="Branch Selection Tree")
    # Outputs order must match the tuple returned by apply_preset.
    scenario.change(
        apply_preset,
        inputs=[scenario],
        outputs=[
            model_name, embedding_model,
            num_candidates, top_k, temperature, max_new_tokens,
            novelty_method, novelty_threshold, beam_weights_state,
        ],
    )
    run_btn.click(
        run,
        inputs=[
            prompt, num_candidates, top_k, temperature, max_new_tokens,
            novelty_method, novelty_threshold, model_name, embedding_model, beam_weights_state,
        ],
        outputs=[kept, pruned, status, tree_plot],
    )

if __name__ == "__main__":
    # Launch only when executed directly (e.g. `python app.py` or as the HF
    # Space entry point); importing this module no longer starts a server.
    demo.launch()