#!/usr/bin/env python3
"""
AI Executive Chatbot - Gradio Application

Main web interface for the CEO chatbot powered by dual-LLM architecture.
Designed for deployment on Hugging Face Spaces.

Usage:
    python app/app.py
    python app/app.py --share
    python app/app.py --voice-model username/model

Environment:
    HF_TOKEN - Hugging Face token for loading models
    VOICE_MODEL_REPO - Voice model repository ID
    REFINEMENT_MODEL - Refinement model ID (optional)
    ENABLE_REFINEMENT - Enable refinement (true/false, default: true)
"""

import json
import os
import sys
import tempfile
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

# Add parent directory to path so the `app` and `src` packages resolve when
# this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

import gradio as gr
from loguru import logger

from app.theme import (
    get_akatsuki_theme,
    CUSTOM_CSS,
    get_header_html,
    get_footer_html,
    get_status_html,
)

# Configuration
DEFAULT_VOICE_MODEL = os.environ.get("VOICE_MODEL_REPO", "Chaitanya-aitf/autotrain-epz4k-fl17r")
DEFAULT_REFINEMENT_MODEL = os.environ.get("REFINEMENT_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
ENABLE_REFINEMENT = os.environ.get("ENABLE_REFINEMENT", "true").lower() == "true"
MAX_HISTORY_TURNS = int(os.environ.get("MAX_HISTORY_TURNS", "5"))

# Global pipeline instance
pipeline = None
model_status = "loading"
status_message = "Initializing..."

# Serializes model loading: the page-load handler can fire once per browser
# session, and two simultaneous loads would build the pipeline twice.
_load_lock = threading.Lock()


def load_pipeline(
    voice_model_id: str = DEFAULT_VOICE_MODEL,
    refinement_model_id: str = DEFAULT_REFINEMENT_MODEL,
    enable_refinement: bool = ENABLE_REFINEMENT,
) -> bool:
    """
    Load the dual LLM pipeline.

    Updates the module-level ``pipeline``, ``model_status`` and
    ``status_message`` globals. Any exception raised while importing or
    constructing the pipeline is caught, logged and reported through the
    return value instead of propagating.

    Args:
        voice_model_id: Voice model repository ID
        refinement_model_id: Refinement model ID
        enable_refinement: Whether to enable refinement

    Returns:
        True if loaded successfully
    """
    global pipeline, model_status, status_message

    try:
        # Imported lazily so that importing this module does not pull in the
        # inference stack.
        from src.inference.dual_llm_pipeline import DualLLMPipeline

        model_status = "loading"
        status_message = f"Loading voice model: {voice_model_id}..."
        logger.info(status_message)

        pipeline = DualLLMPipeline.from_hub(
            voice_model_id=voice_model_id,
            refinement_model_id=refinement_model_id if enable_refinement else None,
            load_in_4bit=True,
            enable_refinement=enable_refinement,
            enable_cache=True,
        )

        model_status = "ready"
        status_message = "Model loaded successfully"
        logger.info(status_message)
        return True

    except Exception as e:
        # Top-level boundary: report the failure to the UI rather than crash.
        model_status = "error"
        status_message = f"Failed to load model: {e}"
        # logger.exception keeps the traceback alongside the message.
        logger.exception(status_message)
        return False


def _history_to_messages(history: Optional[list]) -> list:
    """Convert (user, bot) pair history into a list of role/content dicts.

    A pair whose bot half is empty contributes only the user message.
    """
    # NOTE(review): this assumes the chatbot delivers history as
    # (user, bot) pairs - confirm against the installed Gradio version.
    messages = []
    for user_msg, bot_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def generate_response(
    message: str,
    history: list,
    temperature: float = 0.7,
    skip_refinement: bool = False,
) -> tuple[str, list, str]:
    """
    Generate a response to the user message.

    The incoming ``history`` is never modified; the returned history is a
    new list. Errors raised during generation are caught and returned as
    the response text.

    Args:
        message: User's message
        history: Conversation history as (user, bot) pairs; may be None
        temperature: Generation temperature
        skip_refinement: Skip refinement stage

    Returns:
        Tuple of (response, updated_history, status_info)
    """
    # Work on a copy so the caller's list is not mutated, and tolerate None.
    history = list(history) if history else []

    if not message or not message.strip():
        return "", history, ""

    if pipeline is None:
        return "Model is still loading. Please wait...", history, "Error: Model not loaded"

    try:
        start_time = time.time()

        # The pipeline keeps its own history; overwrite it with this
        # conversation before generating.
        pipeline.conversation_history = _history_to_messages(history)

        # Generate response
        result = pipeline.generate(
            user_message=message,
            skip_refinement=skip_refinement,
            voice_temperature=temperature,
        )

        elapsed = time.time() - start_time

        # Build status info
        status_parts = [f"Time: {elapsed:.2f}s"]
        if result.was_refined:
            status_parts.append(f"Voice: {result.voice_model_time:.2f}s")
            status_parts.append(f"Refine: {result.refinement_time:.2f}s")
        status_info = " | ".join(status_parts)

        # Update history (tuple format for Gradio)
        history.append((message, result.final_response))

        return result.final_response, history, status_info

    except Exception as e:
        # Top-level boundary for the UI: surface the error as text.
        logger.exception(f"Generation error: {e}")
        return f"An error occurred: {e}", history, f"Error: {e}"


def clear_conversation():
    """Clear conversation history.

    Returns:
        Tuple of (empty history, empty input text, empty response info).
    """
    if pipeline is not None:
        pipeline.clear_history()
    return [], "", ""


def export_conversation(history: list) -> Optional[str]:
    """Export conversation to JSON.

    Args:
        history: Conversation history as (user, bot) pairs

    Returns:
        Path of the written JSON file, or None when there is nothing to export.
    """
    if not history:
        return None

    export_data = {
        "exported_at": datetime.now().isoformat(),
        "messages": _history_to_messages(history),
    }

    # A unique temp file per export: a fixed filename in the working
    # directory would be overwritten by concurrent users.
    with tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        prefix="conversation_export_",
        suffix=".json",
        delete=False,
    ) as f:
        json.dump(export_data, f, indent=2, ensure_ascii=False)

    return f.name


def create_app(
    voice_model_id: Optional[str] = None,
    enable_refinement: bool = True,
    share: bool = False,
    refinement_model_id: Optional[str] = None,
) -> gr.Blocks:
    """
    Create the Gradio application.

    Args:
        voice_model_id: Voice model to load (falls back to DEFAULT_VOICE_MODEL)
        enable_refinement: Enable refinement model
        share: Create public share link (not used while building the UI;
            sharing is applied by the caller at launch time)
        refinement_model_id: Refinement model to load (falls back to
            DEFAULT_REFINEMENT_MODEL)

    Returns:
        Gradio Blocks app
    """
    with gr.Blocks(
        title="AI Executive Assistant",
    ) as app:

        # Header
        gr.HTML(get_header_html(
            title="AI Executive Assistant",
            subtitle="Chat with our CEO powered by AI",
        ))

        # Status indicator
        status_html = gr.HTML(
            get_status_html("loading", "Initializing model..."),
            elem_classes=["status-indicator"],
        )

        with gr.Row():
            with gr.Column(scale=4):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    elem_classes=["chat-container"],
                )

                # Input area
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Ask the CEO anything...",
                        show_label=False,
                        container=False,
                        scale=4,
                        elem_classes=["input-area"],
                    )
                    submit_btn = gr.Button(
                        "Send",
                        variant="primary",
                        scale=1,
                        elem_classes=["primary-btn"],
                    )

                # Action buttons
                with gr.Row():
                    clear_btn = gr.Button(
                        "Clear Conversation",
                        variant="secondary",
                        elem_classes=["secondary-btn"],
                    )
                    export_btn = gr.Button(
                        "Export",
                        variant="secondary",
                        elem_classes=["secondary-btn"],
                    )

                # Response info
                response_info = gr.Textbox(
                    label="Response Info",
                    interactive=False,
                    visible=True,
                    max_lines=1,
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("### Settings")

                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )

                skip_refine = gr.Checkbox(
                    label="Skip Refinement",
                    value=False,
                    info="Get raw voice model output",
                )

                gr.Markdown("---")

                # Example prompts
                gr.Markdown("### Example Questions")

                example_prompts = [
                    "What is your vision for AI in business?",
                    "How do you approach leadership?",
                    "What advice do you have for entrepreneurs?",
                    "Tell me about Akatsuki AI Technologies.",
                    "What are the biggest challenges in AI today?",
                ]

                for prompt in example_prompts:
                    # Button labels are shortened to 40 characters; the full
                    # prompt is what gets placed in the input box.
                    label = prompt[:40] + "..." if len(prompt) > 40 else prompt
                    gr.Button(label, size="sm").click(
                        # Bind the prompt as a default to avoid late binding.
                        fn=lambda p=prompt: p,
                        outputs=msg_input,
                    )

        # Footer
        gr.HTML(get_footer_html())

        # Hidden file output for export
        # NOTE(review): this component is created with visible=False -
        # confirm users can actually retrieve the exported file.
        export_file = gr.File(visible=False)

        # Event handlers
        def on_submit(message, history, temp, skip):
            """Run generation, then clear the input box and refresh the chat."""
            _, new_history, info = generate_response(
                message, history, temp, skip
            )
            return "", new_history, info

        # The Send button and pressing Enter share the same wiring.
        for trigger in (submit_btn.click, msg_input.submit):
            trigger(
                fn=on_submit,
                inputs=[msg_input, chatbot, temperature, skip_refine],
                outputs=[msg_input, chatbot, response_info],
            )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, msg_input, response_info],
        )

        export_btn.click(
            fn=export_conversation,
            inputs=[chatbot],
            outputs=[export_file],
        )

        # Load model on app start
        def on_load():
            """Load the pipeline once and return the status banner HTML."""
            voice_id = voice_model_id or DEFAULT_VOICE_MODEL
            refinement_id = refinement_model_id or DEFAULT_REFINEMENT_MODEL

            # The load event can fire for every browser session; only the
            # first one should pay for loading the model.
            with _load_lock:
                if pipeline is not None and model_status == "ready":
                    success = True
                else:
                    success = load_pipeline(
                        voice_id,
                        refinement_id,
                        enable_refinement=enable_refinement,
                    )

            if success:
                return get_status_html("ready", f"Model loaded: {voice_id}")
            return get_status_html("error", "Failed to load model")

        app.load(fn=on_load, outputs=[status_html])

    return app


def main():
    """Main entry point: parse CLI arguments, build the app and launch it."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the AI Executive chatbot",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python app/app.py
    python app/app.py --share
    python app/app.py --voice-model username/model --port 7861

Environment variables:
    VOICE_MODEL_REPO   - Default voice model
    REFINEMENT_MODEL   - Default refinement model
    ENABLE_REFINEMENT  - Enable refinement (true/false)
    HF_TOKEN           - Hugging Face token
        """,
    )

    parser.add_argument(
        "--voice-model",
        help=f"Voice model ID (default: {DEFAULT_VOICE_MODEL})",
    )
    parser.add_argument(
        "--refinement-model",
        help=f"Refinement model ID (default: {DEFAULT_REFINEMENT_MODEL})",
    )
    parser.add_argument(
        "--no-refinement",
        action="store_true",
        help="Disable refinement model",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run on (default: 7860)",
    )
    parser.add_argument(
        "--server-name",
        default="0.0.0.0",
        help="Server name (default: 0.0.0.0)",
    )

    args = parser.parse_args()

    # Create and launch app
    app = create_app(
        voice_model_id=args.voice_model,
        # Refinement is on only if neither the ENABLE_REFINEMENT environment
        # variable nor the --no-refinement flag turns it off.
        enable_refinement=ENABLE_REFINEMENT and not args.no_refinement,
        share=args.share,
        refinement_model_id=args.refinement_model,
    )

    app.launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
        show_error=True,
        theme=get_akatsuki_theme(),
        css=CUSTOM_CSS,
    )


if __name__ == "__main__":
    main()