# ai_exec / app / app.py
# Source: Hugging Face Space file page for user "Chaitanya-aitf"
# ("Update app/app.py", commit 51b7bbc, verified).
#!/usr/bin/env python3
"""
AI Executive Chatbot - Gradio Application
Main web interface for the CEO chatbot powered by dual-LLM architecture.
Designed for deployment on Hugging Face Spaces.
Usage:
python app/app.py
python app/app.py --share
python app/app.py --voice-model username/model
Environment:
HF_TOKEN - Hugging Face token for loading models
VOICE_MODEL_REPO - Voice model repository ID
REFINEMENT_MODEL - Refinement model ID (optional)
"""
import os
import sys
import time
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import gradio as gr
from loguru import logger
from app.theme import (
get_akatsuki_theme,
CUSTOM_CSS,
get_header_html,
get_footer_html,
get_status_html,
)
# Configuration
# Model repositories and runtime behavior are overridable via environment
# variables so the same code runs locally and on Hugging Face Spaces.
DEFAULT_VOICE_MODEL = os.environ.get("VOICE_MODEL_REPO", "Chaitanya-aitf/autotrain-epz4k-fl17r")
DEFAULT_REFINEMENT_MODEL = os.environ.get("REFINEMENT_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
# Second-stage refinement LLM is enabled unless ENABLE_REFINEMENT=false.
ENABLE_REFINEMENT = os.environ.get("ENABLE_REFINEMENT", "true").lower() == "true"
# Read from the environment but not referenced elsewhere in this file —
# presumably consumed by the pipeline; TODO confirm.
MAX_HISTORY_TURNS = int(os.environ.get("MAX_HISTORY_TURNS", "5"))
# Global pipeline instance
# Module-level state shared by the handlers below: `pipeline` is populated by
# load_pipeline(); model_status / status_message track loading progress.
pipeline = None
model_status = "loading"
status_message = "Initializing..."
def load_pipeline(
    voice_model_id: str = DEFAULT_VOICE_MODEL,
    refinement_model_id: str = DEFAULT_REFINEMENT_MODEL,
    enable_refinement: bool = ENABLE_REFINEMENT,
) -> bool:
    """Load the dual-LLM pipeline into the module-level globals.

    Updates `pipeline`, `model_status`, and `status_message` as side effects
    so the UI can poll loading progress.

    Args:
        voice_model_id: Voice model repository ID.
        refinement_model_id: Refinement model ID.
        enable_refinement: Whether to enable the refinement stage.

    Returns:
        True if the pipeline loaded successfully, False otherwise.
    """
    global pipeline, model_status, status_message
    try:
        # Imported lazily so a broken/heavy dependency fails here, not at
        # module import time.
        from src.inference.dual_llm_pipeline import DualLLMPipeline

        model_status = "loading"
        status_message = f"Loading voice model: {voice_model_id}..."
        logger.info(status_message)

        refinement_id = refinement_model_id if enable_refinement else None
        pipeline = DualLLMPipeline.from_hub(
            voice_model_id=voice_model_id,
            refinement_model_id=refinement_id,
            load_in_4bit=True,
            enable_refinement=enable_refinement,
            enable_cache=True,
        )
    except Exception as exc:
        # Broad catch is deliberate: any failure (import, download, OOM)
        # must surface as a status message rather than crash the UI.
        model_status = "error"
        status_message = f"Failed to load model: {str(exc)}"
        logger.error(status_message)
        return False

    model_status = "ready"
    status_message = "Model loaded successfully"
    logger.info(status_message)
    return True
def generate_response(
    message: str,
    history: list,
    temperature: float = 0.7,
    skip_refinement: bool = False,
) -> tuple[str, list, str]:
    """Produce a chatbot reply for one user turn.

    Args:
        message: User's message.
        history: Conversation history as Gradio (user, bot) tuples.
        temperature: Sampling temperature for the voice model.
        skip_refinement: If True, skip the refinement stage.

    Returns:
        Tuple of (response, updated_history, status_info).
    """
    global pipeline

    # Guard clauses: nothing to do for blank input / model not ready yet.
    if not message.strip():
        return "", history, ""
    if pipeline is None:
        return "Model is still loading. Please wait...", history, "Error: Model not loaded"

    try:
        started = time.time()

        # Rebuild the pipeline's message-dict history from Gradio tuples.
        conversation = []
        for user_turn, assistant_turn in history or []:
            conversation.append({"role": "user", "content": user_turn})
            if assistant_turn:
                conversation.append({"role": "assistant", "content": assistant_turn})
        pipeline.conversation_history = conversation

        result = pipeline.generate(
            user_message=message,
            skip_refinement=skip_refinement,
            voice_temperature=temperature,
        )
        elapsed = time.time() - started

        # Timing summary shown in the "Response Info" box.
        timing = [f"Time: {elapsed:.2f}s"]
        if result.was_refined:
            timing.extend([
                f"Voice: {result.voice_model_time:.2f}s",
                f"Refine: {result.refinement_time:.2f}s",
            ])

        # Append the new turn in Gradio's tuple format.
        history.append((message, result.final_response))
        return result.final_response, history, " | ".join(timing)
    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"An error occurred: {str(e)}", history, f"Error: {str(e)}"
def clear_conversation():
    """Reset the conversation.

    Empties the pipeline's internal history (when a pipeline is loaded) and
    returns blank values for the (chatbot, msg_input, response_info) outputs.
    """
    global pipeline
    if pipeline:
        pipeline.clear_history()
    return [], "", ""
def export_conversation(history: list) -> Optional[str]:
    """Export the conversation to a JSON file and return its path.

    Args:
        history: Conversation as Gradio (user_message, bot_message) tuples.

    Returns:
        Path to the written JSON file, or None when there is nothing to
        export. (The original annotation said ``str`` but the empty case
        has always returned None.)
    """
    if not history:
        return None

    # Convert tuple format to messages format for export.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})

    export_data = {
        "exported_at": datetime.now().isoformat(),
        "messages": messages,
    }

    # Write to a unique temp file. The previous fixed filename in the CWD
    # would be clobbered by concurrent users and may not be writable on a
    # hosted Space; NamedTemporaryFile(delete=False) gives each export its
    # own path that Gradio's File component can serve.
    import tempfile

    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".json",
        prefix="conversation_export_",
        delete=False,
        encoding="utf-8",
    ) as f:
        json.dump(export_data, f, indent=2, ensure_ascii=False)
        return f.name
def create_app(
    voice_model_id: Optional[str] = None,
    enable_refinement: bool = True,
    share: bool = False,
) -> gr.Blocks:
    """
    Create the Gradio application.

    Args:
        voice_model_id: Voice model to load (falls back to DEFAULT_VOICE_MODEL).
        enable_refinement: Enable refinement model.
        share: Create public share link. NOTE(review): accepted for
            backward compatibility but unused here — sharing is controlled
            by the caller when invoking app.launch().

    Returns:
        Gradio Blocks app
    """
    # theme/css are gr.Blocks constructor options (gr.Blocks.launch() does
    # not accept them), so the Akatsuki theme must be applied here.
    with gr.Blocks(
        title="AI Executive Assistant",
        theme=get_akatsuki_theme(),
        css=CUSTOM_CSS,
    ) as app:
        # Header
        gr.HTML(get_header_html(
            title="AI Executive Assistant",
            subtitle="Chat with our CEO powered by AI",
        ))
        # Status indicator — replaced by on_load() once loading finishes
        status_html = gr.HTML(
            get_status_html("loading", "Initializing model..."),
            elem_classes=["status-indicator"],
        )
        with gr.Row():
            with gr.Column(scale=4):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    elem_classes=["chat-container"],
                )
                # Input area
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Ask the CEO anything...",
                        show_label=False,
                        container=False,
                        scale=4,
                        elem_classes=["input-area"],
                    )
                    submit_btn = gr.Button(
                        "Send",
                        variant="primary",
                        scale=1,
                        elem_classes=["primary-btn"],
                    )
                # Action buttons
                with gr.Row():
                    clear_btn = gr.Button(
                        "Clear Conversation",
                        variant="secondary",
                        elem_classes=["secondary-btn"],
                    )
                    export_btn = gr.Button(
                        "Export",
                        variant="secondary",
                        elem_classes=["secondary-btn"],
                    )
                # Response info — latency breakdown from generate_response()
                response_info = gr.Textbox(
                    label="Response Info",
                    interactive=False,
                    visible=True,
                    max_lines=1,
                )
            with gr.Column(scale=1):
                # Settings
                gr.Markdown("### Settings")
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )
                skip_refine = gr.Checkbox(
                    label="Skip Refinement",
                    value=False,
                    info="Get raw voice model output",
                )
                gr.Markdown("---")
                # Example prompts
                gr.Markdown("### Example Questions")
                example_prompts = [
                    "What is your vision for AI in business?",
                    "How do you approach leadership?",
                    "What advice do you have for entrepreneurs?",
                    "Tell me about Akatsuki AI Technologies.",
                    "What are the biggest challenges in AI today?",
                ]
                for prompt in example_prompts:
                    gr.Button(
                        prompt[:40] + "..." if len(prompt) > 40 else prompt,
                        size="sm",
                    ).click(
                        # p=prompt default-binds the loop variable so each
                        # button fills in its own prompt (avoids the
                        # late-binding closure pitfall).
                        fn=lambda p=prompt: p,
                        outputs=msg_input,
                    )
        # Footer
        gr.HTML(get_footer_html())
        # Hidden file output for export
        export_file = gr.File(visible=False)

        # Event handlers
        def on_submit(message, history, temp, skip):
            # Clear the textbox, then show the updated chat and timing info.
            response, new_history, info = generate_response(
                message, history, temp, skip
            )
            return "", new_history, info

        submit_btn.click(
            fn=on_submit,
            inputs=[msg_input, chatbot, temperature, skip_refine],
            outputs=[msg_input, chatbot, response_info],
        )
        msg_input.submit(
            fn=on_submit,
            inputs=[msg_input, chatbot, temperature, skip_refine],
            outputs=[msg_input, chatbot, response_info],
        )
        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, msg_input, response_info],
        )
        export_btn.click(
            fn=export_conversation,
            inputs=[chatbot],
            outputs=[export_file],
        )

        # Load model on app start
        def on_load():
            # Runs on page load: loads the pipeline and updates the status HTML.
            global model_status
            voice_id = voice_model_id or DEFAULT_VOICE_MODEL
            success = load_pipeline(voice_id, enable_refinement=enable_refinement)
            if success:
                return get_status_html("ready", f"Model loaded: {voice_id}")
            else:
                return get_status_html("error", "Failed to load model")

        app.load(fn=on_load, outputs=[status_html])
    return app
def main():
    """CLI entry point: parse arguments, build the app, launch the server."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the AI Executive chatbot",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python app/app.py
  python app/app.py --share
  python app/app.py --voice-model username/model --port 7861

Environment variables:
  VOICE_MODEL_REPO  - Default voice model
  REFINEMENT_MODEL  - Default refinement model
  ENABLE_REFINEMENT - Enable refinement (true/false)
  HF_TOKEN          - Hugging Face token
""",
    )
    parser.add_argument(
        "--voice-model",
        help=f"Voice model ID (default: {DEFAULT_VOICE_MODEL})",
    )
    # NOTE(review): --refinement-model is parsed but never forwarded;
    # create_app() has no refinement-model parameter. Left in place so the
    # CLI stays compatible — wire it through once create_app supports it.
    parser.add_argument(
        "--refinement-model",
        help=f"Refinement model ID (default: {DEFAULT_REFINEMENT_MODEL})",
    )
    parser.add_argument(
        "--no-refinement",
        action="store_true",
        help="Disable refinement model",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run on (default: 7860)",
    )
    parser.add_argument(
        "--server-name",
        default="0.0.0.0",
        help="Server name (default: 0.0.0.0)",
    )
    args = parser.parse_args()

    # Create and launch app
    app = create_app(
        voice_model_id=args.voice_model,
        enable_refinement=not args.no_refinement,
        share=args.share,
    )
    # BUG FIX: gr.Blocks.launch() does not accept `theme`/`css` keyword
    # arguments — they are gr.Blocks constructor options — so passing them
    # here raised TypeError at startup. They have been removed from launch().
    app.launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
        show_error=True,
    )


if __name__ == "__main__":
    main()