data-analyst-agent / analyst_tools.py
JamesDominiqueAI
Fix Tool import for current LangChain
7549192
Raw
History Blame Contribute Delete
14.3 kB
import base64
import glob
import html
import json
import logging
import os
import re
import shutil
from pathlib import Path

from dotenv import load_dotenv
from langchain_community.agent_toolkits import FileManagementToolkit
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_core.tools import Tool
from langchain_experimental.tools import PythonREPLTool
# Module-level setup: environment loading, logging, a headless matplotlib
# configuration, and creation of the shared sandbox root directory.
load_dotenv(override=True)
# Only sets MPLBACKEND when the variable is not already defined.
os.environ.setdefault("MPLBACKEND", "Agg")
logger = logging.getLogger(__name__)
try:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    # Replace plt.show with a no-op: any call to it silently does nothing.
    plt.show = lambda *args, **kwargs: None
except Exception:
    # Any failure while importing/configuring matplotlib is tolerated; the
    # name is bound to None so the module still imports.
    matplotlib = None
# Use an absolute path anchored to this file's directory so the sandbox
# location is predictable regardless of the working directory at launch.
_SCRIPT_DIR = Path(__file__).resolve().parent
SANDBOX_DIR = _SCRIPT_DIR / "sandbox"
# Created at import time; exist_ok=True makes repeated imports harmless.
SANDBOX_DIR.mkdir(exist_ok=True)
# -- Shared text helpers -------------------------------------------------------
def normalize_message_text(content) -> str:
    """Flatten LangChain message content into one plain string.

    A list is rendered one item per line: dict items contribute their
    ``"text"`` value (falling back to ``str(item)`` when the key is absent)
    and every other item contributes ``str(item)``.  ``None`` becomes the
    empty string; any other value is passed through ``str()``.
    """
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content)

    def _as_text(item):
        # Dict parts carry their payload under "text"; anything else is
        # simply stringified.
        if isinstance(item, dict):
            return item.get("text", str(item))
        return str(item)

    return "\n".join(_as_text(item) for item in content)
# -- Sandbox helpers ------------------------------------------------------------
def get_session_sandbox_dir(session_id: str) -> Path:
    """Return the directory ``SANDBOX_DIR / session_id``, creating it if needed.

    Missing parent directories are created as well, and an already existing
    directory is not an error.
    """
    target = SANDBOX_DIR.joinpath(session_id)
    target.mkdir(exist_ok=True, parents=True)
    return target
def cleanup_session_sandbox(session_dir: str | Path) -> None:
shutil.rmtree(session_dir, ignore_errors=True)
def get_file_tools(session_dir: str | Path):
    """Return the tools of a ``FileManagementToolkit`` rooted at ``session_dir``."""
    root = str(session_dir)
    return FileManagementToolkit(root_dir=root).get_tools()
def copy_uploaded_file(src_path: str, filename: str, session_dir: str | Path) -> str:
"""Copy an uploaded Gradio file into the session sandbox so the agent can access it."""
dest = Path(session_dir) / filename
shutil.copy(src_path, dest)
return str(dest)
def collect_charts(session_dir: str) -> list[str]:
"""Return paths of any .png files saved in the current session sandbox."""
return sorted(glob.glob(os.path.join(str(session_dir), "*.png")))
def recover_orphaned_charts(session_dir: str | Path) -> None:
    """Move stray ``.png`` files into the session directory.

    The PythonREPLTool executes code in the process CWD, which is usually the
    project root rather than the session sandbox, so a chart saved with a
    bare filename (e.g. ``plt.savefig('chart.png')``) lands in the CWD.

    The current working directory, the sandbox root and this script's
    directory are scanned (top level only).  Every ``.png`` found is moved
    into ``session_dir`` unless a file of the same name already exists there,
    so that ``collect_charts`` picks it up.  A failed move is logged and
    skipped.
    """
    target_dir = Path(session_dir)
    # A set, so a location that appears twice (e.g. launched from the script
    # directory) is scanned only once.
    candidates = {Path.cwd(), SANDBOX_DIR, _SCRIPT_DIR}
    for folder in candidates:
        if not folder.is_dir() or folder == target_dir:
            continue
        for stray in folder.glob("*.png"):
            destination = target_dir / stray.name
            if destination.exists():
                # Never overwrite a chart the session already has.
                continue
            try:
                shutil.move(str(stray), str(destination))
                logger.info("Recovered orphaned chart %s → %s", stray, destination)
            except Exception:
                logger.warning("Could not move %s", stray, exc_info=True)
# -- Notebook generation -------------------------------------------------------
def _make_nb_cell(cell_type: str, source: str, outputs=None) -> dict:
    """Assemble the dict for one Jupyter notebook cell.

    ``source`` is stored as a list of lines with their line endings kept (an
    empty source gives an empty list).  Cells of type ``"code"`` also get
    ``execution_count`` (always ``None``) and ``outputs`` (the given list, or
    an empty list when none is supplied).
    """
    source_lines = source.splitlines(keepends=True) if source else []
    cell = {"cell_type": cell_type, "metadata": {}, "source": source_lines}
    if cell_type == "code":
        cell.update(execution_count=None, outputs=outputs or [])
    return cell
def _png_to_nb_output(png_path: str) -> dict:
    """Build a notebook ``display_data`` output embedding the PNG at ``png_path``.

    The file's bytes are base64-encoded under ``image/png``; the
    ``text/plain`` fallback is a one-line ``<Figure: name>`` label using the
    file's name.
    """
    raw = Path(png_path).read_bytes()
    encoded = base64.b64encode(raw).decode("ascii")
    label = f"<Figure: {Path(png_path).name}>"
    return {
        "output_type": "display_data",
        "metadata": {},
        "data": {"image/png": encoded, "text/plain": [label]},
    }
def build_notebook(
    session_dir: str,
    analyst_summary: str,
    code_snippets: list[str] | None = None,
    dataset_filename: str | None = None,
) -> str:
    """Write ``analysis_report.ipynb`` into ``session_dir`` and return its path.

    The notebook bundles, in order:
      - a title cell, plus a dataset line when ``dataset_filename`` is given
      - one code cell per entry of ``code_snippets``
      - one code cell per ``.png`` chart found in ``session_dir``, with the
        image embedded as a base64 ``display_data`` output
      - a markdown cell holding ``analyst_summary`` under "Findings"

    A chart that cannot be embedded is logged and skipped.

    Returns:
        ``session_dir`` joined with ``analysis_report.ipynb`` (absolute only
        when ``session_dir`` itself is absolute).
    """
    cells: list[dict] = [
        _make_nb_cell(
            "markdown",
            "# Data Analysis Report\n\n*Auto-generated by Data Analyst Agent*",
        )
    ]

    if dataset_filename:
        cells.append(_make_nb_cell("markdown", f"**Dataset:** `{dataset_filename}`"))

    if code_snippets:
        cells.append(_make_nb_cell("markdown", "## Analysis Code"))
        cells.extend(_make_nb_cell("code", snippet) for snippet in code_snippets)

    chart_paths = collect_charts(session_dir)
    if chart_paths:
        cells.append(_make_nb_cell("markdown", "## Charts"))
        for chart_path in chart_paths:
            try:
                output = _png_to_nb_output(chart_path)
            except Exception:
                # Best effort: one unreadable chart must not abort the report.
                logger.warning("Could not embed chart %s", chart_path, exc_info=True)
                continue
            chart_name = Path(chart_path).stem
            cells.append(_make_nb_cell(
                "code",
                (
                    f"# Chart: {chart_name}\n"
                    f"from IPython.display import Image, display\n"
                    f"display(Image(filename=r'{chart_path}'))"
                ),
                outputs=[output],
            ))

    cells.append(_make_nb_cell("markdown", f"## Findings\n\n{analyst_summary}"))

    # nbformat 4.5 (declared below) requires every cell to carry an "id";
    # without one the notebook fails schema validation.  Positional ids are
    # unique within the notebook; setdefault keeps any id already present.
    for index, cell in enumerate(cells):
        cell.setdefault("id", f"cell-{index}")

    notebook = {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {"name": "python", "version": "3.11.0"},
        },
        "cells": cells,
    }

    nb_path = os.path.join(session_dir, "analysis_report.ipynb")
    with open(nb_path, "w", encoding="utf-8") as f:
        json.dump(notebook, f, indent=1, ensure_ascii=False)
    logger.info("Notebook saved to %s", nb_path)
    return nb_path
# -- HTML report for in-browser preview ----------------------------------------
def build_html_report(
    session_dir: str,
    analyst_summary: str,
    dataset_filename: str | None = None,
) -> str:
    """Build a self-contained HTML report and save it as ``analysis_report.html``.

    The report contains:
      - the dataset filename, when given
      - every ``.png`` chart found in ``session_dir``, embedded as a base64
        ``<img>`` tag
      - the analyst's markdown findings rendered as HTML

    The dataset filename and the chart titles come from file names, so they
    are HTML-escaped before being interpolated into the markup.  A chart
    that cannot be read is logged and skipped.

    Returns:
        The HTML string (the same text that is written to disk).
    """
    summary_html = _md_to_simple_html(analyst_summary)

    chart_blocks: list[str] = []
    for cp in collect_charts(session_dir):
        try:
            with open(cp, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("ascii")
        except Exception:
            # Best effort: one unreadable chart must not abort the report.
            logger.warning("Could not embed chart %s in HTML", cp, exc_info=True)
            continue
        # The title is used both as element text and inside the alt
        # attribute, hence the quote-escaping default of html.escape.
        name = html.escape(Path(cp).stem.replace("_", " ").title())
        chart_blocks.append(
            f'<div class="chart"><h3>{name}</h3>'
            f'<img src="data:image/png;base64,{b64}" alt="{name}" /></div>\n'
        )
    charts_html = "".join(chart_blocks)

    ds_label = (
        f"<p class='meta'>Dataset: <code>{html.escape(dataset_filename)}</code></p>"
        if dataset_filename else ""
    )

    page = (
        "<div style='font-family:Inter,Segoe UI,system-ui,sans-serif;"
        "width:100%;padding:1rem;color:#1e293b;box-sizing:border-box;'>"
        "<style>\n"
        ".rpt h1{font-size:1.4rem;color:#3730a3;border-bottom:2px solid #c7d2fe;"
        "padding-bottom:.4rem;margin-top:0}\n"
        ".rpt h2{font-size:1.15rem;color:#3730a3;margin-top:1.5rem}\n"
        ".rpt h3{font-size:1rem;color:#1e293b;font-weight:600}\n"
        ".rpt .meta{color:#475569;font-size:.85rem}\n"
        ".rpt .chart{margin:1rem 0}\n"
        ".rpt .chart img{max-width:100%;border-radius:10px;"
        "box-shadow:0 2px 12px rgba(0,0,0,.06);margin:.4rem 0 1rem}\n"
        ".rpt .findings{background:#f8faff;padding:1.2rem;"
        "border-radius:10px;border:1px solid #e0e7ff;margin-top:.8rem}\n"
        ".rpt .findings h1,.rpt .findings h2,.rpt .findings h3{"
        "color:#1e3a5f}\n"
        ".rpt .findings p,.rpt .findings li,.rpt .findings span{"
        "color:#1e293b}\n"
        ".rpt .findings strong{color:#0f172a}\n"
        ".rpt code{background:#e0e7ff;padding:2px 6px;border-radius:4px;"
        "font-size:.88em;color:#3730a3}\n"
        ".rpt pre{background:#f1f5f9;padding:.8rem;border-radius:8px;"
        "overflow-x:auto;font-size:.85em;color:#1e293b}\n"
        ".rpt ul,.rpt ol{padding-left:1.4rem}"
        ".rpt li{margin-bottom:.25rem;line-height:1.6;color:#1e293b}\n"
        ".rpt p{line-height:1.6;color:#1e293b}\n"
        "</style>\n"
        f"<div class='rpt'>\n<h1>Data Analysis Report</h1>\n{ds_label}\n"
    )
    if charts_html:
        page += f"<h2>Charts</h2>\n{charts_html}\n"
    page += (
        f"<h2>Findings</h2>\n<div class='findings'>\n{summary_html}\n</div>\n"
        "<p class='meta' style='margin-top:2rem;text-align:center;color:#64748b'>"
        "Generated by Data Analyst Agent</p>\n</div>\n</div>"
    )

    html_path = os.path.join(session_dir, "analysis_report.html")
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(page)
    logger.info("HTML report saved to %s", html_path)
    return page
# Markdown line patterns, matched against a stripped line.
_HEADING_RE = re.compile(r"^(#{1,3}) (.*)$")        # "# ", "## " or "### " heading
_ORDERED_ITEM_RE = re.compile(r"^(\d+)\.\s+(.+)$")  # "1. item"


def _md_to_simple_html(text: str) -> str:
    """Minimal markdown-to-HTML conversion for analyst output.

    Supported constructs: ``#``/``##``/``###`` headings, ``-``/``*`` bullet
    lists, ``1.`` numbered lists, fenced code blocks and plain paragraphs
    (one ``<p>`` per non-empty line).  Blank lines are emitted as empty
    lines and end any open list.

    All text is HTML-escaped, so characters such as ``<`` or ``&`` in the
    findings or in code samples are displayed literally instead of being
    interpreted as markup.  An open list is closed before a code fence
    starts, and an unterminated fence is closed at the end of the input.

    Returns:
        The rendered HTML, or ``""`` for empty/``None`` input.
    """
    if not text:
        return ""

    out: list[str] = []
    open_list = ""  # tag of the list currently open: "", "ul" or "ol"
    in_code = False

    def close_list() -> None:
        nonlocal open_list
        if open_list:
            out.append(f"</{open_list}>")
            open_list = ""

    def add_item(tag: str, body: str) -> None:
        # Switch list type (closing the other one) if needed, then emit the item.
        nonlocal open_list
        if open_list != tag:
            close_list()
            out.append(f"<{tag}>")
            open_list = tag
        out.append(f"<li>{_inline_fmt(body)}</li>")

    for line in text.split("\n"):
        s = line.strip()

        if s.startswith("```"):
            if in_code:
                out.append("</code></pre>")
            else:
                # <pre> is not valid directly inside <ul>/<ol>.
                close_list()
                out.append("<pre><code>")
            in_code = not in_code
            continue
        if in_code:
            # Keep the original indentation, but neutralise markup characters.
            out.append(html.escape(line, quote=False))
            continue

        heading = _HEADING_RE.match(s)
        if heading:
            close_list()
            level = len(heading.group(1))
            out.append(f"<h{level}>{_inline_fmt(heading.group(2))}</h{level}>")
            continue

        if s.startswith(("- ", "* ")):
            add_item("ul", s[2:])
            continue

        ordered = _ORDERED_ITEM_RE.match(s)
        if ordered:
            add_item("ol", ordered.group(2))
            continue

        close_list()
        out.append(f"<p>{_inline_fmt(s)}</p>" if s else "")

    close_list()
    if in_code:
        out.append("</code></pre>")
    return "\n".join(out)


def _inline_fmt(text: str) -> str:
    """HTML-escape ``text``, then apply bold and inline-code markdown formatting."""
    # Escape first so only the tags generated below are treated as markup.
    text = html.escape(text, quote=False)
    text = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", text)
    text = re.sub(r"`(.+?)`", r"<code>\1</code>", text)
    return text
# -- Extract code snippets from conversation -----------------------------------
def extract_python_snippets(messages) -> list[str]:
    """Collect the Python code that AI messages sent to a Python tool.

    Only ``AIMessage`` instances with tool calls are inspected.  A tool call
    counts when its name contains "python" (case-insensitive); the code is
    taken from the first non-empty of the ``command``, ``code`` and ``query``
    arguments.  Snippets are returned stripped, in message order, and blank
    ones are dropped.
    """
    from langchain_core.messages import AIMessage

    snippets: list[str] = []
    for msg in messages:
        if not isinstance(msg, AIMessage):
            continue
        for call in getattr(msg, "tool_calls", None) or ():
            if "python" not in call.get("name", "").lower():
                continue
            arguments = call.get("args", {})
            source = (
                arguments.get("command")
                or arguments.get("code")
                or arguments.get("query")
                or ""
            )
            source = source.strip()
            if source:
                snippets.append(source)
    return snippets
# -- Tool factory --------------------------------------------------------------
def get_analyst_tools(session_dir: str | Path, enable_web_search: bool | None = None):
    """Assemble the tool list for the analyst agent.

    The list holds, in order: the file-management tools rooted at
    ``session_dir``, an optional ``web_search`` tool backed by Google Serper,
    a Python REPL tool and a Wikipedia query tool.

    Args:
        session_dir: Root directory for the file-management tools.
        enable_web_search: Whether to include ``web_search``.  ``None`` means
            "include it only when the ``SERPER_API_KEY`` environment variable
            is set to a non-empty value".
    """
    tools = list(get_file_tools(session_dir))

    if enable_web_search is None:
        enable_web_search = bool(os.getenv("SERPER_API_KEY"))

    if enable_web_search:
        tools.append(
            Tool(
                name="web_search",
                func=GoogleSerperAPIWrapper().run,
                description="Search the web for context about data, industry benchmarks, or methodology questions.",
            )
        )

    wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    python_repl = PythonREPLTool()
    tools.extend([python_repl, wiki_tool])
    return tools