Spaces:
Running
Running
| import os | |
| import re | |
| import json | |
| import glob | |
| import base64 | |
| import shutil | |
| import logging | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| from langchain_core.tools import Tool | |
| from langchain_community.agent_toolkits import FileManagementToolkit | |
| from langchain_community.tools.wikipedia.tool import WikipediaQueryRun | |
| from langchain_experimental.tools import PythonREPLTool | |
| from langchain_community.utilities import GoogleSerperAPIWrapper | |
| from langchain_community.utilities.wikipedia import WikipediaAPIWrapper | |
# Read settings from a local .env file; override=True lets values from the
# file replace variables that are already present in the environment.
load_dotenv(override=True)

# Default matplotlib to the Agg backend, but only when MPLBACKEND has not
# already been chosen (by the environment or by the .env file just loaded).
if "MPLBACKEND" not in os.environ:
    os.environ["MPLBACKEND"] = "Agg"

logger = logging.getLogger(__name__)

try:
    import matplotlib

    matplotlib.use("Agg")
    from matplotlib import pyplot as plt

    # Replace plt.show() with a no-op so code that calls it does nothing here.
    plt.show = lambda *_args, **_kwargs: None
except Exception:
    # matplotlib is optional: any failure while importing/configuring it
    # leaves the module usable, with ``matplotlib`` set to None.
    matplotlib = None
# The sandbox lives next to this source file (resolved to an absolute path),
# so its location does not depend on the working directory at launch time.
_SCRIPT_DIR = Path(__file__).resolve().parent
SANDBOX_DIR = _SCRIPT_DIR.joinpath("sandbox")
# Created eagerly at import; an already-existing directory is fine.
SANDBOX_DIR.mkdir(exist_ok=True)
| # -- Shared text helpers ------------------------------------------------------- | |
def normalize_message_text(content) -> str:
    """Convert LangChain message content (str | list | None) to a plain string.

    Args:
        content: ``None``, a single value, or a list of content parts. A part
            may be a dict (its ``"text"`` entry is used when present) or any
            other object (converted with ``str``).

    Returns:
        ``""`` for ``None``; ``str(content)`` for a non-list value; otherwise
        the parts joined with newlines.
    """
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content)

    parts: list[str] = []
    for item in content:
        if isinstance(item, dict) and "text" in item:
            text = item["text"]
            # Coerce so a non-string "text" value (None, numbers, ...) cannot
            # make the final join raise TypeError.
            parts.append("" if text is None else str(text))
        else:
            # Dicts without a "text" entry and non-dict parts are stringified
            # as-is.
            parts.append(str(item))
    return "\n".join(parts)
| # -- Sandbox helpers ------------------------------------------------------------ | |
def get_session_sandbox_dir(session_id: str) -> Path:
    """Create (if needed) and return the sandbox directory for one session.

    Args:
        session_id: Name of the session; becomes a sub-path of ``SANDBOX_DIR``.

    Returns:
        ``SANDBOX_DIR / session_id``, guaranteed to exist on return.

    Raises:
        ValueError: If ``session_id`` is empty or would place the directory
            at or outside ``SANDBOX_DIR`` (e.g. ``".."`` or an absolute path).
    """
    session_dir = SANDBOX_DIR / session_id
    sandbox_root = SANDBOX_DIR.resolve()
    # Compare resolved paths so "..", absolute ids and symlinks cannot place
    # the session directory outside the sandbox root. The root itself is
    # rejected too: it is shared by every session.
    if sandbox_root not in session_dir.resolve().parents:
        raise ValueError(f"Invalid session id: {session_id!r}")
    session_dir.mkdir(parents=True, exist_ok=True)
    return session_dir
| def cleanup_session_sandbox(session_dir: str | Path) -> None: | |
| shutil.rmtree(session_dir, ignore_errors=True) | |
def get_file_tools(session_dir: str | Path):
    """Return the FileManagementToolkit tools rooted at ``session_dir``."""
    root = str(session_dir)
    return FileManagementToolkit(root_dir=root).get_tools()
| def copy_uploaded_file(src_path: str, filename: str, session_dir: str | Path) -> str: | |
| """Copy an uploaded Gradio file into the session sandbox so the agent can access it.""" | |
| dest = Path(session_dir) / filename | |
| shutil.copy(src_path, dest) | |
| return str(dest) | |
| def collect_charts(session_dir: str) -> list[str]: | |
| """Return paths of any .png files saved in the current session sandbox.""" | |
| return sorted(glob.glob(os.path.join(str(session_dir), "*.png"))) | |
def recover_orphaned_charts(session_dir: str | Path) -> None:
    """Move stray ``.png`` files into the session directory.

    The PythonREPLTool executes code in the process CWD, which is usually the
    project root — not the session sandbox. Charts saved with a bare filename
    (e.g. ``plt.savefig('chart.png')``) end up in the CWD.

    This helper moves the .png files found in the CWD, the sandbox root and
    this script's directory into ``session_dir`` so that ``collect_charts``
    picks them up. A file is left where it is when the session directory
    already has one with the same name. Failures are logged, never raised.

    NOTE(review): every ``*.png`` in those directories is moved, not only
    files written by the agent — keep unrelated images out of them.
    """
    session_path = Path(session_dir)
    session_resolved = session_path.resolve()

    visited: set[Path] = set()
    # Fixed search order so the outcome is deterministic when the same file
    # name exists in more than one location.
    for search_dir in (Path.cwd(), SANDBOX_DIR, _SCRIPT_DIR):
        # Compare resolved paths so relative or symlinked spellings of the
        # same directory are recognised as duplicates / as the session dir.
        resolved = search_dir.resolve()
        if resolved in visited or resolved == session_resolved:
            continue
        visited.add(resolved)
        if not search_dir.is_dir():
            continue
        # Take a snapshot of the listing before moving anything out of it.
        for png in sorted(search_dir.glob("*.png")):
            if not png.is_file():
                continue  # e.g. a directory that happens to end in .png
            dest = session_path / png.name
            if dest.exists():
                continue
            try:
                shutil.move(str(png), str(dest))
                logger.info("Recovered orphaned chart %s → %s", png, dest)
            except Exception:
                logger.warning("Could not move %s", png, exc_info=True)
| # -- Notebook generation ------------------------------------------------------- | |
| def _make_nb_cell(cell_type: str, source: str, outputs=None) -> dict: | |
| """Build a single Jupyter notebook cell dict.""" | |
| cell = { | |
| "cell_type": cell_type, | |
| "metadata": {}, | |
| "source": source.splitlines(keepends=True) if source else [], | |
| } | |
| if cell_type == "code": | |
| cell["execution_count"] = None | |
| cell["outputs"] = outputs or [] | |
| return cell | |
| def _png_to_nb_output(png_path: str) -> dict: | |
| """Create a notebook display_data output that embeds a PNG image.""" | |
| with open(png_path, "rb") as f: | |
| b64 = base64.b64encode(f.read()).decode("ascii") | |
| return { | |
| "output_type": "display_data", | |
| "metadata": {}, | |
| "data": { | |
| "image/png": b64, | |
| "text/plain": [f"<Figure: {Path(png_path).name}>"], | |
| }, | |
| } | |
def build_notebook(
    session_dir: str,
    analyst_summary: str,
    code_snippets: list[str] | None = None,
    dataset_filename: str | None = None,
) -> str:
    """Write ``analysis_report.ipynb`` into ``session_dir`` and return its path.

    The notebook bundles, in order:
      - a title cell and, when given, the dataset name
      - one code cell per Python snippet the agent ran
      - one code cell per chart, with the PNG embedded as a base64
        ``display_data`` output
      - a markdown cell with the analyst's summary / findings

    Args:
        session_dir: Session directory; charts are read from it and the
            notebook is written into it.
        analyst_summary: Markdown text for the "Findings" cell.
        code_snippets: Python snippets to include as code cells.
        dataset_filename: Name shown in the "Dataset" cell.

    Returns:
        ``session_dir`` joined with ``analysis_report.ipynb``.
    """
    cells: list[dict] = [
        _make_nb_cell(
            "markdown",
            "# Data Analysis Report\n\n*Auto-generated by Data Analyst Agent*",
        ),
    ]
    if dataset_filename:
        cells.append(_make_nb_cell("markdown", f"**Dataset:** `{dataset_filename}`"))

    if code_snippets:
        cells.append(_make_nb_cell("markdown", "## Analysis Code"))
        cells.extend(_make_nb_cell("code", snippet) for snippet in code_snippets)

    chart_paths = collect_charts(session_dir)
    if chart_paths:
        cells.append(_make_nb_cell("markdown", "## Charts"))
        for chart_path in chart_paths:
            # Only the file read can realistically fail; a chart that cannot
            # be embedded is skipped rather than aborting the whole report.
            try:
                output = _png_to_nb_output(chart_path)
            except Exception:
                logger.warning("Could not embed chart %s", chart_path, exc_info=True)
                continue
            chart_file = Path(chart_path)
            # Reference the chart by bare file name (the notebook is saved in
            # the same directory) and let repr() do the quoting, so the cell
            # stays valid Python and carries no machine-specific path.
            cells.append(_make_nb_cell(
                "code",
                (
                    f"# Chart: {chart_file.stem}\n"
                    "from IPython.display import Image, display\n"
                    f"display(Image(filename={chart_file.name!r}))"
                ),
                outputs=[output],
            ))

    cells.append(_make_nb_cell("markdown", f"## Findings\n\n{analyst_summary}"))

    notebook = {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {"name": "python", "version": "3.11.0"},
        },
        "cells": cells,
    }
    nb_path = os.path.join(session_dir, "analysis_report.ipynb")
    with open(nb_path, "w", encoding="utf-8") as f:
        json.dump(notebook, f, indent=1, ensure_ascii=False)
    logger.info("Notebook saved to %s", nb_path)
    return nb_path
| # -- HTML report for in-browser preview ---------------------------------------- | |
def build_html_report(
    session_dir: str,
    analyst_summary: str,
    dataset_filename: str | None = None,
) -> str:
    """Build a self-contained HTML report and save it as ``analysis_report.html``.

    The report contains:
      - the dataset name, when given
      - every chart found in ``session_dir``, embedded as a base64 <img>
      - the analyst's markdown findings rendered as HTML

    Args:
        session_dir: Session directory; charts are read from it and the
            report is written into it.
        analyst_summary: Markdown text for the "Findings" section.
        dataset_filename: Name shown in the report header.

    Returns:
        The HTML string (the same text that is written to disk).
    """
    from html import escape

    summary_html = _md_to_simple_html(analyst_summary)

    chart_blocks: list[str] = []
    for cp in collect_charts(session_dir):
        # A chart that cannot be read is skipped, not fatal.
        try:
            with open(cp, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("ascii")
        except Exception:
            logger.warning("Could not embed chart %s in HTML", cp, exc_info=True)
            continue
        # The title comes from a file name, so escape it before it is placed
        # in element text and in the alt attribute.
        name = escape(Path(cp).stem.replace("_", " ").title())
        chart_blocks.append(
            f'<div class="chart"><h3>{name}</h3>'
            f'<img src="data:image/png;base64,{b64}" alt="{name}" /></div>\n'
        )
    charts_html = "".join(chart_blocks)

    # The dataset name is user-supplied text: escape it so it cannot inject
    # markup into the report.
    ds_label = (
        f"<p class='meta'>Dataset: <code>{escape(str(dataset_filename))}</code></p>"
        if dataset_filename else ""
    )

    report = (
        "<div style='font-family:Inter,Segoe UI,system-ui,sans-serif;"
        "width:100%;padding:1rem;color:#1e293b;box-sizing:border-box;'>"
        "<style>\n"
        ".rpt h1{font-size:1.4rem;color:#3730a3;border-bottom:2px solid #c7d2fe;"
        "padding-bottom:.4rem;margin-top:0}\n"
        ".rpt h2{font-size:1.15rem;color:#3730a3;margin-top:1.5rem}\n"
        ".rpt h3{font-size:1rem;color:#1e293b;font-weight:600}\n"
        ".rpt .meta{color:#475569;font-size:.85rem}\n"
        ".rpt .chart{margin:1rem 0}\n"
        ".rpt .chart img{max-width:100%;border-radius:10px;"
        "box-shadow:0 2px 12px rgba(0,0,0,.06);margin:.4rem 0 1rem}\n"
        ".rpt .findings{background:#f8faff;padding:1.2rem;"
        "border-radius:10px;border:1px solid #e0e7ff;margin-top:.8rem}\n"
        ".rpt .findings h1,.rpt .findings h2,.rpt .findings h3{"
        "color:#1e3a5f}\n"
        ".rpt .findings p,.rpt .findings li,.rpt .findings span{"
        "color:#1e293b}\n"
        ".rpt .findings strong{color:#0f172a}\n"
        ".rpt code{background:#e0e7ff;padding:2px 6px;border-radius:4px;"
        "font-size:.88em;color:#3730a3}\n"
        ".rpt pre{background:#f1f5f9;padding:.8rem;border-radius:8px;"
        "overflow-x:auto;font-size:.85em;color:#1e293b}\n"
        ".rpt ul,.rpt ol{padding-left:1.4rem}"
        ".rpt li{margin-bottom:.25rem;line-height:1.6;color:#1e293b}\n"
        ".rpt p{line-height:1.6;color:#1e293b}\n"
        "</style>\n"
        f"<div class='rpt'>\n<h1>Data Analysis Report</h1>\n{ds_label}\n"
    )
    if charts_html:
        report += f"<h2>Charts</h2>\n{charts_html}\n"
    report += (
        f"<h2>Findings</h2>\n<div class='findings'>\n{summary_html}\n</div>\n"
        "<p class='meta' style='margin-top:2rem;text-align:center;color:#64748b'>"
        "Generated by Data Analyst Agent</p>\n</div>\n</div>"
    )

    html_path = os.path.join(session_dir, "analysis_report.html")
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(report)
    logger.info("HTML report saved to %s", html_path)
    return report
| def _md_to_simple_html(text: str) -> str: | |
| """Minimal markdown-to-HTML conversion for analyst output.""" | |
| if not text: | |
| return "" | |
| lines = text.split("\n") | |
| out: list[str] = [] | |
| in_ul = False | |
| in_ol = False | |
| in_code = False | |
| for line in lines: | |
| s = line.strip() | |
| if s.startswith("```"): | |
| if in_code: | |
| out.append("</code></pre>") | |
| else: | |
| out.append("<pre><code>") | |
| in_code = not in_code | |
| continue | |
| if in_code: | |
| out.append(line) | |
| continue | |
| if s.startswith("### "): | |
| if in_ul: out.append("</ul>"); in_ul = False | |
| if in_ol: out.append("</ol>"); in_ol = False | |
| out.append(f"<h3>{_inline_fmt(s[4:])}</h3>"); continue | |
| if s.startswith("## "): | |
| if in_ul: out.append("</ul>"); in_ul = False | |
| if in_ol: out.append("</ol>"); in_ol = False | |
| out.append(f"<h2>{_inline_fmt(s[3:])}</h2>"); continue | |
| if s.startswith("# "): | |
| if in_ul: out.append("</ul>"); in_ul = False | |
| if in_ol: out.append("</ol>"); in_ol = False | |
| out.append(f"<h1>{_inline_fmt(s[2:])}</h1>"); continue | |
| if s.startswith("- ") or s.startswith("* "): | |
| if in_ol: out.append("</ol>"); in_ol = False | |
| if not in_ul: out.append("<ul>"); in_ul = True | |
| out.append(f"<li>{_inline_fmt(s[2:])}</li>"); continue | |
| m = re.match(r"^(\d+)\.\s+(.+)$", s) | |
| if m: | |
| if in_ul: out.append("</ul>"); in_ul = False | |
| if not in_ol: out.append("<ol>"); in_ol = True | |
| out.append(f"<li>{_inline_fmt(m.group(2))}</li>"); continue | |
| if in_ul: out.append("</ul>"); in_ul = False | |
| if in_ol: out.append("</ol>"); in_ol = False | |
| if not s: | |
| out.append(""); continue | |
| out.append(f"<p>{_inline_fmt(s)}</p>") | |
| if in_ul: out.append("</ul>") | |
| if in_ol: out.append("</ol>") | |
| if in_code: out.append("</code></pre>") | |
| return "\n".join(out) | |
| def _inline_fmt(text: str) -> str: | |
| """Apply bold and inline-code markdown formatting.""" | |
| text = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", text) | |
| text = re.sub(r"`(.+?)`", r"<code>\1</code>", text) | |
| return text | |
| # -- Extract code snippets from conversation ----------------------------------- | |
def extract_python_snippets(messages) -> list[str]:
    """Collect the Python source the agent sent to its Python tool.

    Scans ``messages`` for ``AIMessage`` objects with tool calls, keeps the
    calls whose tool name contains "python" (case-insensitive), and takes the
    code from the first non-empty of the ``command``, ``code`` or ``query``
    arguments. Snippets are returned stripped, in message order; blank ones
    are dropped.
    """
    from langchain_core.messages import AIMessage

    collected: list[str] = []
    for message in messages:
        if not isinstance(message, AIMessage):
            continue
        if not (hasattr(message, "tool_calls") and message.tool_calls):
            continue
        for call in message.tool_calls:
            tool_name = call.get("name", "")
            if "python" not in tool_name.lower():
                continue
            call_args = call.get("args", {})
            # First truthy value among the known argument names wins.
            source = ""
            for key in ("command", "code", "query"):
                candidate = call_args.get(key)
                if candidate:
                    source = candidate
                    break
            cleaned = source.strip()
            if cleaned:
                collected.append(cleaned)
    return collected
| # -- Tool factory -------------------------------------------------------------- | |
def get_analyst_tools(session_dir: str | Path, enable_web_search: bool | None = None):
    """Assemble the tool list for the analyst agent.

    Order of the returned list: file-management tools rooted at
    ``session_dir``, then (optionally) ``web_search``, then the Python REPL
    and Wikipedia tools.

    Args:
        session_dir: Root directory for the file-management tools.
        enable_web_search: Force web search on/off. ``None`` enables it only
            when the ``SERPER_API_KEY`` environment variable is set.
    """
    toolbox = list(get_file_tools(session_dir))

    use_search = enable_web_search
    if use_search is None:
        use_search = bool(os.getenv("SERPER_API_KEY"))

    if use_search:
        toolbox.append(
            Tool(
                name="web_search",
                func=GoogleSerperAPIWrapper().run,
                description="Search the web for context about data, industry benchmarks, or methodology questions.",
            )
        )

    wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    repl_tool = PythonREPLTool()
    toolbox.extend([repl_tool, wiki_tool])
    return toolbox