codebook / potato /integrations /langchain_callback.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
12.2 kB
"""
LangChain Callback Handler for Potato
Automatically sends LangChain agent traces to a Potato instance
for human evaluation and annotation.
Requires: pip install "langchain-core>=0.1.0" requests
Usage:
from potato.integrations.langchain_callback import PotatoCallbackHandler
handler = PotatoCallbackHandler(potato_url="http://localhost:8000")
chain.invoke({"input": "..."}, config={"callbacks": [handler]})
"""
import json
import logging
import math
import threading
import time
import uuid
from typing import Any, Dict, List, Optional, Sequence, Union

import requests
# Module-level logger; the background sender reports upload success/failure here.
logger = logging.getLogger(__name__)
def _safe_serialize(obj: Any) -> Any:
"""Convert an object to a JSON-safe representation."""
if obj is None or isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, dict):
return {str(k): _safe_serialize(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_safe_serialize(v) for v in obj]
# Fall back to string representation
try:
return str(obj)
except Exception:
return "<unserializable>"
class PotatoCallbackHandler:
    """
    LangChain callback handler that sends completed traces to Potato.

    Collects run events (chain, LLM, tool and retriever starts/ends), tracks
    parent-child relationships, and POSTs the collected runs to Potato's
    webhook endpoint when the root run finishes (successfully or with an
    error). The POST happens on a background daemon thread; call
    :meth:`flush` before process exit to wait for it.

    The payload uses the LangSmith format expected by
    ``POST /api/traces/langsmith``.

    Collected runs are kept until :meth:`reset` is called, so a handler that
    is reused for several traces without a reset sends the earlier runs again
    alongside the new ones.

    NOTE(review): this class does not inherit from LangChain's
    ``BaseCallbackHandler``; confirm that the targeted langchain-core version
    accepts duck-typed handlers.

    Args:
        potato_url: Base URL of the Potato server (e.g., ``http://localhost:8000``)
        api_key: API key for Potato's webhook endpoint; sent as a Bearer
            token when non-empty
        endpoint: Webhook path (default ``/api/traces/langsmith``)
        send_timeout: HTTP timeout in seconds for the POST request (default 10)
        metadata: Extra metadata dict, stored as ``extra_metadata``.
            NOTE(review): the payload builder does not currently include it;
            confirm whether the endpoint expects it.
    """

    def __init__(
        self,
        potato_url: str,
        api_key: str = "",
        endpoint: str = "/api/traces/langsmith",
        send_timeout: int = 10,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.potato_url = potato_url.rstrip("/")
        self.api_key = api_key
        self.endpoint = endpoint
        self.send_timeout = send_timeout
        self.extra_metadata = metadata or {}
        # Run tracking; every access below goes through this lock.
        self._lock = threading.Lock()
        self._runs: Dict[str, dict] = {}  # run_id -> run dict
        self._root_run_id: Optional[str] = None
        self._pending_sends: List[threading.Thread] = []

    # ------------------------------------------------------------------
    # LangChain BaseCallbackHandler interface
    # ------------------------------------------------------------------

    def on_chain_start(
        self,
        serialized: Optional[Dict[str, Any]],
        inputs: Dict[str, Any],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record the start of a chain run."""
        self._start_run(
            run_id=str(run_id),
            parent_run_id=str(parent_run_id) if parent_run_id else None,
            run_type="chain",
            name=self._resolve_name(
                serialized, "chain", kwargs.get("name"), use_id=True
            ),
            inputs=inputs,
            tags=tags,
            metadata=metadata,
        )

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record the successful end of a chain run."""
        self._end_run(str(run_id), outputs=outputs)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record a chain run that ended with an error."""
        self._end_run(str(run_id), error=error)

    def on_llm_start(
        self,
        serialized: Optional[Dict[str, Any]],
        prompts: List[str],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record the start of an LLM run; inputs are ``{"prompts": prompts}``."""
        self._start_run(
            run_id=str(run_id),
            parent_run_id=str(parent_run_id) if parent_run_id else None,
            run_type="llm",
            name=self._resolve_name(serialized, "llm", kwargs.get("name")),
            inputs={"prompts": prompts},
            tags=tags,
            metadata=metadata,
        )

    def on_llm_end(
        self,
        response: Any,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record the end of an LLM run.

        When ``response.generations`` is non-empty, the text of every
        generation is joined with newlines into ``{"text": ...}``; otherwise
        the outputs are empty.
        """
        output: Dict[str, Any] = {}
        generations = getattr(response, "generations", None)
        if generations:
            texts = [
                gen.text if hasattr(gen, "text") else str(gen)
                for gen_list in generations
                for gen in gen_list
            ]
            output = {"text": "\n".join(texts)}
        self._end_run(str(run_id), outputs=output)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record an LLM run that ended with an error."""
        self._end_run(str(run_id), error=error)

    def on_tool_start(
        self,
        serialized: Optional[Dict[str, Any]],
        input_str: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record the start of a tool run; inputs are ``{"input": input_str}``."""
        self._start_run(
            run_id=str(run_id),
            parent_run_id=str(parent_run_id) if parent_run_id else None,
            run_type="tool",
            name=self._resolve_name(serialized, "tool", kwargs.get("name")),
            inputs={"input": input_str},
            tags=tags,
            metadata=metadata,
        )

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record the end of a tool run; outputs are ``{"output": output}``."""
        self._end_run(str(run_id), outputs={"output": output})

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record a tool run that ended with an error."""
        self._end_run(str(run_id), error=error)

    # Retriever callbacks

    def on_retriever_start(
        self,
        serialized: Optional[Dict[str, Any]],
        query: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record the start of a retriever run; inputs are ``{"query": query}``."""
        self._start_run(
            run_id=str(run_id),
            parent_run_id=str(parent_run_id) if parent_run_id else None,
            run_type="retriever",
            name=self._resolve_name(serialized, "retriever", kwargs.get("name")),
            inputs={"query": query},
            tags=tags,
            metadata=metadata,
        )

    def on_retriever_end(
        self,
        documents: Any,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record the end of a retriever run with the serialized documents."""
        output = {"documents": _safe_serialize(documents)}
        self._end_run(str(run_id), outputs=output)

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record a retriever run that ended with an error.

        Mirrors the chain/LLM/tool error hooks so a failed retriever run is
        not left in the "running" state.
        """
        self._end_run(str(run_id), error=error)

    # Text callbacks (no-ops — captured by LLM callbacks)

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Ignore free-text events."""
        pass

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_name(
        serialized: Optional[Dict[str, Any]],
        default: str,
        run_name: Optional[str] = None,
        use_id: bool = False,
    ) -> Any:
        """Pick a display name for a run.

        Precedence: ``serialized["name"]``; then, when ``use_id`` is set, the
        last element of a non-empty ``serialized["id"]`` list; then
        ``run_name`` (the ``name`` keyword passed to the callback, if any);
        then ``default``. A missing or non-dict ``serialized`` is tolerated.
        """
        if isinstance(serialized, dict):
            name = serialized.get("name")
            if name:
                return name
            ident = serialized.get("id")
            if use_id and isinstance(ident, list) and ident:
                return ident[-1]
        return run_name or default

    def _start_run(
        self,
        run_id: str,
        parent_run_id: Optional[str],
        run_type: str,
        name: str,
        inputs: Any,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Register a new run in the "running" state.

        A run without a parent becomes the current root; finishing that run
        triggers the upload.
        """
        run = {
            "id": run_id,
            "parent_run_id": parent_run_id,
            "run_type": run_type,
            "name": name,
            "inputs": _safe_serialize(inputs),
            "outputs": {},
            "status": "running",
            "start_time": time.time(),
            "end_time": None,
            "tags": tags or [],
            "metadata": metadata or {},
        }
        with self._lock:
            self._runs[run_id] = run
            if parent_run_id is None:
                self._root_run_id = run_id

    def _end_run(
        self,
        run_id: str,
        outputs: Optional[Dict[str, Any]] = None,
        error: Optional[BaseException] = None,
    ) -> None:
        """Mark a run as finished and send the trace if it was the root.

        Unknown ``run_id`` values are ignored. On error the run's outputs
        become ``{"error": str(error)}`` and its status ``"error"``;
        otherwise the serialized ``outputs`` are stored with status
        ``"completed"``.
        """
        with self._lock:
            run = self._runs.get(run_id)
            if run is None:
                return
            run["end_time"] = time.time()
            run["latency"] = run["end_time"] - run["start_time"]
            if error is not None:
                run["status"] = "error"
                run["outputs"] = {"error": str(error)}
            else:
                run["status"] = "completed"
                run["outputs"] = _safe_serialize(outputs or {})
            is_root = run_id == self._root_run_id
        # Send only after releasing the lock: building the payload takes the
        # same (non-reentrant) lock again.
        if is_root:
            self._send_trace()

    def _build_payload(self) -> dict:
        """Build a LangSmith-format payload from collected runs.

        Returns a dict with a ``"runs"`` list (id, parent_run_id, run_type,
        name, inputs, outputs, status, latency, tags for every collected run)
        and, when the current root run is known, ``"project_name"`` set to
        the root run's name.
        """
        # Snapshot under the lock so concurrent callbacks cannot change the
        # run dicts while they are being copied.
        with self._lock:
            payload: Dict[str, Any] = {
                "runs": [
                    {
                        "id": r["id"],
                        "parent_run_id": r["parent_run_id"],
                        "run_type": r["run_type"],
                        "name": r["name"],
                        "inputs": r["inputs"],
                        "outputs": r["outputs"],
                        "status": r["status"],
                        "latency": r.get("latency"),
                        "tags": r.get("tags", []),
                    }
                    for r in self._runs.values()
                ],
            }
            root_run = self._runs.get(self._root_run_id)
        if root_run is not None:
            payload["project_name"] = root_run.get("name", "langchain")
        return payload

    def _send_trace(self) -> None:
        """POST the trace to Potato in a background thread.

        The payload is built synchronously; only the HTTP request runs on
        the (daemon) thread. Failures are logged, never raised.
        """
        payload = self._build_payload()
        # Tolerate an endpoint configured without its leading slash.
        url = f"{self.potato_url}/{self.endpoint.lstrip('/')}"
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        def _do_send() -> None:
            # Thread boundary: best-effort delivery, so log and swallow.
            try:
                resp = requests.post(
                    url,
                    json=payload,
                    headers=headers,
                    timeout=self.send_timeout,
                )
                if resp.status_code < 300:
                    # A successful reply is not guaranteed to be JSON; do not
                    # report a delivered trace as a failure because of that.
                    try:
                        body = resp.json()
                    except ValueError:
                        body = resp.text
                    logger.info("Trace sent to Potato: %s", body)
                else:
                    logger.warning(
                        "Potato returned %s: %s", resp.status_code, resp.text
                    )
            except Exception as e:
                logger.error("Failed to send trace to Potato: %s", e)

        thread = threading.Thread(target=_do_send, daemon=True)
        with self._lock:
            # Drop finished senders so the list does not grow without bound,
            # and start the thread before publishing it so flush() never
            # sees an unstarted thread.
            self._pending_sends = [t for t in self._pending_sends if t.is_alive()]
            thread.start()
            self._pending_sends.append(thread)

    def flush(self, timeout: float = 30.0) -> None:
        """Block until all pending sends complete (or timeout).

        ``timeout`` is an overall budget in seconds shared by all pending
        sends, not a per-send limit. ``None`` waits without a limit.
        """
        with self._lock:
            threads = list(self._pending_sends)
        if timeout is None:
            for t in threads:
                t.join()
        else:
            deadline = time.monotonic() + timeout
            for t in threads:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                t.join(timeout=remaining)
        with self._lock:
            self._pending_sends = [t for t in self._pending_sends if t.is_alive()]

    def reset(self) -> None:
        """Clear all collected runs (for reuse across multiple chains)."""
        with self._lock:
            self._runs.clear()
            self._root_run_id = None