""" LangChain Callback Handler for Potato Automatically sends LangChain agent traces to a Potato instance for human evaluation and annotation. Requires: pip install langchain-core>=0.1.0 Usage: from potato.integrations.langchain_callback import PotatoCallbackHandler handler = PotatoCallbackHandler(potato_url="http://localhost:8000") chain.invoke({"input": "..."}, config={"callbacks": [handler]}) """ import json import logging import threading import time import uuid from typing import Any, Dict, List, Optional, Sequence, Union import requests logger = logging.getLogger(__name__) def _safe_serialize(obj: Any) -> Any: """Convert an object to a JSON-safe representation.""" if obj is None or isinstance(obj, (str, int, float, bool)): return obj if isinstance(obj, dict): return {str(k): _safe_serialize(v) for k, v in obj.items()} if isinstance(obj, (list, tuple)): return [_safe_serialize(v) for v in obj] # Fall back to string representation try: return str(obj) except Exception: return "" class PotatoCallbackHandler: """ LangChain callback handler that sends completed traces to Potato. Collects run events (chain, LLM, tool starts/ends), tracks parent-child relationships, and POSTs the full trace to Potato's webhook endpoint when the root chain completes. The payload uses the LangSmith format expected by ``POST /api/traces/langsmith``. Args: potato_url: Base URL of the Potato server (e.g., ``http://localhost:8000``) api_key: API key for authenticating with Potato's webhook endpoint endpoint: Webhook path (default ``/api/traces/langsmith``) send_timeout: HTTP timeout in seconds for the POST request (default 10) metadata: Extra metadata dict attached to every trace """ def __init__( self, potato_url: str, api_key: str = "", endpoint: str = "/api/traces/langsmith", send_timeout: int = 10, metadata: Optional[Dict[str, Any]] = None, ): self.potato_url = potato_url.rstrip("/") self.api_key = api_key self.endpoint = endpoint self.send_timeout = send_timeout self.extra_metadata = metadata or {} # Run tracking — protected by lock for thread-safety self._lock = threading.Lock() self._runs: Dict[str, dict] = {} # run_id -> run dict self._root_run_id: Optional[str] = None self._pending_sends: List[threading.Thread] = [] # ------------------------------------------------------------------ # LangChain BaseCallbackHandler interface # ------------------------------------------------------------------ def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: self._start_run( run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, run_type="chain", name=serialized.get("name", serialized.get("id", ["unknown"])[-1] if isinstance(serialized.get("id"), list) else "chain"), inputs=inputs, tags=tags, metadata=metadata, ) def on_chain_end( self, outputs: Dict[str, Any], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: self._end_run(str(run_id), outputs=outputs) def on_chain_error( self, error: BaseException, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: self._end_run(str(run_id), error=error) def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: self._start_run( run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, run_type="llm", name=serialized.get("name", "llm"), inputs={"prompts": prompts}, tags=tags, metadata=metadata, ) def on_llm_end( self, response: Any, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: output = {} if hasattr(response, "generations") and response.generations: texts = [] for gen_list in response.generations: for gen in gen_list: texts.append(gen.text if hasattr(gen, "text") else str(gen)) output = {"text": "\n".join(texts)} self._end_run(str(run_id), outputs=output) def on_llm_error( self, error: BaseException, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: self._end_run(str(run_id), error=error) def on_tool_start( self, serialized: Dict[str, Any], input_str: str, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: self._start_run( run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, run_type="tool", name=serialized.get("name", "tool"), inputs={"input": input_str}, tags=tags, metadata=metadata, ) def on_tool_end( self, output: str, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: self._end_run(str(run_id), outputs={"output": output}) def on_tool_error( self, error: BaseException, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: self._end_run(str(run_id), error=error) # Retriever callbacks def on_retriever_start( self, serialized: Dict[str, Any], query: str, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: self._start_run( run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, run_type="retriever", name=serialized.get("name", "retriever"), inputs={"query": query}, tags=tags, metadata=metadata, ) def on_retriever_end( self, documents: Any, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> None: output = {"documents": _safe_serialize(documents)} self._end_run(str(run_id), outputs=output) # Text callbacks (no-ops — captured by LLM callbacks) def on_text(self, text: str, **kwargs: Any) -> None: pass # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _start_run( self, run_id: str, parent_run_id: Optional[str], run_type: str, name: str, inputs: Any, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: run = { "id": run_id, "parent_run_id": parent_run_id, "run_type": run_type, "name": name, "inputs": _safe_serialize(inputs), "outputs": {}, "status": "running", "start_time": time.time(), "end_time": None, "tags": tags or [], "metadata": metadata or {}, } with self._lock: self._runs[run_id] = run if parent_run_id is None: self._root_run_id = run_id def _end_run( self, run_id: str, outputs: Optional[Dict[str, Any]] = None, error: Optional[BaseException] = None, ) -> None: with self._lock: run = self._runs.get(run_id) if run is None: return run["end_time"] = time.time() run["latency"] = run["end_time"] - run["start_time"] if error: run["status"] = "error" run["outputs"] = {"error": str(error)} else: run["status"] = "completed" run["outputs"] = _safe_serialize(outputs or {}) # If this is the root run, send the trace is_root = run_id == self._root_run_id if is_root: self._send_trace() def _build_payload(self) -> dict: """Build a LangSmith-format payload from collected runs.""" with self._lock: runs = list(self._runs.values()) root_id = self._root_run_id # Find the root run for metadata root_run = None for r in runs: if r["id"] == root_id: root_run = r break payload = { "runs": [ { "id": r["id"], "parent_run_id": r["parent_run_id"], "run_type": r["run_type"], "name": r["name"], "inputs": r["inputs"], "outputs": r["outputs"], "status": r["status"], "latency": r.get("latency"), "tags": r.get("tags", []), } for r in runs ], } if root_run: payload["project_name"] = root_run.get("name", "langchain") return payload def _send_trace(self) -> None: """POST the trace to Potato in a background thread.""" payload = self._build_payload() def _do_send(): try: url = f"{self.potato_url}{self.endpoint}" headers = {"Content-Type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" resp = requests.post( url, json=payload, headers=headers, timeout=self.send_timeout, ) if resp.status_code < 300: logger.info("Trace sent to Potato: %s", resp.json()) else: logger.warning( "Potato returned %s: %s", resp.status_code, resp.text ) except Exception as e: logger.error("Failed to send trace to Potato: %s", e) thread = threading.Thread(target=_do_send, daemon=True) with self._lock: self._pending_sends.append(thread) thread.start() def flush(self, timeout: float = 30.0) -> None: """Block until all pending sends complete (or timeout).""" with self._lock: threads = list(self._pending_sends) for t in threads: t.join(timeout=timeout) with self._lock: self._pending_sends = [t for t in self._pending_sends if t.is_alive()] def reset(self) -> None: """Clear all collected runs (for reuse across multiple chains).""" with self._lock: self._runs.clear() self._root_run_id = None