KacBo committed on
Commit
7bf9c5c
·
1 Parent(s): 535e62d

copy MA review board

Browse files
Files changed (1) hide show
  1. app.py +419 -0
app.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient, repo_exists
3
+ from concurrent.futures import ThreadPoolExecutor, as_completed
4
+
5
+ # ---------------------------------------------------------------------------
6
+ # Constants
7
+ # ---------------------------------------------------------------------------
8
+
9
# Model IDs offered as the preset choices in each agent's dropdown.
# validate_model() accepts any of these without querying the Hub.
DEFAULT_MODELS = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct-1M",
    "google/gemma-2-2b-it",
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    "deepseek-ai/DeepSeek-R1",
    "mistralai/Mistral-7B-Instruct-v0.3",
]

# Initial value of every agent's temperature slider.
DEFAULT_TEMP = 0.7

# System prompt placed first in every agent's conversation history.
SYSTEM_MESSAGE = (
    "You are a helpful assistant participating in a multi-agent review board. "
    "Provide thoughtful, well-reasoned responses. When reviewing other agents' "
    "responses in later rounds, carefully consider their reasoning and update "
    "your answer if you find compelling arguments."
)
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Core debate logic (adapted from the research notebook for the HF Inference
30
+ # API -- replaces local transformers pipelines with InferenceClient calls)
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
def generate_answer(
    token: str,
    model: str,
    messages: list[dict],
    temperature: float,
) -> str:
    """Call the HF Inference API for a single agent turn.

    Args:
        token: Hugging Face access token handed to ``InferenceClient``.
        model: ID of the model to query.
        messages: Chat history as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The text of the first completion choice, or ``""`` when the API
        returns a message without content.

    Raises:
        Any exception raised by ``InferenceClient`` is propagated unchanged;
        the caller is expected to translate it into a user-facing message.
    """
    client = InferenceClient(token=token, model=model)
    response = client.chat_completion(
        messages=messages,
        max_tokens=2048,
        temperature=temperature,
        top_p=0.9,
    )
    content = response.choices[0].message.content
    # The message content is optional in the response schema. Normalise a
    # missing value so callers always get a ``str`` and never end up
    # formatting the literal text "None" into prompts or the results view.
    return content if content is not None else ""
49
+
50
+
51
def construct_review_message(other_responses: list[str]) -> dict:
    """Create the user turn that asks an agent to revise its answer.

    With peer answers available, each one is quoted in its own fenced section
    (numbered from 1 in the order given) between an introduction and a closing
    instruction. With no peer answers, a short "double-check" prompt is
    returned instead.

    Args:
        other_responses: Latest answers of the other agents.

    Returns:
        A chat message dict with ``role`` set to ``"user"``.
    """
    if other_responses:
        quoted = [
            f"Agent {number}'s response:\n```\n{answer}\n```\n"
            for number, answer in enumerate(other_responses, start=1)
        ]
        sections = [
            "These are the responses to the problem from other agents:\n",
            *quoted,
            "Using the reasoning from other agents as additional advice, "
            "can you give an updated answer? Examine your solution and that "
            "of the other agents step by step. Provide your final, updated "
            "response.",
        ]
        return {"role": "user", "content": "\n".join(sections)}

    # No peers to compare against: just ask for a self-check.
    return {
        "role": "user",
        "content": "Please double-check your answer and provide your final response.",
    }
66
+
67
+
68
def handle_inference_error(error: Exception, model_name: str) -> str:
    """Return a user-friendly error string for common Inference API failures.

    Classification order: timeouts first, then HTTP status (401/403, 404,
    422, 429, 402), then a generic fallback that includes the first 300
    characters of ``repr(error)``.

    The HTTP status is read from ``error.response.status_code`` when the
    exception carries a response object with an integer status. Only when no
    such status is available does the function fall back to searching the
    exception's ``repr`` for the status digits.

    Args:
        error: The exception raised by the inference call.
        model_name: Model ID to mention in the message.

    Returns:
        A human-readable description of the failure.
    """
    raw = repr(error)
    low = raw.lower()
    etype = type(error).__name__.lower()

    if "timeout" in etype or "timeout" in low:
        return (
            f"Request to '{model_name}' timed out. The model may be loading "
            "(cold start) or overloaded. Try again in a moment."
        )

    # Prefer the real HTTP status: a substring search over the repr can be
    # fooled by digits inside URLs, request IDs or message text (for example
    # a 404 whose message happens to contain "401").
    status = getattr(getattr(error, "response", None), "status_code", None)
    if isinstance(status, bool) or not isinstance(status, int):
        status = None

    def _is_status(*codes: int) -> bool:
        """Match against the known status, else fall back to the repr text."""
        if status is not None:
            return status in codes
        return any(str(code) in raw for code in codes)

    if _is_status(401, 403):
        return (
            f"Access denied for '{model_name}'. Visit the model page on "
            "huggingface.co to accept its license/terms."
        )
    if _is_status(404):
        return f"Model '{model_name}' was not found on Hugging Face Hub."
    if _is_status(422):
        return (
            f"Model '{model_name}' does not support chat completion "
            "via the Inference API."
        )
    if _is_status(429):
        return "Rate limited. Please wait a moment and try again."
    if _is_status(402) or "payment" in low or "credit" in low:
        return (
            "Out of Inference API credits. "
            "Check huggingface.co/settings/billing."
        )
    return f"Error with '{model_name}': {raw[:300]}"
99
+
100
+
101
+ def validate_model(model_id: str, token: str | None = None) -> tuple[bool, str]:
102
+ """Return *(ok, error_message)* after checking the model exists on the Hub."""
103
+ if not model_id or not model_id.strip():
104
+ return False, "Model ID cannot be empty."
105
+ model_id = model_id.strip()
106
+ if model_id in DEFAULT_MODELS:
107
+ return True, ""
108
+ try:
109
+ if not repo_exists(model_id, token=token):
110
+ return False, f"Model '{model_id}' not found on Hugging Face Hub."
111
+ return True, ""
112
+ except Exception as exc:
113
+ return False, f"Could not verify '{model_id}': {exc}"
114
+
115
+
116
def run_review_board(
    prompt: str,
    agent_configs: list[dict],
    num_rounds: int,
    token: str,
):
    """Generator yielding *(status_line, results_or_None)* tuples.

    *results* is ``None`` during processing and a dict mapping agent labels to
    their final-round response text on the very last yield.

    Each round queries every agent concurrently. From the second round on,
    each agent first receives a review prompt built from the other agents'
    most recent *successful* answers. Error placeholders are never shown to
    peers as if they were answers.

    Args:
        prompt: The user's question, used as the first user turn for all agents.
        agent_configs: One dict per agent with keys ``"model"`` and ``"temp"``.
        num_rounds: Number of rounds to run.
        token: Access token forwarded to ``generate_answer``.

    Yields:
        ``(status_line, None)`` for every progress update, then a final
        ``(status_line, results)`` tuple.
    """
    num_agents = len(agent_configs)

    # Each agent gets its own conversation history
    agent_contexts: list[list[dict]] = [
        [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": prompt},
        ]
        for _ in range(num_agents)
    ]

    # Most recent successful answer per agent (None until one exists). Kept
    # separately from the histories so that "[Error: ...]" placeholders are
    # not forwarded to the other agents as peer responses.
    latest_answers: list[str | None] = [None] * num_agents

    for round_num in range(num_rounds):
        tag = f"Round {round_num + 1}/{num_rounds}"
        yield f"**{tag}** -- Submitting requests...", None

        # After the first round, inject peer-review context
        if round_num > 0:
            for i in range(num_agents):
                others = [
                    answer
                    for j, answer in enumerate(latest_answers)
                    if j != i and answer is not None
                ]
                agent_contexts[i].append(construct_review_message(others))

        # Fan out requests concurrently. max_workers must be >= 1, so guard
        # against an empty agent list instead of raising ValueError.
        futures: dict = {}
        with ThreadPoolExecutor(max_workers=max(1, num_agents)) as pool:
            for i, cfg in enumerate(agent_configs):
                fut = pool.submit(
                    generate_answer,
                    token,
                    cfg["model"],
                    list(agent_contexts[i]),  # shallow copy for thread safety
                    cfg["temp"],
                )
                futures[fut] = i

            # Consume completions while the pool is still open so progress
            # is reported as each agent finishes.
            for fut in as_completed(futures):
                idx = futures[fut]
                model = agent_configs[idx]["model"]
                try:
                    text = fut.result()
                except Exception as exc:
                    err = handle_inference_error(exc, model)
                    # Keep a placeholder assistant turn so this agent's
                    # history still alternates user/assistant next round.
                    agent_contexts[idx].append(
                        {"role": "assistant", "content": f"[Error: {err}]"}
                    )
                    yield f"**{tag}** -- Agent {idx + 1} error: {err}", None
                else:
                    agent_contexts[idx].append(
                        {"role": "assistant", "content": text}
                    )
                    latest_answers[idx] = text
                    yield (
                        f"**{tag}** -- Agent {idx + 1} (`{model}`) responded.",
                        None,
                    )

    # Collect each agent's final response (an error placeholder is reported
    # as-is so the user can see which agent failed and why).
    results: dict[str, str] = {}
    for i, cfg in enumerate(agent_configs):
        last = "[No response generated]"
        for msg in reversed(agent_contexts[i]):
            if msg["role"] == "assistant":
                last = msg["content"]
                break
        results[f"Agent {i + 1} -- {cfg['model']}"] = last

    yield (
        "**Complete!** Select an agent tab below to view their final response.",
        results,
    )
202
+
203
+
204
+ # ---------------------------------------------------------------------------
205
+ # Gradio UI
206
+ # ---------------------------------------------------------------------------
207
+
208
# Stylesheet passed to gr.Blocks(css=...). It styles the per-agent header
# strip (.agent-header-wrap / .agent-header), adjusts spacing of groups in
# the sidebar, and themes scrollbars via the WebKit pseudo-elements plus the
# scrollbar-width / scrollbar-color properties.
CUSTOM_CSS = """
.agent-header-wrap {
padding: 0 !important;
min-height: 0 !important;
background: rgba(78, 70, 229, 1);
}
.agent-header {
display: block;
text-align: center;
cursor: help;
}

.sidebar .group {
margin-bottom: 8px !important;
}

*::-webkit-scrollbar {
width: 8px;
height: 8px;
}
*::-webkit-scrollbar-track {
background: transparent;
}
*::-webkit-scrollbar-thumb {
background: rgba(139, 92, 246, 0.45);
border-radius: 4px;
}
*::-webkit-scrollbar-thumb:hover {
background: rgba(139, 92, 246, 0.7);
}

/* Themed scrollbar -- Firefox */
* {
scrollbar-width: thin;
scrollbar-color: rgba(139, 92, 246, 0.45) transparent;
}
"""
245
+
246
# Builds the whole Gradio interface: a sidebar with login, round count and a
# dynamic list of agent rows, plus a main area with the prompt box, run
# button, status log and per-agent result tabs.
# NOTE(review): nesting of the layout contexts below was reconstructed from a
# view without indentation -- confirm against the repository copy.
with gr.Blocks(
    title="Multi-Agent Review Board",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
) as demo:

    # Shared state --------------------------------------------------------
    agents_state = gr.State([0, 1])  # list of unique agent IDs
    next_id_state = gr.State(2)  # counter for the next ID to assign
    results_state = gr.State({})  # final responses dict (empty until a run)

    # ---- Sidebar --------------------------------------------------------
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("---")

        gr.Markdown("### Settings")
        num_rounds = gr.Slider(
            minimum=1,
            maximum=10,
            value=2,
            step=1,
            label="Rounds",
            info="Round 1 = independent answers. Round 2+ = peer review.",
            interactive=True
        )

        gr.Markdown("---")
        gr.Markdown("### Agents")

        # Dynamic agent configuration rows
        @gr.render(inputs=agents_state)
        def render_agents(agent_ids):
            """Build one configuration row per agent ID and wire the Run button.

            The Run handler is defined here because it needs the dropdown and
            slider components created in this function.
            """
            dropdowns: list = []
            sliders: list = []

            for idx, aid in enumerate(agent_ids):
                # Preset model is chosen by row position, cycling through the list.
                default_model = DEFAULT_MODELS[idx % len(DEFAULT_MODELS)]

                with gr.Group():
                    with gr.Row():
                        gr.HTML(
                            f'<span class="agent-header" title="Pick a model or type any HF model ID">'
                            f'<strong>Agent {idx + 1}</strong></span>',
                            elem_classes=["agent-header-wrap"],
                        )
                        # A delete button is only offered while more than two
                        # agents exist, so the board never drops below two.
                        if len(agent_ids) > 2:
                            del_btn = gr.Button(
                                "✕",
                                variant="stop",
                                size="sm",
                                min_width=36,
                                scale=0,
                                key=f"del-{aid}",
                            )

                            # Freeze `aid` via default-arg so each button deletes the correct agent
                            def _delete(current_ids, _target=aid):
                                """Return the ID list without this row's agent ID."""
                                return [x for x in current_ids if x != _target]

                            del_btn.click(_delete, agents_state, agents_state)

                    dd = gr.Dropdown(
                        choices=DEFAULT_MODELS,
                        value=default_model,
                        allow_custom_value=True,
                        label=None,
                        show_label=False,
                        key=f"model-{aid}",
                        interactive=True
                    )
                    temp = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=DEFAULT_TEMP,
                        step=0.1,
                        label="Temperature",
                        key=f"temp-{aid}",
                        interactive=True
                    )

                dropdowns.append(dd)
                sliders.append(temp)

            # ---- Wire the Run button (defined further below) ----
            def on_run(data, hf_token: gr.OAuthToken | None = None):
                """Validate the inputs, run the review board and stream its progress.

                ``data`` is indexed by component (the inputs are registered as a
                set). ``hf_token`` being ``None`` is treated as "not logged in".
                Yields ``(accumulated_log, results)`` pairs, where ``results``
                is ``{}`` until the final update.
                """
                if hf_token is None:
                    raise gr.Error(
                        "Please log in with your Hugging Face account first "
                        "(use the Login button in the sidebar)."
                    )

                prompt = data[prompt_tb]
                rounds = data[num_rounds]

                if not prompt or not prompt.strip():
                    raise gr.Error("Please enter a prompt.")

                models = [data[dd] for dd in dropdowns]
                temps = [data[sl] for sl in sliders]

                # Build and validate agent configs
                configs: list[dict] = []
                for i, (model, t) in enumerate(zip(models, temps)):
                    if not model or not model.strip():
                        raise gr.Error(
                            f"Agent {i + 1}: please select or enter a model."
                        )
                    model = model.strip()
                    # Only custom (non-preset) model IDs are checked against the Hub.
                    if model not in DEFAULT_MODELS:
                        ok, err = validate_model(model, hf_token.token)
                        if not ok:
                            raise gr.Error(f"Agent {i + 1}: {err}")
                    configs.append({"model": model, "temp": float(t)})

                # Stream progress as an accumulating log
                log: list[str] = []
                for status_line, results in run_review_board(
                    prompt.strip(), configs, int(rounds), hf_token.token
                ):
                    log.append(status_line)
                    yield (
                        "\n\n".join(log),
                        results if results is not None else {},
                    )

            # prompt_tb, run_btn and status_md are assigned later in this
            # `with` block; they are looked up when this function runs.
            run_btn.click(
                on_run,
                inputs={prompt_tb, num_rounds} | set(dropdowns) | set(sliders),
                outputs=[status_md, results_state],
            )

        # "Add Agent" sits outside @gr.render so it stays at the bottom
        add_btn = gr.Button("+ Add Agent", variant="secondary", size="sm")

        def _add_agent(ids, nid):
            """Append the next unused ID to the agent list and advance the counter."""
            return ids + [nid], nid + 1

        add_btn.click(
            _add_agent,
            [agents_state, next_id_state],
            [agents_state, next_id_state],
        )

    # ---- Main area ------------------------------------------------------
    gr.Markdown("# Multi-Agent Review Board")
    gr.Markdown(
        "Configure your agents in the sidebar, enter a prompt, and let "
        "multiple AI models debate and refine their answers across rounds."
    )

    prompt_tb = gr.Textbox(
        label="Prompt",
        placeholder="Enter your question or prompt here...",
        lines=4,
    )
    run_btn = gr.Button("Run Review Board", variant="primary", size="lg")
    status_md = gr.Markdown("")

    # ---- Dynamic results tabs -------------------------------------------
    @gr.render(inputs=results_state)
    def render_results(results):
        """Render one tab per agent with its final response; nothing if empty."""
        if not results:
            return
        gr.Markdown("---")
        gr.Markdown("### Final Responses")
        with gr.Tabs():
            for name, response in results.items():
                with gr.TabItem(name):
                    gr.Markdown(response)
416
+
417
+
418
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()