mcpserver / app.py
GrizzGrizz's picture
Update app.py
7d8ca3e verified
"""
FastAPI MCP server for advanced visualizations (designed for Hugging Face Spaces).
Endpoints:
- GET /health → { status: "ok" }
- GET /capabilities → list supported features
- POST /advanced-visualization → generate script via Anthropic, execute in sandbox, return base64 image or HTML
Environment variables (set in HF Spaces Secrets):
- ANTHROPIC_API_KEY: required
- SANDBOX_TIMEOUT: optional (default 30s)
- MAX_OUTPUT_SIZE: optional (bytes, default 10MB)
"""
import base64
import json
import os
import shutil
import subprocess
import tempfile
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from anthropic import Anthropic
import re
app = FastAPI(title="MCP Server (FastAPI)", version="0.1.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class VisualizationRequest(BaseModel):
prompt: str
dataset_info: Optional[Dict[str, Any]] = None
visualization_type: Optional[str] = "advanced"
output_format: Optional[str] = "png" # png|html
def get_anthropic_client() -> Anthropic:
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise RuntimeError("ANTHROPIC_API_KEY is missing")
return Anthropic(api_key=api_key)
@app.get("/")
def root() -> Dict[str, str]:
return {"message": "MCP Server is running", "status": "ok"}
@app.get("/health")
def health() -> Dict[str, str]:
return {"status": "ok"}
@app.get("/capabilities")
def capabilities() -> Dict[str, Any]:
return {
"capabilities": [
"advanced_visualizations",
"matplotlib_seaborn_png",
]
}
def strip_markdown_fences(text: str) -> str:
if "```" not in text:
return text
# Prefer fenced ```python blocks, fallback to first triple-backtick section
start_marker = "```python"
if start_marker in text:
start = text.find(start_marker) + len(start_marker)
end = text.find("```", start)
if end != -1:
return text[start:end].strip()
else:
# Pokud nenajde konec, vezmi vše od začátku bloku
return text[start:].strip()
# generic triple backticks
start = text.find("```") + 3
end = text.find("```", start)
if end != -1:
return text[start:end].strip()
else:
# Pokud nenajde konec, vezmi vše od začátku bloku
return text[start:].strip()
def sanitize_script(script: str) -> str:
"""Lightweight fixups to make LLM scripts runnable in our sandbox.
- Force reading data.csv
- Remove deprecated plt.style.use
- Ensure seaborn theme is set
- Avoid unsafe casts of string columns to float
- Remove any Plotly imports/usages (Plotly disabled)
"""
try:
# Bezpečnostní kontrola: odstranění případných zbývajících markdown bloků
script = script.strip()
if script.startswith('```'):
# Pokud začíná markdown blokem, přeskoč ho
start_idx = script.find('```') + 3
if script[start_idx:start_idx+6] == 'python':
start_idx += 6
script = script[start_idx:].strip()
# Odstraň konec bloku pokud existuje
if script.endswith('```'):
script = script[:-3].strip()
# Force df definition at start to prevent 'name df is not defined' errors
if not script.startswith('import pandas'):
script = "import pandas as pd\n" + script
if 'df = pd.read_csv' not in script:
# Find the first import statement and add df definition after it
lines = script.split('\n')
for i, line in enumerate(lines):
if 'import pandas' in line:
lines.insert(i + 1, 'df = pd.read_csv("data.csv")')
break
script = '\n'.join(lines)
# Force read_csv('data.csv')
script = re.sub(r"read_csv\((?:[^)]*)\)", "read_csv('data.csv')", script)
# Drop plt.style.use lines
script = re.sub(r"\n\s*plt\.style\.use\([^\)]*\)\s*\n", "\n", script)
# Ensure seaborn theme
if "seaborn" in script and "set_theme(" not in script and "set_style(" not in script:
# Try to insert after first seaborn import
script = re.sub(
r"(import\s+seaborn\s+as\s+sns\s*\n)",
r"\1sns.set_theme()\n",
script,
count=1,
)
# Replace dangerous DataFrame-wide float casts with safe numeric coercion
# df.astype(float) or something.astype(float)
script = re.sub(
r"\.astype\(\s*float\s*\)",
".apply(pd.to_numeric, errors='coerce')",
script,
)
# If the script reads df from data.csv, ensure numeric_df exists safely
# Insert after first occurrence of reading csv
script = re.sub(
r"(df\s*=\s*pd\.read_csv\('data\.csv'\)\s*\n)",
r"\1# Derive numeric-only dataframe to avoid casting errors\n" \
r"numeric_df = df.select_dtypes(include=['number']).copy()\n",
script,
count=1,
)
# Strip Plotly imports and common usages
script = re.sub(r"^\s*import\s+plotly[\s\S]*?$", "", script, flags=re.MULTILINE)
script = re.sub(r"^\s*from\s+plotly[\s\S]*?$", "", script, flags=re.MULTILINE)
script = re.sub(r"px\.", "# px.", script)
script = re.sub(r"go\.", "# go.", script)
script = re.sub(r"\.write_html\(", "# .write_html(", script)
script = re.sub(r"\.write_image\(", "# .write_image(", script)
return script
except Exception:
return script
def generate_script(prompt: str, dataset_info: Dict[str, Any], visualization_type: str, output_format: str) -> str:
client = get_anthropic_client()
# Require explicit model from env
model = os.getenv("LLM_MODEL")
if not model:
raise HTTPException(status_code=500, detail="LLM_MODEL není nastaveno v prostředí")
sys_prompt = f"""
Jste expert na POKROČILÉ datové vizualizace v Pythonu. Vytvořte složitý, profesionální Python skript s pokročilými technikami.
Požadavek: {prompt}
Typ vizualizace: {visualization_type}
Dataset info: {json.dumps(dataset_info or {}, ensure_ascii=False, default=str)}
Výstupní formát: {output_format}
POVINNÉ POKROČILÉ VIZUALIZACE:
1. MACHINE LEARNING VIZUALIZACE:
- Clustering s K-means, DBSCAN, nebo hierarchické clustering
- PCA/t-SNE/UMAP pro dimensionality reduction
- Decision boundaries pro klasifikátory
- Feature importance plots
- Confusion matrices s heatmapami
- ROC curves a precision-recall curves
2. POKROČILÉ STATISTICKÉ GRAFY:
- Violin plots s distribucí dat
- Box plots s outlier analýzou
- Correlation heatmaps s clustermap
- Pair plots s regresními čarami
- Distribution plots s KDE
- Q-Q plots pro normality testing
3. INTERAKTIVNÍ A 3D VIZUALIZACE:
- 3D scatter plots s color mapping
- Surface plots a contour plots
- Interaktivní prvky pouze pokud je to možné s matplotlib; preferujte statické PNG
- Animated plots s matplotlib.animation
- Subplot grids s komplexními layouty
4. POKROČILÉ KNIHOVNY (POVINNÉ):
- PyTorch pro deep learning vizualizace
- scikit-learn pro ML algoritmy
- seaborn pro statistické grafy
- matplotlib pro custom vizualizace
- numpy pro numerické operace
TECHNICKÉ POŽADAVKY:
- Vytvořte VÍCE grafů (minimálně 3-5 různých vizualizací)
- Použijte pokročilé ML algoritmy na datech
- Implementujte custom color palettes
- Přidejte statistical annotations
- Použijte subplot layouts pro komplexní dashboards
- Implementujte error handling pro všechny operace
- Uložte každý graf jako 'graph1.{output_format}', 'graph2.{output_format}', atd.
- Hlavní graf uložte jako 'main.{output_format}'
OMEZENÍ:
- Plotly je zakázáno; nepoužívejte importy ani funkce Plotly
KRITICKÉ SYNTAX POŽADAVKY:
- NIKDY nepoužívejte plt.style.use() - je to ZASTARALÉ a způsobuje chyby!
- VŽDY použijte: sns.set_theme() nebo sns.set_style()
- Použijte 's' místo 'size' v scatter plot parametrech
- Použijte 'hue' místo 'color' pro kategorické proměnné
- Použijte plt.tight_layout() před plt.savefig()
- Zkontrolujte, že všechny sloupce existují před použitím
- PRO ML a numerické operace používejte pouze numerické sloupce: numeric_df = df.select_dtypes(include=['number'])
- Nikdy neprovádějte df.astype(float) na celém DataFrame; místo toho použijte pd.to_numeric(..., errors='coerce') na jednotlivé sloupce nebo pracujte s numeric_df
Vytvořte pokročilý skript s více vizualizacemi výhradně pomocí matplotlib a seaborn.
"""
resp = client.messages.create(
model=model,
max_tokens=8000, # VÍCE tokenů pro složitější skripty
messages=[{"role": "user", "content": sys_prompt}],
)
raw = resp.content[0].text
return sanitize_script(strip_markdown_fences(raw))
def run_script(script: str, output_format: str, dataset_info: Dict[str, Any] = None) -> Dict[str, Any]:
sandbox_timeout = int(os.getenv("SANDBOX_TIMEOUT", "120")) # Více času pro složité skripty
max_output = int(os.getenv("MAX_OUTPUT_SIZE", str(50 * 1024 * 1024))) # Větší limit pro více grafů
with tempfile.TemporaryDirectory(prefix="mcp_sandbox_") as tmp:
script_path = os.path.join(tmp, "visualization.py")
with open(script_path, "w", encoding="utf-8") as f:
f.write(script)
# Použij skutečná data z dataset_info
import pandas as pd
if dataset_info and 'sample_data' in dataset_info:
# Vytvoř DataFrame ze skutečných dat
df = pd.DataFrame(dataset_info['sample_data'])
df.to_csv(os.path.join(tmp, "data.csv"), index=False)
else:
# Fallback na dummy data pouze pokud nejsou skutečná data
dummy_data = {
'Age': [25, 30, 35, 40, 45, 50, 55, 60, 65, 70],
'Height': [170, 175, 180, 165, 172, 178, 182, 168, 173, 176],
'Weight': [70, 75, 80, 65, 72, 78, 82, 68, 73, 76],
'Gender': ['Male', 'Female', 'Male', 'Female', 'Male', 'Female', 'Male', 'Female', 'Male', 'Female'],
'Income': [30000, 40000, 50000, 35000, 42000, 48000, 55000, 38000, 43000, 52000]
}
dummy_df = pd.DataFrame(dummy_data)
dummy_df.to_csv(os.path.join(tmp, "data.csv"), index=False)
try:
res = subprocess.run(
["python", script_path],
cwd=tmp,
capture_output=True,
text=True,
timeout=sandbox_timeout,
)
except subprocess.TimeoutExpired:
raise HTTPException(status_code=504, detail="Timeout při spuštění skriptu")
# Hledej vygenerované soubory - prioritně main.{format}, pak ostatní
output_file = None
all_output_files = []
# 1. Zkus najít main.{format}
main_file = os.path.join(tmp, f"main.{output_format}")
if os.path.exists(main_file):
output_file = main_file
all_output_files.append(main_file)
# 2. Najdi všechny soubory s daným formátem
for file in os.listdir(tmp):
if file.endswith(f".{output_format}"):
file_path = os.path.join(tmp, file)
if file_path not in all_output_files:
all_output_files.append(file_path)
if not output_file: # První nalezený soubor
output_file = file_path
# 3. Pokud stále nic, hledej jakýkoliv obrázek
if not output_file:
for file in os.listdir(tmp):
if file.endswith(('.png', '.jpg', '.jpeg', '.svg')):
file_path = os.path.join(tmp, file)
all_output_files.append(file_path)
if not output_file:
output_file = file_path
# 4. Pokud stále nic, zkus najít jakýkoliv soubor s 'graph' v názvu
if not output_file:
for file in os.listdir(tmp):
if 'graph' in file.lower() and file.endswith(('.png', '.jpg', '.jpeg', '.svg')):
file_path = os.path.join(tmp, file)
all_output_files.append(file_path)
if not output_file:
output_file = file_path
# Převeď na base64 - sbírej všechny soubory (galerie) a pro kompatibilitu nech i single preview
data_b64 = ""
gallery_b64: list[str] = []
if all_output_files:
# Najdi největší soubor
largest_file = max(all_output_files, key=lambda f: os.path.getsize(f) if os.path.exists(f) else 0)
if os.path.exists(largest_file):
with open(largest_file, "rb") as f:
blob = f.read()
if len(blob) > max_output:
raise HTTPException(status_code=413, detail="Výstup je příliš velký")
data_b64 = base64.b64encode(blob).decode("utf-8")
# Naplň galerii
for fpath in all_output_files:
if os.path.exists(fpath):
with open(fpath, "rb") as fb:
blob = fb.read()
if len(blob) <= max_output:
gallery_b64.append(base64.b64encode(blob).decode("utf-8"))
elif output_file and os.path.exists(output_file):
with open(output_file, "rb") as f:
blob = f.read()
if len(blob) > max_output:
raise HTTPException(status_code=413, detail="Výstup je příliš velký")
data_b64 = base64.b64encode(blob).decode("utf-8")
gallery_b64 = [data_b64]
else:
# Debug: vypiš všechny soubory v tmp adresáři
all_files = os.listdir(tmp)
print(f"DEBUG: Všechny soubory v tmp: {all_files}")
print(f"DEBUG: Hledaný output_file: {output_file}")
print(f"DEBUG: return_code: {res.returncode}")
return {
"return_code": res.returncode,
"stdout": res.stdout,
"stderr": res.stderr,
"output_b64": data_b64,
"all_files": [os.path.basename(f) for f in all_output_files], # Debug info
"gallery_b64": gallery_b64,
}
@app.post("/advanced-visualization")
def advanced_visualization(req: VisualizationRequest) -> Dict[str, Any]:
try:
script = generate_script(
prompt=req.prompt,
dataset_info=req.dataset_info or {},
visualization_type=req.visualization_type or "advanced",
output_format=req.output_format or "png",
)
result = run_script(script, req.output_format or "png", req.dataset_info)
# Úspěch pokud máme výstupní data, i když return_code není 0
success = bool(result.get("output_b64"))
return {
"success": success,
"visualization": result.get("output_b64", ""),
"script": script,
"logs": {
"stdout": result.get("stdout", ""),
"stderr": result.get("stderr", ""),
},
"generated_files": result.get("all_files", []), # Seznam všech vygenerovaných souborů
"file_count": len(result.get("all_files", [])), # Počet vygenerovaných souborů
"visualizations_multi": result.get("gallery_b64", []),
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Chyba serveru: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))