| """ | |
| AstraMind Stage 1 - Gradio Chat Interface | |
| Main application file | |
| """ | |
| import os | |
| import sys | |
| import time | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| # Add src directory to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent / "src")) | |
import gradio as gr
import gradio_client.utils

# --- HF-Space hot-patch for Gradio schema bug (bool is not iterable) ---
# Keep a reference to the original function so the wrapper does not call itself
_original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

def _safe_json_schema_to_python_type(schema, defs=None):
    if isinstance(schema, bool):  # <-- short-circuit the bad case
        return "bool"
    try:
        return _original_json_schema_to_python_type(schema, defs)
    except Exception:
        return "Any"

gradio_client.utils._json_schema_to_python_type = _safe_json_schema_to_python_type
gr.routes.api_info = lambda *a, **k: {}  # skip OpenAPI generation
# -----------------------------------------------------------------------
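# (Context for the hot-patch above, as observed on some gradio/gradio_client
# versions: JSON schemas can contain bare booleans, e.g. additionalProperties
# set to True/False, which the stock _json_schema_to_python_type tries to
# iterate, raising "argument of type 'bool' is not iterable" during API-info
# generation.)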
from backend.chat_engine import ChatEngine
from backend.cache import ResponseCache
from backend.session_manager import SessionManager
from backend.model_registry import list_models, get_model_display_names, list_openrouter_models, list_hf_models
from backend.utils import count_tokens, calculate_cost, format_duration, get_timestamp

# Import export utilities from their location
sys.path.insert(0, str(Path(__file__).parent / "src" / "frontend" / "gradio_app"))
from export_utils import (
    export_to_txt, export_to_markdown, export_to_json,
    export_to_csv, export_to_audio, export_to_pdf
)

# Load custom CSS
css_file = Path(__file__).parent / "src" / "frontend" / "gradio_app" / "styles.css"
with open(css_file, 'r') as f:
    custom_css = f.read()
# Initialize global components
cache = ResponseCache(ttl=3600)
session_manager = SessionManager(base_dir="chat-history")
chat_engine = None  # Will be initialized when API key is provided

# Session state
session_start_time = time.time()
current_session_id = None
total_tokens_used = 0
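# NOTE: these module-level variables live in a single process, so they are
# shared across all concurrent users of the Gradio app rather than being
# scoped to one browser session.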
def initialize_chat_engine(api_key: str) -> tuple:
    """Initialize chat engine with API key and update model choices"""
    global chat_engine
    try:
        if api_key and api_key.strip():
            # OpenRouter mode
            chat_engine = ChatEngine(api_key=api_key, cache=cache)
            model_choices = list_openrouter_models()
            default_model = "gpt-4o-mini"
            status = "✅ OpenRouter initialized"
        else:
            # HuggingFace mode
            chat_engine = ChatEngine(api_key=None, cache=cache)
            model_choices = list_hf_models()
            default_model = "openchat"
            status = "✅ HuggingFace models ready"
        return gr.update(visible=True), gr.update(value=status), gr.update(choices=model_choices, value=default_model)
    except Exception as e:
        return gr.update(visible=False), gr.update(value=f"❌ Error: {str(e)}"), gr.update()
def chat_response(message: str, history: list, model: str, temperature: float,
                  api_key: str, use_cache: bool, system_message: str = "") -> tuple:
    """Handle chat message and generate response using OpenAI message format"""
    global chat_engine, total_tokens_used, current_session_id

    if not message or message.strip() == "":
        # This function is a generator, so early exits must yield; a bare
        # "return value" inside a generator would not update the outputs
        yield history, "", gr.update(), gr.update(), gr.update()
        return

    # Initialize engine if needed
    if chat_engine is None:
        try:
            chat_engine = ChatEngine(api_key=api_key, cache=cache)
        except Exception as e:
            history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            yield history, "", gr.update(), gr.update(), gr.update()
            return

    # Create session if needed
    if current_session_id is None:
        current_session_id = session_manager.create_session(model)

    # Prepend system message if provided and this is the first user message
    if system_message and system_message.strip():
        # Check if system message is not already in history
        if not any(msg.get("role") == "system" for msg in history):
            history.insert(0, {"role": "system", "content": system_message.strip()})

    # Count input tokens
    input_tokens = count_tokens(message, model)

    # Add user message to history (OpenAI format)
    history.append({"role": "user", "content": message})
    # Add placeholder for assistant response
    history.append({"role": "assistant", "content": ""})

    # Stream response
    response_text = ""
    try:
        for chunk in chat_engine.chat(message, model=model, stream=True,
                                      use_cache=use_cache, temperature=temperature,
                                      system_message=system_message if system_message and system_message.strip() else None):
            response_text += chunk
            history[-1]["content"] = response_text
            yield history, "", gr.update(), gr.update(), gr.update()

        # Count output tokens
        output_tokens = count_tokens(response_text, model)
        total_tokens_used += input_tokens + output_tokens

        # Save session
        messages = []
        for msg in history:
            messages.append({
                "role": msg["role"],
                "content": msg["content"],
                "timestamp": get_timestamp(),
                "tokens": count_tokens(msg["content"], model)
            })
        session_manager.save_session(
            current_session_id,
            messages,
            {
                "model": model,
                "total_tokens": total_tokens_used,
                "created_at": datetime.fromtimestamp(session_start_time).isoformat()
            }
        )

        # Update UI
        token_display = f"{total_tokens_used:,}"
        cost = calculate_cost(input_tokens, output_tokens, model)
        cost_display = f"${cost:.6f}"
        cache_stats = cache.get_stats()
        yield history, "", gr.update(value=token_display), gr.update(value=cost_display), gr.update(value=cache_stats)
    except Exception as e:
        history[-1]["content"] = f"Error: {str(e)}"
        yield history, "", gr.update(), gr.update(), gr.update()
def clear_chat() -> tuple:
    """Clear chat history"""
    global chat_engine, total_tokens_used, current_session_id
    if chat_engine:
        chat_engine.clear_history()
    total_tokens_used = 0
    current_session_id = None
    return [], gr.update(value="0"), gr.update(value="$0.00"), gr.update()


def update_timer() -> str:
    """Update session duration timer"""
    elapsed = int(time.time() - session_start_time)
    return format_duration(elapsed)
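# --- Export handlers ---
# The Chatbot history arrives in OpenAI messages format, while the export_to_*
# utilities expect (user, assistant) tuple pairs. A shared helper below does
# the conversion; each handler then delegates to the matching exporter.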
def _history_to_pairs(history: list) -> list:
    """Convert messages-format history into (user, assistant) tuple pairs.

    Assumes user and assistant messages alternate; a trailing unpaired
    message is kept as-is.
    """
    tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
                     for msg in history]
    # Merge consecutive user/assistant pairs
    merged = []
    i = 0
    while i < len(tuple_history):
        if i + 1 < len(tuple_history):
            user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
            asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
            merged.append((user_msg, asst_msg))
            i += 2
        else:
            merged.append(tuple_history[i])
            i += 1
    return merged


def export_txt_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to TXT"""
    return export_to_txt(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_md_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to Markdown"""
    return export_to_markdown(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_json_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to JSON"""
    return export_to_json(_history_to_pairs(history), date_start, date_end, roles, current_session_id,
                          total_tokens_used, session_start_time)


def export_csv_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to CSV"""
    return export_to_csv(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_audio_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to Audio (TTS)"""
    return export_to_audio(_history_to_pairs(history), date_start, date_end, roles, current_session_id)


def export_pdf_handler(history: list, date_start, date_end, roles: list) -> str:
    """Export chat to PDF"""
    return export_to_pdf(_history_to_pairs(history), date_start, date_end, roles, current_session_id,
                         total_tokens_used)
# Build Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AstraMind Chat") as app:
    gr.Markdown("# 🚀 AstraMind Chat - Stage 1")

    # Hidden state for tracking
    init_status = gr.Textbox(visible=False)

    # API Key input at the top
    with gr.Row():
        api_key_input = gr.Textbox(
            label="OpenRouter API Key (optional - leave empty for HuggingFace models)",
            type="password",
            placeholder="sk-or-...",
            scale=4
        )
        init_btn = gr.Button("Initialize", scale=1, variant="primary")
        init_status_display = gr.Textbox(label="Status", scale=2, interactive=False)

    # Main interface (visible by default; initialize_chat_engine re-shows it after setup)
    main_interface = gr.Column(visible=True)
    with main_interface:
        with gr.Row():
            # Left sidebar
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Settings")
                model_dropdown = gr.Dropdown(
                    choices=list_hf_models(),
                    value="openchat",
                    label="Model",
                    info="Select AI model"
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative"
                )
                use_cache_checkbox = gr.Checkbox(
                    value=True,
                    label="Use Response Cache",
                    info="Cache identical queries"
                )

                # System Message Accordion
                with gr.Accordion("💬 System Message", open=False):
                    system_message = gr.Textbox(
                        label="System Prompt",
                        placeholder="Enter a system message to guide the AI's behavior...",
                        lines=4,
                        value=""
                    )

                gr.Markdown("### 📊 Statistics")
                token_display = gr.Textbox(
                    label="Total Tokens",
                    value="0",
                    interactive=False
                )
                cost_display = gr.Textbox(
                    label="Estimated Cost",
                    value="$0.00",
                    interactive=False
                )
                session_timer = gr.Textbox(
                    label="Session Duration",
                    value="0s",
                    interactive=False
                )
                try:
                    safe_cache_stats = json.loads(json.dumps(cache.get_stats(), default=str))
                except Exception:
                    safe_cache_stats = {}
                cache_stats_display = gr.JSON(
                    label="Cache Stats",
                    value=safe_cache_stats
                )
            # Center: Chat interface
            with gr.Column(scale=3):
                # Get absolute path to bot avatar
                avatar_path = Path(__file__).parent / "assets" / "bot-avatar.png"
                chatbot = gr.Chatbot(
                    height=600,
                    show_label=False,
                    type="messages",  # OpenAI-compatible format
                    avatar_images=(
                        None,  # user avatar placeholder
                        str(avatar_path) if avatar_path.exists() else None  # assistant avatar
                    ),
                    render_markdown=True
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Message AstraMind...",
                        show_label=False,
                        scale=9,
                        container=False
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
                    regenerate_btn = gr.Button("🔄 Regenerate", scale=1, visible=False)
            # Right: Export panel
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Export Suite")
                gr.Markdown("**Filters:**")
                date_start = gr.Textbox(
                    label="From Date (YYYY-MM-DD)",
                    placeholder="2024-01-01"
                )
                date_end = gr.Textbox(
                    label="To Date (YYYY-MM-DD)",
                    placeholder="2024-12-31"
                )
                role_filter = gr.CheckboxGroup(
                    choices=["user", "assistant", "system"],
                    value=["user", "assistant"],
                    label="Include Roles"
                )
                gr.Markdown("**Export Formats:**")
                export_txt_btn = gr.Button("📄 Export TXT", size="sm")
                export_md_btn = gr.Button("📝 Export MD", size="sm")
                export_json_btn = gr.Button("📋 Export JSON", size="sm")
                export_csv_btn = gr.Button("📊 Export CSV", size="sm")
                export_audio_btn = gr.Button("🔊 Export Audio", size="sm")
                export_pdf_btn = gr.Button("📑 Export PDF", size="sm")
                download_file = gr.File(label="Download", visible=True)
    # Event handlers
    init_btn.click(
        fn=initialize_chat_engine,
        inputs=[api_key_input],
        outputs=[main_interface, init_status_display, model_dropdown]
    )

    # Chat handlers
    msg_input.submit(
        fn=chat_response,
        inputs=[msg_input, chatbot, model_dropdown, temperature_slider,
                api_key_input, use_cache_checkbox, system_message],
        outputs=[chatbot, msg_input, token_display, cost_display, cache_stats_display]
    )
    send_btn.click(
        fn=chat_response,
        inputs=[msg_input, chatbot, model_dropdown, temperature_slider,
                api_key_input, use_cache_checkbox, system_message],
        outputs=[chatbot, msg_input, token_display, cost_display, cache_stats_display]
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot, token_display, cost_display, cache_stats_display]
    )

    # Export handlers
    export_txt_btn.click(
        fn=export_txt_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_md_btn.click(
        fn=export_md_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_json_btn.click(
        fn=export_json_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_csv_btn.click(
        fn=export_csv_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_audio_btn.click(
        fn=export_audio_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_pdf_btn.click(
        fn=export_pdf_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )

    # Timer update (using gr.Timer for periodic updates)
    timer = gr.Timer(value=1, active=True)
    timer.tick(fn=update_timer, outputs=[session_timer])
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        share=False,
        # allowed_paths=[str(Path(__file__).parent / "assets")],
    )
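# Running locally (assumed workflow; the dependency file and entry-point names
# are assumptions, not confirmed by this file):
#   pip install -r requirements.txt
#   python app.py
# The interface is then served at http://localhost:7860.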