Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| PI wrapper for HF Spaces. | |
| Tries to use native pi-mono RPC if available; otherwise falls back to a | |
| Python-based re-implementation that mimics PI's event stream and tool set | |
| (read, bash) using the OpenAI function-calling API. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import shlex | |
| import subprocess | |
| import sys | |
| import threading | |
| import time | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, Dict, Generator, Iterator, List, Optional | |
| import openai | |
| try: | |
| from charset_normalizer import from_bytes | |
| except ImportError: | |
| from_bytes = None | |
| try: | |
| from pypdf import PdfReader | |
| except ImportError: | |
| PdfReader = None | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# Directory that contains this wrapper module.
APP_DIR = Path(__file__).resolve().parent

# Treat this directory as the repo root when it holds a "pi-mono" entry;
# otherwise fall back to its parent directory.
REPO_ROOT = APP_DIR if APP_DIR.joinpath("pi-mono").exists() else APP_DIR.parent

# Default locations inside the pi-mono checkout.
DEFAULT_PI_REPO = REPO_ROOT.joinpath("pi-mono")
DEFAULT_PACKAGE_DIR = DEFAULT_PI_REPO.joinpath("packages", "coding-agent")
DEFAULT_AGENT_DIR = DEFAULT_PI_REPO.joinpath(".pi", "agent")
| def _node_bin() -> str: | |
| """Find node >= 20.""" | |
| nvm_dir = Path(os.environ.get("NVM_DIR", Path.home() / ".nvm")) | |
| versions_dir = nvm_dir / "versions" / "node" | |
| if versions_dir.is_dir(): | |
| candidates = sorted( | |
| (d for d in versions_dir.iterdir() if d.name.startswith("v")), | |
| key=lambda d: tuple(int(x) for x in d.name.lstrip("v").split(".")), | |
| reverse=True, | |
| ) | |
| for candidate in candidates: | |
| major = int(candidate.name.lstrip("v").split(".")[0]) | |
| node = candidate / "bin" / "node" | |
| if major >= 20 and node.exists(): | |
| return str(node) | |
| # Try system node | |
| for candidate in ("node", "nodejs"): | |
| try: | |
| result = subprocess.run( | |
| [candidate, "--version"], | |
| capture_output=True, | |
| text=True, | |
| timeout=5, | |
| ) | |
| if result.returncode == 0: | |
| version = result.stdout.strip().lstrip("v") | |
| major = int(version.split(".")[0]) | |
| if major >= 20: | |
| return candidate | |
| except Exception: | |
| pass | |
| return "node" | |
def pi_mono_available() -> bool:
    """Report whether ``dist/cli.js`` exists under the default pi-mono package directory."""
    return DEFAULT_PACKAGE_DIR.joinpath("dist", "cli.js").exists()
| # --------------------------------------------------------------------------- | |
| # Native PI RPC client (adapted from pi_rpc_runner.py) | |
| # --------------------------------------------------------------------------- | |
class NativePiClient:
    """Client for a pi-mono coding-agent subprocess running in ``--mode rpc``.

    Requests are written to the child's stdin as one compact JSON object per
    line; responses and events are read from its stdout the same way. The
    child's stderr is drained on a daemon thread into ``stderr_chunks`` so it
    can be included in error messages.
    """

    def __init__(
        self,
        *,
        package_dir: Path,
        cwd: Path,
        agent_dir: Path,
        provider: str,
        model: str,
        tools: str,
        session_dir: Optional[Path] = None,
        session_path: Optional[Path] = None,
        continue_session: bool = False,
        api_key: str = "",
    ) -> None:
        """Store the launch configuration; no process is started here.

        Args:
            package_dir: Package directory expected to contain ``dist/cli.js``.
            cwd: Working directory for the child process.
            agent_dir: Exported to the child as ``PI_CODING_AGENT_DIR``.
            provider: Passed as ``--provider`` when non-empty.
            model: Passed as ``--model`` when non-empty.
            tools: Passed as ``--tools`` when non-empty.
            session_dir: Passed as ``--session-dir`` when set.
            session_path: Passed as ``--session`` when set.
            continue_session: Adds ``--continue`` when no session_path is given.
            api_key: Exported to the child as ``OPENAI_API_KEY`` when non-empty.
        """
        self.package_dir = package_dir
        self.cwd = cwd
        self.agent_dir = agent_dir
        self.provider = provider
        self.model = model
        self.tools = tools
        self.session_dir = session_dir
        self.session_path = session_path
        self.continue_session = continue_session
        self.api_key = api_key
        # Filled in by start(); declared here so the attribute always exists.
        self.command: List[str] = []
        self.proc: Optional[subprocess.Popen] = None
        self.stderr_chunks: List[str] = []
        self._stderr_thread: Optional[threading.Thread] = None
        self._request_id = 0

    def _ensure_built_cli(self) -> Path:
        """Return the path to ``dist/cli.js``, running ``npm run build`` if it is missing.

        Raises:
            subprocess.CalledProcessError: If the build command fails.
            RuntimeError: If the build succeeds but the CLI is still absent.
        """
        dist_cli = self.package_dir / "dist" / "cli.js"
        if dist_cli.exists():
            return dist_cli
        # The build is run two directory levels above the package directory.
        pi_repo_root = self.package_dir.parents[1]
        sys.stderr.write("[setup] dist/cli.js not found, running `npm run build`\n")
        sys.stderr.flush()
        subprocess.run(
            ["npm", "run", "build"],
            cwd=str(pi_repo_root),
            check=True,
        )
        if not dist_cli.exists():
            raise RuntimeError(f"Build completed but CLI not found at {dist_cli}")
        return dist_cli

    def _build_command(self) -> List[str]:
        """Assemble the argv used to launch the CLI in RPC mode."""
        dist_cli = self._ensure_built_cli()
        cmd = [_node_bin(), str(dist_cli), "--mode", "rpc"]
        if self.provider:
            cmd.extend(["--provider", self.provider])
        if self.model:
            cmd.extend(["--model", self.model])
        if self.tools:
            cmd.extend(["--tools", self.tools])
        if self.session_dir:
            cmd.extend(["--session-dir", str(self.session_dir)])
        # An explicit session file takes precedence over --continue.
        if self.session_path:
            cmd.extend(["--session", str(self.session_path)])
        elif self.continue_session:
            cmd.append("--continue")
        return cmd

    def start(self) -> None:
        """Launch the RPC subprocess and begin draining its stderr.

        Raises:
            RuntimeError: If the client has already been started.
        """
        if self.proc is not None:
            raise RuntimeError("RPC client already started")
        env = os.environ.copy()
        env["PI_CODING_AGENT_DIR"] = str(self.agent_dir)
        if self.api_key:
            env["OPENAI_API_KEY"] = self.api_key
        self.command = self._build_command()
        self.proc = subprocess.Popen(
            self.command,
            cwd=str(self.cwd),
            env=env,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True)
        self._stderr_thread.start()

    def _drain_stderr(self) -> None:
        """Collect the child's stderr lines into ``stderr_chunks`` until EOF."""
        # Snapshot the process once so a concurrent stop() cannot swap it out mid-loop.
        proc = self.proc
        if proc is None or proc.stderr is None:
            return
        try:
            for raw in proc.stderr:
                self.stderr_chunks.append(raw.decode("utf-8", errors="replace"))
        except (OSError, ValueError):
            # The pipe was closed underneath us; nothing more to collect.
            pass

    def stop(self) -> None:
        """Terminate the subprocess (escalating to kill) and release its pipes.

        Safe to call when the client was never started or is already stopped.
        """
        proc = self.proc
        if proc is None:
            return
        try:
            if proc.poll() is None:
                proc.terminate()
                proc.wait(timeout=2)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=2)
        finally:
            self.proc = None
            drain = self._stderr_thread
            self._stderr_thread = None
            if drain is not None and drain is not threading.current_thread():
                drain.join(timeout=1)
            drained = drain is None or not drain.is_alive()
            # Close the pipes so their file descriptors are not leaked. stderr is
            # left alone while the drain thread is still blocked reading it.
            streams = [proc.stdin, proc.stdout]
            if drained:
                streams.append(proc.stderr)
            for stream in streams:
                if stream is None:
                    continue
                try:
                    stream.close()
                except OSError:
                    # Flushing stdin into a dead process raises BrokenPipeError.
                    pass

    def _next_id(self) -> str:
        """Return the next request id, ``py-1``, ``py-2``, ..."""
        self._request_id += 1
        return f"py-{self._request_id}"

    def _send(self, payload: Dict[str, Any]) -> None:
        """Write one JSON message, newline-terminated, to the child's stdin.

        Raises:
            RuntimeError: If the subprocess is not running.
        """
        if self.proc is None or self.proc.stdin is None:
            raise RuntimeError("RPC client is not running")
        line = json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n"
        self.proc.stdin.write(line.encode("utf-8"))
        self.proc.stdin.flush()

    def call(self, command_type: str, **payload: Any) -> Dict[str, Any]:
        """Send a command and block until its matching response arrives.

        Lines read while waiting that are not the matching response are discarded.

        Raises:
            RuntimeError: If the response reports ``success`` as false or missing.
        """
        request_id = self._next_id()
        message = {"id": request_id, "type": command_type, **payload}
        self._send(message)
        while True:
            event = self._read_json_line()
            if event.get("type") == "response" and event.get("id") == request_id:
                if not event.get("success", False):
                    raise RuntimeError(f"RPC {command_type} failed: {event.get('error', 'unknown error')}")
                return event

    def _read_json_line(self) -> Dict[str, Any]:
        """Read one line from the child's stdout and decode it as JSON.

        Raises:
            RuntimeError: If the client is not running, or stdout hit EOF
                (the collected stderr text is included in the message).
        """
        if self.proc is None or self.proc.stdout is None:
            raise RuntimeError("RPC client is not running")
        raw = self.proc.stdout.readline()
        if not raw:
            stderr_text = "".join(self.stderr_chunks).strip()
            raise RuntimeError(f"RPC process exited unexpectedly. stderr:\n{stderr_text}")
        return json.loads(raw.rstrip(b"\r\n").decode("utf-8"))

    def stream_events(
        self,
        message: str,
        *,
        max_turns: Optional[int] = None,
    ) -> Generator[Dict[str, Any], None, str]:
        """Yield raw PI events and finally return the assistant's full text.

        Every line read is yielded before it is interpreted. The returned text
        is the concatenation of all ``text_delta`` payloads. When more than
        ``max_turns`` ``turn_start`` events are seen, a single ``abort``
        request is sent.

        Raises:
            RuntimeError: If the prompt is rejected, or ``agent_end`` arrives
                before the prompt was acknowledged.
        """
        request_id = self._next_id()
        self._send({"id": request_id, "type": "prompt", "message": message})
        text_parts: List[str] = []
        prompt_ack = False
        seen_turns = 0
        sent_turn_limit_abort = False
        while True:
            event = self._read_json_line()
            yield event
            event_type = event.get("type")
            if event_type == "response":
                # Responses to other requests (e.g. the abort) are ignored here.
                if event.get("id") == request_id:
                    if not event.get("success", False):
                        raise RuntimeError(f"RPC prompt failed: {event.get('error', 'unknown error')}")
                    prompt_ack = True
                continue
            if event_type == "turn_start":
                seen_turns += 1
                if max_turns is not None and seen_turns > max_turns and not sent_turn_limit_abort:
                    self._send({"id": self._next_id(), "type": "abort"})
                    sent_turn_limit_abort = True
                continue
            if event_type == "message_update":
                assistant_event = event.get("assistantMessageEvent", {})
                if assistant_event.get("type") == "text_delta":
                    text_parts.append(assistant_event.get("delta", ""))
                continue
            if event_type == "agent_end":
                if not prompt_ack:
                    raise RuntimeError("Received agent_end before prompt acknowledgement")
                break
        return "".join(text_parts)
| # --------------------------------------------------------------------------- | |
| # Fallback Python PI agent (OpenAI function calling) | |
| # --------------------------------------------------------------------------- | |
# System message for the fallback agent. The `{cwd}` placeholder is filled in
# with str.format() by FallbackPiAgent.stream_events; the text names the two
# tools (read, bash) that the fallback agent exposes to the model.
SYSTEM_PROMPT_TEMPLATE = """You are Pi, an autonomous research assistant. Your goal is to answer the user's question by exploring documents in the working directory.
You have access to the following tools:
- read: Read the contents of a file.
- bash: Execute a shell command. Useful for listing files, searching with rg/grep/find, etc.
Guidelines:
1. Start by exploring the directory structure to understand what files are available.
2. Use targeted searches (rg, grep) to find relevant content.
3. Read files that seem relevant to the question.
4. Provide a clear, concise final answer based on the evidence found.
5. If the answer cannot be found, state that clearly.
Working directory: {cwd}
"""
| def _read_tool(cwd: Path, file_path: str) -> str: | |
| """Read a file relative to cwd.""" | |
| try: | |
| target = cwd / file_path | |
| # Security: prevent reading outside cwd | |
| target = target.resolve() | |
| cwd_resolved = cwd.resolve() | |
| if not str(target).startswith(str(cwd_resolved)): | |
| return f"Error: Access denied – path '{file_path}' is outside the working directory." | |
| if not target.exists(): | |
| return f"Error: File '{file_path}' does not exist." | |
| if target.is_dir(): | |
| items = sorted(target.iterdir()) | |
| lines = [f"Directory: {file_path}/"] | |
| for item in items[:50]: | |
| suffix = "/" if item.is_dir() else "" | |
| lines.append(f" {item.name}{suffix}") | |
| if len(items) > 50: | |
| lines.append(f" ... and {len(items) - 50} more items") | |
| return "\n".join(lines) | |
| if target.suffix.lower() == ".pdf": | |
| if PdfReader is None: | |
| return "Error: PDF parsing support is unavailable because pypdf is not installed." | |
| reader = PdfReader(str(target)) | |
| page_texts: List[str] = [] | |
| for idx, page in enumerate(reader.pages, start=1): | |
| text = (page.extract_text() or "").strip() | |
| if text: | |
| page_texts.append(f"[Page {idx}]\n{text}") | |
| content = "\n\n".join(page_texts).strip() | |
| else: | |
| raw = target.read_bytes() | |
| if from_bytes is not None: | |
| best = from_bytes(raw).best() | |
| if best is not None: | |
| content = str(best) | |
| else: | |
| for encoding in ("utf-8", "utf-8-sig", "utf-16", "utf-16-le", "utf-16-be", "gb18030", "big5", "shift_jis"): | |
| try: | |
| content = raw.decode(encoding) | |
| break | |
| except UnicodeDecodeError: | |
| continue | |
| else: | |
| content = raw.decode("utf-8", errors="replace") | |
| else: | |
| for encoding in ("utf-8", "utf-8-sig", "utf-16", "utf-16-le", "utf-16-be", "gb18030", "big5", "shift_jis"): | |
| try: | |
| content = raw.decode(encoding) | |
| break | |
| except UnicodeDecodeError: | |
| continue | |
| else: | |
| content = raw.decode("utf-8", errors="replace") | |
| max_len = 12000 | |
| if len(content) > max_len: | |
| content = content[:max_len] + f"\n\n...[truncated, total {len(content)} chars]" | |
| return content | |
| except Exception as e: | |
| return f"Error reading '{file_path}': {e}" | |
| def _bash_tool(cwd: Path, command: str) -> str: | |
| """Run a shell command inside cwd.""" | |
| try: | |
| # Basic security: block dangerous commands | |
| dangerous = {"rm", "mv", "cp", "chmod", "chown", "sudo", "su", "ssh", "curl", "wget", "python", "python3", "node", "npm", "pip", "pip3", "docker", "kubectl"} | |
| tokens = shlex.split(command) | |
| if tokens and tokens[0] in dangerous: | |
| return f"Error: Command '{tokens[0]}' is not allowed for security reasons." | |
| result = subprocess.run( | |
| command, | |
| shell=True, | |
| cwd=str(cwd), | |
| capture_output=True, | |
| text=True, | |
| timeout=30, | |
| ) | |
| output = result.stdout | |
| if result.stderr: | |
| output += "\n" + result.stderr | |
| max_len = 8000 | |
| if len(output) > max_len: | |
| output = output[:max_len] + f"\n\n...[truncated, total {len(output)} chars]" | |
| return output | |
| except subprocess.TimeoutExpired: | |
| return "Error: Command timed out after 30 seconds." | |
| except Exception as e: | |
| return f"Error: {e}" | |
| def _utc_now() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |
| def _chunk_text(text: str, *, size: int = 48) -> Iterator[str]: | |
| if not text: | |
| return | |
| for i in range(0, len(text), size): | |
| yield text[i : i + size] | |
class FallbackPiAgent:
    """A lightweight Python agent that mimics PI's behavior using OpenAI tool use.

    Each turn makes one chat-completions request with the ``read`` and ``bash``
    tools available, executes any requested tool calls locally, and reports
    progress as PI-style event dicts (``turn_start``, ``message_update``,
    ``tool_execution_*``, ``agent_end``, ...).
    """

    def __init__(
        self,
        *,
        cwd: Path,
        api_key: str,
        model: str = "gpt-4o",
        max_turns: int = 6,
    ):
        """Create the agent and its OpenAI client.

        Args:
            cwd: Working directory handed to the read/bash tools.
            api_key: OpenAI API key.
            model: Chat-completions model name.
            max_turns: Maximum number of tool-using turns before a final
                summary is requested.
        """
        self.cwd = cwd
        self.api_key = api_key
        self.model = model
        self.max_turns = max_turns
        self.client = openai.OpenAI(api_key=api_key)

    def _mk_event(self, typ: str, **kwargs) -> Dict[str, Any]:
        """Build an event dict carrying ``type``, a UTC ``timestamp`` and extra fields."""
        return {"type": typ, "timestamp": _utc_now(), **kwargs}

    @staticmethod
    def _tool_schemas() -> List[Dict[str, Any]]:
        """Return the function-calling declarations for the ``read`` and ``bash`` tools."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "read",
                    "description": "Read a file or list a directory. Provide a relative path from the working directory.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "file_path": {
                                "type": "string",
                                "description": "Relative path to the file or directory to read.",
                            },
                        },
                        "required": ["file_path"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "bash",
                    "description": "Execute a shell command in the working directory. Useful for searching with rg/grep/find, listing files, etc.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "command": {
                                "type": "string",
                                "description": "Shell command to execute. Keep it simple and focused.",
                            },
                        },
                        "required": ["command"],
                    },
                },
            },
        ]

    @staticmethod
    def _parse_tool_args(raw_arguments: Optional[str]) -> Dict[str, Any]:
        """Decode a tool call's JSON arguments, returning ``{}`` for anything unusable.

        Missing, malformed, or non-object JSON (e.g. a list) all yield an empty
        dict so the later ``.get`` lookups cannot fail.
        """
        try:
            parsed = json.loads(raw_arguments or "")
        except (json.JSONDecodeError, TypeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _is_error_result(result: str) -> bool:
        """Tell whether a tool result is one of the tools' error messages.

        The tools report failures either as ``Error: ...`` or, for read
        failures, as ``Error reading '...': ...``.
        """
        return result.startswith(("Error:", "Error reading "))

    def _run_tool(self, name: str, args: Dict[str, Any]) -> str:
        """Execute the named tool against ``cwd`` and return its textual result."""
        if name == "read":
            return _read_tool(self.cwd, args.get("file_path", ""))
        if name == "bash":
            return _bash_tool(self.cwd, args.get("command", ""))
        return f"Error: Unknown tool '{name}'"

    def _text_events(self, text: str) -> Iterator[Dict[str, Any]]:
        """Yield text_start, chunked text_delta, and text_end events for ``text``."""
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={"type": "text_start", "contentIndex": 0},
        )
        for chunk in _chunk_text(text):
            yield self._mk_event(
                "message_update",
                assistantMessageEvent={"type": "text_delta", "contentIndex": 0, "delta": chunk},
            )
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={"type": "text_end", "contentIndex": 0, "content": text},
        )

    def _tool_call_events(self, tc: Any) -> Generator[Dict[str, Any], None, str]:
        """Yield the events for one tool call, run the tool, and return its result."""
        fn_name = tc.function.name
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={"type": "toolcall_start", "contentIndex": 0},
        )
        for chunk in _chunk_text(tc.function.arguments or ""):
            yield self._mk_event(
                "message_update",
                assistantMessageEvent={
                    "type": "toolcall_delta",
                    "contentIndex": 0,
                    "delta": chunk,
                },
            )
        yield self._mk_event(
            "message_update",
            assistantMessageEvent={
                "type": "toolcall_end",
                "contentIndex": 0,
                "toolCall": {
                    "id": tc.id,
                    "name": fn_name,
                    "arguments": tc.function.arguments,
                },
            },
        )
        fn_args = self._parse_tool_args(tc.function.arguments)
        yield self._mk_event(
            "tool_execution_start",
            toolName=fn_name,
            args=fn_args,
            toolCallId=tc.id,
        )
        result = self._run_tool(fn_name, fn_args)
        for chunk in _chunk_text(result):
            yield self._mk_event(
                "tool_execution_update",
                toolName=fn_name,
                toolCallId=tc.id,
                output=chunk,
            )
        yield self._mk_event(
            "tool_execution_end",
            toolName=fn_name,
            toolCallId=tc.id,
            isError=self._is_error_result(result),
            result=result,
        )
        return result

    def stream_events(
        self,
        message: str,
    ) -> Generator[Dict[str, Any], None, str]:
        """Run the agent loop for ``message``, yielding events as it goes.

        Returns (via ``StopIteration.value``) the text of the last assistant
        message that had content, the max-turns summary, or a placeholder
        string when no text was produced. Exceptions from the per-turn API
        request propagate to the caller; a failure of the final summary
        request is reported on the ``agent_end`` event instead.
        """
        system_prompt = SYSTEM_PROMPT_TEMPLATE.format(cwd=str(self.cwd.resolve()))
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ]
        tools = self._tool_schemas()
        final_text = ""
        for turn in range(1, self.max_turns + 1):
            yield self._mk_event("turn_start", turn=turn)
            yield self._mk_event("message_start", message={"role": "assistant", "content": []})
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,  # type: ignore[arg-type]
                tools=tools,  # type: ignore[arg-type]
                tool_choice="auto",
                temperature=0.2,
            )
            msg = response.choices[0].message

            # Stream assistant text if present
            if msg.content:
                final_text = msg.content
                yield from self._text_events(msg.content)

            # No tool calls → agent is done
            if not msg.tool_calls:
                yield self._mk_event("message_end", message={"role": "assistant", "content": msg.content or ""})
                yield self._mk_event("agent_end")
                return final_text or "(no answer provided)"

            # Assistant message with tool_calls MUST come before tool results
            messages.append({
                "role": "assistant",
                "content": msg.content or None,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in msg.tool_calls
                ],
            })
            for tc in msg.tool_calls:
                result = yield from self._tool_call_events(tc)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": result,
                })
            yield self._mk_event("message_end", message={"role": "assistant", "content": []})

        # Max turns reached — ask model to summarise with all gathered context
        messages.append({
            "role": "user",
            "content": (
                "You have used all available research turns. "
                "Based on every file you have read and every search you have run, "
                "please give your best final answer to the original question. "
                "Be as complete and accurate as the evidence allows."
            ),
        })
        end_fields: Dict[str, Any] = {"note": f"Reached max_turns={self.max_turns}"}
        try:
            summary = self.client.chat.completions.create(
                model=self.model,
                messages=messages,  # type: ignore[arg-type]
                temperature=0.2,
            )
            final_text = summary.choices[0].message.content or final_text
            if final_text:
                yield from self._text_events(final_text)
        except Exception as exc:
            # Best effort: keep whatever text we already have, but surface the
            # failure on the closing event instead of hiding it.
            end_fields["summaryError"] = str(exc)
        yield self._mk_event("agent_end", **end_fields)
        return final_text or "(max turns reached without final answer)"
| # --------------------------------------------------------------------------- | |
| # Unified runner | |
| # --------------------------------------------------------------------------- | |
def run_pi(
    question: str,
    *,
    cwd: Path,
    api_key: str,
    provider: str = "openai",
    model: str = "gpt-4o",
    max_turns: int = 6,
    session_dir: Optional[Path] = None,
    session_path: Optional[Path] = None,
    continue_session: bool = False,
) -> Generator[Dict[str, Any], None, str]:
    """
    Run PI (native if available, else fallback) and yield events.
    Returns the final answer text via StopIteration.value.

    The native client is used when a built pi-mono CLI exists and ``provider``
    is not ``"fallback"``; it is always stopped when the generator finishes,
    fails, or is closed. The session arguments only apply to the native client.
    """
    if pi_mono_available() and provider != "fallback":
        client = NativePiClient(
            package_dir=DEFAULT_PACKAGE_DIR,
            cwd=cwd,
            agent_dir=DEFAULT_AGENT_DIR,
            provider=provider,
            model=model,
            tools="read,bash",
            session_dir=session_dir,
            session_path=session_path,
            continue_session=continue_session,
            api_key=api_key,
        )
        try:
            client.start()
            # `yield from` evaluates to the sub-generator's return value; pass
            # it on so the native path also returns the answer text.
            return (yield from client.stream_events(question, max_turns=max_turns))
        finally:
            client.stop()
    agent = FallbackPiAgent(
        cwd=cwd,
        api_key=api_key,
        model=model,
        max_turns=max_turns,
    )
    return (yield from agent.stream_events(question))
# Streaming variant: a plain iterator whose last item is a "__final__" sentinel carrying the answer text.
def _native_session_event(provider: str, model: str, state: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ``__session__`` event for the native runner from a get_state payload."""
    return {
        "type": "__session__",
        "runner": "Native PI",
        "provider": provider,
        "model": model,
        "sessionFile": state.get("sessionFile"),
        "sessionId": state.get("sessionId"),
        "thinkingLevel": state.get("thinkingLevel"),
    }


def _text_delta_of(event: Dict[str, Any]) -> Optional[str]:
    """Return the delta of a ``text_delta`` message_update event, else None."""
    if event.get("type") != "message_update":
        return None
    # Tolerate an explicit null payload as well as a missing one.
    assistant_event = event.get("assistantMessageEvent") or {}
    if assistant_event.get("type") != "text_delta":
        return None
    return assistant_event.get("delta", "")


def run_pi_stream(
    question: str,
    *,
    cwd: Path,
    api_key: str,
    provider: str = "openai",
    model: str = "gpt-4o",
    max_turns: int = 6,
    session_dir: Optional[Path] = None,
    session_path: Optional[Path] = None,
    continue_session: bool = False,
) -> Iterator[Dict[str, Any]]:
    """
    Iterator that yields PI events. The *last* item is a sentinel dict
    ``{"type": "__final__", "text": "..."}`` containing the answer.

    A ``__session__`` event describing the runner is yielded first (and, for
    the native runner, again after the prompt finishes). Failures are reported
    as ``{"type": "error", ...}`` events rather than raised, and the sentinel
    is still emitted with whatever text deltas were collected.
    """
    text_parts: List[str] = []

    if not (pi_mono_available() and provider != "fallback"):
        yield {
            "type": "__session__",
            "runner": "Fallback Agent",
            "provider": provider,
            "model": model,
            "thinkingLevel": "unsupported",
        }
        agent = FallbackPiAgent(
            cwd=cwd,
            api_key=api_key,
            model=model,
            max_turns=max_turns,
        )
        try:
            for event in agent.stream_events(question):
                yield event
                delta = _text_delta_of(event)
                if delta is not None:
                    text_parts.append(delta)
        except Exception as exc:
            yield {"type": "error", "error": str(exc)}
        yield {"type": "__final__", "text": "".join(text_parts)}
        return

    client = NativePiClient(
        package_dir=DEFAULT_PACKAGE_DIR,
        cwd=cwd,
        agent_dir=DEFAULT_AGENT_DIR,
        provider=provider,
        model=model,
        tools="read,bash",
        session_dir=session_dir,
        session_path=session_path,
        continue_session=continue_session,
        api_key=api_key,
    )
    try:
        client.start()
        state_data = client.call("get_state").get("data", {}) or {}
        if state_data.get("thinkingLevel") == "off":
            # Raise the thinking level, then re-read the state to report it.
            client.call("set_thinking_level", level="high")
            state_data = client.call("get_state").get("data", {}) or {}
        yield _native_session_event(provider, model, state_data)

        request_id = client._next_id()
        client._send({"id": request_id, "type": "prompt", "message": question})
        seen_turns = 0
        sent_abort = False
        while True:
            event = client._read_json_line()
            yield event
            et = event.get("type")
            if et == "response" and event.get("id") == request_id:
                # A response without "success" counts as a failure, matching
                # NativePiClient.call and NativePiClient.stream_events.
                if not event.get("success", False):
                    yield {"type": "error", "error": event.get("error", "RPC prompt failed")}
                    break
                continue
            if et == "turn_start":
                seen_turns += 1
                if max_turns and seen_turns > max_turns and not sent_abort:
                    client._send({"id": client._next_id(), "type": "abort"})
                    sent_abort = True
                continue
            if et == "message_update":
                delta = _text_delta_of(event)
                if delta is not None:
                    text_parts.append(delta)
                continue
            if et == "agent_end":
                break

        final_state_data = client.call("get_state").get("data", {}) or {}
        yield _native_session_event(provider, model, final_state_data)
    except Exception as exc:
        yield {"type": "error", "error": str(exc)}
    finally:
        client.stop()
    yield {"type": "__final__", "text": "".join(text_parts)}