# Hot patch applied for the Gradio HF Space (commit 638a685).
"""
AstraMind Stage 1 - Gradio Chat Interface
Main application file
"""
import os
import sys
import time
import json
from datetime import datetime
from pathlib import Path
# Add src directory to path for imports so `backend.*` packages resolve
# when running this file directly (e.g. on a HF Space).
sys.path.insert(0, str(Path(__file__).parent / "src"))
import gradio as gr
import gradio_client.utils
# --- HF-Space hot-patch for Gradio schema bug (bool is not iterable) ---
# Capture the original converter BEFORE rebinding the module attribute.
# The previous patch called `gradio_client.utils._json_schema_to_python_type`
# inside the wrapper, but after patching that attribute points back at the
# wrapper itself, so every non-bool schema recursed until RecursionError
# (silently swallowed into "Any"). Calling the saved original avoids that.
_orig_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

def _safe_json_schema_to_python_type(schema, defs=None):
    """Tolerant wrapper around gradio_client's schema->type converter.

    Gradio sometimes feeds a bare bool as a JSON schema, which crashes the
    stock converter. Short-circuit that case and fall back to "Any" on any
    other conversion failure instead of taking down the whole app.
    """
    if isinstance(schema, bool):  # <-- short-circuit the bad case
        return "bool"
    try:
        return _orig_json_schema_to_python_type(schema, defs)
    except Exception:
        return "Any"

gradio_client.utils._json_schema_to_python_type = _safe_json_schema_to_python_type
gr.routes.api_info = lambda *a, **k: {}  # skip OpenAPI generation
# -----------------------------------------------------------------------
from backend.chat_engine import ChatEngine
from backend.cache import ResponseCache
from backend.session_manager import SessionManager
from backend.model_registry import list_models, get_model_display_names, list_openrouter_models, list_hf_models
from backend.utils import count_tokens, calculate_cost, format_duration, get_timestamp
# Import export utilities from their location: they live beside the Gradio
# frontend and are imported flat (`from export_utils import ...`), so their
# directory must be on sys.path too.
sys.path.insert(0, str(Path(__file__).parent / "src" / "frontend" / "gradio_app"))
from export_utils import (
export_to_txt, export_to_markdown, export_to_json,
export_to_csv, export_to_audio, export_to_pdf
)
# Load custom CSS for the Gradio interface. Use an explicit encoding: the
# platform default is not guaranteed to be UTF-8 (e.g. on Windows), and the
# stylesheet may contain non-ASCII characters.
css_file = Path(__file__).parent / "src" / "frontend" / "gradio_app" / "styles.css"
custom_css = css_file.read_text(encoding="utf-8")

# Initialize global components shared by all handlers.
cache = ResponseCache(ttl=3600)  # response cache with a 1-hour TTL
session_manager = SessionManager(base_dir="chat-history")
chat_engine = None  # Will be initialized when API key is provided

# Module-level session state.
session_start_time = time.time()  # wall-clock start, read by the duration timer
current_session_id = None         # assigned on the first chat message
total_tokens_used = 0             # running input+output token total
def initialize_chat_engine(api_key: str) -> tuple:
    """Create the global ChatEngine and return UI updates.

    A non-empty key selects OpenRouter mode; otherwise the free
    HuggingFace backend is used. Returns updates for the main column's
    visibility, the status textbox, and the model dropdown.
    """
    global chat_engine
    try:
        if api_key and api_key.strip():
            # OpenRouter mode: paid models behind the user's key.
            chat_engine = ChatEngine(api_key=api_key, cache=cache)
            choices = list_openrouter_models()
            default = "gpt-4o-mini"
            status = "βœ“ OpenRouter initialized"
        else:
            # HuggingFace fallback: no key required.
            chat_engine = ChatEngine(api_key=None, cache=cache)
            choices = list_hf_models()
            default = "openchat"
            status = "βœ“ HuggingFace models ready"
        return (
            gr.update(visible=True),
            gr.update(value=status),
            gr.update(choices=choices, value=default),
        )
    except Exception as exc:
        return (
            gr.update(visible=False),
            gr.update(value=f"βœ— Error: {str(exc)}"),
            gr.update(),
        )
def chat_response(message: str, history: list, model: str, temperature: float,
                  api_key: str, use_cache: bool, system_message: str = "") -> tuple:
    """Handle one chat turn, streaming the assistant reply into the history.

    Yields (history, cleared_input, token_update, cost_update, cache_update)
    tuples for Gradio to render incrementally. `history` uses the OpenAI
    message format: [{"role": ..., "content": ...}, ...].
    """
    global chat_engine, total_tokens_used, current_session_id

    # BUGFIX: this function is a generator (it contains `yield`), so a bare
    # `return value` is swallowed as the StopIteration value and Gradio never
    # sees the update. Every exit path must yield before returning.
    if not message or message.strip() == "":
        yield history, "", gr.update(), gr.update(), gr.update()
        return

    # Lazily create the engine if the user skipped the Initialize button.
    if chat_engine is None:
        try:
            chat_engine = ChatEngine(api_key=api_key, cache=cache)
        except Exception as e:
            history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            yield history, "", gr.update(), gr.update(), gr.update()
            return

    # Create a persistent session on the first message.
    if current_session_id is None:
        current_session_id = session_manager.create_session(model)

    # Prepend system message if provided and not already present in history.
    if system_message and system_message.strip():
        if not any(msg.get("role") == "system" for msg in history):
            history.insert(0, {"role": "system", "content": system_message.strip()})

    # Count input tokens before streaming starts.
    input_tokens = count_tokens(message, model)

    # Add the user turn plus an empty assistant placeholder to stream into.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": ""})

    response_text = ""
    try:
        for chunk in chat_engine.chat(message, model=model, stream=True,
                                      use_cache=use_cache, temperature=temperature,
                                      system_message=system_message if system_message and system_message.strip() else None):
            response_text += chunk
            history[-1]["content"] = response_text
            yield history, "", gr.update(), gr.update(), gr.update()

        # Account for tokens once the stream has completed.
        output_tokens = count_tokens(response_text, model)
        total_tokens_used += input_tokens + output_tokens

        # Persist the full transcript with per-message token counts.
        messages = []
        for msg in history:
            messages.append({
                "role": msg["role"],
                "content": msg["content"],
                "timestamp": get_timestamp(),
                "tokens": count_tokens(msg["content"], model)
            })
        session_manager.save_session(
            current_session_id,
            messages,
            {
                "model": model,
                "total_tokens": total_tokens_used,
                "created_at": datetime.fromtimestamp(session_start_time).isoformat()
            }
        )

        # Final yield refreshes the statistics widgets.
        token_display = f"{total_tokens_used:,}"
        cost = calculate_cost(input_tokens, output_tokens, model)
        cost_display = f"${cost:.6f}"
        cache_stats = cache.get_stats()
        yield history, "", gr.update(value=token_display), gr.update(value=cost_display), gr.update(value=cache_stats)
    except Exception as e:
        # Surface the failure in-line instead of crashing the UI.
        history[-1]["content"] = f"Error: {str(e)}"
        yield history, "", gr.update(), gr.update(), gr.update()
def clear_chat() -> tuple:
    """Reset the conversation and all per-session counters.

    Returns updates for the chatbot, token display, cost display and
    cache-stats widgets.
    """
    global chat_engine, total_tokens_used, current_session_id
    total_tokens_used = 0
    current_session_id = None
    if chat_engine is not None:
        chat_engine.clear_history()
    return [], gr.update(value="0"), gr.update(value="$0.00"), gr.update()
def update_timer() -> str:
    """Return the formatted wall-clock time since the session began."""
    return format_duration(int(time.time() - session_start_time))
def export_txt_handler(history: list, date_start, date_end, roles: list) -> str:
"""Export chat to TXT - converts OpenAI format to tuple format for export"""
tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
for msg in history]
# Merge consecutive user/assistant pairs
merged = []
i = 0
while i < len(tuple_history):
if i + 1 < len(tuple_history):
user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
merged.append((user_msg, asst_msg))
i += 2
else:
merged.append(tuple_history[i])
i += 1
return export_to_txt(merged, date_start, date_end, roles, current_session_id)
def export_md_handler(history: list, date_start, date_end, roles: list) -> str:
"""Export chat to Markdown"""
tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
for msg in history]
merged = []
i = 0
while i < len(tuple_history):
if i + 1 < len(tuple_history):
user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
merged.append((user_msg, asst_msg))
i += 2
else:
merged.append(tuple_history[i])
i += 1
return export_to_markdown(merged, date_start, date_end, roles, current_session_id)
def export_json_handler(history: list, date_start, date_end, roles: list) -> str:
"""Export chat to JSON"""
tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
for msg in history]
merged = []
i = 0
while i < len(tuple_history):
if i + 1 < len(tuple_history):
user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
merged.append((user_msg, asst_msg))
i += 2
else:
merged.append(tuple_history[i])
i += 1
return export_to_json(merged, date_start, date_end, roles, current_session_id,
total_tokens_used, session_start_time)
def export_csv_handler(history: list, date_start, date_end, roles: list) -> str:
"""Export chat to CSV"""
tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
for msg in history]
merged = []
i = 0
while i < len(tuple_history):
if i + 1 < len(tuple_history):
user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
merged.append((user_msg, asst_msg))
i += 2
else:
merged.append(tuple_history[i])
i += 1
return export_to_csv(merged, date_start, date_end, roles, current_session_id)
def export_audio_handler(history: list, date_start, date_end, roles: list) -> str:
"""Export chat to Audio (TTS)"""
tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
for msg in history]
merged = []
i = 0
while i < len(tuple_history):
if i + 1 < len(tuple_history):
user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
merged.append((user_msg, asst_msg))
i += 2
else:
merged.append(tuple_history[i])
i += 1
return export_to_audio(merged, date_start, date_end, roles, current_session_id)
def export_pdf_handler(history: list, date_start, date_end, roles: list) -> str:
"""Export chat to PDF"""
tuple_history = [(msg["content"], "") if msg["role"] == "user" else ("", msg["content"])
for msg in history]
merged = []
i = 0
while i < len(tuple_history):
if i + 1 < len(tuple_history):
user_msg = tuple_history[i][0] if tuple_history[i][0] else ""
asst_msg = tuple_history[i + 1][1] if i + 1 < len(tuple_history) else ""
merged.append((user_msg, asst_msg))
i += 2
else:
merged.append(tuple_history[i])
i += 1
return export_to_pdf(merged, date_start, date_end, roles, current_session_id,
total_tokens_used)
# Build Gradio Interface.
# Layout: API-key row on top, then a three-column main area
# (settings/stats | chat | export suite), then event wiring.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AstraMind Chat") as app:
    gr.Markdown("# 🌟 AstraMind Chat - Stage 1")

    # Hidden state for tracking
    init_status = gr.Textbox(visible=False)

    # API Key input at the top
    with gr.Row():
        api_key_input = gr.Textbox(
            label="OpenRouter API Key (optional - leave empty for HuggingFace models)",
            type="password",
            placeholder="sk-or-...",
            scale=4
        )
        init_btn = gr.Button("Initialize", scale=1, variant="primary")
        init_status_display = gr.Textbox(label="Status", scale=2, interactive=False)

    # Main interface (hidden until initialized)
    # NOTE(review): visible=True means the column actually shows BEFORE
    # initialization -- presumably intentional for the HF Space hot patch;
    # confirm whether it should start hidden.
    main_interface = gr.Column(visible=True)
    with main_interface:
        with gr.Row():
            # Left sidebar: model settings + live statistics.
            with gr.Column(scale=1):
                gr.Markdown("### βš™οΈ Settings")
                model_dropdown = gr.Dropdown(
                    choices=list_hf_models(),
                    value="openchat",
                    label="Model",
                    info="Select AI model"
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative"
                )
                use_cache_checkbox = gr.Checkbox(
                    value=True,
                    label="Use Response Cache",
                    info="Cache identical queries"
                )
                # System Message Accordion
                with gr.Accordion("πŸ’¬ System Message", open=False):
                    system_message = gr.Textbox(
                        label="System Prompt",
                        placeholder="Enter a system message to guide the AI's behavior...",
                        lines=4,
                        value=""
                    )
                gr.Markdown("### πŸ“Š Statistics")
                token_display = gr.Textbox(
                    label="Total Tokens",
                    value="0",
                    interactive=False
                )
                cost_display = gr.Textbox(
                    label="Estimated Cost",
                    value="$0.00",
                    interactive=False
                )
                session_timer = gr.Textbox(
                    label="Session Duration",
                    value="0s",
                    interactive=False
                )
                # Round-trip through JSON so non-serializable stat values
                # are stringified before being handed to gr.JSON.
                try:
                    safe_cache_stats = json.loads(json.dumps(cache.get_stats(), default=str))
                except Exception:
                    safe_cache_stats = {}
                cache_stats_display = gr.JSON(
                    label="Cache Stats",
                    value=safe_cache_stats
                )
            # Center: Chat interface
            with gr.Column(scale=3):
                # Get absolute path to bot avatar (optional asset).
                avatar_path = Path(__file__).parent / "assets" / "bot-avatar.png"
                chatbot = gr.Chatbot(
                    height=600,
                    show_label=False,
                    type="messages",  # OpenAI-compatible format
                    avatar_images=(
                        None,  # user avatar placeholder
                        str(avatar_path) if avatar_path.exists() else None  # assistant avatar
                    ),
                    render_markdown=True
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Message AstraMind...",
                        show_label=False,
                        scale=9,
                        container=False
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")
                with gr.Row():
                    clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", scale=1)
                    regenerate_btn = gr.Button("πŸ”„ Regenerate", scale=1, visible=False)
            # Right: Export panel
            with gr.Column(scale=1):
                gr.Markdown("### πŸ“€ Export Suite")
                gr.Markdown("**Filters:**")
                date_start = gr.Textbox(
                    label="From Date (YYYY-MM-DD)",
                    placeholder="2024-01-01"
                )
                date_end = gr.Textbox(
                    label="To Date (YYYY-MM-DD)",
                    placeholder="2024-12-31"
                )
                role_filter = gr.CheckboxGroup(
                    choices=["user", "assistant", "system"],
                    value=["user", "assistant"],
                    label="Include Roles"
                )
                gr.Markdown("**Export Formats:**")
                export_txt_btn = gr.Button("πŸ“„ Export TXT", size="sm")
                export_md_btn = gr.Button("πŸ“ Export MD", size="sm")
                export_json_btn = gr.Button("πŸ“‹ Export JSON", size="sm")
                export_csv_btn = gr.Button("πŸ“Š Export CSV", size="sm")
                export_audio_btn = gr.Button("πŸ”Š Export Audio", size="sm")
                export_pdf_btn = gr.Button("πŸ“• Export PDF", size="sm")
                download_file = gr.File(label="Download", visible=True)

    # Event handlers
    init_btn.click(
        fn=initialize_chat_engine,
        inputs=[api_key_input],
        outputs=[main_interface, init_status_display, model_dropdown]
    )
    # Chat handlers: Enter key and Send button trigger the same flow.
    msg_input.submit(
        fn=chat_response,
        inputs=[msg_input, chatbot, model_dropdown, temperature_slider,
                api_key_input, use_cache_checkbox, system_message],
        outputs=[chatbot, msg_input, token_display, cost_display, cache_stats_display]
    )
    send_btn.click(
        fn=chat_response,
        inputs=[msg_input, chatbot, model_dropdown, temperature_slider,
                api_key_input, use_cache_checkbox, system_message],
        outputs=[chatbot, msg_input, token_display, cost_display, cache_stats_display]
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot, token_display, cost_display, cache_stats_display]
    )
    # Export handlers: each writes a file and routes it to the download widget.
    export_txt_btn.click(
        fn=export_txt_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_md_btn.click(
        fn=export_md_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_json_btn.click(
        fn=export_json_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_csv_btn.click(
        fn=export_csv_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_audio_btn.click(
        fn=export_audio_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    export_pdf_btn.click(
        fn=export_pdf_handler,
        inputs=[chatbot, date_start, date_end, role_filter],
        outputs=[download_file]
    )
    # Timer update (using gr.Timer for periodic updates, 1s tick).
    timer = gr.Timer(value=1, active=True)
    timer.tick(fn=update_timer, outputs=[session_timer])
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",  # bind all interfaces (required inside a container/HF Space)
        server_port=7860,       # the port HF Spaces expects
        show_api=False,         # API docs disabled (schema generation is hot-patched above)
        share=False,
    )
    # allowed_paths=[str(Path(__file__).parent / "assets")],