dicksinyass committed on
Commit e090a0a · verified · 1 Parent(s): 004447c

Update app.py

Files changed (1)
  1. app.py +125 -155
app.py CHANGED
@@ -3,7 +3,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIter
  import random
  import threading
  import torch
- import os
  import time
  from typing import List, Dict, Generator, Tuple, Optional, Union
  import logging
@@ -19,7 +18,7 @@ from datetime import datetime
  # Set up logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)
- warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly")
+ warnings.filterwarnings("ignore")

  # Enums and Data Classes
  class DebateStyle(str, Enum):
@@ -28,14 +27,14 @@ class DebateStyle(str, Enum):
  BALANCED = "Balanced"

  class OutputStyle(str, Enum):
- TRANSCRIPT = "Transcript (Markdown)"
- CHATBOT = "Chatbot (Chat History)"
+ TRANSCRIPT = "Transcript"
+ CHATBOT = "Chatbot"

  @dataclass
  class ModelInfo:
  id: str
  name: str
- required_memory: str # Estimated VRAM requirement
+ required_memory: str
  supports_quantization: bool = False
  quantization_config: Optional[Dict] = field(default_factory=dict)

@@ -118,15 +117,8 @@ class DebateHistoryManager:
  # Constants
  MODELS = [
  ModelInfo(
- "meta-llama/Meta-Llama-3-8B-Instruct",
- "Llama 3 8B Instruct",
- "16GB",
- True,
- {"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.float16}
- ),
- ModelInfo(
- "Qwen/Qwen1.5-7B-Chat",
- "Qwen1.5 7B Chat",
+ "mistralai/Mistral-7B-Instruct-v0.2",
+ "Mistral 7B Instruct",
  "14GB",
  True,
  {"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.float16}
@@ -138,12 +130,12 @@ MODELS = [
  False
  ),
  ModelInfo(
- "mistralai/Mistral-7B-Instruct-v0.2",
- "Mistral 7B Instruct",
+ "Qwen/Qwen1.5-7B-Chat",
+ "Qwen1.5 7B Chat",
  "14GB",
  True,
  {"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.float16}
- ),
+ )
  ]

  PERSONAS = [
@@ -167,13 +159,6 @@ PERSONAS = [
  traits="practical, solution-oriented, experienced",
  style="direct, concise, example-driven",
  emoji="🛠️"
- ),
- Persona(
- name="Dr. Emeka Okafor",
- description="A social scientist specializing in cultural perspectives.",
- traits="culturally aware, nuanced, community-focused",
- style="inclusive, storytelling, perspective-oriented",
- emoji="🌍"
  )
  ]

@@ -182,7 +167,6 @@ model_cache = {}
  current_device = None
  performance_monitor = ModelPerformance()

- # Core Functions
  def get_device() -> str:
  global current_device
  if current_device:
@@ -201,11 +185,10 @@ def get_device() -> str:

  def clear_model_cache():
  global model_cache
- for model_id in list(model_cache.keys()):
- del model_cache[model_id]
+ model_cache.clear()
  gc.collect()
- torch.cuda.empty_cache()
- model_cache = {}
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
  logger.info("Model cache cleared")

  def load_model(model_info: ModelInfo) -> Tuple[pipeline, AutoTokenizer]:
@@ -217,13 +200,6 @@ def load_model(model_info: ModelInfo) -> Tuple[pipeline, AutoTokenizer]:
  device = get_device()
  kwargs = {"trust_remote_code": True}

- if device == "cuda":
- gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
- required_mem = float(model_info.required_memory.replace("GB", ""))
- if gpu_mem < required_mem and not model_info.supports_quantization:
- logger.warning(f"Insufficient GPU memory for {model_info.name} (needs {required_mem}GB, has {gpu_mem:.1f}GB)")
-
- # Handle quantization if supported and on CUDA
  if device == "cuda" and model_info.supports_quantization:
  kwargs.update(model_info.quantization_config)
  kwargs["device_map"] = "auto"
@@ -236,9 +212,6 @@ def load_model(model_info: ModelInfo) -> Tuple[pipeline, AutoTokenizer]:
  tokenizer = AutoTokenizer.from_pretrained(model_info.id)
  model = AutoModelForCausalLM.from_pretrained(model_info.id, **kwargs)

- if device == "cuda" and not model_info.supports_quantization:
- model = model.to(device)
-
  pipe = pipeline(
  "text-generation",
  model=model,
@@ -354,7 +327,6 @@ def stream_response(
  else:
  yield buffer.strip()

- # Record performance metrics
  performance_monitor.record_generation(
  pipe.model.config._name_or_path,
  time.time() - start_time,
@@ -380,7 +352,6 @@ def council_chat_stream(
  yield "Please enter a topic for debate."
  return

- # Convert string style to Enum if needed
  if isinstance(debate_style, str):
  try:
  debate_style = DebateStyle(debate_style)
@@ -398,10 +369,8 @@ def council_chat_stream(
  loaded_models = []
  for model_info in selected_model_infos:
  try:
- with gr.Progress() as progress:
- progress(0, desc=f"Loading {model_info.name}")
- pipe, tokenizer = load_model(model_info)
- loaded_models.append((pipe, tokenizer, model_info))
+ pipe, tokenizer = load_model(model_info)
+ loaded_models.append((pipe, tokenizer, model_info))
  except Exception as e:
  logger.error(f"Skipping {model_info.name}: {str(e)}")
  yield f"⚠️ Couldn't load {model_info.name}, skipping..."
@@ -424,10 +393,7 @@ def council_chat_stream(
  display_name = f"{persona.emoji} {persona.name} ({model_info.name})"
  participant_names.append(display_name)

- thinking_msg = f"**{display_name}** is thinking..."
- current_output = "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + [thinking_msg])
- yield current_output
-
+ yield f"**{display_name}** is thinking..."
  prompt = create_debate_prompt(
  user_prompt,
  persona,
@@ -438,33 +404,26 @@ def council_chat_stream(
  full_response = ""
  for chunk in stream_response(pipe, tokenizer, prompt, display_name, temperature):
  full_response = chunk
- current_output = "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + [chunk])
- yield current_output
+ yield chunk

  persona_responses.append(f"{persona.name}: {full_response.split('**:')[-1].strip()}")
  formatted_responses.append(full_response)

- # Generate synthesis
- synth_pipe, synth_tokenizer, _ = random.choice(loaded_models)
+ synth_pipe, synth_tokenizer, _ = loaded_models[0]
  synth_prompt = create_synthesis_prompt(user_prompt, persona_responses)

- yield "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + ["✨ **Facilitator** is synthesizing..."])
-
- synthesis = ""
- for chunk in stream_response(synth_pipe, synth_tokenizer, synth_prompt, "✨ Facilitator", temperature):
- synthesis = chunk
- current_output = "\n\n".join([f"**User:** {user_prompt}"] + formatted_responses + [chunk])
- yield current_output
+ yield "✨ **Facilitator** is synthesizing..."
+ for chunk in stream_response(synth_pipe, synth_tokenizer, synth_prompt, "Facilitator", temperature):
+ yield chunk

  elapsed_time = time.time() - start_time
  transcript = (
  f"**User:** {user_prompt}\n\n" +
  "\n\n".join(formatted_responses) +
- f"\n\n{synthesis}\n\n" +
+ f"\n\n**Facilitator:** {chunk.split('**:')[-1].strip()}\n\n" +
  f"---\n*Debate completed in {elapsed_time:.1f} seconds*"
  )

- # Save to history
  if save_history:
  history_item = DebateHistoryItem(
  id=str(uuid.uuid4()),
@@ -478,104 +437,115 @@ def council_chat_stream(

  yield transcript

- def council_chat_stream_chatbot(
- user_prompt: str,
- num_members: int = 3,
- debate_style: Union[DebateStyle, str] = DebateStyle.BALANCED,
- temperature: float = 0.7,
- selected_models: Optional[List[str]] = None,
- continue_debate: bool = False,
- history: Optional[List[str]] = None,
- save_history: bool = True
- ) -> Generator[list, None, None]:
- chat_history = []
- for output in council_chat_stream(
- user_prompt, num_members, debate_style, temperature,
- selected_models, continue_debate, history, save_history
- ):
- chat_history.append((None, output))
- yield chat_history
-
- # UI Components
- def build_persona_card(persona: Persona) -> gr.Box:
- with gr.Box(elem_classes="member-card") as card:
- gr.Markdown(f"""
- <h3>{persona.emoji} {persona.name}</h3>
- <p><strong>Description:</strong> {persona.description}</p>
- <p><strong>Traits:</strong> {persona.traits}</p>
- <p><strong>Style:</strong> {persona.style}</p>
- """)
- return card
-
- def build_model_info_card(model: ModelInfo) -> gr.Box:
- with gr.Box(elem_classes="model-card") as card:
- gr.Markdown(f"""
- <h3>{model.name}</h3>
- <p><strong>ID:</strong> {model.id}</p>
- <p><strong>Memory Requirement:</strong> {model.required_memory}</p>
- <p><strong>Quantization:</strong> {'Supported' if model.supports_quantization else 'Not Supported'}</p>
- """)
- return card
-
- def build_history_item_ui(history_item: Dict) -> gr.Box:
- with gr.Box(elem_classes="history-item") as item:
- with gr.Row():
- with gr.Column(scale=3):
- gr.Markdown(f"**{history_item['topic']}**")
- gr.Markdown(f"*{datetime.fromtimestamp(history_item['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}*")
- with gr.Column(scale=1):
- view_btn = gr.Button("View", size="sm")
- load_btn = gr.Button("Load", size="sm")
- return item, view_btn, load_btn
-
- # Gradio Interface
- def build_gradio_interface():
- custom_css = """
- .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
- .member-card, .model-card, .history-item {
- border: 1px solid #e0e0e0;
- border-radius: 8px;
- padding: 15px;
- margin-bottom: 15px;
- background: #f9f9f9;
- }
- .member-card h3, .model-card h3 { margin-top: 0; color: #333; }
- #transcript-container { position: relative; max-height: 600px; overflow-y: auto; }
- #chatbot-container { max-height: 600px; }
- .stats-table { width: 100%; border-collapse: collapse; }
+ def create_interface():
+ css = """
+ .member-card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin: 10px; background: #f9f9f9; }
+ .member-card h3 { margin-top: 0; }
+ #debate-output { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ddd; border-radius: 8px; }
+ .history-item { border: 1px solid #e0e0e0; border-radius: 8px; padding: 10px; margin: 5px 0; background: #f5f5f5; }
+ .stats-table { width: 100%; border-collapse: collapse; margin-top: 10px; }
  .stats-table th, .stats-table td { padding: 8px; text-align: left; border-bottom: 1px solid #ddd; }
  """

- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as app:
  current_debate = gr.State([])
- current_history_id = gr.State(None)

- gr.Markdown("# 🏛️ AI Council Debate\n*Get diverse AI perspectives on any topic*")
+ gr.Markdown("# 🏛️ AI Council Debate")

  with gr.Row():
  with gr.Column(scale=2):
- # Debate Input Section
- with gr.Group():
- user_prompt = gr.Textbox(
- label="Debate Topic",
- placeholder="Enter your question or topic for debate...",
- lines=4,
- max_lines=6
- )
-
- with gr.Accordion("⚙️ Debate Settings", open=False):
- with gr.Row():
- num_members = gr.Slider(
- minimum=2,
- maximum=len(PERSONAS),
- value=3,
- step=1,
- label="Number of Council Members"
- )
- debate_style = gr.Radio(
- list(DebateStyle),
- value=DebateStyle.BALANCED,
- label="Debate Style"
- )
-
-
+ user_input = gr.Textbox(label="Debate Topic", lines=3)
+ with gr.Row():
+ num_members = gr.Slider(2, len(PERSONAS), value=3, step=1, label="Number of Members")
+ temperature = gr.Slider(0.1, 1.0, value=0.7, label="Creativity")
+ debate_style = gr.Radio(
+ list(DebateStyle),
+ value=DebateStyle.BALANCED,
+ label="Debate Style"
+ )
+ model_selection = gr.CheckboxGroup(
+ choices=[model.name for model in MODELS],
+ value=[model.name for model in MODELS],
+ label="Select Models"
+ )
+ with gr.Row():
+ submit_btn = gr.Button("Start Debate", variant="primary")
+ clear_btn = gr.Button("Clear", variant="secondary")
+ continue_btn = gr.Checkbox(label="Continue Debate", value=False)
+ save_history = gr.Checkbox(label="Save History", value=True)
+
+ with gr.Column(scale=3):
+ output = gr.HTML(elem_id="debate-output")
+
+ with gr.Accordion("👥 Council Members", open=False):
+ for persona in PERSONAS:
+ with gr.Group(elem_classes="member-card"):
+ gr.Markdown(f"""
+ <h3>{persona.emoji} {persona.name}</h3>
+ <p><strong>Description:</strong> {persona.description}</p>
+ <p><strong>Style:</strong> {persona.style}</p>
+ <p><strong>Traits:</strong> {persona.traits}</p>
+ """)
+
+ with gr.Accordion("📜 Debate History", open=False):
+ history_output = gr.Column()
+ refresh_history = gr.Button("Refresh History")
+
+ with gr.Accordion("📊 Performance Stats", open=False):
+ stats_output = gr.HTML()
+ refresh_stats = gr.Button("Refresh Stats")
+
+ def debate_wrapper(user_prompt, num_members, debate_style, temperature, model_selection, continue_debate, save_history, current_debate):
+ selected_models = [m.id for m in MODELS if m.name in model_selection]
+ return council_chat_stream(
+ user_prompt, num_members, debate_style, temperature,
+ selected_models, continue_debate, current_debate, save_history
+ )
+
+ def update_history(history, new_output):
+ if "Facilitator" in new_output:
+ return []
+ return history + [new_output] if history else [new_output]
+
+ def load_history():
+ history = DebateHistoryManager.load_history()
+ return [
+ gr.Group(elem_classes="history-item", visible=True, render=False) for _ in history
+ ]
+
+ def show_stats():
+ stats = "<table class='stats-table'><tr><th>Model</th><th>Calls</th><th>Avg Time</th><th>Tokens/s</th></tr>"
+ for model in MODELS:
+ data = performance_monitor.get_stats(model.id)
+ stats += f"""
+ <tr>
+ <td>{model.name}</td>
+ <td>{data['total_calls']}</td>
+ <td>{data['avg_time']:.2f}s</td>
+ <td>{data['tokens_per_second']:.1f}</td>
+ </tr>
+ """
+ stats += "</table>"
+ return stats
+
+ submit_btn.click(
+ debate_wrapper,
+ [user_input, num_members, debate_style, temperature, model_selection, continue_btn, save_history, current_debate],
+ output
+ ).then(
+ lambda x: x,
+ output,
+ current_debate,
+ preprocess=update_history
+ )
+
+ clear_btn.click(lambda: "", None, output)
+ refresh_history.click(load_history, None, history_output)
+ refresh_stats.click(show_stats, None, stats_output)
+
+ return app
+
+ if __name__ == "__main__":
+ get_device()
+ app = create_interface()
+ app.launch()