"""
AstraMind Stage 1 - Gradio Chat Interface
Main application file
"""

import os
import sys
import time
import json
from datetime import datetime
from pathlib import Path

# Add src directory to path for imports
sys.path.insert(0, str(Path(__file__).parent / "src"))

import gradio as gr
import gradio_client.utils

# --- HF-Space hot-patch for Gradio schema bug (bool is not iterable) ---
# Capture the real implementation BEFORE rebinding the module attribute.
# The wrapper must call this saved reference: calling
# `gradio_client.utils._json_schema_to_python_type` from inside the wrapper
# would resolve to the wrapper itself and recurse until RecursionError.
# gradio_client's own recursive calls go through the module-level name, so
# nested boolean sub-schemas are still routed through the wrapper.
_original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type


def _safe_json_schema_to_python_type(schema, defs=None):
    """Convert a JSON schema to a type-hint string without ever raising.

    JSON Schema permits bare booleans (e.g. ``additionalProperties: true``),
    which some gradio_client versions try to iterate. Booleans are
    short-circuited; any other conversion failure falls back to ``"Any"``.
    """
    if isinstance(schema, bool):  # <-- short-circuit the bad case
        return "bool"
    try:
        return _original_json_schema_to_python_type(schema, defs)
    except Exception:
        return "Any"


gradio_client.utils._json_schema_to_python_type = _safe_json_schema_to_python_type
gr.routes.api_info = lambda *a, **k: {}  # skip OpenAPI generation
# -----------------------------------------------------------------------

from backend.chat_engine import ChatEngine
from backend.cache import ResponseCache
from backend.session_manager import SessionManager
from backend.model_registry import list_models, get_model_display_names, list_openrouter_models, list_hf_models
from backend.utils import count_tokens, calculate_cost, format_duration, get_timestamp

# Import export utilities from their location
sys.path.insert(0, str(Path(__file__).parent / "src" / "frontend" / "gradio_app"))
from export_utils import (
    export_to_txt,
    export_to_markdown,
    export_to_json,
    export_to_csv,
    export_to_audio,
    export_to_pdf,
)

# Load custom CSS
css_file = Path(__file__).parent / "src" / "frontend" / "gradio_app" / "styles.css"
custom_css = css_file.read_text(encoding="utf-8")

# Initialize global components
cache = ResponseCache(ttl=3600)
session_manager = SessionManager(base_dir="chat-history")
chat_engine = None  # Will be initialized when API key is provided

# Session state.
# NOTE(review): these are module-level globals, so they are shared by every
# browser session connected to this process.
session_start_time = time.time()
current_session_id = None
total_tokens_used = 0
total_cost_usd = 0.0  # cumulative estimated cost, reset by clear_chat()


def _normalize_api_key(api_key):
    """Return the stripped API key, or None when it is missing/blank.

    None selects HuggingFace mode; both initialization paths use this helper
    so a blank textbox is treated identically everywhere.
    """
    if api_key and api_key.strip():
        return api_key.strip()
    return None


def _safe_cache_stats():
    """Return cache statistics coerced to JSON-serializable values (or {} on failure)."""
    try:
        # default=str stringifies values json cannot encode natively, so gr.JSON can render them.
        return json.loads(json.dumps(cache.get_stats(), default=str))
    except Exception:
        return {}


def initialize_chat_engine(api_key: str) -> tuple:
    """Initialize chat engine with API key and update model choices.

    A non-blank key selects OpenRouter models; a blank key selects
    HuggingFace models.

    Returns:
        Gradio updates for (main interface visibility, status text, model dropdown).
    """
    global chat_engine
    key = _normalize_api_key(api_key)
    try:
        if key:
            # OpenRouter mode
            chat_engine = ChatEngine(api_key=key, cache=cache)
            model_choices = list_openrouter_models()
            default_model = "gpt-4o-mini"
            status = "✓ OpenRouter initialized"
        else:
            # HuggingFace mode
            chat_engine = ChatEngine(api_key=None, cache=cache)
            model_choices = list_hf_models()
            default_model = "openchat"
            status = "✓ HuggingFace models ready"
        return (
            gr.update(visible=True),
            gr.update(value=status),
            gr.update(choices=model_choices, value=default_model),
        )
    except Exception as e:
        return gr.update(visible=False), gr.update(value=f"✗ Error: {str(e)}"), gr.update()


def _save_session(history: list, model: str) -> None:
    """Persist the current conversation and session metadata via session_manager."""
    messages = [
        {
            "role": msg["role"],
            "content": msg["content"],
            "timestamp": get_timestamp(),
            "tokens": count_tokens(msg["content"], model),
        }
        for msg in history
    ]
    session_manager.save_session(
        current_session_id,
        messages,
        {
            "model": model,
            "total_tokens": total_tokens_used,
            "created_at": datetime.fromtimestamp(session_start_time).isoformat(),
        },
    )


def chat_response(message: str, history: list, model: str, temperature: float,
                  api_key: str, use_cache: bool, system_message: str = "") -> tuple:
    """Handle a chat message and stream the response using OpenAI message format.

    This is a generator. Every step yields a 5-tuple of
    (chat history, cleared input text, token display update,
    cost display update, cache stats update).
    """
    global chat_engine, total_tokens_used, total_cost_usd, current_session_id

    if history is None:
        history = []

    if not message or message.strip() == "":
        # Generator: a `return <value>` would be discarded, so the state
        # must be emitted with `yield` for the UI to receive it.
        yield history, "", gr.update(), gr.update(), gr.update()
        return

    # Initialize engine if needed (same blank-key handling as initialize_chat_engine)
    if chat_engine is None:
        try:
            chat_engine = ChatEngine(api_key=_normalize_api_key(api_key), cache=cache)
        except Exception as e:
            history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            yield history, "", gr.update(), gr.update(), gr.update()
            return

    # Create session if needed
    if current_session_id is None:
        current_session_id = session_manager.create_session(model)

    system_prompt = (system_message or "").strip() or None

    # Prepend the system message once, if provided and not already present
    if system_prompt and not any(msg.get("role") == "system" for msg in history):
        history.insert(0, {"role": "system", "content": system_prompt})

    # Count input tokens
    input_tokens = count_tokens(message, model)

    # Add user message and an assistant placeholder to fill while streaming
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": ""})

    response_text = ""
    try:
        for chunk in chat_engine.chat(
            message,
            model=model,
            stream=True,
            use_cache=use_cache,
            temperature=temperature,
            system_message=system_prompt,
        ):
            response_text += chunk
            history[-1]["content"] = response_text
            yield history, "", gr.update(), gr.update(), gr.update()

        # Update running totals
        output_tokens = count_tokens(response_text, model)
        total_tokens_used += input_tokens + output_tokens
        total_cost_usd += calculate_cost(input_tokens, output_tokens, model)

        _save_session(history, model)

        yield (
            history,
            "",
            gr.update(value=f"{total_tokens_used:,}"),
            gr.update(value=f"${total_cost_usd:.6f}"),
            gr.update(value=_safe_cache_stats()),
        )
    except Exception as e:
        history[-1]["content"] = f"Error: {str(e)}"
        yield history, "", gr.update(), gr.update(), gr.update()


def clear_chat() -> tuple:
    """Clear chat history and reset the per-conversation token/cost counters."""
    global total_tokens_used, total_cost_usd, current_session_id
    if chat_engine:
        chat_engine.clear_history()
    total_tokens_used = 0
    total_cost_usd = 0.0
    current_session_id = None
    return [], gr.update(value="0"), gr.update(value="$0.00"), gr.update()


def update_timer() -> str:
    """Update session duration timer"""
    elapsed = int(time.time() - session_start_time)
    return format_duration(elapsed)


def _history_to_pairs(history: list) -> list:
    """Convert OpenAI-style message dicts into (user, assistant) tuples for export_utils.

    - Each user message opens a new pair.
    - The next assistant message fills that pair's reply slot.
    - An assistant message with no pending user message (e.g. an
      initialization error) becomes ("", reply).
    - A user message with no reply becomes (user, "").
    - System messages have no slot in the pair format and are skipped, so
      they cannot shift the pairing of later messages.
    """
    pairs = []
    pending_user = None
    for msg in history or []:
        role = msg.get("role")
        content = msg.get("content") or ""
        if role == "user":
            if pending_user is not None:
                pairs.append((pending_user, ""))
            pending_user = content
        elif role == "assistant":
            pairs.append((pending_user if pending_user is not None else "", content))
            pending_user = None
    if pending_user is not None:
        pairs.append((pending_user, ""))
    return pairs


def export_txt_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to TXT - converts OpenAI format to tuple format for export"""
    return export_to_txt(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_md_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to Markdown"""
    return export_to_markdown(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_json_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to JSON"""
    return export_to_json(
        _history_to_pairs(history), date_start, date_end, roles,
        current_session_id, total_tokens_used, session_start_time,
    )


def export_csv_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to CSV"""
    return export_to_csv(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_audio_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to Audio (TTS)"""
    return export_to_audio(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_pdf_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to PDF"""
    return export_to_pdf(
        _history_to_pairs(history), date_start, date_end, roles,
        current_session_id, total_tokens_used,
    )


# Build Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AstraMind Chat") as app:
    gr.Markdown("# 🌟 AstraMind Chat - Stage 1")

    # Hidden state for tracking (not currently wired to any event)
    init_status = gr.Textbox(visible=False)

    # API Key input at the top
    with gr.Row():
        api_key_input = gr.Textbox(
            label="OpenRouter API Key (optional - leave empty for HuggingFace models)",
            type="password",
            placeholder="sk-or-...",
            scale=4,
        )
        init_btn = gr.Button("Initialize", scale=1, variant="primary")
        init_status_display = gr.Textbox(label="Status", scale=2, interactive=False)

    # Main interface: visible from the start (HuggingFace models need no key);
    # initialize_chat_engine hides it only when initialization fails.
    main_interface = gr.Column(visible=True)
    with main_interface:
        with gr.Row():
            # Left sidebar
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Settings")

                model_dropdown = gr.Dropdown(
                    choices=list_hf_models(),
                    value="openchat",
                    label="Model",
                    info="Select AI model",
                )

                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )

                use_cache_checkbox = gr.Checkbox(
                    value=True,
                    label="Use Response Cache",
                    info="Cache identical queries",
                )

                # System Message Accordion
                with gr.Accordion("💬 System Message", open=False):
                    system_message = gr.Textbox(
                        label="System Prompt",
                        placeholder="Enter a system message to guide the AI's behavior...",
                        lines=4,
                        value="",
                    )

                gr.Markdown("### 📊 Statistics")

                token_display = gr.Textbox(label="Total Tokens", value="0", interactive=False)
                cost_display = gr.Textbox(label="Estimated Cost", value="$0.00", interactive=False)
                session_timer = gr.Textbox(label="Session Duration", value="0s", interactive=False)
                cache_stats_display = gr.JSON(label="Cache Stats", value=_safe_cache_stats())

            # Center: Chat interface
            with gr.Column(scale=3):
                # Get absolute path to bot avatar
                avatar_path = Path(__file__).parent / "assets" / "bot-avatar.png"

                chatbot = gr.Chatbot(
                    height=600,
                    show_label=False,
                    type="messages",  # OpenAI-compatible format
                    avatar_images=(
                        None,  # user avatar placeholder
                        str(avatar_path) if avatar_path.exists() else None,  # assistant avatar
                    ),
                    render_markdown=True,
                )

                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Message AstraMind...",
                        show_label=False,
                        scale=9,
                        container=False,
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
                    regenerate_btn = gr.Button("🔄 Regenerate", scale=1, visible=False)

            # Right: Export panel
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Export Suite")

                gr.Markdown("**Filters:**")
                date_start = gr.Textbox(label="From Date (YYYY-MM-DD)", placeholder="2024-01-01")
                date_end = gr.Textbox(label="To Date (YYYY-MM-DD)", placeholder="2024-12-31")
                role_filter = gr.CheckboxGroup(
                    choices=["user", "assistant", "system"],
                    value=["user", "assistant"],
                    label="Include Roles",
                )

                gr.Markdown("**Export Formats:**")
                export_txt_btn = gr.Button("📄 Export TXT", size="sm")
                export_md_btn = gr.Button("📝 Export MD", size="sm")
                export_json_btn = gr.Button("📋 Export JSON", size="sm")
                export_csv_btn = gr.Button("📊 Export CSV", size="sm")
                export_audio_btn = gr.Button("🔊 Export Audio", size="sm")
                export_pdf_btn = gr.Button("📕 Export PDF", size="sm")

                download_file = gr.File(label="Download", visible=True)

    # Event handlers
    init_btn.click(
        fn=initialize_chat_engine,
        inputs=[api_key_input],
        outputs=[main_interface, init_status_display, model_dropdown],
    )

    # Chat handlers: Enter in the textbox and the Send button share one wiring
    chat_inputs = [msg_input, chatbot, model_dropdown, temperature_slider,
                   api_key_input, use_cache_checkbox, system_message]
    chat_outputs = [chatbot, msg_input, token_display, cost_display, cache_stats_display]
    msg_input.submit(fn=chat_response, inputs=chat_inputs, outputs=chat_outputs)
    send_btn.click(fn=chat_response, inputs=chat_inputs, outputs=chat_outputs)

    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot, token_display, cost_display, cache_stats_display],
    )

    # Export handlers: every format takes the same inputs and writes to the same file slot
    export_inputs = [chatbot, date_start, date_end, role_filter]
    for export_button, export_handler in (
        (export_txt_btn, export_txt_handler),
        (export_md_btn, export_md_handler),
        (export_json_btn, export_json_handler),
        (export_csv_btn, export_csv_handler),
        (export_audio_btn, export_audio_handler),
        (export_pdf_btn, export_pdf_handler),
    ):
        export_button.click(fn=export_handler, inputs=export_inputs, outputs=[download_file])

    # Timer update (using gr.Timer for periodic updates)
    timer = gr.Timer(value=1, active=True)
    timer.tick(fn=update_timer, outputs=[session_timer])


if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        share=False,
        # allowed_paths=[str(Path(__file__).parent / "assets")],
    )