| import os |
| import sys |
| import re |
| import gradio as gr |
| import json |
| import tempfile |
| import base64 |
| import io |
| from typing import List, Dict, Any, Optional, Tuple, Union |
| import logging |
| import pandas as pd |
| import plotly.express as px |
| import plotly.graph_objects as go |
| from plotly.subplots import make_subplots |
| from shared import initialize_llm, setup_database_connection, create_agent |
|
|
# Prefer LangChain's message classes. If LangChain is not installed, fall back to
# minimal stand-ins that only carry a ``content`` attribute, so chat history can
# still be assembled. ``LANGCHAIN_AVAILABLE`` records which path was taken.
try:
    from langchain_core.messages import HumanMessage, AIMessage
except ImportError:
    LANGCHAIN_AVAILABLE = False

    class HumanMessage:
        """Minimal stand-in for ``langchain_core.messages.HumanMessage``."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in for ``langchain_core.messages.AIMessage``."""

        def __init__(self, content):
            self.content = content
else:
    LANGCHAIN_AVAILABLE = True

# Module-wide logging at INFO level; this module logs through ``logger``.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
def create_ui():
    """Lay out the SQL assistant interface.

    The page has a chat column (chatbot, question box, Send button and a hidden
    Markdown placeholder) next to a chart column, followed by a status panel.

    Returns:
        Tuple ``(demo, chatbot, chart_display, question_input, submit_button,
        streaming_output_display)`` in exactly that order.
    """
    page_css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-container {
        height: 600px;
        overflow-y: auto;
    }
    .chart-container {
        height: 600px;
        overflow-y: auto;
    }
    """
    # Static text: it is a literal and does not reflect any live health check.
    status_markdown = (
        "### ✅ System Status\n"
        "- **Database**: Ready\n"
        "- **AI Model**: Ready\n"
        "- **API**: Available"
    )

    with gr.Blocks(css=page_css, title="🤖 SQL Database Assistant") as demo:
        gr.Markdown("# 🤖 SQL Database Assistant")
        gr.Markdown("Ask questions about your database in natural language!")

        with gr.Row():
            # Left column: conversation plus input controls.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Chat", elem_classes="chat-container", type="messages")

                with gr.Row():
                    question_input = gr.Textbox(
                        label="Ask your question",
                        placeholder="Type your question here...",
                        lines=2,
                        scale=4,
                    )
                    submit_button = gr.Button("Send", variant="primary", scale=1)

                # Hidden placeholder; returned to the caller but never shown here.
                streaming_output_display = gr.Markdown(visible=False)

            # Right column: chart for the latest answer.
            with gr.Column(scale=1):
                chart_display = gr.Plot(label="Charts", elem_classes="chart-container")

        with gr.Row():
            gr.Markdown(status_markdown, elem_id="status")

    return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display
|
|
| |
|
|
def create_application():
    """Create the Gradio app and connect its events to the SQL agent.

    Builds the layout with ``create_ui``. When the ``SPACE_ID`` environment
    variable is set, the app is also mounted at ``/`` on ``api.app``.
    Pressing Enter in the question box or clicking Send first appends the
    question to the chat (``user_message``). It then produces the answer and an
    optional chart (``bot_response``). Only the Enter path is registered under
    the API name ``"ask"``.

    Returns:
        The configured ``gr.Blocks`` app.
    """
    demo, chatbot, chart_display, question_input, submit_button, _hidden_markdown = create_ui()

    if os.getenv('SPACE_ID'):
        import api  # imported lazily, only when SPACE_ID is set
        api.app = gr.mount_gradio_app(api.app, demo, path="/")

    def _pair_history(messages: List[Dict[str, str]]) -> List[List[str]]:
        """Collect consecutive user->assistant turns as ``[question, answer]`` pairs.

        Messages that do not form such a pair are skipped, including a trailing
        unanswered user question.
        """
        pairs: List[List[str]] = []
        position = 0
        last_start = len(messages) - 1
        while position < last_start:
            first, second = messages[position], messages[position + 1]
            if (
                isinstance(first, dict)
                and first.get("role") == "user"
                and isinstance(second, dict)
                and second.get("role") == "assistant"
            ):
                pairs.append([first.get("content", ""), second.get("content", "")])
                position += 2
            else:
                position += 1
        return pairs

    def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
        """Append a non-blank question to the messages-format history.

        Returns an empty string (clearing the textbox) and the history. The list
        passed in is appended to in place.
        """
        if not user_input.strip():
            return "", chat_history

        logger.info(f"User message: {user_input}")

        history = [] if chat_history is None else chat_history
        history.append({"role": "user", "content": user_input})
        return "", history

    async def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
        """Answer the pending user question and return ``(history, chart or None)``.

        Does nothing unless the last message is a non-empty user message.
        Errors are reported as an assistant message instead of being raised.
        """
        if not chat_history:
            return chat_history, None

        latest = chat_history[-1]
        has_pending_question = (
            isinstance(latest, dict)
            and latest.get("role") == "user"
            and bool(latest.get("content"))
        )
        if not has_pending_question:
            return chat_history, None

        try:
            question = latest["content"]
            logger.info(f"Processing question: {question}")

            assistant_message, chart_fig = await stream_agent_response(
                question, _pair_history(chat_history)
            )

            chat_history.append({"role": "assistant", "content": assistant_message})
            logger.info("Response generation complete")
            return chat_history, chart_fig

        except Exception as e:
            error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
            logger.error(error_msg, exc_info=True)
            chat_history.append({"role": "assistant", "content": error_msg})
            return chat_history, None

    with demo:
        # Textbox Enter and the Send button run the same two-step chain; only the
        # Enter chain is exposed under the API name "ask".
        triggers = (
            (question_input.submit, {"api_name": "ask"}),
            (submit_button.click, {}),
        )
        for register, extra_options in triggers:
            register(
                fn=user_message,
                inputs=[question_input, chatbot],
                outputs=[question_input, chatbot],
                queue=True,
            ).then(
                fn=bot_response,
                inputs=[chatbot],
                outputs=[chatbot, chart_display],
                **extra_options,
            )

    return demo
|
|
# Words that mark a data- or schema-modifying statement (``into`` catches
# ``SELECT ... INTO new_table``). A query containing any of them is never
# re-executed for charting.
_WRITE_SQL_KEYWORDS = re.compile(
    r"\b(insert|update|delete|merge|drop|alter|create|truncate|grant|revoke|into)\b",
    re.IGNORECASE,
)


def _is_read_only_sql(sql: str) -> bool:
    """Return True if ``sql`` looks like a single read-only SELECT/WITH query.

    This is a best-effort heuristic, not a security boundary. It exists so that
    a write the agent already performed is never executed a second time just to
    draw a chart. False negatives (e.g. a string literal containing ``update``)
    only mean no chart is drawn.
    """
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        # Empty, or several statements batched together.
        return False
    if not re.match(r"(select|with)\b", statement, re.IGNORECASE):
        return False
    return _WRITE_SQL_KEYWORDS.search(statement) is None


def _extract_response_text(response: Any) -> str:
    """Pull the answer text out of whatever shape the agent returned.

    Checked in order: a truthy ``.output`` attribute, an ``'output'`` key of a
    dict, a plain string, and finally ``str(response)``.
    """
    if hasattr(response, 'output') and response.output:
        return response.output
    if isinstance(response, dict) and 'output' in response:
        return response['output']
    if isinstance(response, str):
        return response
    return str(response)


def _build_chart(result: Any, bar_max_rows: int = 20) -> Optional[go.Figure]:
    """Plot the first two columns of a query result, or return None.

    Args:
        result: Rows returned by the database connection. A ``str`` is parsed
            with ``ast.literal_eval`` as a best effort. Assumption to confirm
            against ``setup_database_connection``: some wrappers (e.g.
            LangChain's ``SQLDatabase.run``) return ``str(list_of_tuples)``.
            Only a non-empty ``list`` is charted.
        bar_max_rows: Results with at most this many rows become a bar chart;
            longer results become a line+marker scatter plot.

    Returns:
        A Plotly figure with column 0 on x and column 1 on y, or None when the
        result cannot be charted (unparseable, empty, or fewer than 2 columns).
    """
    if isinstance(result, str):
        import ast  # local: only needed for this best-effort parse
        try:
            result = ast.literal_eval(result)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # e.g. reprs containing Decimal(...) or datetime(...) are not literals.
            return None

    if not isinstance(result, list) or not result:
        return None

    df = pd.DataFrame(result)
    if len(df.columns) < 2:
        return None

    x_name, y_name = df.columns[0], df.columns[1]
    x_values, y_values = df.iloc[:, 0], df.iloc[:, 1]
    if len(df) <= bar_max_rows:
        trace = go.Bar(x=x_values, y=y_values, name=str(y_name))
    else:
        trace = go.Scatter(x=x_values, y=y_values, mode='lines+markers', name=str(y_name))

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        title=f"{x_name} vs {y_name}",
        xaxis_title=str(x_name),
        yaxis_title=str(y_name),
    )
    return fig


async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional[go.Figure]]:
    """Answer ``question`` with the SQL agent and chart the result when possible.

    Args:
        question: The user's natural-language question.
        chat_history: Prior turns as ``[user_text, assistant_text]`` pairs. They
            are passed to the agent as alternating Human/AI messages.

    Returns:
        ``(markdown_text, figure_or_None)``. Failures while initialising the
        LLM, the database connection or the agent, or while running the agent,
        are returned as an error string with no figure; nothing is raised.
        If the answer contains a fenced ```sql block, the last such query is
        re-run for charting, but only when it is a single read-only SELECT/WITH
        query. Chart failures are logged and yield no figure.
    """
    import asyncio  # local: only this coroutine offloads blocking work

    llm, llm_error = initialize_llm()
    if llm_error:
        return f"**LLM Error:** {llm_error}", None

    db_connection, db_error = setup_database_connection()
    if db_error:
        return f"**Database Error:** {db_error}", None

    agent, agent_error = create_agent(llm, db_connection)
    if agent_error:
        return f"**Agent Error:** {agent_error}", None

    try:
        logger.info(f"Processing question: {question}")

        input_data: Dict[str, Any] = {"input": question}
        if chat_history:
            input_data["chat_history"] = [
                message
                for human, ai in chat_history
                for message in (HumanMessage(content=human), AIMessage(content=ai))
            ]

        # agent.invoke is synchronous. Run it in a worker thread so the event
        # loop keeps serving other requests while the LLM/database work runs.
        response = await asyncio.to_thread(agent.invoke, input_data)
        response_text = _extract_response_text(response)

        chart_fig = None
        sql_matches = re.findall(r'```sql\s*(.*?)\s*```', response_text, re.DOTALL)
        if sql_matches:
            sql_query = sql_matches[-1].strip()
            try:
                if _is_read_only_sql(sql_query):
                    logger.info(f"Executing SQL query: {sql_query}")
                    result = await asyncio.to_thread(db_connection.run, sql_query)
                    chart_fig = _build_chart(result)
                else:
                    # The agent has already executed this statement; running a
                    # write again would apply it twice.
                    logger.info("Skipping chart: SQL is not a single read-only query")
            except Exception as e:
                logger.warning(f"Could not create chart: {e}")

        return response_text, chart_fig

    except Exception as e:
        error_msg = f"**Error processing question:** {str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg, None
|
|
| |
# Build the Gradio app once at import time; get_app() returns this instance and
# the __main__ guard below launches it.
demo = create_application()
|
|
| |
def get_app():
    """Return the module-level Gradio app instance for Hugging Face Spaces.

    When the ``SPACE_ID`` environment variable is set (non-empty), the app's
    ``title`` and ``description`` attributes are overwritten with demo labels
    before the same instance is returned.
    """
    if not os.getenv('SPACE_ID'):
        return demo

    # Running on a Space: relabel the app as a demo build.
    demo.title = "🤖 Asistente de Base de Datos SQL (Demo)"
    demo.description = """
Este es un demo del asistente de base de datos SQL.
Para usar la versión completa con conexión a base de datos, clona este espacio y configura las variables de entorno.
"""
    return demo
|
|
| |
if __name__ == "__main__":
    # Local run: listen on all interfaces at port 7860, debug on, no share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
        "share": False,
    }
    demo.launch(**launch_options)
|
|