dicksinyass committed
Commit c845a2a · verified · 1 Parent(s): 284cd04

Update app.py

Files changed (1)
1. app.py +347 -312
app.py CHANGED
@@ -5,411 +5,446 @@ import threading
  import torch
  import os
  import time
- from typing import List, Dict, Generator, Tuple, Optional
  import logging

  # Set up logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

- # --- Best Free Models for Council ---
  MODELS = [
-     ("meta-llama/Meta-Llama-3-8B-Instruct", "Llama 3 8B Instruct"),
-     ("Qwen/Qwen1.5-7B-Chat", "Qwen1.5 7B Chat"),
-     ("HuggingFaceH4/zephyr-7b-beta", "Zephyr 7B Beta"),
-     ("mistralai/Mistral-7B-Instruct-v0.2", "Mistral 7B Instruct"),
  ]

- # Define council member personas with enhanced characteristics
  PERSONAS = [
-     {
-         "name": "Dr. Ana Rodriguez",
-         "description": "An analytical scientist who values empirical evidence and logical reasoning. Often plays devil's advocate and questions assumptions.",
-         "traits": "analytical, skeptical, evidence-focused",
-         "style": "formal, precise, methodical",
-         "emoji": "🔬"
-     },
-     {
-         "name": "Professor Marcus Chen",
-         "description": "A creative philosopher with an interest in ethics and societal implications. Considers the bigger picture and long-term consequences.",
-         "traits": "philosophical, visionary, empathetic",
-         "style": "eloquent, metaphorical, conceptual",
-         "emoji": "🧠"
-     },
-     {
-         "name": "Sarah Johnson",
-         "description": "A pragmatic problem-solver with real-world experience. Focuses on practicality and implementation details.",
-         "traits": "practical, solution-oriented, experienced",
-         "style": "direct, concise, example-driven",
-         "emoji": "🛠️"
-     },
-     {
-         "name": "Dr. Emeka Okafor",
-         "description": "A social scientist specializing in cultural perspectives and community impacts. Brings diverse viewpoints and contextual understanding.",
-         "traits": "culturally aware, nuanced, community-focused",
-         "style": "inclusive, storytelling, perspective-oriented",
-         "emoji": "🌍"
-     }
  ]

- # Cache for models to avoid reloading
  model_cache = {}

- def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
-     """Load model and tokenizer with caching to improve performance"""
      global model_cache
-
-     if model_id in model_cache:
-         logger.info(f"Using cached model: {model_id}")
-         return model_cache[model_id]
-
-     logger.info(f"Loading model: {model_id}")
      try:
-         # Set environment variables for optimizations
-         os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-         # Load tokenizer and model
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-         # Determine if CUDA is available and set appropriate device
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-
-         # Configure model loading for memory efficiency
          model_kwargs = {
              "trust_remote_code": True,
-             "device_map": "auto",
              "torch_dtype": torch.float16 if device == "cuda" else torch.float32
          }
-
-         model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
-
-         # Create pipeline with appropriate settings
-         pipe = pipeline("text-generation",
-                         model=model,
-                         tokenizer=tokenizer,
-                         max_new_tokens=512,
-                         device=model.device)
-
-         # Cache the model and tokenizer
-         model_cache[model_id] = (pipe, tokenizer)
-         logger.info(f"Model loaded successfully: {model_id} on {device}")
          return pipe, tokenizer
-
      except Exception as e:
-         logger.error(f"Failed to load model {model_id}: {str(e)}")
          raise

- def create_debate_prompt(user_prompt: str,
-                          persona: Dict,
-                          debate_style: str = "Balanced",
-                          previous_responses: Optional[List[str]] = None) -> str:
-     """Create a prompt that encourages a natural debate-like response with adjustable style"""
-     persona_desc = f"You are {persona['name']}, {persona['description']} Your communication style is {persona['style']}."
-
-     # Adjust prompt based on debate style
-     style_guidance = ""
-     if debate_style == "Collaborative":
-         style_guidance = "Focus on building upon and synthesizing the ideas of others. Look for common ground and areas of agreement."
-     elif debate_style == "Adversarial":
-         style_guidance = "Challenge assumptions and present contrasting viewpoints. Don't be afraid to disagree strongly with others."
-     else:  # Balanced
-         style_guidance = "Present your authentic perspective while being respectful of other viewpoints. Balance critique with constructive ideas."
-
-     if not previous_responses:
-         prompt = f"""{persona_desc}

- You are part of a council debating the following topic:
- "{user_prompt}"

  {style_guidance}
-
- Give your authentic perspective on this topic based on your persona. Be natural and conversational.
- Directly address the topic without hedging or being overly formal. Make specific points that others can respond to.
- Keep your response to 3-4 paragraphs maximum.
-
- {persona['name']}:"""
-     else:
          debate_history = "\n\n".join(previous_responses)
-         prompt = f"""{persona_desc}
-
- You are part of a council debating the following topic:
- "{user_prompt}"

- {style_guidance}
-
- The debate so far:
  {debate_history}

- Now it's your turn to speak. Based on your persona and the previous speakers:
- - You may agree or disagree with previous points
- - Add new perspectives they missed
- - Point out flaws in reasoning or suggest compromises
- - Address someone directly if appropriate
- - Be authentic to your character - don't just summarize

- Give your natural, conversational response as if in a real discussion.
- Keep your response to 3-4 paragraphs maximum.
-
- {persona['name']}:"""
-
-     return prompt

  def create_synthesis_prompt(user_prompt: str, all_responses: List[str]) -> str:
-     """Create a prompt for the facilitator to synthesize the debate"""
      debate_history = "\n\n".join(all_responses)
-     prompt = f"""You are the Facilitator, responsible for synthesizing the council's discussion on:
- "{user_prompt}"

- The full debate:
  {debate_history}

- Provide a thoughtful synthesis that:
- 1. Identifies the key points of agreement and disagreement
- 2. Highlights the most compelling insights from each perspective
- 3. Draws a balanced conclusion that respects the nuance of the discussion
- 4. Offers a path forward or recommendation when appropriate
-
- Be concise but comprehensive. Focus on substance over style.
- Keep your synthesis to 3-5 paragraphs maximum.

  Facilitator:"""
-     return prompt

- def stream_model_response(pipe: pipeline,
-                           tokenizer: AutoTokenizer,
-                           prompt: str,
-                           speaker_name: str,
-                           temperature: float = 0.7) -> Generator[str, None, None]:
-     """Stream model responses with better error handling"""
      try:
-         # Set up the streamer
          streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
          input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(pipe.model.device)
-
-         # Run model generation in a separate thread
          generation_kwargs = dict(
              input_ids=input_ids,
              streamer=streamer,
-             max_new_tokens=512,
              do_sample=True,
              temperature=temperature,
              top_p=0.95,
              repetition_penalty=1.1,
              eos_token_id=tokenizer.eos_token_id,
          )
-
-         thread = threading.Thread(
-             target=pipe.model.generate,
-             kwargs=generation_kwargs
-         )
          thread.start()
-
-         # Stream the response as it's generated
-         response = ""
          for new_text in streamer:
-             response += new_text
-             # Yield the running response prefixed with the speaker's name
-             yield f"**{speaker_name}:** {response.strip()}"
-
          thread.join()
-         return response.strip()
-
      except Exception as e:
-         logger.error(f"Error streaming response: {str(e)}")
-         yield f"**{speaker_name}:** [Error generating response. Please try again.]"

- def council_chat_stream(user_prompt: str,
-                         num_members: int = 3,
-                         debate_style: str = "Balanced",
-                         temperature: float = 0.7) -> Generator[str, None, None]:
-     """Generate a council debate with a configurable number of members and style"""
-     # Validate inputs
      if not user_prompt.strip():
-         yield "Please enter a topic for the council to debate."
          return
-
      start_time = time.time()
-
-     # Determine which personas and models to use
-     selected_personas = random.sample(PERSONAS, min(num_members, len(PERSONAS)))
-     selected_models = random.sample(MODELS, min(num_members, len(MODELS)))
-
-     # Load models
      loaded_models = []
-     for model_id, _ in selected_models:
          try:
-             pipe, tokenizer = load_model(model_id)
-             loaded_models.append((pipe, tokenizer))
          except Exception as e:
-             logger.error(f"Failed to load model {model_id}: {str(e)}")
-             yield f"Error loading model {model_id}. Please try again."
-             return
-
      responses = []
      formatted_responses = []
      persona_responses = []
-
-     # Generate responses from each council member
-     for i, (persona, (pipe, tokenizer), (model_id, model_name)) in enumerate(zip(selected_personas, loaded_models, selected_models)):
-         display_name = f"{persona['emoji']} {persona['name']} ({model_name})"
-
-         if i == 0:
-             prompt = create_debate_prompt(user_prompt, persona, debate_style)
-         else:
-             prompt = create_debate_prompt(user_prompt, persona, debate_style, persona_responses)
-
-         # Stream and collect response
-         response_text = ""
-         for partial in stream_model_response(pipe, tokenizer, prompt, display_name, temperature):
-             # Format the full output
-             current_output = f"**User:** {user_prompt}\n\n" + "\n\n".join(formatted_responses + [partial])
              yield current_output
-             response_text = partial.split("**:")[-1].strip()
-
-         # Add this response to the collected responses
-         persona_responses.append(f"{persona['name']}: {response_text}")
-         formatted_responses.append(partial)
-
-     # Facilitator synthesis (use a random model)
-     rand_model_idx = random.randint(0, len(loaded_models) - 1)
-     pipe, tokenizer = loaded_models[rand_model_idx]
-
-     synthesis_prompt = create_synthesis_prompt(user_prompt, persona_responses)
      synthesis = ""
-
-     for partial in stream_model_response(pipe, tokenizer, synthesis_prompt, "✨ Facilitator's Synthesis", temperature):
-         current_output = f"**User:** {user_prompt}\n\n" + "\n\n".join(formatted_responses + [partial])
          yield current_output
-         synthesis = partial
-
-     # Final output with timing
      elapsed_time = time.time() - start_time
-     transcript = f"**User:** {user_prompt}\n\n" + "\n\n".join(formatted_responses) + f"\n\n{synthesis}\n\n---\n*Debate completed in {elapsed_time:.1f} seconds*"
      yield transcript

- # Gradio interface with improved UI
  def build_gradio_interface():
-     """Build a more structured and visually appealing Gradio interface"""
-
-     # Custom CSS for better appearance
      custom_css = """
-     .gradio-container {
-         font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-     }
-     .council-header {
-         text-align: center;
-         margin-bottom: 1em;
-     }
-     .council-member {
-         margin: 0.5em 0;
-         padding: 0.5em;
-         border-radius: 8px;
-         background-color: #f5f5f5;
-     }
      """
-
      with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
-         gr.Markdown("# 🤖🏛️ AI Council Debate", elem_classes=["council-header"])
-         gr.Markdown("Ask a question and watch as AI personas debate and deliberate on your topic with different perspectives.")
-
          with gr.Row():
-             with gr.Column():
-                 inp = gr.Textbox(
-                     label="Your Topic or Question",
-                     lines=4,
-                     placeholder="Enter a topic, question, or issue for the council to debate..."
                  )
-
-                 # Advanced options
-                 with gr.Accordion("Advanced Options", open=False):
                      with gr.Row():
                          num_members = gr.Slider(
-                             minimum=2,
-                             maximum=len(PERSONAS),
-                             value=3,
-                             step=1,
                              label="Number of Council Members"
                          )
-
-                     with gr.Row():
                          debate_style = gr.Radio(
-                             ["Collaborative", "Adversarial", "Balanced"],
-                             label="Debate Style",
-                             value="Balanced"
                          )
-
                      with gr.Row():
                          temperature = gr.Slider(
-                             minimum=0.1,
-                             maximum=1.0,
-                             value=0.7,
-                             step=0.1,
-                             label="Temperature (Creativity)"
                          )
-
-                 btn = gr.Button("Start Council Debate", variant="primary")
-
-             with gr.Column():
-                 out = gr.Markdown(label="Council Debate Transcript")
-
-         # Display council member information
-         with gr.Accordion("Meet the Council Members", open=False):
-             member_info = ""
              for persona in PERSONAS:
-                 member_info += f"""
-                 <div class="council-member">
-                     <h3>{persona['emoji']} {persona['name']}</h3>
-                     <p><strong>Description:</strong> {persona['description']}</p>
-                     <p><strong>Traits:</strong> {persona['traits']}</p>
-                     <p><strong>Communication Style:</strong> {persona['style']}</p>
-                 </div>
-                 """
-             gr.HTML(member_info)
-
-         # Example prompts for users to try
-         with gr.Accordion("Example Topics", open=False):
-             examples = [
-                 "What role should AI play in education?",
-                 "Is universal basic income a good idea?",
-                 "How should society balance privacy concerns with security needs?",
-                 "What are the ethical implications of genetic engineering?",
-                 "How can we address climate change effectively?"
-             ]
-             gr.Examples(examples=examples, inputs=inp)
-
-         # Event handlers
-         btn.click(
-             fn=council_chat_stream,
-             inputs=[inp, num_members, debate_style, temperature],
-             outputs=out
          )
-
-         # Footer with additional information
-         gr.Markdown("""
-         ### About This App
-
-         This application demonstrates how multiple AI models can collaborate in a structured debate.
-         Each AI persona has distinctive traits and perspectives that influence how they approach topics.
-
-         The models used are open-source LLMs hosted on Hugging Face:
-         - Meta's Llama 3 8B Instruct
-         - Qwen 1.5 7B Chat
-         - Zephyr 7B Beta
-         - Mistral 7B Instruct v0.2
-
-         ⚠️ Note: First-time loading may take a minute as models are downloaded and initialized.
-         """)
-
      return demo

- # Main application
  if __name__ == "__main__":
-     # Check GPU availability
-     if torch.cuda.is_available():
-         logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
      else:
-         logger.info("No GPU available, using CPU. Performance may be slower.")
-
-     # Create and launch the Gradio interface
      demo = build_gradio_interface()
      demo.launch()
  import torch
  import os
  import time
+ from typing import List, Dict, Generator, Tuple, Optional, Union
  import logging
+ import warnings
+ from dataclasses import dataclass
+ import gc

  # Set up logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

+ warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly")
+
+ @dataclass
+ class ModelInfo:
+     id: str
+     name: str
+     required_memory: str  # Estimated VRAM requirement
+
+ @dataclass
+ class Persona:
+     name: str
+     description: str
+     traits: str
+     style: str
+     emoji: str
+
  MODELS = [
+     ModelInfo("meta-llama/Meta-Llama-3-8B-Instruct", "Llama 3 8B Instruct", "16GB"),
+     ModelInfo("Qwen/Qwen1.5-7B-Chat", "Qwen1.5 7B Chat", "14GB"),
+     ModelInfo("HuggingFaceH4/zephyr-7b-beta", "Zephyr 7B Beta", "14GB"),
+     ModelInfo("mistralai/Mistral-7B-Instruct-v0.2", "Mistral 7B Instruct", "14GB"),
  ]

  PERSONAS = [
+     Persona(
+         name="Dr. Ana Rodriguez",
+         description="An analytical scientist who values empirical evidence and logical reasoning.",
+         traits="analytical, skeptical, evidence-focused",
+         style="formal, precise, methodical",
+         emoji="🔬"
+     ),
+     Persona(
+         name="Professor Marcus Chen",
+         description="A creative philosopher with an interest in ethics and societal implications.",
+         traits="philosophical, visionary, empathetic",
+         style="eloquent, metaphorical, conceptual",
+         emoji="🧠"
+     ),
+     Persona(
+         name="Sarah Johnson",
+         description="A pragmatic problem-solver with real-world experience.",
+         traits="practical, solution-oriented, experienced",
+         style="direct, concise, example-driven",
+         emoji="🛠️"
+     ),
+     Persona(
+         name="Dr. Emeka Okafor",
+         description="A social scientist specializing in cultural perspectives.",
+         traits="culturally aware, nuanced, community-focused",
+         style="inclusive, storytelling, perspective-oriented",
+         emoji="🌍"
+     )
  ]
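
The dataclass records above replace the old `(id, name)` tuples; `required_memory` is kept as a display string and parsed back to a number wherever the code compares it against available VRAM. A minimal standalone sketch of that round-trip (the `required_gb` helper is hypothetical, not part of this commit):

```python
from dataclasses import dataclass

@dataclass
class ModelInfo:
    id: str
    name: str
    required_memory: str  # e.g. "16GB"

def required_gb(info: ModelInfo) -> float:
    # Hypothetical helper mirroring the inline parsing done in app.py
    return float(info.required_memory.replace("GB", ""))

info = ModelInfo("meta-llama/Meta-Llama-3-8B-Instruct", "Llama 3 8B Instruct", "16GB")
assert required_gb(info) == 16.0
```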

  model_cache = {}
+ current_device = None
+
+ def get_device() -> str:
+     global current_device
+     if current_device:
+         return current_device
+     if torch.cuda.is_available():
+         gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+         current_device = "cuda"
+         logger.info(f"GPU available with {gpu_mem:.1f}GB memory")
+         min_required = min(float(model.required_memory.replace("GB", "")) for model in MODELS)
+         if gpu_mem < min_required:
+             logger.warning(f"GPU memory may be insufficient for some models (has {gpu_mem:.1f}GB, needs {min_required}GB)")
+     else:
+         current_device = "cpu"
+         logger.info("Using CPU")
+     return current_device
+
+ def clear_model_cache():
+     global model_cache
+     for model_id in list(model_cache.keys()):
+         del model_cache[model_id]
+     gc.collect()
+     torch.cuda.empty_cache()
+     model_cache = {}
+     logger.info("Model cache cleared")

+ def load_model(model_info: ModelInfo) -> Tuple[pipeline, AutoTokenizer]:
      global model_cache
+     if model_info.id in model_cache:
+         logger.info(f"Using cached model: {model_info.name}")
+         return model_cache[model_info.id]
+     device = get_device()
+     if device == "cuda":
+         gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+         required_mem = float(model_info.required_memory.replace("GB", ""))
+         if gpu_mem < required_mem:
+             logger.warning(f"Insufficient GPU memory for {model_info.name} (needs {required_mem}GB, has {gpu_mem:.1f}GB)")
+     logger.info(f"Loading {model_info.name} on {device}")
      try:
+         start_time = time.time()
+         tokenizer = AutoTokenizer.from_pretrained(model_info.id)
          model_kwargs = {
              "trust_remote_code": True,
+             "device_map": "auto" if device == "cuda" else None,
              "torch_dtype": torch.float16 if device == "cuda" else torch.float32
          }
+         with gr.Progress() as progress:
+             progress(0, desc=f"Loading {model_info.name}")
+             model = AutoModelForCausalLM.from_pretrained(model_info.id, **model_kwargs)
+         if device == "cuda":
+             model = model.to(device)
+         pipe = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device=model.device
+         )
+         model_cache[model_info.id] = (pipe, tokenizer)
+         load_time = time.time() - start_time
+         logger.info(f"Loaded {model_info.name} in {load_time:.1f}s")
          return pipe, tokenizer
      except Exception as e:
+         logger.error(f"Failed to load {model_info.name}: {str(e)}")
          raise
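
`load_model` keeps the manual dict cache from the previous version and layers the VRAM check on top. For comparison, the same lazy load-and-cache behavior can be sketched with `functools.lru_cache`, which adds eviction for free; the cache size here is illustrative, not what the app uses:

```python
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

@lru_cache(maxsize=2)  # keep at most two models resident
def load_pipeline(model_id: str):
    # Same lazy-load idea as load_model() above, minus the VRAM bookkeeping
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return pipeline("text-generation", model=model, tokenizer=tokenizer), tokenizer
```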

+ def create_debate_prompt(
+     user_prompt: str,
+     persona: Persona,
+     debate_style: str = "Balanced",
+     previous_responses: Optional[List[str]] = None
+ ) -> str:
+     style_guidance = {
+         "Collaborative": "Focus on building upon ideas and finding common ground.",
+         "Adversarial": "Challenge assumptions and present strong contrasting views.",
+         "Balanced": "Present your perspective while respecting others."
+     }.get(debate_style, "Present your authentic perspective.")
+     base_prompt = f"""You are {persona.name}, {persona.description}
+ Your communication style: {persona.style}
+ Traits: {persona.traits}

+ You're in a council debating: "{user_prompt}"

  {style_guidance}
+ Respond naturally in 3-4 paragraphs."""
+     if previous_responses:
          debate_history = "\n\n".join(previous_responses)
+         return f"""{base_prompt}

+ Current discussion:
  {debate_history}

+ Now respond thoughtfully to the ongoing debate:
+ {persona.name}:"""
+     return f"""{base_prompt}

+ Begin your response:
+ {persona.name}:"""
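
Both branches of `create_debate_prompt` end by cueing the model with the speaker's name; only `previous_responses` distinguishes an opening statement from a rebuttal. A quick usage sketch against the definitions above:

```python
persona = PERSONAS[0]
opening = create_debate_prompt("Is universal basic income a good idea?", persona)
rebuttal = create_debate_prompt(
    "Is universal basic income a good idea?",
    PERSONAS[1],
    debate_style="Adversarial",
    previous_responses=[f"{persona.name}: UBI needs empirical pilots first."],
)
print(opening.endswith(f"{persona.name}:"))  # True: the prompt cues the model to speak next
```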
 
 
 
 
  def create_synthesis_prompt(user_prompt: str, all_responses: List[str]) -> str:
      debate_history = "\n\n".join(all_responses)
+     return f"""As the Facilitator, synthesize this discussion:
+ Topic: "{user_prompt}"

+ Debate:
  {debate_history}

+ Provide:
+ 1. Key agreements/disagreements
+ 2. Important insights
+ 3. Balanced conclusion
+ 4. Recommended next steps

  Facilitator:"""

+ def stream_response(
+     pipe: pipeline,
+     tokenizer: AutoTokenizer,
+     prompt: str,
+     speaker_name: Optional[str] = None,
+     temperature: float = 0.7,
+     max_tokens: int = 512
+ ) -> Generator[str, None, None]:
      try:
          streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
          input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(pipe.model.device)
          generation_kwargs = dict(
              input_ids=input_ids,
              streamer=streamer,
+             max_new_tokens=max_tokens,
              do_sample=True,
              temperature=temperature,
              top_p=0.95,
              repetition_penalty=1.1,
              eos_token_id=tokenizer.eos_token_id,
          )
+         thread = threading.Thread(target=pipe.model.generate, kwargs=generation_kwargs)
          thread.start()
+         buffer = ""
          for new_text in streamer:
+             buffer += new_text
+             if new_text and new_text[-1] in " .,;!?\n":
+                 if speaker_name:
+                     yield f"**{speaker_name}:** {buffer.strip()}"
+                 else:
+                     yield buffer.strip()
          thread.join()
+         if buffer.strip():
+             if speaker_name:
+                 yield f"**{speaker_name}:** {buffer.strip()}"
+             else:
+                 yield buffer.strip()
      except Exception as e:
+         logger.error(f"Streaming error: {str(e)}")
+         error_msg = f"[Error: {str(e)}]"
+         yield f"**{speaker_name}:** {error_msg}" if speaker_name else error_msg

+ def council_chat_stream(
+     user_prompt: str,
+     num_members: int = 3,
+     debate_style: str = "Balanced",
+     temperature: float = 0.7,
+     selected_models: Optional[List[str]] = None,
+     continue_debate: bool = False,
+     history: Optional[List[str]] = None
+ ) -> Generator[str, None, None]:
      if not user_prompt.strip():
+         yield "Please enter a topic for debate."
          return
+     num_members = max(2, min(num_members, len(PERSONAS)))
+     temperature = max(0.1, min(temperature, 1.0))
      start_time = time.time()
+     selected_personas = random.sample(PERSONAS, num_members)
+     model_pool = selected_models if selected_models else [model.id for model in MODELS]
+     selected_model_infos = random.sample([m for m in MODELS if m.id in model_pool], num_members)
      loaded_models = []
+     for model_info in selected_model_infos:
          try:
+             with gr.Progress() as progress:
+                 progress(0, desc=f"Loading {model_info.name}")
+                 pipe, tokenizer = load_model(model_info)
+             loaded_models.append((pipe, tokenizer, model_info))
          except Exception as e:
+             logger.error(f"Skipping {model_info.name}: {str(e)}")
+             yield f"⚠️ Couldn't load {model_info.name}, skipping..."
+             continue
+     if not loaded_models:
+         yield "❌ No models could be loaded. Please try again later."
+         return
      responses = []
      formatted_responses = []
      persona_responses = []
+     if continue_debate and history:
+         formatted_responses.extend(history)
+         persona_responses.extend([r.split("**:")[-1].strip() for r in history if "**:" in r])
+     for i, (persona, (pipe, tokenizer, model_info)) in enumerate(zip(selected_personas, loaded_models)):
+         display_name = f"{persona.emoji} {persona.name} ({model_info.name})"
+         thinking_msg = f"**{display_name}** is thinking..."
+         current_output = "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + [thinking_msg])
+         yield current_output
+         prompt = create_debate_prompt(
+             user_prompt,
+             persona,
+             debate_style,
+             persona_responses if i > 0 else None
+         )
+         full_response = ""
+         for chunk in stream_response(pipe, tokenizer, prompt, display_name, temperature):
+             full_response = chunk
+             current_output = "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + [chunk])
              yield current_output
+         persona_responses.append(f"{persona.name}: {full_response.split('**:')[-1].strip()}")
+         formatted_responses.append(full_response)
+     synth_pipe, synth_tokenizer, _ = random.choice(loaded_models)
+     synth_prompt = create_synthesis_prompt(user_prompt, persona_responses)
+     yield "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + ["✨ **Facilitator** is synthesizing..."])
      synthesis = ""
+     for chunk in stream_response(synth_pipe, synth_tokenizer, synth_prompt, "✨ Facilitator", temperature):
+         synthesis = chunk
+         current_output = "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + [chunk])
          yield current_output
      elapsed_time = time.time() - start_time
+     transcript = (
+         f"**User:** {user_prompt}\n\n" +
+         "\n\n".join(formatted_responses) +
+         f"\n\n{synthesis}\n\n" +
+         f"---\n*Debate completed in {elapsed_time:.1f} seconds*"
+     )
      yield transcript
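
Each yield from `council_chat_stream` is the full transcript so far, so a consumer only needs the last value; a sketch (note the generator still touches `gr.Progress`, so Gradio must be importable):

```python
final = None
for snapshot in council_chat_stream("What role should AI play in education?", num_members=2):
    final = snapshot  # every yield replaces, not appends to, the transcript
print(final)
```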

+ def council_chat_stream_chatbot(
+     user_prompt: str,
+     num_members: int = 3,
+     debate_style: str = "Balanced",
+     temperature: float = 0.7,
+     selected_models: Optional[List[str]] = None,
+     continue_debate: bool = False,
+     history: Optional[List[str]] = None
+ ) -> Generator[list, None, None]:
+     chat_history = []
+     for output in council_chat_stream(
+         user_prompt, num_members, debate_style, temperature, selected_models, continue_debate, history
+     ):
+         chat_history.append((None, output))
+         yield chat_history
+
  def build_gradio_interface():
      custom_css = """
+     .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
+     #transcript-container { position: relative; }
+     #copy-btn { position: absolute; top: 10px; right: 10px; z-index: 100; }
+     .member-card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin-bottom: 15px; background: #f9f9f9; }
+     .member-card h3 { margin-top: 0; color: #333; }
      """
      with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+         current_debate = gr.State([])
+         gr.Markdown("# 🏛️ AI Council Debate\n*Get diverse AI perspectives on any topic*")
          with gr.Row():
+             with gr.Column(scale=2):
+                 user_prompt = gr.Textbox(
+                     label="Debate Topic",
+                     placeholder="Enter your question or topic for debate...",
+                     lines=4,
+                     max_lines=6
                  )
+                 with gr.Accordion("⚙️ Debate Settings", open=False):
                      with gr.Row():
                          num_members = gr.Slider(
+                             minimum=2,
+                             maximum=len(PERSONAS),
+                             value=3,
+                             step=1,
                              label="Number of Council Members"
                          )
                          debate_style = gr.Radio(
+                             ["Collaborative", "Adversarial", "Balanced"],
+                             value="Balanced",
+                             label="Debate Style"
                          )
                      with gr.Row():
                          temperature = gr.Slider(
+                             minimum=0.1,
+                             maximum=1.0,
+                             value=0.7,
+                             step=0.1,
+                             label="Creativity (Temperature)"
+                         )
+                         model_selection = gr.CheckboxGroup(
+                             choices=[model.name for model in MODELS],
+                             value=[model.name for model in MODELS],
+                             label="Models to Use"
                          )
+                     with gr.Row():
+                         continue_btn = gr.Checkbox(
+                             label="Continue Previous Debate",
+                             value=False
+                         )
+                         clear_cache_btn = gr.Button(
+                             "Clear Model Cache",
+                             variant="secondary"
+                         )
+                 with gr.Row():
+                     output_style = gr.Radio(
+                         ["Transcript (Markdown)", "Chatbot (Chat History)"],
+                         value="Transcript (Markdown)",
+                         label="Output Style"
+                     )
+                 submit_btn = gr.Button(
+                     "Start Debate",
+                     variant="primary"
+                 )
+                 stop_btn = gr.Button(
+                     "Stop",
+                     variant="stop"
+                 )
+             with gr.Column(scale=3):
+                 transcript_out = gr.HTML(label="Council Debate", elem_id="transcript-container", visible=True)
+                 chatbot_out = gr.Chatbot(label="Council Debate (Chat)", visible=False, height=500)
+         with gr.Accordion("👥 Meet the Council Members", open=False):
              for persona in PERSONAS:
+                 with gr.Group(elem_classes="member-card"):
+                     gr.Markdown(f"""
+                     <h3>{persona.emoji} {persona.name}</h3>
+                     <p><strong>Description:</strong> {persona.description}</p>
+                     <p><strong>Traits:</strong> {persona.traits}</p>
+                     <p><strong>Style:</strong> {persona.style}</p>
+                     """)
+         with gr.Accordion("ℹ️ System Information", open=False):
+             gr.Markdown(f"""
+             - **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
+             - **Available Models:** {len(MODELS)}
+             - **Council Members:** {len(PERSONAS)}
+             - **Note:** First run may take time to download models
+             """)
+         def route_debate(user_prompt, num_members, debate_style, temperature, model_selection, continue_btn, current_debate, output_style):
+             selected_model_ids = [m.id for m in MODELS if m.name in model_selection]
+             if output_style == "Transcript (Markdown)":
+                 for out in council_chat_stream(
+                     user_prompt, num_members, debate_style, temperature, selected_model_ids, continue_btn, current_debate
+                 ):
+                     yield gr.update(visible=True, value=out), gr.update(visible=False)
+             else:
+                 for out in council_chat_stream_chatbot(
+                     user_prompt, num_members, debate_style, temperature, selected_model_ids, continue_btn, current_debate
+                 ):
+                     yield gr.update(visible=False), gr.update(visible=True, value=out)
+         submit_btn.click(
+             route_debate,
+             [user_prompt, num_members, debate_style, temperature, model_selection, continue_btn, current_debate, output_style],
+             [transcript_out, chatbot_out],
+             queue=True
+         )
+         stop_btn.click(
+             fn=None, inputs=None, outputs=None, cancels=[submit_btn]
+         )
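
One caveat on the Stop wiring above: Gradio's `cancels` parameter expects the event handle returned by `.click()`, not the button component itself, so the conventional pattern looks roughly like this sketch (same component names as above):

```python
# Sketch: keep the event returned by .click() and cancel that
debate_event = submit_btn.click(
    route_debate,
    [user_prompt, num_members, debate_style, temperature,
     model_selection, continue_btn, current_debate, output_style],
    [transcript_out, chatbot_out],
)
stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[debate_event])
```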
+         clear_cache_btn.click(
+             fn=clear_model_cache, inputs=None, outputs=None
+         )
+         def update_history(history: List[str], new_output: str) -> List[str]:
+             if "✨ Facilitator" in new_output:
+                 return []
+             return history + [new_output] if history else [new_output]
+         transcript_out.change(
+             fn=update_history,
+             inputs=[current_debate, transcript_out],
+             outputs=current_debate
          )

      return demo

  if __name__ == "__main__":
+     device = get_device()
+     if device == "cuda":
+         gpu_info = torch.cuda.get_device_properties(0)
+         logger.info(f"Using GPU: {gpu_info.name} ({gpu_info.total_memory / (1024**3):.1f}GB)")
      else:
+         logger.info("Using CPU")
      demo = build_gradio_interface()
      demo.launch()