GrizzGrizz commited on
Commit
841c495
·
verified ·
1 Parent(s): a51a0cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -47
app.py CHANGED
@@ -1,66 +1,342 @@
1
  """
2
- Hlavní spouštěcí soubor pro multiagentní chatbot pro vizualizaci dat
3
  """
 
 
 
4
  import os
5
- import sys
6
- from dotenv import load_dotenv
7
 
8
- load_dotenv()
 
 
 
 
 
 
 
 
 
9
 
10
 
11
- def check_requirements():
12
- """Kontrola, zda jsou splněny všechny požadavky"""
13
- required_vars = ["ANTHROPIC_API_KEY"]
14
- missing_vars = []
15
 
16
- for var in required_vars:
17
- if not os.getenv(var) or os.getenv(var) == f"your_{var.lower()}_here":
18
- missing_vars.append(var)
19
 
20
- if missing_vars:
21
- print("❌ Chybějící environment proměnné:")
22
- for var in missing_vars:
23
- print(f" - {var}")
24
- print("\n📝 Nastavte je v .env souboru:")
25
- print(" ANTHROPIC_API_KEY=your_actual_api_key_here")
26
- return False
27
 
28
- return True
29
 
 
30
 
31
- def main():
32
- """Hlavní funkce pro spuštění aplikace"""
33
- print("🤖 Multiagentní Chatbot pro Vizualizaci Dat")
34
- print("=" * 50)
35
 
36
- # Kontrola požadavků
37
- if not check_requirements():
38
- print(
39
- "\n⚠️ Aplikace nemůže být spuštěna bez správně nastavených API klíčů.")
40
- return
41
 
42
- print("✅ Všechny požadavky splněny!")
43
- print("🚀 Spouštím Gradio rozhraní...")
44
 
45
- try:
46
- from gradio_app import create_gradio_interface
 
 
 
47
 
48
- interface = create_gradio_interface()
49
- interface.launch(
50
- server_name="0.0.0.0",
51
- server_port=7862,
52
- share=False,
53
- show_error=True,
54
- inbrowser=True
55
- )
56
 
57
- except ImportError as e:
58
- print(f"❌ Chyba při importu: {e}")
59
- print("💡 Spusťte: pip install -r requirements.txt")
60
- except Exception as e:
61
- print(f"❌ Chyba při spuštění: {e}")
62
 
63
 
64
- if __name__ == "__main__":
65
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ MCP Visualization Server
3
  """
4
+
5
+ import base64
6
+ import io
7
  import os
8
+ import textwrap
9
+ from typing import Dict, Any, List
10
 
11
+ import pandas as pd
12
+ import matplotlib
13
+ matplotlib.use("Agg")
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ import numpy as np
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel
20
+ from anthropic import Anthropic
21
 
22
 
23
+ # FastAPI
 
 
 
24
 
25
+ app = FastAPI(title="MCP Visualization Server", version="8.0.0")
 
 
26
 
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
 
 
35
 
36
+ # Models
37
 
38
+ class VisualizationRequest(BaseModel):
39
+ prompt: str
40
+ dataset_info: Dict[str, Any]
41
+ output_format: str = "png"
42
 
 
 
 
 
 
43
 
44
+ # Helpers
 
45
 
46
+ def get_llm() -> Anthropic:
47
+ key = os.getenv("ANTHROPIC_API_KEY")
48
+ if not key:
49
+ raise RuntimeError("ANTHROPIC_API_KEY missing")
50
+ return Anthropic(api_key=key)
51
 
 
 
 
 
 
 
 
 
52
 
53
+ def get_model() -> str:
54
+ model = os.getenv("LLM_MODEL")
55
+ if not model:
56
+ raise RuntimeError("LLM_MODEL not set")
57
+ return model
58
 
59
 
60
+ def load_df(dataset_info: Dict[str, Any]) -> pd.DataFrame:
61
+ if "sample_data" not in dataset_info:
62
+ raise HTTPException(400, "dataset_info.sample_data missing")
63
+ df = pd.DataFrame(dataset_info["sample_data"])
64
+ if df.empty:
65
+ raise HTTPException(400, "Dataset is empty")
66
+ return df
67
+
68
+
69
+ def fig_to_base64(fig: plt.Figure, fmt: str = "png") -> str:
70
+ buf = io.BytesIO()
71
+ fig.savefig(buf, format=fmt, bbox_inches="tight", dpi=120)
72
+ buf.seek(0)
73
+ out = base64.b64encode(buf.read()).decode()
74
+ plt.close(fig)
75
+ return out
76
+
77
+
78
+ def build_schema(df: pd.DataFrame) -> Dict:
79
+ """Bohaté schéma datasetu pro LLM - kardinalita, typy, ukázky."""
80
+ schema = {}
81
+ for col in df.columns:
82
+ if col.lower().startswith("unnamed"):
83
+ continue
84
+ nunique = int(df[col].nunique())
85
+ dtype = str(df[col].dtype)
86
+ sample = df[col].dropna().head(5).tolist()
87
+ info = {"dtype": dtype, "nunique": nunique, "sample": sample}
88
+ if pd.api.types.is_numeric_dtype(df[col]):
89
+ info["min"] = float(df[col].min())
90
+ info["max"] = float(df[col].max())
91
+ info["mean"] = round(float(df[col].mean()), 3)
92
+ schema[col] = info
93
+ return schema
94
+
95
+
96
+ # Tool schema (plan)
97
+
98
+ PLAN_TOOL = {
99
+ "name": "create_dashboard_plan",
100
+ "description": "Vytvoří plán dashboardu - insight a seznam 3-4 grafů s popisem co každý má ukázat.",
101
+ "input_schema": {
102
+ "type": "object",
103
+ "properties": {
104
+ "insight": {
105
+ "type": "string",
106
+ "description": "Hlavní datový insight v jedné větě"
107
+ },
108
+ "charts": {
109
+ "type": "array",
110
+ "items": {
111
+ "type": "object",
112
+ "properties": {
113
+ "title": {"type": "string", "description": "Nadpis grafu"},
114
+ "description": {"type": "string", "description": "Co graf ukazuje a proč je zajímavý"},
115
+ "chart_type": {
116
+ "type": "string",
117
+ "enum": ["line", "bar", "scatter", "histogram", "violin", "dual_axes"]
118
+ },
119
+ "columns_used": {
120
+ "type": "array",
121
+ "items": {"type": "string"},
122
+ "description": "Přesné názvy sloupců použité v grafu"
123
+ }
124
+ },
125
+ "required": ["title", "description", "chart_type", "columns_used"]
126
+ },
127
+ "minItems": 3,
128
+ "maxItems": 4
129
+ }
130
+ },
131
+ "required": ["insight", "charts"]
132
+ }
133
+ }
134
+
135
+
136
+ # Prompts
137
+
138
+ PLAN_SYSTEM = """
139
+ Jsi zkušený datový analytik. Tvým úkolem je navrhnout dashboard s 3-4 grafy.
140
+
141
+ Pravidla pro výběr grafů:
142
+ - line: pouze pro datum/čas nebo pořadové hodnoty (nunique > 20)
143
+ - bar: pro kategorie s nunique 2-25, zobraz top hodnoty seřazené sestupně
144
+ - scatter: pro vztah dvou numerických sloupů, přidej regresní linii
145
+ - histogram: pro distribuci jednoho numerického sloupce, přidej průměr a medián
146
+ - violin: pro distribuci čísla podle kategorie (nunique kategorie < 15)
147
+ - dual_axes: pouze pokud chceš srovnat 2 metriky s velmi různými škálami
148
+
149
+ KRITICKÁ PRAVIDLA:
150
+ - Nepoužívej sloupce začínající "Unnamed"
151
+ - bar NIKDY pro sloupce s nunique > 25
152
+ - violin NIKDY pro kategorie s nunique > 15
153
+ - Každý graf musí přinést JINOU informaci
154
+ - Nepoužívej stejný typ grafu dvakrát
155
+ """
156
+
157
+ CODE_SYSTEM = """
158
+ Jsi expert na Python vizualizace s matplotlib a seaborn.
159
+
160
+ Napiš Python kód pro JEDEN konkrétní graf.
161
+
162
+ Pravidla:
163
+ - DataFrame je dostupný jako proměnná `df` (již načtený)
164
+ - Figure je dostupný jako proměnná `fig` a `ax` (již vytvořený: fig, ax = plt.subplots(...))
165
+ - NEPIŠ: import, plt.subplots(), plt.show(), plt.savefig(), plt.close()
166
+ - Kresli pouze na `ax`
167
+ - Používej sns nebo ax přímé volání
168
+ - Přidej popisné osy a title
169
+ - Zpracuj data správně (agregace, filtrování, konverze typů)
170
+ - Pro datetime: pd.to_datetime() a resample("ME").mean()
171
+ - Pro bar s mnoha kategoriemi: zobraz jen top 15 podle hodnoty, horizontálně
172
+ - Pro scatter: přidej regresní linii přes sns.regplot(scatter=False)
173
+ - Pro histogram: přidej ax.axvline pro průměr a medián
174
+ - Kód musí být robustní: dropna(), pd.to_numeric(errors='coerce') kde je potřeba
175
+
176
+ Napiš POUZE spustitelný Python kód, bez vysvětlení, bez markdown.
177
+ """
178
+
179
+
180
+ # Step 1: Plan (tool_use)
181
+
182
+ def create_plan(prompt: str, df: pd.DataFrame) -> Dict[str, Any]:
183
+ """LLM navrhne strukturovaný plán dashboardu přes tool_use."""
184
+ llm = get_llm()
185
+ schema = build_schema(df)
186
+
187
+ user_msg = f"""
188
+ Požadavek: {prompt}
189
+
190
+ Schéma datasetu ({len(df)} řádků):
191
+ {schema}
192
+
193
+ Navrhni 3-4 různé grafy pro dashboard.
194
+ """
195
+
196
+ resp = llm.messages.create(
197
+ model=get_model(),
198
+ max_tokens=1000,
199
+ system=PLAN_SYSTEM,
200
+ messages=[{"role": "user", "content": user_msg}],
201
+ tools=[PLAN_TOOL],
202
+ tool_choice={"type": "tool", "name": "create_dashboard_plan"},
203
+ )
204
+
205
+ for block in resp.content:
206
+ if block.type == "tool_use" and block.name == "create_dashboard_plan":
207
+ return block.input # již Python dict, bez json.loads()
208
+
209
+ raise HTTPException(500, "LLM did not return tool_use block")
210
+
211
+
212
+ # Step 2: Code per chart
213
+
214
+ def generate_chart_code(chart: Dict[str, Any], df: pd.DataFrame) -> str:
215
+ """LLM napíše matplotlib kód na míru pro jeden konkrétní graf."""
216
+ llm = get_llm()
217
+ schema = build_schema(df)
218
 
219
+ # Ukázka dat pro relevantní sloupce
220
+ cols = chart.get("columns_used", [])
221
+ valid_cols = [c for c in cols if c in df.columns]
222
+ sample_data = df[valid_cols].head(10).to_string() if valid_cols else df.head(5).to_string()
223
+
224
+ user_msg = f"""
225
+ Graf: {chart['title']}
226
+ Typ: {chart['chart_type']}
227
+ Popis: {chart['description']}
228
+ Použité sloupce: {chart['columns_used']}
229
+
230
+ Schéma datasetu:
231
+ {schema}
232
+
233
+ Ukázka dat:
234
+ {sample_data}
235
+
236
+ Napiš Python kód pro tento graf. Kresli na proměnnou `ax`, data jsou v `df`.
237
+ """
238
+
239
+ resp = llm.messages.create(
240
+ model=get_model(),
241
+ max_tokens=800,
242
+ system=CODE_SYSTEM,
243
+ messages=[{"role": "user", "content": user_msg}],
244
+ )
245
+
246
+ code = resp.content[0].text.strip()
247
+
248
+ # Odstranění markdown pokud LLM přidá
249
+ if "```python" in code:
250
+ code = code.split("```python")[1].split("```")[0].strip()
251
+ elif "```" in code:
252
+ code = code.split("```")[1].split("```")[0].strip()
253
+
254
+ return code
255
+
256
+
257
+ # Step 3: Execute code
258
+
259
+ def execute_chart_code(code: str, df: pd.DataFrame, fmt: str) -> str:
260
+ """Spustí kód grafu a vrátí base64 obrázek."""
261
+ sns.set_theme(style="whitegrid", palette="Set2")
262
+ fig, ax = plt.subplots(figsize=(10, 6))
263
+
264
+ exec_globals = {
265
+ "df": df.copy(),
266
+ "fig": fig,
267
+ "ax": ax,
268
+ "plt": plt,
269
+ "pd": pd,
270
+ "sns": sns,
271
+ "np": np,
272
+ }
273
+
274
+ exec(textwrap.dedent(code), exec_globals) # noqa: S102
275
+
276
+ return fig_to_base64(fig, fmt)
277
+
278
+
279
+ # Endpoint
280
+
281
+ @app.post("/advanced-visualization")
282
+ def advanced_visualization(req: VisualizationRequest):
283
+ df = load_df(req.dataset_info)
284
+ fmt = req.output_format
285
+
286
+ # Krok 1: strukturovaný plán přes tool_use
287
+ plan = create_plan(req.prompt, df)
288
+ print(f"Plan: insight='{plan.get('insight')}', charts={[c['title'] for c in plan.get('charts', [])]}")
289
+
290
+ images = {}
291
+ errors = []
292
+
293
+ # Krok 2+3: pro každý graf LLM napíše kód
294
+ for chart in plan.get("charts", [])[:4]:
295
+ title = chart.get("title", "chart")
296
+ print(f"Generating code for: {title} ({chart.get('chart_type')})")
297
+
298
+ try:
299
+ code = generate_chart_code(chart, df)
300
+ print(f"Code for '{title}':\n{code}\n---")
301
+
302
+ img = execute_chart_code(code, df, fmt)
303
+
304
+ key = title.lower().replace(" ", "_")[:30]
305
+ counter = 1
306
+ while key in images:
307
+ key = f"{key}_{counter}"
308
+ counter += 1
309
+ images[key] = img
310
+
311
+ except Exception as e:
312
+ import traceback
313
+ tb = traceback.format_exc()
314
+ print(f"Error for '{title}': {tb}")
315
+ errors.append(f"{title}: {str(e)}")
316
+
317
+ if not images:
318
+ raise HTTPException(500, f"No visualizations generated. Errors: {errors}")
319
+
320
+ return {
321
+ "success": True,
322
+ "insight": plan.get("insight"),
323
+ "visualization": next(iter(images.values())),
324
+ "visualizations": images,
325
+ "chart_count": len(images),
326
+ "tool_errors": errors,
327
+ "llm_plan": plan,
328
+ }
329
+
330
+
331
+ @app.get("/health")
332
+ def health():
333
+ return {"status": "ok"}
334
+
335
+
336
+ if __name__ == "__main__":
337
+ import uvicorn
338
+ uvicorn.run(
339
+ app,
340
+ host="0.0.0.0",
341
+ port=int(os.getenv("PORT", "7860")),
342
+ )