dicksinyass committed
Commit e22f850 · verified · 1 Parent(s): 3472123

Update app.py

Files changed (1):
  1. app.py (+68, -75)
app.py CHANGED
Old version of the changed hunks (removed lines marked with -):

@@ -5,20 +5,33 @@ import threading
 import torch
 import os
 import time
-from typing import List, Dict, Generator, Tuple, Optional
 import logging
 from collections import defaultdict
 
 # Set up logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 # --- Best Free Models for Council ---
 MODELS = [
-    ("mistralai/Mistral-7B-Instruct-v0.2", "Mistral 7B Instruct"),  # Good default choice
-    ("HuggingFaceH4/zephyr-7b-beta", "Zephyr 7B Beta"),  # Smaller alternative
-    ("NousResearch/Hermes-2-Pro-Mistral-7B", "Hermes 2 Pro"),  # Good for debate
-    ("cognitivecomputations/dolphin-2.6-mistral-7b", "Dolphin Mistral"),  # Uncensored
 ]
 
 # Define council member personas
@@ -29,7 +42,7 @@ PERSONAS = [
         "traits": "analytical, skeptical, evidence-focused",
         "style": "formal, precise, methodical",
         "emoji": "🔬",
-        "preferred_models": ["Mistral 7B Instruct", "Zephyr 7B Beta"]  # More factual models
     },
     {
         "name": "Professor Marcus Chen",
@@ -37,7 +50,7 @@ PERSONAS = [
         "traits": "philosophical, visionary, empathetic",
         "style": "eloquent, metaphorical, conceptual",
         "emoji": "🧠",
-        "preferred_models": ["Hermes 2 Pro", "Dolphin Mistral"]  # More creative models
     },
     {
         "name": "Sarah Johnson",
@@ -57,10 +70,9 @@ PERSONAS = [
     }
 ]
 
-# Cache for models to avoid reloading
 model_cache = {}
 model_loading_lock = threading.Lock()
-active_sessions = defaultdict(dict)
 stop_signal = threading.Event()
 
 def get_device_preference():
@@ -93,13 +105,15 @@ def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
             "torch_dtype": torch.float16 if device == "cuda" else torch.float32
         }
 
-        # More efficient loading for low-memory systems
         if device == "cpu":
-            model_kwargs["low_cpu_mem_usage"] = True
 
         model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
 
-        if device != "cuda":  # For CPU/MPS, manually move to device
             model = model.to(device)
 
         pipe = pipeline(
@@ -115,7 +129,6 @@ def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
 
     except Exception as e:
         logger.error(f"Failed to load model {model_id}: {str(e)}")
-        # Try with smaller precision if failed
         if "out of memory" in str(e).lower() and device == "cuda":
             logger.info("Attempting to load with float16 to save memory")
             try:
@@ -128,10 +141,7 @@ def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
                 logger.error(f"Still failed to load model: {str(e2)}")
                 raise
 
-def create_debate_prompt(user_prompt: str,
-                         persona: Dict,
-                         debate_style: str = "Balanced",
-                         previous_responses: Optional[List[str]] = None) -> str:
     """Enhanced prompt engineering for better debates"""
     persona_desc = (
         f"Roleplay as {persona['name']}, {persona['description']}\n"
@@ -187,12 +197,7 @@ Write in clear, concise bullet points followed by a short paragraph summary.
 
 Facilitator:"""
 
-def stream_model_response(pipe: pipeline,
-                          tokenizer: AutoTokenizer,
-                          prompt: str,
-                          speaker_name: str = None,
-                          temperature: float = 0.7,
-                          max_tokens: int = 512) -> Generator[str, None, None]:
     """Robust streaming with better formatting and stop handling"""
     try:
         if stop_signal.is_set():
@@ -207,7 +212,7 @@ def stream_model_response(pipe: pipeline,
             streamer=streamer,
             max_new_tokens=max_tokens,
             do_sample=True,
-            temperature=min(max(temperature, 0.1), 1.0),  # Clamped to reasonable range
             top_p=0.95,
             repetition_penalty=1.1,
             eos_token_id=tokenizer.eos_token_id,
@@ -219,19 +224,17 @@ def stream_model_response(pipe: pipeline,
         buffer = ""
         for new_text in streamer:
             if stop_signal.is_set():
-                pipe.model.config.use_cache = False  # Try to stop generation
                 thread.join(timeout=1)
                 break
 
             buffer += new_text
-            # Only yield when we have a complete word to avoid mid-word breaks
             if " " in new_text or "\n" in new_text:
                 if speaker_name:
                     yield f"**{speaker_name}:** {buffer.strip()}"
                 else:
                     yield buffer.strip()
 
-        # Yield any remaining content
         if buffer.strip():
             if speaker_name:
                 yield f"**{speaker_name}:** {buffer.strip()}"
@@ -243,6 +246,9 @@ def stream_model_response(pipe: pipeline,
     except Exception as e:
         logger.error(f"Error in streaming: {str(e)}")
         yield "[Error in generation]" if not speaker_name else f"**{speaker_name}:** [Error in generation]"
 
 def select_models_for_personas(personas: List[Dict], models: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     """Match models to personas based on preferences"""
@@ -250,22 +256,16 @@ def select_models_for_personas(personas: List[Dict], models: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     model_names = [m[1] for m in models]
 
     for persona in personas:
-        # Try to match preferred models first
        for pref in persona.get("preferred_models", []):
            if pref in model_names:
                selected.append(models[model_names.index(pref)])
                break
        else:
-            # Fallback to random selection
            selected.append(random.choice(models))
 
     return selected
 
-def council_chat_stream(user_prompt: str,
-                        num_members: int = 3,
-                        debate_style: str = "Balanced",
-                        temperature: float = 0.7,
-                        session_id: str = None) -> Generator[str, None, None]:
     """Enhanced debate generation with better state management"""
     stop_signal.clear()
 
@@ -276,11 +276,9 @@ def council_chat_stream(user_prompt: str,
     start_time = time.time()
 
     try:
-        # Select personas and models
         selected_personas = random.sample(PERSONAS, min(num_members, len(PERSONAS)))
         selected_models = select_models_for_personas(selected_personas, MODELS)
 
-        # Load all models first with progress updates
         loaded_models = []
         for i, (model_id, model_name) in enumerate(selected_models):
             if stop_signal.is_set():
@@ -300,7 +298,6 @@ def council_chat_stream(user_prompt: str,
             yield "❌ Error: No models could be loaded. Please try again later."
             return
 
-        # Conduct the debate
         responses = []
         formatted_responses = []
         persona_responses = []
@@ -311,17 +308,10 @@ def council_chat_stream(user_prompt: str,
                 return
 
             display_name = f"{persona['emoji']} {persona['name']} ({model_name})"
-            prompt = create_debate_prompt( user_prompt, persona, debate_style, persona_responses)
 
-            # Stream and collect response
             response_text = ""
-            for partial in stream_model_response(
-                pipe,
-                tokenizer,
-                prompt,
-                display_name,
-                temperature
-            ):
                 if stop_signal.is_set():
                     break
                 yield partial
@@ -331,7 +321,6 @@ def council_chat_stream(user_prompt: str,
                 yield "[Debate stopped during responses]"
                 return
 
-            # Store response data
             response_data = {
                 "name": persona['name'],
                 "model": model_name,
@@ -341,7 +330,6 @@ def council_chat_stream(user_prompt: str,
             persona_responses.append(response_data)
             formatted_responses.append(partial)
 
-        # Facilitator synthesis
         if not stop_signal.is_set():
             yield "\n\n**✨ Council is now synthesizing the discussion...**\n"
             synthesis_model = random.choice(loaded_models)
@@ -352,13 +340,12 @@ def council_chat_stream(user_prompt: str,
                 synthesis_model[1],
                 synthesis_prompt,
                 "✨ Facilitator's Synthesis",
-                temperature*0.8  # Slightly lower temp for synthesis
             ):
                 if stop_signal.is_set():
                     break
                 yield partial
 
-        # Final output
         elapsed_time = time.time() - start_time
         if not stop_signal.is_set():
             transcript = (
@@ -414,7 +401,6 @@ def build_gradio_interface():
     """
 
     with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
-        # Header section
        with gr.Row():
            gr.Markdown("""
            <div class="council-header">
@@ -423,10 +409,8 @@ def build_gradio_interface():
            </div>
            """)
 
-        # Main controls
        with gr.Row():
            with gr.Column(scale=2):
-                # Input section
                inp = gr.Textbox(
                    label="Debate Topic",
                    placeholder="Enter a topic or question for the council to debate...",
@@ -434,7 +418,6 @@ def build_gradio_interface():
                    max_lines=6
                )
 
-                # Debate controls
                with gr.Group(elem_classes="debate-controls"):
                    with gr.Row():
                        btn = gr.Button("Start Debate", variant="primary")
@@ -447,8 +430,7 @@ def build_gradio_interface():
                        minimum=2,
                        maximum=4,
                        step=1,
-                        value=3,
-                        info="Number of AI participants"
                    )
                    debate_style = gr.Dropdown(
                        label="Debate Style",
@@ -461,11 +443,9 @@ def build_gradio_interface():
                        minimum=0.1,
                        maximum=1.0,
                        step=0.1,
-                        value=0.7,
-                        info="Higher = more creative/random"
                    )
 
-                # Persona information
                with gr.Accordion("Meet the Council Members", open=False):
                    for persona in PERSONAS:
                        with gr.Group(elem_classes="persona-card"):
@@ -476,7 +456,6 @@ def build_gradio_interface():
                            **Preferred Models:** {', '.join(persona.get('preferred_models', ['Any']))}
                            """)
 
-            # Output section
            with gr.Column(scale=3):
                out = gr.Markdown(
                    label="Live Debate Transcript",
@@ -491,7 +470,6 @@ def build_gradio_interface():
                - Debate memory and context tracking
                """)
 
-        # Example prompts
        with gr.Accordion("Example Debate Topics", open=False):
            examples = gr.Examples(
                examples=[
@@ -505,7 +483,6 @@ def build_gradio_interface():
                label="Click to try these examples"
            )
 
-        # Event handlers
        btn.click(
            fn=council_chat_stream,
            inputs=[inp, num_members, debate_style, temperature],
@@ -517,7 +494,6 @@ def build_gradio_interface():
            queue=False
        )
 
-        # Footer
        gr.Markdown("""
        ---
        **About This System:**
@@ -529,19 +505,36 @@ def build_gradio_interface():
 
     return demo
 
-# Main application
 if __name__ == "__main__":
-    # Check system resources
     device = get_device_preference()
-    logger.info(f"Running on device: {device}")
 
     if device == "cpu":
-        logger.warning("Running on CPU - performance will be significantly slower than GPU")
 
-    # Launch interface
-    demo = build_gradio_interface()
-    demo.queue(concurrency_count=1).launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False
-    )
Updated version of the same hunks (added lines marked with +):

@@ -5,20 +5,33 @@ import threading
 import torch
 import os
 import time
+import sys
 import logging
+from typing import List, Dict, Generator, Tuple, Optional
 from collections import defaultdict
+import gc
+
+# Configure Torch for CPU optimization
+torch.set_num_threads(os.cpu_count() or 1)
+torch.backends.quantized.engine = 'qnnpack' if torch.backends.quantized.supported_engines else None
 
 # Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler('council_debate.log'),
+        logging.StreamHandler()
+    ]
+)
 logger = logging.getLogger(__name__)
 
 # --- Best Free Models for Council ---
 MODELS = [
+    ("mistralai/Mistral-7B-Instruct-v0.2", "Mistral 7B Instruct"),
+    ("HuggingFaceH4/zephyr-7b-beta", "Zephyr 7B Beta"),
+    ("NousResearch/Hermes-2-Pro-Mistral-7B", "Hermes 2 Pro"),
+    ("cognitivecomputations/dolphin-2.6-mistral-7b", "Dolphin Mistral"),
 ]
 
 # Define council member personas
@@ -29,7 +42,7 @@ PERSONAS = [
         "traits": "analytical, skeptical, evidence-focused",
         "style": "formal, precise, methodical",
         "emoji": "🔬",
+        "preferred_models": ["Mistral 7B Instruct", "Zephyr 7B Beta"]
     },
     {
         "name": "Professor Marcus Chen",
@@ -37,7 +50,7 @@ PERSONAS = [
         "traits": "philosophical, visionary, empathetic",
         "style": "eloquent, metaphorical, conceptual",
         "emoji": "🧠",
+        "preferred_models": ["Hermes 2 Pro", "Dolphin Mistral"]
     },
     {
         "name": "Sarah Johnson",
@@ -57,10 +70,9 @@ PERSONAS = [
     }
 ]
 
+# Cache for models
 model_cache = {}
 model_loading_lock = threading.Lock()
 stop_signal = threading.Event()
 
 def get_device_preference():
@@ -93,13 +105,15 @@ def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
             "torch_dtype": torch.float16 if device == "cuda" else torch.float32
         }
 
         if device == "cpu":
+            model_kwargs.update({
+                "low_cpu_mem_usage": True,
+                "torch_dtype": torch.float32,
+            })
 
         model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
 
+        if device != "cuda":
             model = model.to(device)
 
         pipe = pipeline(
@@ -115,7 +129,6 @@ def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
 
     except Exception as e:
         logger.error(f"Failed to load model {model_id}: {str(e)}")
         if "out of memory" in str(e).lower() and device == "cuda":
             logger.info("Attempting to load with float16 to save memory")
             try:
@@ -128,10 +141,7 @@ def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
                 logger.error(f"Still failed to load model: {str(e2)}")
                 raise
 
+def create_debate_prompt(user_prompt: str, persona: Dict, debate_style: str = "Balanced", previous_responses: Optional[List[Dict]] = None) -> str:
     """Enhanced prompt engineering for better debates"""
     persona_desc = (
         f"Roleplay as {persona['name']}, {persona['description']}\n"
@@ -187,12 +197,7 @@ Write in clear, concise bullet points followed by a short paragraph summary.
 
 Facilitator:"""
 
+def stream_model_response(pipe: pipeline, tokenizer: AutoTokenizer, prompt: str, speaker_name: str = None, temperature: float = 0.7, max_tokens: int = 512) -> Generator[str, None, None]:
     """Robust streaming with better formatting and stop handling"""
     try:
         if stop_signal.is_set():
@@ -207,7 +212,7 @@ def stream_model_response(pipe: pipeline,
             streamer=streamer,
             max_new_tokens=max_tokens,
             do_sample=True,
+            temperature=min(max(temperature, 0.1), 1.0),
             top_p=0.95,
             repetition_penalty=1.1,
             eos_token_id=tokenizer.eos_token_id,
@@ -219,19 +224,17 @@ def stream_model_response(pipe: pipeline,
         buffer = ""
         for new_text in streamer:
             if stop_signal.is_set():
+                pipe.model.config.use_cache = False
                 thread.join(timeout=1)
                 break
 
             buffer += new_text
             if " " in new_text or "\n" in new_text:
                 if speaker_name:
                     yield f"**{speaker_name}:** {buffer.strip()}"
                 else:
                     yield buffer.strip()
 
         if buffer.strip():
             if speaker_name:
                 yield f"**{speaker_name}:** {buffer.strip()}"
@@ -243,6 +246,9 @@ def stream_model_response(pipe: pipeline,
     except Exception as e:
         logger.error(f"Error in streaming: {str(e)}")
         yield "[Error in generation]" if not speaker_name else f"**{speaker_name}:** [Error in generation]"
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
 def select_models_for_personas(personas: List[Dict], models: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     """Match models to personas based on preferences"""
@@ -250,22 +256,16 @@ def select_models_for_personas(personas: List[Dict], models: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     model_names = [m[1] for m in models]
 
     for persona in personas:
        for pref in persona.get("preferred_models", []):
            if pref in model_names:
                selected.append(models[model_names.index(pref)])
                break
        else:
            selected.append(random.choice(models))
 
     return selected
 
+def council_chat_stream(user_prompt: str, num_members: int = 3, debate_style: str = "Balanced", temperature: float = 0.7) -> Generator[str, None, None]:
     """Enhanced debate generation with better state management"""
     stop_signal.clear()
 
@@ -276,11 +276,9 @@ def council_chat_stream(user_prompt: str,
     start_time = time.time()
 
     try:
         selected_personas = random.sample(PERSONAS, min(num_members, len(PERSONAS)))
         selected_models = select_models_for_personas(selected_personas, MODELS)
 
         loaded_models = []
         for i, (model_id, model_name) in enumerate(selected_models):
             if stop_signal.is_set():
@@ -300,7 +298,6 @@ def council_chat_stream(user_prompt: str,
             yield "❌ Error: No models could be loaded. Please try again later."
             return
 
         responses = []
         formatted_responses = []
         persona_responses = []
@@ -311,17 +308,10 @@ def council_chat_stream(user_prompt: str,
                 return
 
             display_name = f"{persona['emoji']} {persona['name']} ({model_name})"
+            prompt = create_debate_prompt(user_prompt, persona, debate_style, persona_responses)
 
             response_text = ""
+            for partial in stream_model_response(pipe, tokenizer, prompt, display_name, temperature):
                 if stop_signal.is_set():
                     break
                 yield partial
@@ -331,7 +321,6 @@ def council_chat_stream(user_prompt: str,
                 yield "[Debate stopped during responses]"
                 return
 
             response_data = {
                 "name": persona['name'],
                 "model": model_name,
@@ -341,7 +330,6 @@ def council_chat_stream(user_prompt: str,
             persona_responses.append(response_data)
             formatted_responses.append(partial)
 
         if not stop_signal.is_set():
             yield "\n\n**✨ Council is now synthesizing the discussion...**\n"
             synthesis_model = random.choice(loaded_models)
@@ -352,13 +340,12 @@ def council_chat_stream(user_prompt: str,
                 synthesis_model[1],
                 synthesis_prompt,
                 "✨ Facilitator's Synthesis",
+                temperature*0.8
             ):
                 if stop_signal.is_set():
                     break
                 yield partial
 
         elapsed_time = time.time() - start_time
         if not stop_signal.is_set():
             transcript = (
@@ -414,7 +401,6 @@ def build_gradio_interface():
     """
 
     with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        with gr.Row():
            gr.Markdown("""
            <div class="council-header">
@@ -423,10 +409,8 @@ def build_gradio_interface():
            </div>
            """)
 
        with gr.Row():
            with gr.Column(scale=2):
                inp = gr.Textbox(
                    label="Debate Topic",
                    placeholder="Enter a topic or question for the council to debate...",
@@ -434,7 +418,6 @@ def build_gradio_interface():
                    max_lines=6
                )
 
                with gr.Group(elem_classes="debate-controls"):
                    with gr.Row():
                        btn = gr.Button("Start Debate", variant="primary")
@@ -447,8 +430,7 @@ def build_gradio_interface():
                        minimum=2,
                        maximum=4,
                        step=1,
+                        value=3
                    )
                    debate_style = gr.Dropdown(
                        label="Debate Style",
@@ -461,11 +443,9 @@ def build_gradio_interface():
                        minimum=0.1,
                        maximum=1.0,
                        step=0.1,
+                        value=0.7
                    )
 
                with gr.Accordion("Meet the Council Members", open=False):
                    for persona in PERSONAS:
                        with gr.Group(elem_classes="persona-card"):
@@ -476,7 +456,6 @@ def build_gradio_interface():
                            **Preferred Models:** {', '.join(persona.get('preferred_models', ['Any']))}
                            """)
 
            with gr.Column(scale=3):
                out = gr.Markdown(
                    label="Live Debate Transcript",
@@ -491,7 +470,6 @@ def build_gradio_interface():
                - Debate memory and context tracking
                """)
 
        with gr.Accordion("Example Debate Topics", open=False):
            examples = gr.Examples(
                examples=[
@@ -505,7 +483,6 @@ def build_gradio_interface():
                label="Click to try these examples"
            )
 
        btn.click(
            fn=council_chat_stream,
            inputs=[inp, num_members, debate_style, temperature],
@@ -517,7 +494,6 @@ def build_gradio_interface():
            queue=False
        )
 
        gr.Markdown("""
        ---
        **About This System:**
@@ -529,19 +505,36 @@ def build_gradio_interface():
 
     return demo
 
 if __name__ == "__main__":
+    # System checks
     device = get_device_preference()
+    print(f"\n{'='*40}")
+    print(f"Starting AI Council Debate on {device.upper()}")
+    print(f"Python: {sys.version.split()[0]}")
+    print(f"PyTorch: {torch.__version__}")
+    print(f"Gradio: {gr.__version__}")
+    print(f"{'='*40}\n")
 
     if device == "cpu":
+        print("WARNING: Running on CPU - expect slower performance")
+        print("Recommendations:")
+        print("- Close other memory-intensive applications")
+        print("- Reduce number of council members (2-3)")
+        print("- Be patient with response times (30-90 sec per response)\n")
 
+    try:
+        demo = build_gradio_interface()
+        demo.launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            share=False,
+            show_error=True
+        )
+    except Exception as e:
+        print(f"\nERROR: {str(e)}")
+        print("\nTroubleshooting steps:")
+        print("1. Check internet connection (required for model download)")
+        print("2. Verify Hugging Face token is set if using Llama models")
+        print("3. Try reducing number of council members")
+        print("4. Restart the application\n")
+        raise
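
Two notes on the updated code, with small illustrative sketches (not part of the commit).

First, the launch block: the old demo.queue(concurrency_count=1).launch(...) call was dropped in favor of a plain demo.launch(...). If the app runs on Gradio 4.x, queue() no longer accepts concurrency_count; a single-worker queue can usually still be requested via default_concurrency_limit, but that keyword should be checked against the installed Gradio version. A minimal, self-contained sketch, where the echo handler is a hypothetical stand-in for council_chat_stream:

    import gradio as gr

    def echo(text: str) -> str:
        # Hypothetical stand-in for council_chat_stream, for illustration only.
        return text

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Debate Topic")
        out = gr.Markdown()
        gr.Button("Start Debate").click(fn=echo, inputs=inp, outputs=out)

    if __name__ == "__main__":
        # Assumes Gradio 4.x, where default_concurrency_limit replaced concurrency_count
        # in queue(); verify against the installed version's signature before relying on it.
        demo.queue(default_concurrency_limit=1).launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
        )

Second, the new CPU-optimization line selects torch.backends.quantized.engine based only on whether supported_engines is non-empty, so it can try to set 'qnnpack' even on builds where that engine is unavailable and fail at import time. A more defensive variant of the same idea, using only the public torch.backends API:

    import os
    import torch

    torch.set_num_threads(os.cpu_count() or 1)
    # Only switch the quantized engine if 'qnnpack' is actually supported on this build.
    if "qnnpack" in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = "qnnpack"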