mlxen committed on
Commit
9d6065e
·
verified ·
1 Parent(s): 1e26b91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -26
app.py CHANGED
@@ -5,28 +5,28 @@ import time
5
  import plotly.graph_objects as go
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from chatroutes_autobranch import BranchSelector, Candidate
8
-
9
 
10
  # =====================================================
11
- # 🧠 PRESETS
12
  # =====================================================
13
  PRESETS = {
14
  "Reasoning & Problem Solving": {
15
- "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
16
  "embedding": "intfloat/e5-small-v2",
17
  "N": 8, "K": 3, "T": 0.8, "MaxTok": 96,
18
  "novelty_method": "cosine", "novelty_threshold": 0.82,
19
  "weights": {"confidence": 0.55, "relevance": 0.30, "novelty_parent": 0.15},
20
  },
21
  "Creative Writing & Storytelling": {
22
- "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
23
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
24
- "N": 10, "K": 3, "T": 1.1, "MaxTok": 120,
25
  "novelty_method": "cosine", "novelty_threshold": 0.88,
26
- "weights": {"confidence": 0.35, "relevance": 0.25, "novelty_parent": 0.40},
27
  },
28
- "Data Science & Math": {
29
- "model": "microsoft/phi-2",
30
  "embedding": "intfloat/e5-small-v2",
31
  "N": 8, "K": 3, "T": 0.7, "MaxTok": 96,
32
  "novelty_method": "cosine", "novelty_threshold": 0.78,
@@ -39,21 +39,57 @@ PRESETS = {
39
  "novelty_method": "cosine", "novelty_threshold": 0.82,
40
  "weights": {"confidence": 0.50, "relevance": 0.30, "novelty_parent": 0.20},
41
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  "Marketing & Copywriting": {
43
- "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
44
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
45
- "N": 10, "K": 3, "T": 1.1, "MaxTok": 80,
46
- "novelty_method": "cosine", "novelty_threshold": 0.88,
47
- "weights": {"confidence": 0.35, "relevance": 0.25, "novelty_parent": 0.40},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  },
49
  }
50
 
51
 
 
52
  # =====================================================
53
  # βš™οΈ HELPERS
54
  # =====================================================
 
55
  def load_textgen(model_name):
56
- """Load text-generation pipeline for selected model."""
57
  tokenizer = AutoTokenizer.from_pretrained(model_name)
58
  model = AutoModelForCausalLM.from_pretrained(model_name)
59
  if tokenizer.pad_token_id is None:
@@ -62,8 +98,15 @@ def load_textgen(model_name):
62
  return gen, tokenizer
63
 
64
 
 
 
 
 
 
 
 
 
65
  def apply_preset(name):
66
- """Apply preset values to the UI."""
67
  p = PRESETS[name]
68
  return (
69
  p["model"], p["embedding"],
@@ -73,15 +116,14 @@ def apply_preset(name):
73
 
74
 
75
  def make_tree_plot(prompt, kept, pruned):
76
- """Visual tree (prompt β†’ kept/pruned candidates)."""
77
  if not kept and not pruned:
78
  return go.Figure()
79
 
80
  labels = [prompt] + [c.id for c in kept] + [c.id for c in pruned]
81
  colors = (
82
- ["#2563EB"]
83
- + ["#22C55E"] * len(kept)
84
- + ["#EF4444"] * len(pruned)
85
  )
86
  sources = [0] * (len(kept) + len(pruned))
87
  targets = list(range(1, len(labels)))
@@ -115,9 +157,11 @@ def make_tree_plot(prompt, kept, pruned):
115
  return fig
116
 
117
 
 
 
 
118
  def run(prompt, num_candidates, top_k, temperature, max_new_tokens,
119
  novelty_method, novelty_threshold, model_name, embedding_model, beam_weights):
120
- """Main run logic."""
121
  start = time.time()
122
  try:
123
  gen, tokenizer = load_textgen(model_name)
@@ -127,11 +171,12 @@ def run(prompt, num_candidates, top_k, temperature, max_new_tokens,
127
  temperature=float(temperature),
128
  max_new_tokens=int(max_new_tokens),
129
  num_return_sequences=int(num_candidates),
130
- pad_token_id=tokenizer.pad_token_id,
131
- return_full_text=True,
 
132
  )
133
- candidates = [o["generated_text"] for o in out]
134
- cand_objs = [Candidate(id=f"c{i}", text=t) for i, t in enumerate(candidates)]
135
  except Exception:
136
  tb = traceback.format_exc()
137
  return "", "", f"❌ **Generation failed**\n```\n{tb}\n```", go.Figure()
@@ -145,6 +190,7 @@ def run(prompt, num_candidates, top_k, temperature, max_new_tokens,
145
  }
146
  selector = BranchSelector.from_config(cfg)
147
  parent = Candidate(id="root", text=prompt)
 
148
  result = selector.step(parent, cand_objs)
149
 
150
  kept = getattr(result, "kept", getattr(result, "selected", []))
@@ -180,11 +226,12 @@ def run(prompt, num_candidates, top_k, temperature, max_new_tokens,
180
  # =====================================================
181
  # 🖥️ UI
182
  # =====================================================
183
- with gr.Blocks(title="AutoBranch β€” Visual Scenario Explorer") as demo:
184
  gr.Markdown(
185
  """
186
- # 🌳 AutoBranch — Visual Scenario Explorer
187
- Select a **scenario preset**, edit parameters, and visualize how beam search selects or prunes candidates based on **novelty**, **entropy**, and **confidence**.
 
188
  """
189
  )
190
 
 
5
  import plotly.graph_objects as go
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from chatroutes_autobranch import BranchSelector, Candidate
8
+ from functools import lru_cache
9
 
10
  # =====================================================
11
+ # 🧠 PRESETS – 10 Scenario Profiles
12
  # =====================================================
13
  PRESETS = {
14
  "Reasoning & Problem Solving": {
15
+ "model": "microsoft/Phi-3-mini-4k-instruct",
16
  "embedding": "intfloat/e5-small-v2",
17
  "N": 8, "K": 3, "T": 0.8, "MaxTok": 96,
18
  "novelty_method": "cosine", "novelty_threshold": 0.82,
19
  "weights": {"confidence": 0.55, "relevance": 0.30, "novelty_parent": 0.15},
20
  },
21
  "Creative Writing & Storytelling": {
22
+ "model": "HuggingFaceH4/zephyr-7b-beta",
23
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
24
+ "N": 10, "K": 3, "T": 1.2, "MaxTok": 160,
25
  "novelty_method": "cosine", "novelty_threshold": 0.88,
26
+ "weights": {"confidence": 0.30, "relevance": 0.20, "novelty_parent": 0.50},
27
  },
28
+ "Data Science & Math Explanations": {
29
+ "model": "microsoft/Phi-3-mini-4k-instruct",
30
  "embedding": "intfloat/e5-small-v2",
31
  "N": 8, "K": 3, "T": 0.7, "MaxTok": 96,
32
  "novelty_method": "cosine", "novelty_threshold": 0.78,
 
39
  "novelty_method": "cosine", "novelty_threshold": 0.82,
40
  "weights": {"confidence": 0.50, "relevance": 0.30, "novelty_parent": 0.20},
41
  },
42
+ "Engineering & Design Trade-offs": {
43
+ "model": "microsoft/Phi-3-mini-4k-instruct",
44
+ "embedding": "intfloat/e5-small-v2",
45
+ "N": 8, "K": 3, "T": 0.75, "MaxTok": 120,
46
+ "novelty_method": "cosine", "novelty_threshold": 0.80,
47
+ "weights": {"confidence": 0.55, "relevance": 0.35, "novelty_parent": 0.10},
48
+ },
49
+ "Ethics & Philosophy": {
50
+ "model": "HuggingFaceH4/zephyr-7b-beta",
51
+ "embedding": "sentence-transformers/all-MiniLM-L6-v2",
52
+ "N": 8, "K": 3, "T": 1.0, "MaxTok": 128,
53
+ "novelty_method": "cosine", "novelty_threshold": 0.85,
54
+ "weights": {"confidence": 0.40, "relevance": 0.30, "novelty_parent": 0.30},
55
+ },
56
+ "Education & Pedagogy": {
57
+ "model": "microsoft/Phi-3-mini-4k-instruct",
58
+ "embedding": "sentence-transformers/all-MiniLM-L6-v2",
59
+ "N": 8, "K": 3, "T": 0.9, "MaxTok": 100,
60
+ "novelty_method": "cosine", "novelty_threshold": 0.84,
61
+ "weights": {"confidence": 0.45, "relevance": 0.30, "novelty_parent": 0.25},
62
+ },
63
  "Marketing & Copywriting": {
64
+ "model": "microsoft/Phi-3-mini-4k-instruct",
65
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
66
+ "N": 10, "K": 3, "T": 1.1, "MaxTok": 96,
67
+ "novelty_method": "cosine", "novelty_threshold": 0.85,
68
+ "weights": {"confidence": 0.40, "relevance": 0.30, "novelty_parent": 0.30},
69
+ },
70
+ "Code Generation & Refactoring": {
71
+ "model": "microsoft/Phi-3-mini-4k-instruct",
72
+ "embedding": "intfloat/e5-small-v2",
73
+ "N": 6, "K": 3, "T": 0.6, "MaxTok": 120,
74
+ "novelty_method": "mmr", "novelty_threshold": 0.75,
75
+ "weights": {"confidence": 0.60, "relevance": 0.30, "novelty_parent": 0.10},
76
+ },
77
+ "Meta / Self-Exploration": {
78
+ "model": "HuggingFaceH4/zephyr-7b-beta",
79
+ "embedding": "sentence-transformers/all-MiniLM-L6-v2",
80
+ "N": 8, "K": 3, "T": 1.0, "MaxTok": 128,
81
+ "novelty_method": "cosine", "novelty_threshold": 0.86,
82
+ "weights": {"confidence": 0.40, "relevance": 0.30, "novelty_parent": 0.30},
83
  },
84
  }
85
 
86
 
87
+
88
  # =====================================================
89
  # βš™οΈ HELPERS
90
  # =====================================================
91
+ @lru_cache(maxsize=3)
92
  def load_textgen(model_name):
 
93
  tokenizer = AutoTokenizer.from_pretrained(model_name)
94
  model = AutoModelForCausalLM.from_pretrained(model_name)
95
  if tokenizer.pad_token_id is None:
 
98
  return gen, tokenizer
99
 
100
 
101
+ def strip_echo(text, prompt):
102
+ lower = text.lower()
103
+ plower = prompt.lower().strip()
104
+ if lower.startswith(plower):
105
+ text = text[len(plower):].lstrip()
106
+ return text.strip()
107
+
108
+
109
  def apply_preset(name):
 
110
  p = PRESETS[name]
111
  return (
112
  p["model"], p["embedding"],
 
116
 
117
 
118
  def make_tree_plot(prompt, kept, pruned):
 
119
  if not kept and not pruned:
120
  return go.Figure()
121
 
122
  labels = [prompt] + [c.id for c in kept] + [c.id for c in pruned]
123
  colors = (
124
+ ["#2563EB"] +
125
+ ["#22C55E"] * len(kept) +
126
+ ["#EF4444"] * len(pruned)
127
  )
128
  sources = [0] * (len(kept) + len(pruned))
129
  targets = list(range(1, len(labels)))
 
157
  return fig
158
 
159
 
160
+ # =====================================================
161
+ # 🚀 MAIN RUN
162
+ # =====================================================
163
  def run(prompt, num_candidates, top_k, temperature, max_new_tokens,
164
  novelty_method, novelty_threshold, model_name, embedding_model, beam_weights):
 
165
  start = time.time()
166
  try:
167
  gen, tokenizer = load_textgen(model_name)
 
171
  temperature=float(temperature),
172
  max_new_tokens=int(max_new_tokens),
173
  num_return_sequences=int(num_candidates),
174
+ pad_token_id=tokenizer.eos_token_id,
175
+ eos_token_id=tokenizer.eos_token_id,
176
+ return_full_text=False, # ⬅️ removes echo
177
  )
178
+ candidates = [strip_echo(o["generated_text"], prompt)
179
+ for o in out if len(o["generated_text"].strip()) > 30]
180
  except Exception:
181
  tb = traceback.format_exc()
182
  return "", "", f"❌ **Generation failed**\n```\n{tb}\n```", go.Figure()
 
190
  }
191
  selector = BranchSelector.from_config(cfg)
192
  parent = Candidate(id="root", text=prompt)
193
+ cand_objs = [Candidate(id=f"c{i}", text=t) for i, t in enumerate(candidates)]
194
  result = selector.step(parent, cand_objs)
195
 
196
  kept = getattr(result, "kept", getattr(result, "selected", []))
 
226
  # =====================================================
227
  # 🖥️ UI
228
  # =====================================================
229
+ with gr.Blocks(title="AutoBranch β€” Visual Branching Explorer") as demo:
230
  gr.Markdown(
231
  """
232
+ # 🌳 AutoBranch — Visual Branching Explorer
233
+ Experiment with how AI reasoning **branches and prunes** through multiple ideas.
234
+ Visualize **entropy**, **novelty**, and **confidence** as the model explores diverse paths.
235
  """
236
  )
237