Spaces:
Sleeping
Sleeping
File size: 5,621 Bytes
import atexit
import base64
import os
import re
from pathlib import Path

from dotenv import load_dotenv
import gradio as gr
from smolagents import ToolCallingAgent, AzureOpenAIModel
from smolagents.mcp_client import MCPClient

from tools import generate_chart, LAST_CHART, AGENT_STEPS
from prompt import AGENT_INSTRUCTIONS
# Config
# Resolve to an absolute path first: if __file__ is relative (e.g. "app.py"),
# Path("app.py").parent.parent is still ".", and the .env lookup below would
# silently target the wrong directory.
BASE_DIR = Path(__file__).resolve().parent
# Settings are read from a .env file in the parent of this file's directory.
load_dotenv(BASE_DIR.parent / ".env")
MCP_SERVER_URL = os.getenv("MCP_SERVER_URL", "https://sitsope-mcp-server-test.hf.space/gradio_api/mcp/sse")
AZURE_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "https://collier-llm.openai.azure.com/")
AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")  # no default: must come from the environment
AZURE_API_VER = os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
AZURE_MODEL = os.getenv("AZURE_OPENAI_MODEL", "gpt-4o")

# Canned prompts rendered as one-click sidebar buttons; clicking one copies
# its text into the question box.
SUGGESTIONS = [
    "How many unique questions?",
    "Sub-options per question",
    "Questions above average sub-options",
    "Distribution by question type",
    "Which question has the most sub-options?",
    "Percentage of rows per question",
]

# Keep CSS minimal to avoid HF Spaces iframe layout bugs
CSS = """
.gradio-container {
max-width: none !important;
width: 100% !important;
margin: 0 auto !important;
padding-left: 20px !important;
padding-right: 20px !important;
}
"""

# Agent setup
# Remote tools come from an MCP server over SSE; the local chart tool is
# appended to them.
mcp_client = MCPClient({"url": MCP_SERVER_URL, "transport": "sse"}, structured_output=False)
# The client keeps a live connection for the lifetime of the process; release
# it when the interpreter exits instead of leaking it.
atexit.register(mcp_client.disconnect)
tools = mcp_client.get_tools()
model = AzureOpenAIModel(
    model_id=AZURE_MODEL,
    azure_endpoint=AZURE_ENDPOINT,
    api_key=AZURE_API_KEY,
    api_version=AZURE_API_VER,
)
agent = ToolCallingAgent(
    tools=[*tools, generate_chart],
    model=model,
    instructions=AGENT_INSTRUCTIONS,
)
# Agent runner
# Display limits for each entry of the tool trace stored in AGENT_STEPS.
_MAX_INPUT_CHARS = 300
_MAX_OBS_CHARS = 500
# Fallback chart lookup: a "chart_*.png" name mentioned in the final answer.
_CHART_NAME_RE = re.compile(r"(chart_[^\s]+\.png)")


def _truncate(value, limit: int) -> str:
    """Return ``str(value)`` cut to *limit* characters, plus "…" if it was cut."""
    text = str(value)
    return text[:limit] + ("…" if len(text) > limit else "")


def _step_tool_calls(step):
    """Yield ``(tool_name, tool_input)`` for each tool call recorded on *step*.

    smolagents action steps list their calls in ``step.tool_calls`` (objects
    exposing ``.name`` and ``.arguments``). Single-call attributes
    (``tool_name``/``action``, ``tool_arguments``/``tool_input``) are probed as
    a fallback for step objects shaped that way; steps with neither yield a
    ``None`` name, which the caller skips.
    """
    calls = getattr(step, "tool_calls", None)
    if calls:
        for call in calls:
            yield getattr(call, "name", None), getattr(call, "arguments", "")
    else:
        yield (
            getattr(step, "tool_name", None) or getattr(step, "action", None),
            getattr(step, "tool_arguments", None) or getattr(step, "tool_input", ""),
        )


def run_agent(question: str):
    """Run the agent on *question*; collect its answer, tool trace and chart.

    Side effects:
        Resets ``LAST_CHART["path"]`` to ``None`` before running, and refills
        ``AGENT_STEPS`` with ``(tool_name, input_str, observation_str)`` tuples
        (inputs cut to 300 chars, observations to 500), skipping
        ``final_answer``.

    Returns:
        ``(answer_text, chart_path)``. ``chart_path`` is ``LAST_CHART["path"]``
        (presumably set by the ``generate_chart`` tool during the run); if that
        is still ``None``, it is the first ``chart_*.png`` name found in the
        answer text, else ``None``.
    """
    LAST_CHART["path"] = None
    AGENT_STEPS.clear()
    response = agent.run(question)

    memory = getattr(agent, "memory", None)
    for step in getattr(memory, "steps", None) or []:
        # Observations are recorded per step, so every call in the step shares them.
        observation = getattr(step, "observations", None) or getattr(step, "observation", "")
        for tool_name, tool_input in _step_tool_calls(step):
            if tool_name and tool_name != "final_answer":
                AGENT_STEPS.append(
                    (
                        str(tool_name),
                        _truncate(tool_input, _MAX_INPUT_CHARS),
                        _truncate(observation, _MAX_OBS_CHARS),
                    )
                )

    chart_path = LAST_CHART["path"]
    if chart_path is None:
        match = _CHART_NAME_RE.search(str(response))
        if match:
            chart_path = match.group(1)
    return str(response), chart_path
# Event handlers
def img_to_base64(path: str) -> str:
    """Return the raw bytes of the file at *path* encoded as a base64 ``str``."""
    payload = Path(path).read_bytes()
    encoded = base64.b64encode(payload)
    return encoded.decode()
def ask_agent(question: str, history: list):
    """Gradio handler: answer *question* and append the exchange to the chat.

    Args:
        question: Text from the question box.
        history: Current chatbot value, a list of ``{"role": ..., "content": ...}``
            message dicts; ``None`` is treated as an empty chat.

    Returns:
        ``(new_history, welcome_update, "")``: the chat messages, a visibility
        update for the welcome banner, and an empty string to clear the input.
    """
    history = history or []
    if not question.strip():
        # Blank input: leave the chat unchanged and keep the welcome banner
        # visible only while the chat is still empty.
        return history, gr.update(visible=(len(history) == 0)), ""

    response, chart_path = run_agent(question)
    assistant_content = response
    if chart_path and os.path.exists(chart_path):
        # Embed the chart inline as a base64 data-URI Markdown image above the
        # text answer (previously b64 was computed but never used).
        b64 = img_to_base64(chart_path)
        assistant_content = f"![chart](data:image/png;base64,{b64})\n\n{response}"

    updated_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": assistant_content},
    ]
    return updated_history, gr.update(visible=False), ""
def new_chat():
    """Reset the UI: empty chat, welcome banner shown, question box cleared."""
    cleared_history = []
    show_welcome = gr.update(visible=True)
    cleared_input = ""
    return cleared_history, show_welcome, cleared_input
# Gradio app
with gr.Blocks(title="Data Analyst agent with MCP", fill_height=True) as demo:
    with gr.Row(equal_height=True):
        # Sidebar: title, reset button and one-click suggestion prompts.
        with gr.Column(scale=1, min_width=260):
            gr.Markdown("## Data Analyst agent with MCP")
            new_btn = gr.Button("+ New chat", variant="secondary")
            gr.Markdown("### Suggestions")
            suggestion_buttons = [gr.Button(text, size="sm") for text in SUGGESTIONS]

        # Main pane: welcome banner, conversation, input row and disclaimer.
        with gr.Column(scale=4):
            welcome = gr.Markdown(
                "## How can I help you today?\nAsk anything about your survey data and I will generate stats and charts.",
                visible=True,
            )
            chatbot = gr.Chatbot(
                show_label=False,
                avatar_images=(None, "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"),
                height=620,
            )
            with gr.Row():
                question_box = gr.Textbox(
                    placeholder="Ask anything about the Stack Overflow survey…",
                    show_label=False,
                    lines=1,
                    max_lines=6,
                    scale=8,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1, min_width=90)
            gr.Markdown("Generated content may be inaccurate or false.")

    # Clicking "Send" and pressing Enter in the box run the same handler.
    for trigger in (send_btn.click, question_box.submit):
        trigger(ask_agent, inputs=[question_box, chatbot], outputs=[chatbot, welcome, question_box])
    new_btn.click(new_chat, outputs=[chatbot, welcome, question_box])

    # Each suggestion button only fills the question box; `t=text` binds the
    # current suggestion at definition time (avoids late-binding closures).
    for button, text in zip(suggestion_buttons, SUGGESTIONS):
        button.click(lambda t=text: t, outputs=question_box)
if __name__ == "__main__":
    # Theme and CSS are supplied to launch() rather than to gr.Blocks above.
    # Bind on all interfaces, on $PORT when set (7860 otherwise), without a
    # public share link.
    port = int(os.getenv("PORT", 7860))
    demo.launch(
        theme=gr.themes.Soft(),
        css=CSS,
        server_name="0.0.0.0",
        server_port=port,
        share=False,
    )
|