import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
import random
import threading
import torch
import os
import time
import sys
import logging
from typing import List, Dict, Generator, Tuple, Optional
from collections import defaultdict
import gc
# Configure Torch for CPU optimization
torch.set_num_threads(os.cpu_count() or 1)
if 'qnnpack' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'qnnpack'
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('council_debate.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# --- Best Free Models for Council ---
MODELS = [
    ("mistralai/Mistral-7B-Instruct-v0.2", "Mistral 7B Instruct"),
    ("HuggingFaceH4/zephyr-7b-beta", "Zephyr 7B Beta"),
    ("NousResearch/Hermes-2-Pro-Mistral-7B", "Hermes 2 Pro"),
    ("cognitivecomputations/dolphin-2.6-mistral-7b", "Dolphin Mistral"),
]
# Define council member personas
PERSONAS = [
    {
        "name": "Dr. Ana Rodriguez",
        "description": "An analytical scientist who values empirical evidence and logical reasoning.",
        "traits": "analytical, skeptical, evidence-focused",
        "style": "formal, precise, methodical",
        "emoji": "🔬",
        "preferred_models": ["Mistral 7B Instruct", "Zephyr 7B Beta"]
    },
    {
        "name": "Professor Marcus Chen",
        "description": "A creative philosopher with an interest in ethics and societal implications.",
        "traits": "philosophical, visionary, empathetic",
        "style": "eloquent, metaphorical, conceptual",
        "emoji": "🧠",
        "preferred_models": ["Hermes 2 Pro", "Dolphin Mistral"]
    },
    {
        "name": "Sarah Johnson",
        "description": "A pragmatic problem-solver with real-world experience.",
        "traits": "practical, solution-oriented, experienced",
        "style": "direct, concise, example-driven",
        "emoji": "🛠️",
        "preferred_models": ["Mistral 7B Instruct", "Hermes 2 Pro"]
    },
    {
        "name": "Dr. Emeka Okafor",
        "description": "A social scientist specializing in cultural perspectives.",
        "traits": "culturally aware, nuanced, community-focused",
        "style": "inclusive, storytelling, perspective-oriented",
        "emoji": "🌍",
        "preferred_models": ["Dolphin Mistral", "Zephyr 7B Beta"]
    }
]
# Cache for models
model_cache = {}
model_loading_lock = threading.Lock()
stop_signal = threading.Event()
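# The cache holds one (pipeline, tokenizer) pair per model id so repeated debates
# reuse already-loaded weights; the lock serializes loading across request threads,
# and stop_signal is a cooperative flag polled by the streaming loops to abort early.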
def get_device_preference():
    """Determine best device based on available resources"""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"
def load_model(model_id: str) -> Tuple[pipeline, AutoTokenizer]:
    """Improved model loading with better caching and error handling"""
    global model_cache
    with model_loading_lock:
        if model_id in model_cache:
            logger.info(f"Using cached model: {model_id}")
            return model_cache[model_id]
        logger.info(f"Loading model: {model_id}")
        try:
            os.environ["TOKENIZERS_PARALLELISM"] = "true"
            device = get_device_preference()
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model_kwargs = {
                "trust_remote_code": True,
                "device_map": "auto" if device == "cuda" else None,
                "torch_dtype": torch.float16 if device == "cuda" else torch.float32
            }
            if device == "cpu":
                model_kwargs.update({
                    "low_cpu_mem_usage": True,
                    "torch_dtype": torch.float32,
                })
            model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
            if device != "cuda":
                model = model.to(device)
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=model.device
            )
            model_cache[model_id] = (pipe, tokenizer)
            logger.info(f"Model loaded successfully: {model_id} on {device}")
            return pipe, tokenizer
        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {str(e)}")
            if "out of memory" in str(e).lower() and device == "cuda":
                logger.info("Attempting to load with float16 to save memory")
                try:
                    model_kwargs["torch_dtype"] = torch.float16
                    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
                    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device)
                    model_cache[model_id] = (pipe, tokenizer)
                    return pipe, tokenizer
                except Exception as e2:
                    logger.error(f"Still failed to load model: {str(e2)}")
            raise
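# Prompt construction: each debate turn layers the persona card (name, traits, style),
# the chosen debate-style guidance, and any prior turns, so a model answers in
# character while staying grounded in the running discussion.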
def create_debate_prompt(user_prompt: str, persona: Dict, debate_style: str = "Balanced", previous_responses: Optional[List[Dict]] = None) -> str:
    """Enhanced prompt engineering for better debates"""
    persona_desc = (
        f"Roleplay as {persona['name']}, {persona['description']}\n"
        f"Communication style: {persona['style']}\n"
        f"Key traits: {persona['traits']}\n\n"
    )
    style_guidance = {
        "Collaborative": "Focus on building consensus and finding common ground. Acknowledge valid points from others.",
        "Adversarial": "Challenge assumptions and present strong counter-arguments. Don't shy from disagreement.",
        "Balanced": "Present your perspective while considering others' views. Be constructive in criticism."
    }.get(debate_style, "Present your authentic perspective.")
    context = (
        f"The user has posed this topic for debate:\n\"{user_prompt}\"\n\n"
        f"Debate style: {style_guidance}\n"
    )
    if previous_responses:
        debate_history = "\n\n".join([f"{r['name']}: {r['text']}" for r in previous_responses])
        instructions = (
            f"Previous discussion:\n{debate_history}\n\n"
            "Now respond naturally as your persona. Add new insights, agree/disagree respectfully, "
            "and maintain your character's style. Keep it to 3-4 paragraphs maximum."
        )
    else:
        instructions = (
            "Offer your initial perspective on the topic. Establish your position clearly "
            "while leaving room for discussion. 3-4 paragraphs maximum."
        )
    return f"{persona_desc}{context}{instructions}\n\n{persona['name']}:"
def create_synthesis_prompt(user_prompt: str, all_responses: List[Dict]) -> str:
    """Improved synthesis prompt for better conclusions"""
    debate_history = "\n\n".join([f"{r['name']} ({r['model']}): {r['text']}" for r in all_responses])
    return f"""As the debate facilitator, synthesize this discussion:

Original topic: "{user_prompt}"

Debate transcript:
{debate_history}

Your synthesis should:
1. Identify 2-3 key points of agreement
2. Note major disagreements and why they exist
3. Highlight unique perspectives
4. Offer a balanced conclusion
5. Suggest next steps if appropriate

Write in clear, concise bullet points followed by a short paragraph summary.

Facilitator:"""
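# Streaming: generation runs on a background thread while TextIteratorStreamer feeds
# decoded tokens back into this generator. Each yield re-emits the full text so far
# (prefixed with the speaker name), which suits Gradio's replace-on-update streaming.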
def stream_model_response(pipe: pipeline, tokenizer: AutoTokenizer, prompt: str, speaker_name: str = None, temperature: float = 0.7, max_tokens: int = 512) -> Generator[str, None, None]:
    """Robust streaming with better formatting and stop handling"""
    try:
        if stop_signal.is_set():
            yield "[Stopped by user]" if not speaker_name else f"**{speaker_name}:** [Stopped by user]"
            return
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(pipe.model.device)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=min(max(temperature, 0.1), 1.0),
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Generation runs on a worker thread; the streamer yields tokens as they arrive.
        thread = threading.Thread(target=pipe.model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            if stop_signal.is_set():
                # Best-effort stop: this does not abort generate(); we simply stop consuming tokens.
                pipe.model.config.use_cache = False
                thread.join(timeout=1)
                break
            buffer += new_text
            if " " in new_text or "\n" in new_text:
                if speaker_name:
                    yield f"**{speaker_name}:** {buffer.strip()}"
                else:
                    yield buffer.strip()
        if buffer.strip():
            if speaker_name:
                yield f"**{speaker_name}:** {buffer.strip()}"
            else:
                yield buffer.strip()
        thread.join()
    except Exception as e:
        logger.error(f"Error in streaming: {str(e)}")
        yield "[Error in generation]" if not speaker_name else f"**{speaker_name}:** [Error in generation]"
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
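# Model matching: each persona lists preferred checkpoints; if none of them is
# present in MODELS, a random model is assigned instead.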
def select_models_for_personas(personas: List[Dict], models: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Match models to personas based on preferences"""
    selected = []
    model_names = [m[1] for m in models]
    for persona in personas:
        for pref in persona.get("preferred_models", []):
            if pref in model_names:
                selected.append(models[model_names.index(pref)])
                break
        else:
            selected.append(random.choice(models))
    return selected
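# Debate flow: pick personas, match and load their models, stream each member's turn
# in sequence (later speakers see earlier turns via persona_responses), then have a
# randomly chosen loaded model synthesize the discussion and emit a final transcript.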
def council_chat_stream(user_prompt: str, num_members: int = 3, debate_style: str = "Balanced", temperature: float = 0.7) -> Generator[str, None, None]:
    """Enhanced debate generation with better state management"""
    stop_signal.clear()
    if not user_prompt.strip():
        yield "Please enter a topic for the council to debate."
        return
    start_time = time.time()
    try:
        selected_personas = random.sample(PERSONAS, min(num_members, len(PERSONAS)))
        selected_models = select_models_for_personas(selected_personas, MODELS)
        loaded_models = []
        for i, (model_id, model_name) in enumerate(selected_models):
            if stop_signal.is_set():
                yield "[Debate stopped during setup]"
                return
            yield f"**Loading:** {model_name} ({i+1}/{len(selected_models)})..."
            try:
                pipe, tokenizer = load_model(model_id)
                loaded_models.append((pipe, tokenizer, model_name))
            except Exception as e:
                logger.error(f"Model loading failed: {str(e)}")
                yield f"⚠️ Failed to load {model_name}. Trying with remaining models..."
                continue
        if not loaded_models:
            yield "❌ Error: No models could be loaded. Please try again later."
            return
        formatted_responses = []
        persona_responses = []
        for i, (persona, (pipe, tokenizer, model_name)) in enumerate(zip(selected_personas, loaded_models)):
            if stop_signal.is_set():
                yield "[Debate stopped by user]"
                return
            display_name = f"{persona['emoji']} {persona['name']} ({model_name})"
            prompt = create_debate_prompt(user_prompt, persona, debate_style, persona_responses)
            response_text = ""
            formatted_response = ""
            for partial in stream_model_response(pipe, tokenizer, prompt, display_name, temperature):
                if stop_signal.is_set():
                    break
                yield partial
                formatted_response = partial
                # Strip the "**<speaker>:**" prefix so the history keeps only the member's words.
                response_text = partial.split(":**", 1)[-1].strip()
            if stop_signal.is_set():
                yield "[Debate stopped during responses]"
                return
            response_data = {
                "name": persona['name'],
                "model": model_name,
                "text": response_text,
                "persona": persona
            }
            persona_responses.append(response_data)
            formatted_responses.append(formatted_response)
        if not stop_signal.is_set():
            yield "\n\n**✨ Council is now synthesizing the discussion...**\n"
            synthesis_model = random.choice(loaded_models)
            synthesis_prompt = create_synthesis_prompt(user_prompt, persona_responses)
            for partial in stream_model_response(
                synthesis_model[0],
                synthesis_model[1],
                synthesis_prompt,
                "✨ Facilitator's Synthesis",
                temperature * 0.8
            ):
                if stop_signal.is_set():
                    break
                yield partial
        elapsed_time = time.time() - start_time
        if not stop_signal.is_set():
            transcript = (
                f"**User Topic:** {user_prompt}\n\n" +
                "\n\n".join(formatted_responses) +
                f"\n\n---\n*Debate completed in {elapsed_time:.1f} seconds*"
            )
            yield transcript
        else:
            yield "[Debate was stopped before completion]"
    except Exception as e:
        logger.error(f"Debate error: {str(e)}")
        yield f"⚠️ An error occurred during the debate: {str(e)}"
def stop_debate():
    """Signal to stop current debate generation"""
    stop_signal.set()
    return "Debate stopping... Please wait."
def build_gradio_interface():
    """Enhanced Gradio interface with better controls"""
    custom_css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        max-width: 900px !important;
    }
    .council-header {
        text-align: center;
        margin-bottom: 1em;
        background: linear-gradient(45deg, #4b6cb7, #182848);
        color: white;
        padding: 1em;
        border-radius: 8px;
    }
    .debate-controls {
        background: #f8f9fa;
        padding: 1em;
        border-radius: 8px;
        margin-bottom: 1em;
    }
    .persona-card {
        margin: 0.5em 0;
        padding: 1em;
        border-radius: 8px;
        background: #ffffff;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .stop-button {
        background: #ff4d4d !important;
        color: white !important;
    }
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        with gr.Row():
            gr.Markdown("""
            <div class="council-header">
                <h1>🤖🏛️ AI Council Debate Chamber</h1>
                <p>Experience multi-perspective AI debates with distinct personalities and models</p>
            </div>
            """)
        with gr.Row():
            with gr.Column(scale=2):
                inp = gr.Textbox(
                    label="Debate Topic",
                    placeholder="Enter a topic or question for the council to debate...",
                    lines=4,
                    max_lines=6
                )
                with gr.Group(elem_classes="debate-controls"):
                    with gr.Row():
                        btn = gr.Button("Start Debate", variant="primary")
                        stop_btn = gr.Button("Stop Debate", variant="stop", elem_classes="stop-button")
                    with gr.Accordion("Debate Configuration", open=True):
                        with gr.Row():
                            num_members = gr.Slider(
                                label="Council Members",
                                minimum=2,
                                maximum=4,
                                step=1,
                                value=3
                            )
                            debate_style = gr.Dropdown(
                                label="Debate Style",
                                choices=["Collaborative", "Adversarial", "Balanced"],
                                value="Balanced"
                            )
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Creativity Level",
                                minimum=0.1,
                                maximum=1.0,
                                step=0.1,
                                value=0.7
                            )
                with gr.Accordion("Meet the Council Members", open=False):
                    for persona in PERSONAS:
                        with gr.Group(elem_classes="persona-card"):
                            gr.Markdown(f"""
                            **{persona['emoji']} {persona['name']}**
                            *{persona['description']}*
                            **Style:** {persona['style']}
                            **Preferred Models:** {', '.join(persona.get('preferred_models', ['Any']))}
                            """)
            with gr.Column(scale=3):
                out = gr.Markdown(
                    label="Live Debate Transcript",
                    value="*Debate transcript will appear here...*"
                )
                with gr.Accordion("Session Information", open=False):
                    gr.Markdown("""
                    **Technical Details:**
                    - Uses multiple open-weight LLMs from Hugging Face
                    - Each persona is matched with suitable models
                    - Real-time streaming responses
                    - Debate memory and context tracking
                    """)
                with gr.Accordion("Example Debate Topics", open=False):
                    examples = gr.Examples(
                        examples=[
                            "Should AI development be regulated internationally?",
                            "What's the most effective way to address income inequality?",
                            "How should society balance free speech with preventing harm?",
                            "Is universal basic income a viable solution for automation job loss?",
                            "What ethical guidelines should govern genetic engineering?"
                        ],
                        inputs=inp,
                        label="Click to try these examples"
                    )
        btn.click(
            fn=council_chat_stream,
            inputs=[inp, num_members, debate_style, temperature],
            outputs=out
        )
        stop_btn.click(
            fn=stop_debate,
            outputs=out,
            queue=False
        )
        gr.Markdown("""
        ---
        **About This System:**
        - Each council member has distinct expertise and communication style
        - Different AI models are matched to personas for varied perspectives
        - Facilitator synthesizes the discussion at the end
        - Works best with GPU acceleration (but runs on CPU)
        """)
    return demo
| if __name__ == "__main__": | |
| # System checks | |
| device = get_device_preference() | |
| print(f"\n{'='*40}") | |
| print(f"Starting AI Council Debate on {device.upper()}") | |
| print(f"Python: {sys.version.split()[0]}") | |
| print(f"PyTorch: {torch.__version__}") | |
| print(f"Gradio: {gr.__version__}") | |
| print(f"{'='*40}\n") | |
| if device == "cpu": | |
| print("WARNING: Running on CPU - expect slower performance") | |
| print("Recommendations:") | |
| print("- Close other memory-intensive applications") | |
| print("- Reduce number of council members (2-3)") | |
| print("- Be patient with response times (30-90 sec per response)\n") | |
| try: | |
| demo = build_gradio_interface() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| show_error=True | |
| ) | |
| except Exception as e: | |
| print(f"\nERROR: {str(e)}") | |
| print("\nTroubleshooting steps:") | |
| print("1. Check internet connection (required for model download)") | |
| print("2. Verify Hugging Face token is set if using Llama models") | |
| print("3. Try reducing number of council members") | |
| print("4. Restart the application\n") | |
| raise |
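# Minimal sketch of exercising the debate loop without the UI (assumes the models
# listed above can be downloaded in the current environment; names come from this file):
#
#     for update in council_chat_stream("Should AI development be open source?", num_members=2):
#         print(update)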