Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import base64 | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| import gradio as gr | |
| from smolagents import ToolCallingAgent, AzureOpenAIModel | |
| from smolagents.mcp_client import MCPClient | |
| from tools import generate_chart, LAST_CHART, AGENT_STEPS | |
| from prompt import AGENT_INSTRUCTIONS | |
# Config
# BASE_DIR is the directory holding this file; the .env file is read from
# the directory one level above it.
BASE_DIR = Path(__file__).parent
load_dotenv(BASE_DIR.parent.joinpath(".env"))

# Each setting is taken from the environment when present; otherwise the
# literal default next to it is used.
MCP_SERVER_URL = os.environ.get(
    "MCP_SERVER_URL",
    "https://sitsope-mcp-server-test.hf.space/gradio_api/mcp/sse",
)
AZURE_ENDPOINT = os.environ.get(
    "AZURE_OPENAI_ENDPOINT",
    "https://collier-llm.openai.azure.com/",
)
# No default for the key: this is None when the variable is unset.
AZURE_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_API_VER = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
AZURE_MODEL = os.environ.get("AZURE_OPENAI_MODEL", "gpt-4o")
# Canned prompts; one sidebar button is created per entry, in this order.
SUGGESTIONS: list[str] = [
    "How many unique questions?",
    "Sub-options per question",
    "Questions above average sub-options",
    "Distribution by question type",
    "Which question has the most sub-options?",
    "Percentage of rows per question",
]
# Stylesheet handed to Gradio at launch. Deliberately tiny: heavier custom
# CSS is avoided because of layout bugs inside the HF Spaces iframe.
CSS = """
.gradio-container {
max-width: none !important;
width: 100% !important;
margin: 0 auto !important;
padding-left: 20px !important;
padding-right: 20px !important;
}
"""
# Agent setup
# MCP client for the configured SSE endpoint, and the tools it reports.
mcp_client = MCPClient(
    {"url": MCP_SERVER_URL, "transport": "sse"},
    structured_output=False,
)
tools = mcp_client.get_tools()

# Azure OpenAI model built from the settings in the Config section.
model = AzureOpenAIModel(
    api_key=AZURE_API_KEY,
    api_version=AZURE_API_VER,
    azure_endpoint=AZURE_ENDPOINT,
    model_id=AZURE_MODEL,
)

# The agent gets every MCP tool plus the local generate_chart tool.
agent = ToolCallingAgent(
    model=model,
    tools=list(tools) + [generate_chart],
    instructions=AGENT_INSTRUCTIONS,
)
# Agent runner
def _truncate(value, limit: int) -> str:
    """Return ``str(value)`` cut to *limit* characters, with "…" appended if cut."""
    text = str(value)
    return text[:limit] + "…" if len(text) > limit else text


def _step_tool_calls(step):
    """Yield ``(tool_name, tool_input)`` pairs recorded on one memory step."""
    # Legacy shape: the step itself carries the tool name and its arguments.
    name = getattr(step, "tool_name", None) or getattr(step, "action", None)
    if name:
        yield name, getattr(step, "tool_arguments", None) or getattr(step, "tool_input", "")
        return
    # NOTE(review): smolagents' ActionStep appears to store calls in
    # ``tool_calls`` (objects with ``name`` / ``arguments``) — confirm against
    # the installed version. Without this fallback no step is ever recorded.
    for call in getattr(step, "tool_calls", None) or []:
        yield getattr(call, "name", None), getattr(call, "arguments", "")


def run_agent(question: str):
    """Run the agent on *question* and collect its answer, chart and tool steps.

    Resets ``LAST_CHART["path"]`` and ``AGENT_STEPS`` before the run. After the
    run, every tool call found in ``agent.memory.steps`` (except
    ``final_answer``) is appended to ``AGENT_STEPS`` as a
    ``(tool_name, input, observation)`` tuple; the input is truncated to 300
    characters and the observation to 500.

    Args:
        question: The user's prompt, passed unchanged to ``agent.run``.

    Returns:
        A ``(answer, chart_path)`` tuple. ``answer`` is ``str()`` of the agent's
        response. ``chart_path`` is ``LAST_CHART["path"]`` when set, otherwise
        the first ``chart_*.png`` name found in the answer text, otherwise None.
    """
    LAST_CHART["path"] = None
    AGENT_STEPS.clear()

    response = agent.run(question)
    answer = str(response)

    memory = getattr(agent, "memory", None)
    for step in getattr(memory, "steps", None) or []:
        observation = getattr(step, "observations", None) or getattr(step, "observation", "")
        for tool_name, tool_input in _step_tool_calls(step):
            if tool_name and tool_name != "final_answer":
                AGENT_STEPS.append(
                    (str(tool_name), _truncate(tool_input, 300), _truncate(observation, 500))
                )

    # Presumably LAST_CHART["path"] is filled in by the generate_chart tool
    # during the run (not visible here); fall back to a chart filename
    # mentioned in the answer text.
    chart_path = LAST_CHART["path"]
    if chart_path is None:
        match = re.search(r"(chart_[^\s]+\.png)", answer)
        if match:
            chart_path = match.group(1)
    return answer, chart_path
# Event handlers
def img_to_base64(path: str) -> str:
    """Read the file at *path* and return its bytes as a base64 text string."""
    raw_bytes = Path(path).read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
def ask_agent(question: str, history: list):
    """Handle a submitted question and return the updated chat state.

    Args:
        question: Text from the input box. Empty, whitespace-only or None
            input is ignored and the agent is not called.
        history: Current chat messages as ``{"role", "content"}`` dicts, or
            None/empty for a fresh conversation.

    Returns:
        A 3-tuple matching the event outputs ``(chatbot, welcome,
        question_box)``: the message list, a visibility update for the welcome
        text, and ``""`` to clear the input box.
    """
    history = history or []
    if not (question or "").strip():
        # Nothing to send: keep the welcome text only while the chat is empty.
        return history, gr.update(visible=not history), ""

    response, chart_path = run_agent(question)

    assistant_content = response
    if chart_path and os.path.exists(chart_path):
        # Inline the chart as a base64 data URI ahead of the answer text.
        # Previously the encoded image was computed and then dropped, so the
        # chart never reached the chat message.
        b64 = img_to_base64(chart_path)
        assistant_content = f"![Chart](data:image/png;base64,{b64})\n\n{response}"

    updated_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": assistant_content},
    ]
    return updated_history, gr.update(visible=False), ""
def new_chat():
    """Return the reset state: empty chat, welcome text visible, input cleared."""
    empty_history = []
    show_welcome = gr.update(visible=True)
    cleared_input = ""
    return empty_history, show_welcome, cleared_input
# Gradio app
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the file — confirm the column/row placement against the deployed layout.
with gr.Blocks(title="Data Analyst agent with MCP", fill_height=True) as demo:
    with gr.Row(equal_height=True):
        # Left column: title, "new chat" button and one button per suggestion.
        with gr.Column(scale=1, min_width=260):
            gr.Markdown("## Data Analyst agent with MCP")
            new_btn = gr.Button("+ New chat", variant="secondary")
            gr.Markdown("### Suggestions")
            suggestion_buttons = [gr.Button(label, size="sm") for label in SUGGESTIONS]

        # Right column: welcome text, conversation, input row and disclaimer.
        with gr.Column(scale=4):
            welcome = gr.Markdown(
                "## How can I help you today?\nAsk anything about your survey data and I will generate stats and charts.",
                visible=True,
            )
            chatbot = gr.Chatbot(
                show_label=False,
                avatar_images=(
                    None,
                    "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
                ),
                height=620,
            )
            with gr.Row():
                question_box = gr.Textbox(
                    placeholder="Ask anything about the Stack Overflow survey…",
                    show_label=False,
                    lines=1,
                    max_lines=6,
                    scale=8,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1, min_width=90)
            gr.Markdown("Generated content may be inaccurate or false.")

    # The Send button and Enter in the textbox trigger the same handler with
    # the same inputs and outputs.
    chat_inputs = [question_box, chatbot]
    chat_outputs = [chatbot, welcome, question_box]
    for trigger in (send_btn.click, question_box.submit):
        trigger(ask_agent, inputs=chat_inputs, outputs=chat_outputs)
    new_btn.click(new_chat, outputs=chat_outputs)

    # A suggestion button only copies its text into the input box; the default
    # argument pins each button's own text at definition time.
    for button, text in zip(suggestion_buttons, SUGGESTIONS):
        button.click(lambda t=text: t, outputs=question_box)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, or 7860 when unset.
    port = int(os.getenv("PORT", 7860))
    launch_options = {
        "theme": gr.themes.Soft(),
        "css": CSS,
        "server_name": "0.0.0.0",
        "server_port": port,
        "share": False,
    }
    demo.launch(**launch_options)