#!/usr/bin/env python3
"""
PI wrapper for HF Spaces.

Tries to use native pi-mono RPC if available; otherwise falls back to a
Python-based re-implementation that mimics PI's event stream and tool set
(read, bash) using the OpenAI function-calling API.
"""

from __future__ import annotations

import json
import os
import re
import shlex
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Generator, Iterator, List, Optional

import openai

try:
    from charset_normalizer import from_bytes
except ImportError:
    from_bytes = None

try:
    from pypdf import PdfReader
except ImportError:
    PdfReader = None

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

APP_DIR = Path(__file__).resolve().parent
if (APP_DIR / "pi-mono").exists():
    REPO_ROOT = APP_DIR
else:
    REPO_ROOT = APP_DIR.parent
DEFAULT_PI_REPO = REPO_ROOT / "pi-mono"
DEFAULT_PACKAGE_DIR = DEFAULT_PI_REPO / "packages" / "coding-agent"
DEFAULT_AGENT_DIR = DEFAULT_PI_REPO / ".pi" / "agent"


def _parse_node_version(text: str) -> Optional[tuple[int, ...]]:
    """Parse ``v20.11.1``-style text into a tuple of ints.

    Returns None when any dot-separated component is not a plain integer
    (for example a pre-release directory such as ``v22.0.0-rc.1``), so
    callers can skip the entry instead of crashing.
    """
    try:
        return tuple(int(part) for part in text.strip().lstrip("v").split("."))
    except ValueError:
        return None


def _node_bin(min_major: int = 20) -> str:
    """Find a node executable whose major version is at least *min_major*.

    Search order:
      1. nvm installs under ``$NVM_DIR/versions/node`` (``~/.nvm`` when the
         variable is unset or empty), newest version first.
      2. ``node`` and then ``nodejs`` on PATH, accepted only when
         ``--version`` reports a new enough major version.

    Args:
        min_major: Lowest acceptable node major version.

    Returns:
        The absolute path of an nvm-managed node, the bare command name of a
        suitable system node, or the literal ``"node"`` when nothing
        suitable was found.
    """
    # Only consult Path.home() when NVM_DIR is not provided.
    nvm_root = os.environ.get("NVM_DIR")
    nvm_dir = Path(nvm_root) if nvm_root else Path.home() / ".nvm"
    versions_dir = nvm_dir / "versions" / "node"

    installed = []
    try:
        if versions_dir.is_dir():
            for entry in versions_dir.iterdir():
                if not entry.name.startswith("v"):
                    continue
                version = _parse_node_version(entry.name)
                # Unparsable directory names are ignored rather than
                # aborting the whole lookup.
                if version is not None:
                    installed.append((version, entry))
    except OSError:
        # An unreadable nvm directory must not prevent the PATH lookup below.
        installed = []

    installed.sort(key=lambda pair: pair[0], reverse=True)
    for version, entry in installed:
        node = entry / "bin" / "node"
        if version[0] >= min_major and node.exists():
            return str(node)

    # Try system node
    for candidate in ("node", "nodejs"):
        try:
            result = subprocess.run(
                [candidate, "--version"],
                capture_output=True,
                text=True,
                timeout=5,
            )
        except (OSError, subprocess.SubprocessError):
            # Missing binary, permission problem, or timeout: try the next name.
            continue
        if result.returncode != 0:
            continue
        version = _parse_node_version(result.stdout)
        if version is not None and version[0] >= min_major:
            return candidate

    return "node"
def pi_mono_available() -> bool:
    """Check whether a built pi-mono package exists.

    Returns:
        True when ``dist/cli.js`` is present under ``DEFAULT_PACKAGE_DIR``.
    """
    return (DEFAULT_PACKAGE_DIR / "dist" / "cli.js").exists()


def _text_delta(event: Dict[str, Any]) -> str:
    """Return the text carried by a ``text_delta`` message update, else ``""``."""
    if event.get("type") != "message_update":
        return ""
    inner = event.get("assistantMessageEvent") or {}
    if inner.get("type") != "text_delta":
        return ""
    return inner.get("delta") or ""


# ---------------------------------------------------------------------------
# Native PI RPC client (adapted from pi_rpc_runner.py)
# ---------------------------------------------------------------------------

class NativePiClient:
    """Client for the pi coding-agent CLI running in ``--mode rpc``.

    The CLI is started as a child process. Requests are written to its stdin
    and events are read from its stdout, one JSON object per line. stderr is
    collected on a background thread so it can be included in error messages.
    """

    def __init__(
        self,
        *,
        package_dir: Path,
        cwd: Path,
        agent_dir: Path,
        provider: str,
        model: str,
        tools: str,
        session_dir: Optional[Path] = None,
        session_path: Optional[Path] = None,
        continue_session: bool = False,
        api_key: str = "",
    ) -> None:
        self.package_dir = package_dir
        self.cwd = cwd
        self.agent_dir = agent_dir
        self.provider = provider
        self.model = model
        self.tools = tools
        self.session_dir = session_dir
        self.session_path = session_path
        self.continue_session = continue_session
        self.api_key = api_key
        # Filled in by start(); kept for inspection/debugging.
        self.command: List[str] = []
        self.proc: Optional[subprocess.Popen[bytes]] = None
        self.stderr_chunks: List[str] = []
        self._stderr_thread: Optional[threading.Thread] = None
        self._request_id = 0

    def _ensure_built_cli(self) -> Path:
        """Return the path of ``dist/cli.js``, running ``npm run build`` if absent.

        Raises:
            subprocess.CalledProcessError: If the build command fails.
            RuntimeError: If the build succeeds but the CLI is still missing.
        """
        dist_cli = self.package_dir / "dist" / "cli.js"
        if dist_cli.exists():
            return dist_cli

        # The build is run two directory levels above the package directory.
        pi_repo_root = self.package_dir.parents[1]
        sys.stderr.write("[setup] dist/cli.js not found, running `npm run build`\n")
        sys.stderr.flush()
        subprocess.run(
            ["npm", "run", "build"],
            cwd=str(pi_repo_root),
            check=True,
        )
        if not dist_cli.exists():
            raise RuntimeError(f"Build completed but CLI not found at {dist_cli}")
        return dist_cli

    def _build_command(self) -> List[str]:
        """Assemble the argv used to launch the CLI in RPC mode."""
        cmd = [_node_bin(), str(self._ensure_built_cli()), "--mode", "rpc"]
        if self.provider:
            cmd.extend(["--provider", self.provider])
        if self.model:
            cmd.extend(["--model", self.model])
        if self.tools:
            cmd.extend(["--tools", self.tools])
        if self.session_dir:
            cmd.extend(["--session-dir", str(self.session_dir)])
        # An explicit session file takes precedence over --continue.
        if self.session_path:
            cmd.extend(["--session", str(self.session_path)])
        elif self.continue_session:
            cmd.append("--continue")
        return cmd

    def start(self) -> None:
        """Launch the RPC child process and begin draining its stderr.

        Raises:
            RuntimeError: If the client was already started.
        """
        if self.proc is not None:
            raise RuntimeError("RPC client already started")

        env = os.environ.copy()
        env["PI_CODING_AGENT_DIR"] = str(self.agent_dir)
        if self.api_key:
            env["OPENAI_API_KEY"] = self.api_key

        self.command = self._build_command()
        self.proc = subprocess.Popen(
            self.command,
            cwd=str(self.cwd),
            env=env,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True)
        self._stderr_thread.start()

    def _drain_stderr(self) -> None:
        """Accumulate the child's stderr into ``self.stderr_chunks`` until EOF."""
        proc = self.proc
        if proc is None or proc.stderr is None:
            return
        try:
            for raw in proc.stderr:
                self.stderr_chunks.append(raw.decode("utf-8", errors="replace"))
        except (OSError, ValueError):
            # The pipe was closed underneath us during shutdown; nothing left to read.
            pass

    def stop(self) -> None:
        """Terminate the child process (if any) and release its pipes.

        The process is sent terminate() and, if it has not exited within two
        seconds, kill(). Safe to call when the client was never started.
        """
        proc = self.proc
        if proc is None:
            return
        try:
            if proc.poll() is None:
                proc.terminate()
                try:
                    proc.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait(timeout=2)
        finally:
            self.proc = None
            drainer, self._stderr_thread = self._stderr_thread, None
            if drainer is not None:
                drainer.join(timeout=1)
            streams = [proc.stdin, proc.stdout]
            # Only close stderr once its reader thread is done; closing a
            # stream another thread is still blocked on could stall here.
            if drainer is None or not drainer.is_alive():
                streams.append(proc.stderr)
            for stream in streams:
                if stream is None:
                    continue
                try:
                    stream.close()
                except OSError:
                    # e.g. BrokenPipeError while flushing stdin to a dead child.
                    pass

    def _next_id(self) -> str:
        """Return a fresh request id of the form ``py-<n>``."""
        self._request_id += 1
        return f"py-{self._request_id}"

    def _send(self, payload: Dict[str, Any]) -> None:
        """Write *payload* to the child's stdin as one compact JSON line.

        Raises:
            RuntimeError: If the client is not running.
        """
        if self.proc is None or self.proc.stdin is None:
            raise RuntimeError("RPC client is not running")
        line = json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n"
        self.proc.stdin.write(line.encode("utf-8"))
        self.proc.stdin.flush()

    def call(self, command_type: str, **payload: Any) -> Dict[str, Any]:
        """Send a command and block until its matching response arrives.

        Any other events read while waiting are discarded.

        Raises:
            RuntimeError: If the response does not report success, or the
                process exits before responding.
        """
        request_id = self._next_id()
        self._send({"id": request_id, "type": command_type, **payload})

        while True:
            event = self._read_json_line()
            if event.get("type") == "response" and event.get("id") == request_id:
                if not event.get("success", False):
                    raise RuntimeError(f"RPC {command_type} failed: {event.get('error', 'unknown error')}")
                return event

    def _read_json_line(self) -> Dict[str, Any]:
        """Read and decode the next JSON line from the child's stdout.

        Raises:
            RuntimeError: If the client is not running or stdout hit EOF.
        """
        if self.proc is None or self.proc.stdout is None:
            raise RuntimeError("RPC client is not running")
        raw = self.proc.stdout.readline()
        if not raw:
            stderr_text = "".join(self.stderr_chunks).strip()
            raise RuntimeError(f"RPC process exited unexpectedly. stderr:\n{stderr_text}")
        return json.loads(raw.rstrip(b"\r\n").decode("utf-8"))

    def stream_events(
        self,
        message: str,
        *,
        max_turns: Optional[int] = None,
    ) -> Generator[Dict[str, Any], None, str]:
        """Yield raw PI events and finally return the assistant's full text.

        Args:
            message: Prompt text sent to the agent.
            max_turns: When set, a single ``abort`` request is sent once more
                than this many ``turn_start`` events have been seen.

        Raises:
            RuntimeError: If the prompt is rejected, or ``agent_end`` arrives
                before the prompt was acknowledged.
        """
        request_id = self._next_id()
        self._send({"id": request_id, "type": "prompt", "message": message})

        text_parts: List[str] = []
        prompt_ack = False
        seen_turns = 0
        abort_sent = False

        while True:
            event = self._read_json_line()
            yield event
            event_type = event.get("type")

            if event_type == "response":
                # Responses to other requests (e.g. our abort) are passed
                # through to the consumer but otherwise ignored.
                if event.get("id") == request_id:
                    if not event.get("success", False):
                        raise RuntimeError(f"RPC prompt failed: {event.get('error', 'unknown error')}")
                    prompt_ack = True
            elif event_type == "turn_start":
                seen_turns += 1
                if max_turns is not None and seen_turns > max_turns and not abort_sent:
                    self._send({"id": self._next_id(), "type": "abort"})
                    abort_sent = True
            elif event_type == "message_update":
                delta = _text_delta(event)
                if delta:
                    text_parts.append(delta)
            elif event_type == "agent_end":
                if not prompt_ack:
                    raise RuntimeError("Received agent_end before prompt acknowledgement")
                break

        return "".join(text_parts)


# ---------------------------------------------------------------------------
# Fallback Python PI agent (OpenAI function calling)
# ---------------------------------------------------------------------------

SYSTEM_PROMPT_TEMPLATE = """You are Pi, an autonomous research assistant.
Your goal is to answer the user's question by exploring documents in the working directory.

You have access to the following tools:
- read: Read the contents of a file.
- bash: Execute a shell command. Useful for listing files, searching with rg/grep/find, etc.

Guidelines:
1. Start by exploring the directory structure to understand what files are available.
2. Use targeted searches (rg, grep) to find relevant content.
3. Read files that seem relevant to the question.
4. Provide a clear, concise final answer based on the evidence found.
5. If the answer cannot be found, state that clearly.

Working directory: {cwd}
"""

# Tried in order when charset detection is unavailable or inconclusive.
# utf-8-sig comes first so a leading BOM is stripped; it also decodes plain UTF-8.
_FALLBACK_ENCODINGS = (
    "utf-8-sig",
    "utf-8",
    "utf-16",
    "utf-16-le",
    "utf-16-be",
    "gb18030",
    "big5",
    "shift_jis",
)

_DANGEROUS_COMMANDS = frozenset({
    "rm", "mv", "cp", "chmod", "chown", "sudo", "su", "ssh", "curl", "wget",
    "python", "python3", "node", "npm", "pip", "pip3", "docker", "kubectl",
})

# Characters that make up shell control operators (; | || & && and subshell parens).
_SHELL_CONTROL_CHARS = frozenset(";|&()")

# Leading ``NAME=value`` words are environment assignments, not the command.
_ENV_ASSIGNMENT_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*=")


def _decode_bytes(raw: bytes) -> str:
    """Decode *raw* to text, preferring charset detection when available."""
    if from_bytes is not None:
        best = from_bytes(raw).best()
        if best is not None:
            return str(best)
    for encoding in _FALLBACK_ENCODINGS:
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="replace")


def _read_tool(cwd: Path, file_path: str) -> str:
    """Read a file (or list a directory) relative to cwd.

    Args:
        cwd: Working directory the agent is confined to.
        file_path: Path supplied by the model, resolved against *cwd*.

    Returns:
        The file text (PDF pages are extracted via pypdf), a listing of at
        most 50 directory entries, or a message starting with ``Error``.
        Text longer than 12000 characters is truncated with a note.
    """
    try:
        cwd_resolved = cwd.resolve()
        target = (cwd / file_path).resolve()

        # Security: prevent reading outside cwd. A component-wise check is
        # required; a string prefix test would accept sibling directories
        # such as "/work-evil" for cwd "/work".
        try:
            target.relative_to(cwd_resolved)
        except ValueError:
            return f"Error: Access denied – path '{file_path}' is outside the working directory."

        if not target.exists():
            return f"Error: File '{file_path}' does not exist."

        if target.is_dir():
            items = sorted(target.iterdir())
            lines = [f"Directory: {file_path}/"]
            for item in items[:50]:
                suffix = "/" if item.is_dir() else ""
                lines.append(f" {item.name}{suffix}")
            if len(items) > 50:
                lines.append(f" ... and {len(items) - 50} more items")
            return "\n".join(lines)

        if target.suffix.lower() == ".pdf":
            if PdfReader is None:
                return "Error: PDF parsing support is unavailable because pypdf is not installed."
            reader = PdfReader(str(target))
            page_texts: List[str] = []
            for idx, page in enumerate(reader.pages, start=1):
                text = (page.extract_text() or "").strip()
                if text:
                    page_texts.append(f"[Page {idx}]\n{text}")
            content = "\n\n".join(page_texts).strip()
        else:
            content = _decode_bytes(target.read_bytes())

        max_len = 12000
        if len(content) > max_len:
            content = content[:max_len] + f"\n\n...[truncated, total {len(content)} chars]"
        return content
    except Exception as e:
        # Tool failures are reported to the model as text rather than raised.
        return f"Error reading '{file_path}': {e}"


def _blocked_command(command: str) -> Optional[str]:
    """Return the name of the first deny-listed program in *command*, if any.

    The first word of every segment separated by a shell control operator is
    checked by basename, so ``ls; rm x`` and ``/bin/rm x`` are both caught.
    This is a best-effort filter, NOT a sandbox: indirect execution such as
    ``xargs rm``, ``env rm`` or ``$(rm ...)`` is not detected, and a quoted
    operator character (``grep ";" rm``) is conservatively treated as a
    separator.

    Raises:
        ValueError: If the command has unbalanced quoting.
    """
    lexer = shlex.shlex(command, posix=True, punctuation_chars=True)
    lexer.whitespace_split = True
    lexer.commenters = ""  # match shlex.split(): '#' is not a comment marker

    expect_command = True
    for token in lexer:
        if token and set(token) <= _SHELL_CONTROL_CHARS:
            expect_command = True
            continue
        if not expect_command:
            continue
        if _ENV_ASSIGNMENT_RE.match(token):
            continue  # still waiting for the actual program name
        name = os.path.basename(token)
        if name in _DANGEROUS_COMMANDS:
            return name
        expect_command = False
    return None


def _bash_tool(cwd: Path, command: str) -> str:
    """Run a shell command inside cwd.

    Args:
        cwd: Directory the command is executed in.
        command: Shell command line supplied by the model.

    Returns:
        stdout followed by stderr (truncated to 8000 characters), or a
        message starting with ``Error:`` when the command is empty,
        deny-listed, times out after 30 seconds, or cannot be run.
    """
    try:
        if not command or not command.strip():
            return "Error: Empty command."

        # Basic security: block dangerous commands
        blocked = _blocked_command(command)
        if blocked is not None:
            return f"Error: Command '{blocked}' is not allowed for security reasons."

        # NOTE(security): the command still runs through a shell with the
        # privileges of this process; the deny-list above is only a guard rail.
        result = subprocess.run(
            command,
            shell=True,
            cwd=str(cwd),
            capture_output=True,
            text=True,
            timeout=30,
        )
        output = result.stdout
        if result.stderr:
            output += "\n" + result.stderr
        max_len = 8000
        if len(output) > max_len:
            output = output[:max_len] + f"\n\n...[truncated, total {len(output)} chars]"
        return output
    except subprocess.TimeoutExpired:
        return "Error: Command timed out after 30 seconds."
    except Exception as e:
        return f"Error: {e}"


def _utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _chunk_text(text: str, *, size: int = 48) -> Iterator[str]:
    """Yield consecutive slices of *text* that are at most *size* characters long."""
    if not text:
        return
    for start in range(0, len(text), size):
        yield text[start : start + size]


# Tool schemas advertised to the model by FallbackPiAgent.
_FALLBACK_TOOLS: List[Dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "read",
            "description": "Read a file or list a directory. Provide a relative path from the working directory.",
            "parameters": {
                "type": "object",
                "properties": {
                    "file_path": {
                        "type": "string",
                        "description": "Relative path to the file or directory to read.",
                    },
                },
                "required": ["file_path"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "bash",
            "description": "Execute a shell command in the working directory. Useful for searching with rg/grep/find, listing files, etc.",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "Shell command to execute. Keep it simple and focused.",
                    },
                },
                "required": ["command"],
            },
        },
    },
]


class FallbackPiAgent:
    """A lightweight Python agent that mimics PI's behavior using OpenAI tool use."""

    def __init__(
        self,
        *,
        cwd: Path,
        api_key: str,
        model: str = "gpt-4o",
        max_turns: int = 6,
    ):
        self.cwd = cwd
        self.api_key = api_key
        self.model = model
        self.max_turns = max_turns
        self.client = openai.OpenAI(api_key=api_key)

    def _mk_event(self, typ: str, **kwargs) -> Dict[str, Any]:
        """Build an event dict stamped with the current UTC time."""
        return {"type": typ, "timestamp": _utc_now(), **kwargs}

    def _text_events(self, text: str) -> Iterator[Dict[str, Any]]:
        """Yield the text_start / text_delta... / text_end sequence for *text*."""
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={"type": "text_start", "contentIndex": 0},
        )
        for chunk in _chunk_text(text):
            yield self._mk_event(
                "message_update",
                assistantMessageEvent={"type": "text_delta", "contentIndex": 0, "delta": chunk},
            )
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={"type": "text_end", "contentIndex": 0, "content": text},
        )

    @staticmethod
    def _parse_tool_args(raw: Optional[str]) -> Dict[str, Any]:
        """Decode tool-call arguments; anything but a JSON object yields ``{}``."""
        try:
            parsed = json.loads(raw or "{}")
        except (json.JSONDecodeError, TypeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def _run_tool(self, name: str, args: Dict[str, Any]) -> str:
        """Dispatch a tool call by name and return its textual result."""
        if name == "read":
            return _read_tool(self.cwd, args.get("file_path", ""))
        if name == "bash":
            return _bash_tool(self.cwd, args.get("command", ""))
        return f"Error: Unknown tool '{name}'"

    def _tool_call_events(
        self,
        tool_call: Any,
        messages: List[Dict[str, Any]],
    ) -> Iterator[Dict[str, Any]]:
        """Announce, execute and report one tool call.

        Appends the ``tool`` result message to *messages* as a side effect.
        """
        fn_name = tool_call.function.name
        raw_args = tool_call.function.arguments

        yield self._mk_event(
            "message_update",
            assistantMessageEvent={"type": "toolcall_start", "contentIndex": 0},
        )
        for chunk in _chunk_text(raw_args or ""):
            yield self._mk_event(
                "message_update",
                assistantMessageEvent={
                    "type": "toolcall_delta",
                    "contentIndex": 0,
                    "delta": chunk,
                },
            )
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={
                "type": "toolcall_end",
                "contentIndex": 0,
                "toolCall": {
                    "id": tool_call.id,
                    "name": fn_name,
                    "arguments": raw_args,
                },
            },
        )

        fn_args = self._parse_tool_args(raw_args)
        yield self._mk_event(
            "tool_execution_start",
            toolName=fn_name,
            args=fn_args,
            toolCallId=tool_call.id,
        )

        result = self._run_tool(fn_name, fn_args)
        for chunk in _chunk_text(result):
            yield self._mk_event(
                "tool_execution_update",
                toolName=fn_name,
                toolCallId=tool_call.id,
                output=chunk,
            )
        yield self._mk_event(
            "tool_execution_end",
            toolName=fn_name,
            toolCallId=tool_call.id,
            isError=result.startswith("Error:"),
            result=result,
        )
        messages.append({
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": result,
        })

    def stream_events(
        self,
        message: str,
    ) -> Generator[Dict[str, Any], None, str]:
        """Run the tool-use loop, yielding PI-style events.

        Each turn performs one chat completion. Tool calls are executed and
        fed back to the model; a reply without tool calls ends the run. When
        ``max_turns`` is exhausted, one extra completion without tools asks
        the model for a final answer.

        Returns:
            The final answer text (via ``StopIteration.value``), or a
            placeholder string when the model produced none.
        """
        system_prompt = SYSTEM_PROMPT_TEMPLATE.format(cwd=str(self.cwd.resolve()))
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ]

        final_text = ""
        for turn in range(1, self.max_turns + 1):
            yield self._mk_event("turn_start", turn=turn)
            yield self._mk_event("message_start", message={"role": "assistant", "content": []})

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,  # type: ignore[arg-type]
                tools=_FALLBACK_TOOLS,  # type: ignore[arg-type]
                tool_choice="auto",
                temperature=0.2,
            )
            msg = response.choices[0].message

            # Stream assistant text if present
            if msg.content:
                final_text = msg.content
                yield from self._text_events(msg.content)

            if not msg.tool_calls:
                # No tool calls → agent is done
                yield self._mk_event("message_end", message={"role": "assistant", "content": msg.content or ""})
                yield self._mk_event("agent_end")
                return final_text or "(no answer provided)"

            # Assistant message with tool_calls MUST come before tool results
            messages.append({
                "role": "assistant",
                "content": msg.content or None,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in msg.tool_calls
                ],
            })
            for tc in msg.tool_calls:
                yield from self._tool_call_events(tc, messages)

            yield self._mk_event("message_end", message={"role": "assistant", "content": []})

        # Max turns reached — ask model to summarise with all gathered context
        messages.append({
            "role": "user",
            "content": (
                "You have used all available research turns. "
                "Based on every file you have read and every search you have run, "
                "please give your best final answer to the original question. "
                "Be as complete and accurate as the evidence allows."
            ),
        })
        try:
            summary = self.client.chat.completions.create(
                model=self.model,
                messages=messages,  # type: ignore[arg-type]
                temperature=0.2,
            )
            summary_text = summary.choices[0].message.content
        except Exception as exc:
            # Best effort: keep whatever answer we already have, but surface
            # the failure to the consumer instead of hiding it.
            yield self._mk_event("error", error=str(exc))
        else:
            final_text = summary_text or final_text
            if final_text:
                yield from self._text_events(final_text)

        yield self._mk_event("agent_end", note=f"Reached max_turns={self.max_turns}")
        return final_text or "(max turns reached without final answer)"


# ---------------------------------------------------------------------------
# Unified runner
# ---------------------------------------------------------------------------

def run_pi(
    question: str,
    *,
    cwd: Path,
    api_key: str,
    provider: str = "openai",
    model: str = "gpt-4o",
    max_turns: int = 6,
    session_dir: Optional[Path] = None,
    session_path: Optional[Path] = None,
    continue_session: bool = False,
) -> Generator[Dict[str, Any], None, str]:
    """
    Run PI (native if available, else fallback) and yield events.

    Returns the final answer text via StopIteration.value. The native client
    is used when a built pi-mono package exists and ``provider`` is not
    ``"fallback"``; its child process is always stopped when the generator
    finishes or is closed.
    """
    if pi_mono_available() and provider != "fallback":
        client = NativePiClient(
            package_dir=DEFAULT_PACKAGE_DIR,
            cwd=cwd,
            agent_dir=DEFAULT_AGENT_DIR,
            provider=provider,
            model=model,
            tools="read,bash",
            session_dir=session_dir,
            session_path=session_path,
            continue_session=continue_session,
            api_key=api_key,
        )
        try:
            client.start()
            # Propagate the sub-generator's return value as our own.
            return (yield from client.stream_events(question, max_turns=max_turns))
        finally:
            client.stop()

    agent = FallbackPiAgent(
        cwd=cwd,
        api_key=api_key,
        model=model,
        max_turns=max_turns,
    )
    return (yield from agent.stream_events(question))


def _rpc_state(client: NativePiClient) -> Dict[str, Any]:
    """Fetch the ``data`` payload of a ``get_state`` call (``{}`` when absent)."""
    return client.call("get_state").get("data", {}) or {}


def _native_session_event(provider: str, model: str, state: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ``__session__`` event describing a native PI session."""
    return {
        "type": "__session__",
        "runner": "Native PI",
        "provider": provider,
        "model": model,
        "sessionFile": state.get("sessionFile"),
        "sessionId": state.get("sessionId"),
        "thinkingLevel": state.get("thinkingLevel"),
    }


def run_pi_stream(
    question: str,
    *,
    cwd: Path,
    api_key: str,
    provider: str = "openai",
    model: str = "gpt-4o",
    max_turns: int = 6,
    session_dir: Optional[Path] = None,
    session_path: Optional[Path] = None,
    continue_session: bool = False,
) -> Iterator[Dict[str, Any]]:
    """
    Iterator that yields PI events. The *last* item is a sentinel dict
    ``{"type": "__final__", "text": "..."}`` containing the answer.

    A ``__session__`` event is yielded first (and, for the native runner,
    again after the run). Failures are reported as ``{"type": "error", ...}``
    events instead of being raised.
    """
    text_parts: List[str] = []

    if pi_mono_available() and provider != "fallback":
        client = NativePiClient(
            package_dir=DEFAULT_PACKAGE_DIR,
            cwd=cwd,
            agent_dir=DEFAULT_AGENT_DIR,
            provider=provider,
            model=model,
            tools="read,bash",
            session_dir=session_dir,
            session_path=session_path,
            continue_session=continue_session,
            api_key=api_key,
        )
        try:
            client.start()
            state = _rpc_state(client)
            if state.get("thinkingLevel") == "off":
                client.call("set_thinking_level", level="high")
                state = _rpc_state(client)
            yield _native_session_event(provider, model, state)

            request_id = client._next_id()
            client._send({"id": request_id, "type": "prompt", "message": question})
            seen_turns = 0
            abort_sent = False

            while True:
                event = client._read_json_line()
                yield event
                kind = event.get("type")

                if kind == "response" and event.get("id") == request_id:
                    # A response without an explicit success flag counts as a
                    # failure, matching NativePiClient.call/stream_events.
                    if not event.get("success", False):
                        yield {"type": "error", "error": event.get("error", "RPC prompt failed")}
                        break
                elif kind == "turn_start":
                    seen_turns += 1
                    if max_turns and seen_turns > max_turns and not abort_sent:
                        client._send({"id": client._next_id(), "type": "abort"})
                        abort_sent = True
                elif kind == "message_update":
                    delta = _text_delta(event)
                    if delta:
                        text_parts.append(delta)
                elif kind == "agent_end":
                    break

            # Report the session again: its details may have changed during the run.
            yield _native_session_event(provider, model, _rpc_state(client))
        except Exception as exc:
            yield {"type": "error", "error": str(exc)}
        finally:
            client.stop()

        yield {"type": "__final__", "text": "".join(text_parts)}
        return

    yield {
        "type": "__session__",
        "runner": "Fallback Agent",
        "provider": provider,
        "model": model,
        "thinkingLevel": "unsupported",
    }
    agent = FallbackPiAgent(
        cwd=cwd,
        api_key=api_key,
        model=model,
        max_turns=max_turns,
    )
    try:
        for event in agent.stream_events(question):
            yield event
            delta = _text_delta(event)
            if delta:
                text_parts.append(delta)
    except Exception as exc:
        yield {"type": "error", "error": str(exc)}

    yield {"type": "__final__", "text": "".join(text_parts)}