#!/usr/bin/env python3
"""
PI wrapper for HF Spaces.
Tries to use native pi-mono RPC if available; otherwise falls back to a
Python-based re-implementation that mimics PI's event stream and tool set
(read, bash) using the OpenAI function-calling API.
"""
from __future__ import annotations
import json
import os
import re
import shlex
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Generator, Iterator, List, Optional
import openai
try:
from charset_normalizer import from_bytes
except ImportError:
from_bytes = None
try:
from pypdf import PdfReader
except ImportError:
PdfReader = None
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
APP_DIR = Path(__file__).resolve().parent
if (APP_DIR / "pi-mono").exists():
REPO_ROOT = APP_DIR
else:
REPO_ROOT = APP_DIR.parent
DEFAULT_PI_REPO = REPO_ROOT / "pi-mono"
DEFAULT_PACKAGE_DIR = DEFAULT_PI_REPO / "packages" / "coding-agent"
DEFAULT_AGENT_DIR = DEFAULT_PI_REPO / ".pi" / "agent"
def _node_bin() -> str:
"""Find node >= 20."""
nvm_dir = Path(os.environ.get("NVM_DIR", Path.home() / ".nvm"))
versions_dir = nvm_dir / "versions" / "node"
if versions_dir.is_dir():
candidates = sorted(
(d for d in versions_dir.iterdir() if d.name.startswith("v")),
key=lambda d: tuple(int(x) for x in d.name.lstrip("v").split(".")),
reverse=True,
)
for candidate in candidates:
major = int(candidate.name.lstrip("v").split(".")[0])
node = candidate / "bin" / "node"
if major >= 20 and node.exists():
return str(node)
# Try system node
for candidate in ("node", "nodejs"):
try:
result = subprocess.run(
[candidate, "--version"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
version = result.stdout.strip().lstrip("v")
major = int(version.split(".")[0])
if major >= 20:
return candidate
except Exception:
pass
return "node"
def pi_mono_available() -> bool:
"""Check whether a built pi-mono package exists."""
dist_cli = DEFAULT_PACKAGE_DIR / "dist" / "cli.js"
return dist_cli.exists()
# ---------------------------------------------------------------------------
# Native PI RPC client (adapted from pi_rpc_runner.py)
# ---------------------------------------------------------------------------
class NativePiClient:
def __init__(
self,
*,
package_dir: Path,
cwd: Path,
agent_dir: Path,
provider: str,
model: str,
tools: str,
session_dir: Optional[Path] = None,
session_path: Optional[Path] = None,
continue_session: bool = False,
api_key: str = "",
) -> None:
self.package_dir = package_dir
self.cwd = cwd
self.agent_dir = agent_dir
self.provider = provider
self.model = model
self.tools = tools
self.session_dir = session_dir
self.session_path = session_path
self.continue_session = continue_session
self.api_key = api_key
self.proc: Optional[subprocess.Popen[bytes]] = None
self.stderr_chunks: List[str] = []
self._stderr_thread: Optional[threading.Thread] = None
self._request_id = 0
def _ensure_built_cli(self) -> Path:
dist_cli = self.package_dir / "dist" / "cli.js"
if dist_cli.exists():
return dist_cli
pi_repo_root = self.package_dir.parents[1]
sys.stderr.write("[setup] dist/cli.js not found, running `npm run build`\n")
sys.stderr.flush()
subprocess.run(
["npm", "run", "build"],
cwd=str(pi_repo_root),
check=True,
)
if not dist_cli.exists():
raise RuntimeError(f"Build completed but CLI not found at {dist_cli}")
return dist_cli
def _build_command(self) -> List[str]:
dist_cli = self._ensure_built_cli()
cmd = [_node_bin(), str(dist_cli), "--mode", "rpc"]
if self.provider:
cmd.extend(["--provider", self.provider])
if self.model:
cmd.extend(["--model", self.model])
if self.tools:
cmd.extend(["--tools", self.tools])
if self.session_dir:
cmd.extend(["--session-dir", str(self.session_dir)])
if self.session_path:
cmd.extend(["--session", str(self.session_path)])
elif self.continue_session:
cmd.append("--continue")
return cmd
def start(self) -> None:
if self.proc is not None:
raise RuntimeError("RPC client already started")
env = os.environ.copy()
env["PI_CODING_AGENT_DIR"] = str(self.agent_dir)
if self.api_key:
env["OPENAI_API_KEY"] = self.api_key
self.command = self._build_command()
self.proc = subprocess.Popen(
self.command,
cwd=str(self.cwd),
env=env,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
assert self.proc.stderr is not None
self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True)
self._stderr_thread.start()
def _drain_stderr(self) -> None:
assert self.proc is not None
assert self.proc.stderr is not None
for raw in self.proc.stderr:
self.stderr_chunks.append(raw.decode("utf-8", errors="replace"))
def stop(self) -> None:
if self.proc is None:
return
try:
if self.proc.poll() is None:
self.proc.terminate()
self.proc.wait(timeout=2)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc.wait(timeout=2)
finally:
self.proc = None
def _next_id(self) -> str:
self._request_id += 1
return f"py-{self._request_id}"
def _send(self, payload: Dict[str, Any]) -> None:
if self.proc is None or self.proc.stdin is None:
raise RuntimeError("RPC client is not running")
line = json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n"
self.proc.stdin.write(line.encode("utf-8"))
self.proc.stdin.flush()
def call(self, command_type: str, **payload: Any) -> Dict[str, Any]:
request_id = self._next_id()
message = {"id": request_id, "type": command_type, **payload}
self._send(message)
while True:
event = self._read_json_line()
if event.get("type") == "response" and event.get("id") == request_id:
if not event.get("success", False):
raise RuntimeError(f"RPC {command_type} failed: {event.get('error', 'unknown error')}")
return event
def _read_json_line(self) -> Dict[str, Any]:
if self.proc is None or self.proc.stdout is None:
raise RuntimeError("RPC client is not running")
raw = self.proc.stdout.readline()
if not raw:
stderr_text = "".join(self.stderr_chunks).strip()
raise RuntimeError(f"RPC process exited unexpectedly. stderr:\n{stderr_text}")
if raw.endswith(b"\n"):
raw = raw[:-1]
if raw.endswith(b"\r"):
raw = raw[:-1]
return json.loads(raw.decode("utf-8"))
def stream_events(
self,
message: str,
*,
max_turns: Optional[int] = None,
) -> Generator[Dict[str, Any], None, str]:
"""Yield raw PI events and finally return the assistant's full text."""
request_id = self._next_id()
self._send({"id": request_id, "type": "prompt", "message": message})
auxiliary_ids: set[str] = set()
text_parts: List[str] = []
prompt_ack = False
seen_turns = 0
sent_turn_limit_abort = False
while True:
event = self._read_json_line()
yield event
event_type = event.get("type")
if event_type == "response":
response_id = event.get("id")
if response_id == request_id:
if not event.get("success", False):
raise RuntimeError(f"RPC prompt failed: {event.get('error', 'unknown error')}")
prompt_ack = True
continue
if event_type == "turn_start":
seen_turns += 1
if max_turns is not None and seen_turns > max_turns and not sent_turn_limit_abort:
abort_id = self._next_id()
auxiliary_ids.add(abort_id)
self._send({"id": abort_id, "type": "abort"})
sent_turn_limit_abort = True
continue
if event_type == "message_update":
assistant_event = event.get("assistantMessageEvent", {})
if assistant_event.get("type") == "text_delta":
text_parts.append(assistant_event.get("delta", ""))
continue
if event_type == "agent_end":
if not prompt_ack:
raise RuntimeError("Received agent_end before prompt acknowledgement")
break
return "".join(text_parts)
# ---------------------------------------------------------------------------
# Fallback Python PI agent (OpenAI function calling)
# ---------------------------------------------------------------------------
SYSTEM_PROMPT_TEMPLATE = """You are Pi, an autonomous research assistant. Your goal is to answer the user's question by exploring documents in the working directory.
You have access to the following tools:
- read: Read the contents of a file.
- bash: Execute a shell command. Useful for listing files, searching with rg/grep/find, etc.
Guidelines:
1. Start by exploring the directory structure to understand what files are available.
2. Use targeted searches (rg, grep) to find relevant content.
3. Read files that seem relevant to the question.
4. Provide a clear, concise final answer based on the evidence found.
5. If the answer cannot be found, state that clearly.
Working directory: {cwd}
"""
def _read_tool(cwd: Path, file_path: str) -> str:
"""Read a file relative to cwd."""
try:
target = cwd / file_path
        # Security: prevent reading outside cwd. A plain string-prefix check would
        # let sibling paths such as "/data-evil" slip past a "/data" root, so check
        # the resolved path's ancestry instead.
        target = target.resolve()
        cwd_resolved = cwd.resolve()
        if target != cwd_resolved and cwd_resolved not in target.parents:
            return f"Error: Access denied – path '{file_path}' is outside the working directory."
if not target.exists():
return f"Error: File '{file_path}' does not exist."
if target.is_dir():
items = sorted(target.iterdir())
lines = [f"Directory: {file_path}/"]
for item in items[:50]:
suffix = "/" if item.is_dir() else ""
lines.append(f" {item.name}{suffix}")
if len(items) > 50:
lines.append(f" ... and {len(items) - 50} more items")
return "\n".join(lines)
if target.suffix.lower() == ".pdf":
if PdfReader is None:
return "Error: PDF parsing support is unavailable because pypdf is not installed."
reader = PdfReader(str(target))
page_texts: List[str] = []
for idx, page in enumerate(reader.pages, start=1):
text = (page.extract_text() or "").strip()
if text:
page_texts.append(f"[Page {idx}]\n{text}")
content = "\n\n".join(page_texts).strip()
        else:
            raw = target.read_bytes()
            content = None
            # Prefer charset-normalizer's detection when available.
            if from_bytes is not None:
                best = from_bytes(raw).best()
                if best is not None:
                    content = str(best)
            if content is None:
                # Fall back to a fixed list of common encodings.
                for encoding in ("utf-8", "utf-8-sig", "utf-16", "utf-16-le", "utf-16-be", "gb18030", "big5", "shift_jis"):
                    try:
                        content = raw.decode(encoding)
                        break
                    except UnicodeDecodeError:
                        continue
                else:
                    content = raw.decode("utf-8", errors="replace")
max_len = 12000
if len(content) > max_len:
content = content[:max_len] + f"\n\n...[truncated, total {len(content)} chars]"
return content
except Exception as e:
return f"Error reading '{file_path}': {e}"
def _bash_tool(cwd: Path, command: str) -> str:
"""Run a shell command inside cwd."""
try:
        # Basic guardrail: block obviously dangerous commands. Only the first token
        # is checked, so this is a convenience filter, not a sandbox.
dangerous = {"rm", "mv", "cp", "chmod", "chown", "sudo", "su", "ssh", "curl", "wget", "python", "python3", "node", "npm", "pip", "pip3", "docker", "kubectl"}
tokens = shlex.split(command)
if tokens and tokens[0] in dangerous:
return f"Error: Command '{tokens[0]}' is not allowed for security reasons."
result = subprocess.run(
command,
shell=True,
cwd=str(cwd),
capture_output=True,
text=True,
timeout=30,
)
output = result.stdout
if result.stderr:
output += "\n" + result.stderr
max_len = 8000
if len(output) > max_len:
output = output[:max_len] + f"\n\n...[truncated, total {len(output)} chars]"
return output
except subprocess.TimeoutExpired:
return "Error: Command timed out after 30 seconds."
except Exception as e:
return f"Error: {e}"
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat()
def _chunk_text(text: str, *, size: int = 48) -> Iterator[str]:
if not text:
return
for i in range(0, len(text), size):
yield text[i : i + size]
class FallbackPiAgent:
"""A lightweight Python agent that mimics PI's behavior using OpenAI tool use."""
def __init__(
self,
*,
cwd: Path,
api_key: str,
model: str = "gpt-4o",
max_turns: int = 6,
):
self.cwd = cwd
self.api_key = api_key
self.model = model
self.max_turns = max_turns
self.client = openai.OpenAI(api_key=api_key)
def _mk_event(self, typ: str, **kwargs) -> Dict[str, Any]:
return {"type": typ, "timestamp": _utc_now(), **kwargs}
def stream_events(
self,
message: str,
) -> Generator[Dict[str, Any], None, str]:
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(cwd=str(self.cwd.resolve()))
messages: List[Dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": message},
]
tools = [
{
"type": "function",
"function": {
"name": "read",
"description": "Read a file or list a directory. Provide a relative path from the working directory.",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Relative path to the file or directory to read.",
},
},
"required": ["file_path"],
},
},
},
{
"type": "function",
"function": {
"name": "bash",
"description": "Execute a shell command in the working directory. Useful for searching with rg/grep/find, listing files, etc.",
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Shell command to execute. Keep it simple and focused.",
},
},
"required": ["command"],
},
},
},
]
final_text = ""
for turn in range(1, self.max_turns + 1):
yield self._mk_event("turn_start", turn=turn)
yield self._mk_event("message_start", message={"role": "assistant", "content": []})
response = self.client.chat.completions.create(
model=self.model,
messages=messages, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
tool_choice="auto",
temperature=0.2,
)
choice = response.choices[0]
msg = choice.message
# Stream assistant text if present
if msg.content:
final_text = msg.content
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "text_start", "contentIndex": 0},
)
for chunk in _chunk_text(msg.content):
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "text_delta", "contentIndex": 0, "delta": chunk},
)
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "text_end", "contentIndex": 0, "content": msg.content},
)
# Handle tool calls
if msg.tool_calls:
# Build single assistant message with all tool calls
assistant_msg: Dict[str, Any] = {
"role": "assistant",
"content": msg.content or None,
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in msg.tool_calls
],
}
# Assistant message with tool_calls MUST come before tool results
messages.append(assistant_msg)
for tc in msg.tool_calls:
fn_name = tc.function.name
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "toolcall_start", "contentIndex": 0},
)
for chunk in _chunk_text(tc.function.arguments or ""):
yield self._mk_event(
"message_update",
assistantMessageEvent={
"type": "toolcall_delta",
"contentIndex": 0,
"delta": chunk,
},
)
yield self._mk_event(
"message_update",
assistantMessageEvent={
"type": "toolcall_end",
"contentIndex": 0,
"toolCall": {
"id": tc.id,
"name": fn_name,
"arguments": tc.function.arguments,
},
},
)
try:
fn_args = json.loads(tc.function.arguments)
except json.JSONDecodeError:
fn_args = {}
yield self._mk_event(
"tool_execution_start",
toolName=fn_name,
args=fn_args,
toolCallId=tc.id,
)
if fn_name == "read":
result = _read_tool(self.cwd, fn_args.get("file_path", ""))
elif fn_name == "bash":
result = _bash_tool(self.cwd, fn_args.get("command", ""))
else:
result = f"Error: Unknown tool '{fn_name}'"
is_error = result.startswith("Error:")
for chunk in _chunk_text(result):
yield self._mk_event(
"tool_execution_update",
toolName=fn_name,
toolCallId=tc.id,
output=chunk,
)
yield self._mk_event(
"tool_execution_end",
toolName=fn_name,
toolCallId=tc.id,
isError=is_error,
result=result,
)
messages.append({
"role": "tool",
"tool_call_id": tc.id,
"content": result,
})
yield self._mk_event("message_end", message={"role": "assistant", "content": []})
continue # next turn
# No tool calls → agent is done
yield self._mk_event("message_end", message={"role": "assistant", "content": msg.content or ""})
yield self._mk_event("agent_end")
return final_text or "(no answer provided)"
# Max turns reached — ask model to summarise with all gathered context
messages.append({
"role": "user",
"content": (
"You have used all available research turns. "
"Based on every file you have read and every search you have run, "
"please give your best final answer to the original question. "
"Be as complete and accurate as the evidence allows."
),
})
try:
summary = self.client.chat.completions.create(
model=self.model,
messages=messages, # type: ignore[arg-type]
temperature=0.2,
)
final_text = summary.choices[0].message.content or final_text
if final_text:
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "text_start", "contentIndex": 0},
)
for chunk in _chunk_text(final_text):
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "text_delta", "contentIndex": 0, "delta": chunk},
)
yield self._mk_event(
"message_update",
assistantMessageEvent={"type": "text_end", "contentIndex": 0, "content": final_text},
)
except Exception:
pass
yield self._mk_event("agent_end", note=f"Reached max_turns={self.max_turns}")
return final_text or "(max turns reached without final answer)"
# ---------------------------------------------------------------------------
# Unified runner
# ---------------------------------------------------------------------------
def run_pi(
question: str,
*,
cwd: Path,
api_key: str,
provider: str = "openai",
model: str = "gpt-4o",
max_turns: int = 6,
session_dir: Optional[Path] = None,
session_path: Optional[Path] = None,
continue_session: bool = False,
) -> Generator[Dict[str, Any], None, str]:
"""
Run PI (native if available, else fallback) and yield events.
Returns the final answer text via StopIteration.value.
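
    Example (a minimal sketch of driving the generator and capturing its return
    value; the question, path, and key are placeholders)::

        gen = run_pi("Summarise report.pdf", cwd=Path("./docs"), api_key=api_key)
        try:
            while True:
                next(gen)  # each item is a PI event dict
        except StopIteration as stop:
            answer = stop.value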
"""
if pi_mono_available() and provider != "fallback":
client = NativePiClient(
package_dir=DEFAULT_PACKAGE_DIR,
cwd=cwd,
agent_dir=DEFAULT_AGENT_DIR,
provider=provider,
model=model,
tools="read,bash",
session_dir=session_dir,
session_path=session_path,
continue_session=continue_session,
api_key=api_key,
)
        try:
            client.start()
            # ``yield from`` propagates stream_events' return value (the final answer text).
            return (yield from client.stream_events(question, max_turns=max_turns))
        finally:
            client.stop()
else:
agent = FallbackPiAgent(
cwd=cwd,
api_key=api_key,
model=model,
max_turns=max_turns,
)
return (yield from agent.stream_events(question))
# Flat-iterator variant: yields the same events plus a final sentinel dict, which is
# easier to consume from UI code than a generator return value.
def run_pi_stream(
question: str,
*,
cwd: Path,
api_key: str,
provider: str = "openai",
model: str = "gpt-4o",
max_turns: int = 6,
session_dir: Optional[Path] = None,
session_path: Optional[Path] = None,
continue_session: bool = False,
) -> Iterator[Dict[str, Any]]:
"""
Iterator that yields PI events. The *last* item is a sentinel dict
``{"type": "__final__", "text": "..."}`` containing the answer.
"""
if pi_mono_available() and provider != "fallback":
client = NativePiClient(
package_dir=DEFAULT_PACKAGE_DIR,
cwd=cwd,
agent_dir=DEFAULT_AGENT_DIR,
provider=provider,
model=model,
tools="read,bash",
session_dir=session_dir,
session_path=session_path,
continue_session=continue_session,
api_key=api_key,
)
text_parts: List[str] = []
try:
client.start()
state_response = client.call("get_state")
state_data = state_response.get("data", {}) or {}
if state_data.get("thinkingLevel") == "off":
client.call("set_thinking_level", level="high")
state_response = client.call("get_state")
state_data = state_response.get("data", {}) or {}
yield {
"type": "__session__",
"runner": "Native PI",
"provider": provider,
"model": model,
"sessionFile": state_data.get("sessionFile"),
"sessionId": state_data.get("sessionId"),
"thinkingLevel": state_data.get("thinkingLevel"),
}
request_id = client._next_id()
client._send({"id": request_id, "type": "prompt", "message": question})
prompt_ack = False
seen_turns = 0
sent_abort = False
while True:
event = client._read_json_line()
yield event
et = event.get("type")
if et == "response" and event.get("id") == request_id:
if not event.get("success", True):
yield {"type": "error", "error": event.get("error", "RPC prompt failed")}
break
prompt_ack = True
continue
if et == "turn_start":
seen_turns += 1
if max_turns and seen_turns > max_turns and not sent_abort:
client._send({"id": client._next_id(), "type": "abort"})
sent_abort = True
continue
if et == "message_update":
ame = event.get("assistantMessageEvent", {})
if ame.get("type") == "text_delta":
text_parts.append(ame.get("delta", ""))
continue
if et == "agent_end":
break
final_state_response = client.call("get_state")
final_state_data = final_state_response.get("data", {}) or {}
yield {
"type": "__session__",
"runner": "Native PI",
"provider": provider,
"model": model,
"sessionFile": final_state_data.get("sessionFile"),
"sessionId": final_state_data.get("sessionId"),
"thinkingLevel": final_state_data.get("thinkingLevel"),
}
except Exception as exc:
yield {"type": "error", "error": str(exc)}
finally:
client.stop()
final_text = "".join(text_parts)
yield {"type": "__final__", "text": final_text}
else:
yield {
"type": "__session__",
"runner": "Fallback Agent",
"provider": provider,
"model": model,
"thinkingLevel": "unsupported",
}
agent = FallbackPiAgent(
cwd=cwd,
api_key=api_key,
model=model,
max_turns=max_turns,
)
text_parts: List[str] = []
try:
for event in agent.stream_events(question):
yield event
if event.get("type") == "message_update":
ame = event.get("assistantMessageEvent", {})
if ame.get("type") == "text_delta":
text_parts.append(ame.get("delta", ""))
except Exception as exc:
yield {"type": "error", "error": str(exc)}
yield {"type": "__final__", "text": "".join(text_parts)}