import os
import re
import json
import glob
import uuid
import base64
import shutil
import logging
import platform
from html import escape as html_escape
from pathlib import Path

from dotenv import load_dotenv
from langchain_core.tools import Tool
from langchain_community.agent_toolkits import FileManagementToolkit
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
from langchain_experimental.tools import PythonREPLTool
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper

# Load settings such as SERPER_API_KEY from a local .env file; values in the
# file take precedence over variables already set in the environment.
load_dotenv(override=True)
# Select a non-interactive matplotlib backend before anything imports pyplot.
os.environ.setdefault("MPLBACKEND", "Agg")

logger = logging.getLogger(__name__)

try:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # Agent code runs headless: turn stray plt.show() calls into no-ops.
    plt.show = lambda *args, **kwargs: None
except Exception:
    # matplotlib is optional here; skip the headless patching if unavailable.
    matplotlib = None

# Use an absolute path anchored to this file's directory so the sandbox
# location is predictable regardless of the working directory at launch.
_SCRIPT_DIR = Path(__file__).resolve().parent
SANDBOX_DIR = _SCRIPT_DIR / "sandbox"
SANDBOX_DIR.mkdir(exist_ok=True)


# -- Shared text helpers -------------------------------------------------------

def normalize_message_text(content) -> str:
    """Convert LangChain message content (str | list[dict] | None) to a plain string.

    List items that are dicts contribute their ``"text"`` value (or their
    ``str()`` form when they have none); other items are stringified. Parts
    are joined with newlines.
    """
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                parts.append(item.get("text", str(item)))
            else:
                parts.append(str(item))
        return "\n".join(parts)
    return "" if content is None else str(content)


# -- Sandbox helpers ------------------------------------------------------------

def get_session_sandbox_dir(session_id: str) -> Path:
    """Return (creating it if needed) the sandbox sub-directory for *session_id*."""
    session_dir = SANDBOX_DIR / session_id
    session_dir.mkdir(parents=True, exist_ok=True)
    return session_dir


def cleanup_session_sandbox(session_dir: str | Path) -> None:
    """Delete a session sandbox directory tree, ignoring any errors."""
    shutil.rmtree(session_dir, ignore_errors=True)


def get_file_tools(session_dir: str | Path):
    """Return LangChain file-management tools rooted at *session_dir*."""
    toolkit = FileManagementToolkit(root_dir=str(session_dir))
    return toolkit.get_tools()


def copy_uploaded_file(src_path: str, filename: str, session_dir: str | Path) -> str:
    """Copy an uploaded Gradio file into the session sandbox so the agent can access it.

    Args:
        src_path: Path of the uploaded temporary file.
        filename: Desired file name. Treated as untrusted: only its final
            path component is used, so it can never place the copy outside
            *session_dir*.
        session_dir: Destination sandbox directory.

    Returns:
        The destination path as a string.
    """
    # "../../x", "/etc/passwd", etc. collapse to their basename. Empty, "." and
    # ".." names fall back to the source file's own name.
    safe_name = Path(filename).name
    if safe_name in ("", ".", ".."):
        safe_name = Path(src_path).name
    dest = Path(session_dir) / safe_name
    shutil.copy(src_path, dest)
    return str(dest)


def collect_charts(session_dir: str | Path) -> list[str]:
    """Return sorted paths of any .png files saved in the current session sandbox."""
    return sorted(glob.glob(os.path.join(str(session_dir), "*.png")))


def recover_orphaned_charts(session_dir: str | Path) -> None:
    """
    The PythonREPLTool executes code in the process CWD, which is usually the
    project root — not the session sandbox. Charts saved with a bare filename
    (e.g. ``plt.savefig('chart.png')``) end up in the CWD.

    This helper moves any .png files found in the CWD, the sandbox root and
    this module's directory into the session directory so that
    ``collect_charts`` picks them up. Files whose name already exists in the
    session directory are left where they are; move failures are logged.
    """
    # Resolve both sides so the "skip the session dir itself" check works even
    # when the caller passes a relative or symlinked path.
    session_path = Path(session_dir).resolve()
    search_dirs = {Path.cwd().resolve(), SANDBOX_DIR, _SCRIPT_DIR}

    for search_dir in search_dirs:
        if not search_dir.is_dir() or search_dir == session_path:
            continue
        for png in search_dir.glob("*.png"):
            dest = session_path / png.name
            if dest.exists():
                continue
            try:
                shutil.move(str(png), str(dest))
                logger.info("Recovered orphaned chart %s → %s", png, dest)
            except Exception:
                logger.warning("Could not move %s", png, exc_info=True)


# -- Notebook generation -------------------------------------------------------

def _make_nb_cell(cell_type: str, source: str, outputs=None) -> dict:
    """Build a single Jupyter notebook cell dict (nbformat 4.5)."""
    cell = {
        # nbformat >= 4.5 requires a unique id (1-64 chars of [A-Za-z0-9-_])
        # on every cell; without it validators report "missing id" errors.
        "id": uuid.uuid4().hex[:8],
        "cell_type": cell_type,
        "metadata": {},
        "source": source.splitlines(keepends=True) if source else [],
    }
    if cell_type == "code":
        cell["execution_count"] = None
        cell["outputs"] = outputs or []
    return cell


def _png_to_nb_output(png_path: str) -> dict:
    """Create a notebook display_data output that embeds a PNG image as base64."""
    with open(png_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("ascii")
    return {
        "output_type": "display_data",
        "metadata": {},
        "data": {
            "image/png": b64,
            # Plain-text fallback shown by front-ends that cannot render PNG.
            "text/plain": [f"<Image: {Path(png_path).name}>"],
        },
    }


def build_notebook(
    session_dir: str | Path,
    analyst_summary: str,
    code_snippets: list[str] | None = None,
    dataset_filename: str | None = None,
) -> str:
    """
    Build an .ipynb file that bundles:
      - a markdown cell with the analyst's summary / findings
      - code cells for each Python snippet the agent ran
      - inline chart images (embedded as base64 display_data outputs)

    The notebook is written to ``<session_dir>/analysis_report.ipynb``.
    Charts that cannot be read are skipped with a logged warning.

    Returns the path to the saved notebook.
    """
    cells: list[dict] = [
        _make_nb_cell(
            "markdown",
            "# Data Analysis Report\n\n*Auto-generated by Data Analyst Agent*",
        )
    ]

    if dataset_filename:
        cells.append(_make_nb_cell("markdown", f"**Dataset:** `{dataset_filename}`"))

    if code_snippets:
        cells.append(_make_nb_cell("markdown", "## Analysis Code"))
        for snippet in code_snippets:
            cells.append(_make_nb_cell("code", snippet))

    chart_paths = collect_charts(session_dir)
    if chart_paths:
        cells.append(_make_nb_cell("markdown", "## Charts"))
        for chart_path in chart_paths:
            try:
                output = _png_to_nb_output(chart_path)
            except Exception:
                logger.warning("Could not embed chart %s", chart_path, exc_info=True)
                continue
            chart_name = Path(chart_path).stem
            cells.append(_make_nb_cell(
                "code",
                (
                    f"# Chart: {chart_name}\n"
                    "from IPython.display import Image, display\n"
                    # repr() yields a valid literal even for paths containing
                    # quotes or a trailing backslash (unlike r'...').
                    f"display(Image(filename={chart_path!r}))"
                ),
                outputs=[output],
            ))

    cells.append(_make_nb_cell("markdown", f"## Findings\n\n{analyst_summary}"))

    notebook = {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            # Record the interpreter that actually produced the analysis.
            "language_info": {"name": "python", "version": platform.python_version()},
        },
        "cells": cells,
    }

    nb_path = os.path.join(session_dir, "analysis_report.ipynb")
    with open(nb_path, "w", encoding="utf-8") as f:
        json.dump(notebook, f, indent=1, ensure_ascii=False)

    logger.info("Notebook saved to %s", nb_path)
    return nb_path


# -- HTML report for in-browser preview ----------------------------------------

def build_html_report(
    session_dir: str | Path,
    analyst_summary: str,
    dataset_filename: str | None = None,
) -> str:
    """
    Build a self-contained HTML report with:
      - the analyst's markdown findings rendered as HTML
      - charts embedded as base64 <img> tags

    All interpolated text (dataset name, chart names, findings) is
    HTML-escaped. The report is written to
    ``<session_dir>/analysis_report.html``.

    Returns the HTML string (also saved to disk).
    """
    summary_html = _md_to_simple_html(analyst_summary)

    chart_parts: list[str] = []
    for cp in collect_charts(session_dir):
        try:
            with open(cp, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("ascii")
        except Exception:
            logger.warning("Could not embed chart %s in HTML", cp, exc_info=True)
            continue
        # Chart file names come from agent-written code; escape them.
        name = html_escape(Path(cp).stem.replace("_", " ").title())
        chart_parts.append(
            '<div class="chart">'
            f"<h3>{name}</h3>"
            f'<img src="data:image/png;base64,{b64}" alt="{name}" '
            'style="max-width:100%;height:auto;"/>'
            "</div>\n"
        )
    charts_html = "".join(chart_parts)

    ds_label = (
        f"<p><strong>Dataset:</strong> {html_escape(dataset_filename)}</p>"
        if dataset_filename
        else ""
    )

    parts = [
        "<!DOCTYPE html>\n",
        '<html><head><meta charset="utf-8"></head><body>\n',
        '<div class="report" style="font-family:sans-serif;max-width:960px;margin:auto;">\n',
        "<h1>Data Analysis Report</h1>\n",
        f"{ds_label}\n",
    ]
    if charts_html:
        parts.append(f"<h2>Charts</h2>\n{charts_html}\n")
    parts.append(
        f'<h2>Findings</h2>\n<div class="findings">\n{summary_html}\n</div>\n'
        '<p class="footer"><em>Generated by Data Analyst Agent</em></p>\n'
        "</div>\n</body></html>\n"
    )
    report_html = "".join(parts)

    html_path = os.path.join(session_dir, "analysis_report.html")
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(report_html)

    logger.info("HTML report saved to %s", html_path)
    return report_html


def _md_to_simple_html(text: str) -> str:
    """Minimal markdown-to-HTML conversion for analyst output.

    Supports ``#``/``##``/``###`` headings, ``-``/``*`` bullet lists,
    ``1.``-style numbered lists, fenced code blocks and plain paragraphs;
    inline bold / code spans are handled by ``_inline_fmt``. All text is
    HTML-escaped, so model output cannot inject markup into the report.
    """
    if not text:
        return ""

    out: list[str] = []
    open_list: str | None = None  # "ul" or "ol" while inside a list
    in_code = False

    def close_list() -> None:
        nonlocal open_list
        if open_list is not None:
            out.append(f"</{open_list}>")
            open_list = None

    def start_list(tag: str) -> None:
        # Switching between bullet and numbered lists closes the previous one.
        nonlocal open_list
        if open_list != tag:
            close_list()
            out.append(f"<{tag}>")
            open_list = tag

    for line in text.split("\n"):
        s = line.strip()

        if s.startswith("```"):
            if in_code:
                out.append("</code></pre>")
            else:
                # A code block may not live inside a <ul>/<ol>.
                close_list()
                out.append("<pre><code>")
            in_code = not in_code
            continue
        if in_code:
            # Preserve the raw line (including indentation); only escape it.
            out.append(html_escape(line, quote=False))
            continue

        heading = re.match(r"^(#{1,3})\s+(.+)$", s)
        if heading:
            close_list()
            level = len(heading.group(1))
            out.append(f"<h{level}>{_inline_fmt(heading.group(2))}</h{level}>")
            continue

        if s.startswith(("- ", "* ")):
            start_list("ul")
            out.append(f"<li>{_inline_fmt(s[2:])}</li>")
            continue

        numbered = re.match(r"^(\d+)\.\s+(.+)$", s)
        if numbered:
            start_list("ol")
            out.append(f"<li>{_inline_fmt(numbered.group(2))}</li>")
            continue

        close_list()
        if not s:
            out.append("")
            continue
        out.append(f"<p>{_inline_fmt(s)}</p>")

    close_list()
    if in_code:
        # Unterminated fence: close it so the HTML stays well-formed.
        out.append("</code></pre>")
    return "\n".join(out)


def _inline_fmt(text: str) -> str:
    """HTML-escape *text*, then apply bold (``**x**``) and inline-code (`` `x` ``) formatting."""
    # Escape first so only the tags added below are real markup; escaping
    # leaves the '*' and '`' markers untouched.
    text = html_escape(text, quote=False)
    text = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", text)
    text = re.sub(r"`(.+?)`", r"<code>\1</code>", text)
    return text


# -- Extract code snippets from conversation -----------------------------------

def extract_python_snippets(messages) -> list[str]:
    """
    Walk through LangChain messages and pull out the Python code that was
    sent to the PythonREPLTool (via tool_calls).

    A tool call counts when its name contains "python" (case-insensitive);
    the code is read from the ``command``, ``code`` or ``query`` argument,
    whichever is set first. Blank snippets are dropped; the rest are stripped.
    """
    from langchain_core.messages import AIMessage

    snippets: list[str] = []
    for msg in messages:
        if not isinstance(msg, AIMessage) or not msg.tool_calls:
            continue
        for tc in msg.tool_calls:
            if "python" not in tc.get("name", "").lower():
                continue
            args = tc.get("args") or {}
            code = (
                args.get("command")
                or args.get("code")
                or args.get("query")
                or ""
            )
            # Guard against non-string argument payloads.
            if isinstance(code, str) and code.strip():
                snippets.append(code.strip())
    return snippets


# -- Tool factory --------------------------------------------------------------

def get_analyst_tools(session_dir: str | Path, enable_web_search: bool | None = None):
    """Assemble the agent's tool list.

    Includes the sandboxed file-management tools, a Python REPL and a
    Wikipedia lookup. A Serper-backed ``web_search`` tool is added when
    *enable_web_search* is true, or — when it is ``None`` — when the
    ``SERPER_API_KEY`` environment variable is set.
    """
    tools = list(get_file_tools(session_dir))

    if enable_web_search is None:
        enable_web_search = bool(os.getenv("SERPER_API_KEY"))

    if enable_web_search:
        serper = GoogleSerperAPIWrapper()
        tools.append(Tool(
            name="web_search",
            func=serper.run,
            description="Search the web for context about data, industry benchmarks, or methodology questions.",
        ))

    wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    # NOTE: PythonREPLTool runs in the process CWD, not the sandbox — see
    # recover_orphaned_charts for how stray chart files are collected.
    python_repl = PythonREPLTool()
    tools.extend([python_repl, wiki_tool])

    return tools