# NOTE: "Spaces: Paused" below was a Hugging Face Spaces page-status header
# captured during extraction — it is not part of the application source.
| #!/usr/bin/env python3 | |
| """ | |
| AI Executive Chatbot - Gradio Application | |
| Main web interface for the CEO chatbot powered by dual-LLM architecture. | |
| Designed for deployment on Hugging Face Spaces. | |
| Usage: | |
| python app/app.py | |
| python app/app.py --share | |
| python app/app.py --voice-model username/model | |
| Environment: | |
| HF_TOKEN - Hugging Face token for loading models | |
| VOICE_MODEL_REPO - Voice model repository ID | |
| REFINEMENT_MODEL - Refinement model ID (optional) | |
| """ | |
| import os | |
| import sys | |
| import time | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Optional | |
| # Add parent directory to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| import gradio as gr | |
| from loguru import logger | |
| from app.theme import ( | |
| get_akatsuki_theme, | |
| CUSTOM_CSS, | |
| get_header_html, | |
| get_footer_html, | |
| get_status_html, | |
| ) | |
# Configuration
# Voice model: the fine-tuned "CEO voice" LLM, loaded from the Hugging Face Hub.
DEFAULT_VOICE_MODEL = os.environ.get("VOICE_MODEL_REPO", "Chaitanya-aitf/autotrain-epz4k-fl17r")
# Refinement model: optional second-stage LLM that polishes the voice model's draft.
DEFAULT_REFINEMENT_MODEL = os.environ.get("REFINEMENT_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
# Any value other than the string "true" (case-insensitive) disables refinement.
ENABLE_REFINEMENT = os.environ.get("ENABLE_REFINEMENT", "true").lower() == "true"
# NOTE(review): MAX_HISTORY_TURNS is read from the environment here but never
# used anywhere in this file — confirm whether history trimming was meant to be
# applied in generate_response().
MAX_HISTORY_TURNS = int(os.environ.get("MAX_HISTORY_TURNS", "5"))
# Global pipeline instance shared by all handlers (single-process Gradio app).
pipeline = None
# One of "loading" | "ready" | "error"; mirrored into the UI status banner.
model_status = "loading"
status_message = "Initializing..."
def load_pipeline(
    voice_model_id: str = DEFAULT_VOICE_MODEL,
    refinement_model_id: str = DEFAULT_REFINEMENT_MODEL,
    enable_refinement: bool = ENABLE_REFINEMENT,
) -> bool:
    """Load the dual-LLM pipeline into the module-level ``pipeline`` global.

    Also keeps ``model_status`` / ``status_message`` in sync so the UI can
    reflect loading progress.

    Args:
        voice_model_id: Voice model repository ID on the Hugging Face Hub.
        refinement_model_id: Refinement model ID; ignored when refinement
            is disabled.
        enable_refinement: Whether to enable the refinement stage.

    Returns:
        True if the pipeline loaded successfully, False otherwise.
    """
    global pipeline, model_status, status_message

    try:
        from src.inference.dual_llm_pipeline import DualLLMPipeline

        model_status = "loading"
        status_message = f"Loading voice model: {voice_model_id}..."
        logger.info(status_message)

        # With refinement disabled the pipeline gets no refinement model at all.
        refiner_id = refinement_model_id if enable_refinement else None
        pipeline = DualLLMPipeline.from_hub(
            voice_model_id=voice_model_id,
            refinement_model_id=refiner_id,
            load_in_4bit=True,
            enable_refinement=enable_refinement,
            enable_cache=True,
        )

        model_status = "ready"
        status_message = "Model loaded successfully"
        logger.info(status_message)
        return True
    except Exception as e:
        # Any failure (import, download, OOM, ...) flips the app into the
        # error state; the message is surfaced in the status banner.
        model_status = "error"
        status_message = f"Failed to load model: {str(e)}"
        logger.error(status_message)
        return False
def generate_response(
    message: str,
    history: list,
    temperature: float = 0.7,
    skip_refinement: bool = False,
) -> tuple[str, list, str]:
    """Generate an assistant reply for *message* given the chat *history*.

    Args:
        message: The user's message.
        history: Gradio tuple-format history: list of (user, assistant) pairs.
        temperature: Sampling temperature forwarded to the voice model.
        skip_refinement: If True, bypass the refinement stage.

    Returns:
        Tuple of (response, updated_history, status_info).
    """
    global pipeline

    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return "", history, ""
    if pipeline is None:
        return "Model is still loading. Please wait...", history, "Error: Model not loaded"

    try:
        started = time.time()

        # Flatten Gradio's (user, assistant) tuples into role/content dicts
        # so the pipeline sees a standard chat-messages list.
        turns = []
        for user_text, assistant_text in history or []:
            turns.append({"role": "user", "content": user_text})
            if assistant_text:
                turns.append({"role": "assistant", "content": assistant_text})
        pipeline.conversation_history = turns

        result = pipeline.generate(
            user_message=message,
            skip_refinement=skip_refinement,
            voice_temperature=temperature,
        )
        elapsed = time.time() - started

        # Per-stage timings shown beneath the chat box.
        info_parts = [f"Time: {elapsed:.2f}s"]
        if result.was_refined:
            info_parts.append(f"Voice: {result.voice_model_time:.2f}s")
            info_parts.append(f"Refine: {result.refinement_time:.2f}s")

        # Gradio expects tuple-format history back.
        history.append((message, result.final_response))
        return result.final_response, history, " | ".join(info_parts)
    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"An error occurred: {str(e)}", history, f"Error: {str(e)}"
def clear_conversation():
    """Reset the conversation on both the pipeline and the UI side.

    Returns:
        Cleared values for the (chatbot, message box, response info) outputs.
    """
    global pipeline
    if pipeline:
        pipeline.clear_history()
    return [], "", ""
def export_conversation(history: list) -> Optional[str]:
    """Export the conversation to a JSON file and return its path.

    Args:
        history: Gradio tuple-format history: list of (user, assistant) pairs.

    Returns:
        Path of the written JSON file, or None when there is nothing to
        export (empty or missing history).
    """
    import tempfile

    if not history:
        return None

    # Convert tuple format to role/content messages for a portable export.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})

    export_data = {
        "exported_at": datetime.now().isoformat(),
        "messages": messages,
    }

    # Fix: use a unique temp file per export. The previous fixed filename in
    # the working directory was clobbered by concurrent users on a shared
    # Space (and the CWD may not even be writable there).
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".json",
        prefix="conversation_export_",
        delete=False,
        encoding="utf-8",
    ) as f:
        json.dump(export_data, f, indent=2, ensure_ascii=False)
        return f.name
def create_app(
    voice_model_id: Optional[str] = None,
    enable_refinement: bool = True,
    share: bool = False,
) -> gr.Blocks:
    """
    Create the Gradio application.

    Args:
        voice_model_id: Voice model to load; falls back to DEFAULT_VOICE_MODEL.
        enable_refinement: Enable refinement model.
        share: Kept for interface compatibility; the share flag itself is
            applied by the caller via ``launch()``, not used here.

    Returns:
        Gradio Blocks app
    """
    with gr.Blocks(
        title="AI Executive Assistant",
        # Fix: theme and css belong on Blocks(); Blocks.launch() does not
        # accept these keyword arguments (previously passed there in main()).
        theme=get_akatsuki_theme(),
        css=CUSTOM_CSS,
    ) as app:
        # Header
        gr.HTML(get_header_html(
            title="AI Executive Assistant",
            subtitle="Chat with our CEO powered by AI",
        ))

        # Status indicator, replaced once model loading finishes (see on_load).
        status_html = gr.HTML(
            get_status_html("loading", "Initializing model..."),
            elem_classes=["status-indicator"],
        )

        with gr.Row():
            with gr.Column(scale=4):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    elem_classes=["chat-container"],
                )

                # Input area
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Ask the CEO anything...",
                        show_label=False,
                        container=False,
                        scale=4,
                        elem_classes=["input-area"],
                    )
                    submit_btn = gr.Button(
                        "Send",
                        variant="primary",
                        scale=1,
                        elem_classes=["primary-btn"],
                    )

                # Action buttons
                with gr.Row():
                    clear_btn = gr.Button(
                        "Clear Conversation",
                        variant="secondary",
                        elem_classes=["secondary-btn"],
                    )
                    export_btn = gr.Button(
                        "Export",
                        variant="secondary",
                        elem_classes=["secondary-btn"],
                    )

                # Response info: timing string from the last generation.
                response_info = gr.Textbox(
                    label="Response Info",
                    interactive=False,
                    visible=True,
                    max_lines=1,
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("### Settings")
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )
                skip_refine = gr.Checkbox(
                    label="Skip Refinement",
                    value=False,
                    info="Get raw voice model output",
                )

                gr.Markdown("---")

                # Example prompts: clicking one copies it into the input box.
                gr.Markdown("### Example Questions")
                example_prompts = [
                    "What is your vision for AI in business?",
                    "How do you approach leadership?",
                    "What advice do you have for entrepreneurs?",
                    "Tell me about Akatsuki AI Technologies.",
                    "What are the biggest challenges in AI today?",
                ]
                for prompt in example_prompts:
                    # `p=prompt` binds the current value as a default argument,
                    # avoiding the classic late-binding closure pitfall.
                    gr.Button(
                        prompt[:40] + "..." if len(prompt) > 40 else prompt,
                        size="sm",
                    ).click(
                        fn=lambda p=prompt: p,
                        outputs=msg_input,
                    )

        # Footer
        gr.HTML(get_footer_html())

        # Hidden file output for export downloads.
        export_file = gr.File(visible=False)

        # Event handlers
        def on_submit(message, history, temp, skip):
            # Clear the input box, push the new chat history and timing info.
            response, new_history, info = generate_response(
                message, history, temp, skip
            )
            return "", new_history, info

        submit_btn.click(
            fn=on_submit,
            inputs=[msg_input, chatbot, temperature, skip_refine],
            outputs=[msg_input, chatbot, response_info],
        )
        msg_input.submit(
            fn=on_submit,
            inputs=[msg_input, chatbot, temperature, skip_refine],
            outputs=[msg_input, chatbot, response_info],
        )
        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, msg_input, response_info],
        )
        export_btn.click(
            fn=export_conversation,
            inputs=[chatbot],
            outputs=[export_file],
        )

        # Load the model once the UI is up, then swap the status banner.
        def on_load():
            voice_id = voice_model_id or DEFAULT_VOICE_MODEL
            success = load_pipeline(voice_id, enable_refinement=enable_refinement)
            if success:
                return get_status_html("ready", f"Model loaded: {voice_id}")
            else:
                return get_status_html("error", "Failed to load model")

        app.load(fn=on_load, outputs=[status_html])

    return app
def main():
    """Main entry point: parse CLI arguments, build the app, and launch it."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the AI Executive chatbot",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python app/app.py
  python app/app.py --share
  python app/app.py --voice-model username/model --port 7861

Environment variables:
  VOICE_MODEL_REPO - Default voice model
  REFINEMENT_MODEL - Default refinement model
  ENABLE_REFINEMENT - Enable refinement (true/false)
  HF_TOKEN - Hugging Face token
""",
    )
    parser.add_argument(
        "--voice-model",
        help=f"Voice model ID (default: {DEFAULT_VOICE_MODEL})",
    )
    parser.add_argument(
        "--refinement-model",
        help=f"Refinement model ID (default: {DEFAULT_REFINEMENT_MODEL})",
    )
    parser.add_argument(
        "--no-refinement",
        action="store_true",
        help="Disable refinement model",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run on (default: 7860)",
    )
    parser.add_argument(
        "--server-name",
        default="0.0.0.0",
        help="Server name (default: 0.0.0.0)",
    )
    args = parser.parse_args()

    if args.refinement_model:
        # NOTE(review): --refinement-model is parsed but create_app() has no
        # parameter for it, so the flag previously had no effect. Warn instead
        # of silently ignoring it; set REFINEMENT_MODEL in the environment to
        # actually change the refinement model.
        logger.warning(
            "--refinement-model is currently ignored; set the REFINEMENT_MODEL "
            "environment variable before startup instead."
        )

    # Create and launch app
    app = create_app(
        voice_model_id=args.voice_model,
        enable_refinement=not args.no_refinement,
        share=args.share,
    )
    # Fix: theme=/css= removed from launch() — Blocks.launch() does not accept
    # those keyword arguments; they are gr.Blocks() constructor arguments.
    app.launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
        show_error=True,
    )


if __name__ == "__main__":
    main()