smolcode / engine /rust_session.py
seanpoyner's picture
Upload folder using huggingface_hub
daea45b verified
Raw
History Blame Contribute Delete
12.5 kB
"""Python facade over the Rust smolcode agent engine (smolcode_core)."""
from __future__ import annotations
import asyncio
import json
import os
import tempfile
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from .trace_collector import TraceCollector, TraceEvent
# Optional native extension. When it cannot be imported, _rust is None:
# RustSession() then raises RuntimeError, and most module-level helpers fall
# back to empty results instead of calling into the engine.
try:
    import smolcode_core as _rust
except ImportError:
    _rust = None  # type: ignore
def rust_available() -> bool:
    """Report whether the native ``smolcode_core`` extension was importable."""
    extension_missing = _rust is None
    return not extension_missing
# Async callback consulted when the engine emits an "approval" event: it is
# called with the event's description string and its awaited result is passed
# to the native session's approve() (truthy = allow, falsy = deny).
ApprovalHandler = Callable[[str], Awaitable[bool]]
@dataclass
class RustRunResult:
    """Outcome of a single agent turn as returned by ``RustSession.run``.

    Attributes are, in order: the final assistant text, whether the turn
    appears to have stopped at the step limit, and whether any error event
    was seen during the turn.
    """

    final: str
    hit_max_steps: bool = field(default=False)
    errored: bool = field(default=False)
class RustSession:
    """Thin wrapper around smolcode_core.Session.

    Holds a native ``smolcode_core.Session`` in ``self._session`` and mirrors
    the engine events it polls (tool calls/results, finals, errors) into
    ``self.trace_collector``.
    """

    def __init__(
        self,
        *,
        workspace: str | None = None,
        agent: str = "build",
        yolo: bool = False,
        model: str | None = None,
        base_url: str | None = None,
        api_key: str | None = None,
        profile: str = "full",
        approval_handler: ApprovalHandler | None = None,
    ) -> None:
        """Create the native session.

        Args:
            workspace: Working directory for the agent. If None, the
                ``SMALLCODE_WORKSPACE`` environment variable is used; if that
                is unset too, a new ``smolcode-`` temporary directory is made.
            agent: Initial agent name passed to the engine.
            yolo: Passed through to the native session.
            model: Optional model name passed to the engine.
            base_url: Optional API base URL passed to the engine.
            api_key: Optional API key passed to the engine.
            profile: Profile name passed to the engine.
            approval_handler: Optional async callback answering approval
                requests (see :meth:`run` for the behaviour without one).

        Raises:
            RuntimeError: If the ``smolcode_core`` extension is not installed.
        """
        if _rust is None:
            raise RuntimeError(
                "smolcode_core is not installed; build with "
                "`maturin develop --release` in smolcode-cli/crates/smolcode-py"
            )
        if workspace is None:
            # Do not pass mkdtemp() as the .get() default: arguments are
            # evaluated eagerly, which would create (and leak) an empty temp
            # directory even when the environment variable is set.
            # NOTE(review): the variable is spelled SMALLCODE_, not SMOLCODE_;
            # confirm this is intended before renaming it.
            workspace = os.environ.get("SMALLCODE_WORKSPACE")
            if workspace is None:
                workspace = tempfile.mkdtemp(prefix="smolcode-")
        self._session = _rust.Session(
            workspace=workspace,
            agent=agent,
            yolo=yolo,
            model=model,
            base_url=base_url,
            api_key=api_key,
            profile=profile,
        )
        self.trace_collector = TraceCollector()
        self.approval_handler = approval_handler
        # Per-turn flags, reset at the start of every run().
        self.hit_max_steps = False
        self.errored = False
        self._steps: list[dict[str, Any]] = []
        self._final: str = ""
        self._cancelled = False

    def request_cancel(self) -> None:
        """Flag the current turn as cancelled and ask the engine to stop it."""
        self._cancelled = True
        self.cancel_turn()

    @property
    def cancelled(self) -> bool:
        """True once request_cancel() was called and not yet cleared."""
        return self._cancelled

    def clear_cancel(self) -> None:
        """Reset the cancellation flag (run() does this at turn start)."""
        self._cancelled = False

    @property
    def session_id(self) -> str:
        """Identifier of the underlying native session."""
        return self._session.session_id

    @property
    def workspace_path(self) -> str:
        """Workspace directory as reported by the native session."""
        return self._session.workspace()

    def set_model(self, model: str) -> None:
        """Set the model used by subsequent turns."""
        self._session.set_model(model)

    def set_agent(self, agent: str) -> None:
        """Switch the active agent."""
        self._session.set_agent(agent)

    def set_think(self, level: str) -> None:
        """Set the thinking level on the native session."""
        self._session.set_think(level)

    def register_tool(self, name: str, fn: Callable[[dict], dict]) -> None:
        """Register Python callable *fn* under *name* with the native session."""
        self._session.register_tool(name, fn)

    def files(self) -> dict[str, str]:
        """Return ``{path: content}`` for the workspace files.

        Paths for which the engine's ``read_file`` returns None are skipped.
        """
        out: dict[str, str] = {}
        for path in self._session.workspace_files():
            content = self._session.read_file(path)
            if content is not None:
                out[path] = content
        return out

    def run_shell(self, command: str) -> str:
        """Run *command* through the native session and return its result."""
        return self._session.run_shell(command)

    async def _answer_approval(self, ev: dict[str, Any], *, default: bool) -> None:
        """Answer one "approval" event via approval_handler, else *default*."""
        if self.approval_handler is not None:
            approved = await self.approval_handler(ev.get("desc", ""))
        else:
            approved = default
        self._session.approve(approved)

    async def run(
        self,
        task: str,
        *,
        think: str | None = None,
        yolo: bool | None = None,
    ) -> RustRunResult:
        """Run one agent turn to completion.

        Starts a turn for *task* and polls engine events (in a worker thread)
        until a ``done`` event arrives or :meth:`request_cancel` is called.
        Approval requests go to ``approval_handler``; without one they are
        granted only when this call's *yolo* is truthy.

        Args:
            task: Prompt for this turn; also passed to ``record_turn``.
            think: Optional thinking level forwarded to ``start_turn``.
            yolo: Optional per-turn value forwarded to ``start_turn``.

        Returns:
            RustRunResult with the last ``final`` text seen and the
            max-steps / error flags.
        """
        self.hit_max_steps = False
        self.errored = False
        self._final = ""
        self.clear_cancel()
        self._session.start_turn(task, think=think, yolo=yolo)
        # Without a handler, approvals are denied unless yolo was truthy.
        default_approval = bool(yolo)
        final_text = ""
        while not self._cancelled:
            ev = await asyncio.to_thread(self._session.poll_event)
            if ev is None:
                # Nothing queued yet; back off briefly instead of spinning.
                await asyncio.sleep(0.05)
                continue
            kind = ev.get("kind")
            if kind == "approval":
                await self._answer_approval(ev, default=default_approval)
                continue
            # _ingest_event also sets self.errored for "error" events.
            self._ingest_event(ev)
            if kind == "final":
                final_text = ev.get("text", "")
            elif kind == "done":
                break
        self._final = final_text
        # Heuristic: the step-limit message is assumed to contain both
        # "step" and "without finishing".
        lowered = final_text.lower()
        if "step" in lowered and "without finishing" in lowered:
            self.hit_max_steps = True
        self._session.record_turn(task, final_text)
        return RustRunResult(
            final=final_text,
            hit_max_steps=self.hit_max_steps,
            errored=self.errored,
        )

    async def poll_events_once(self) -> list[dict[str, Any]]:
        """Drain queued engine events for live UI updates during a turn.

        Polls until ``poll_event`` returns None or a ``done`` event is seen
        (the ``done`` event is included). Approval events are answered and
        not returned.
        """
        events: list[dict[str, Any]] = []
        while True:
            ev = await asyncio.to_thread(self._session.poll_event)
            if ev is None:
                break
            kind = ev.get("kind")
            if kind == "approval":
                # NOTE(review): unlike run(), approvals are GRANTED here when
                # no approval_handler is set; confirm this is intended.
                await self._answer_approval(ev, default=True)
                continue
            self._ingest_event(ev)
            events.append(ev)
            if kind == "done":
                break
        return events

    def _ingest_event(self, ev: dict[str, Any]) -> None:
        """Record one engine event in the trace collector.

        Tool-call args and tool-result text are JSON-decoded when possible;
        undecodable payloads are wrapped as ``{"raw": ...}`` and
        ``{"output": ...}`` respectively. Error events also set ``errored``.
        """
        kind = ev.get("kind")
        if kind == "tool_call":
            args_raw = ev.get("args", "{}")
            try:
                args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
            except json.JSONDecodeError:
                args = {"raw": args_raw}
            self.trace_collector.record_tool_call(ev.get("name", ""), args)
        elif kind == "tool_result":
            text = ev.get("text", "")
            try:
                result = json.loads(text)
            except json.JSONDecodeError:
                result = {"output": text}
            self.trace_collector.record_tool_result(ev.get("name", ""), result)
        elif kind == "final":
            self.trace_collector.record_final(ev.get("text", ""))
        elif kind == "error":
            self.trace_collector.record_error(ev.get("text", ""))
            self.errored = True

    def save(self) -> None:
        """Ask the native session to save itself."""
        self._session.save()

    @staticmethod
    def list_sessions() -> list[dict[str, Any]]:
        """Return stored sessions, or [] when smolcode_core is unavailable."""
        if _rust is None:
            return []
        return _rust.Session.list_sessions()

    def load_session(self, session_id: str) -> bool:
        """Load stored session *session_id*; returns the engine's result."""
        return self._session.load_session(session_id)

    def fork(self) -> str | None:
        """Fork the session; returns the engine's result (may be None)."""
        return self._session.fork()

    def rename(self, title: str) -> bool:
        """Rename the session to *title*; returns the engine's result."""
        return self._session.rename(title)

    def delete(self) -> bool:
        """Delete the session; returns the engine's result."""
        return self._session.delete()

    def cancel_turn(self) -> None:
        """Ask the engine to cancel the in-flight turn."""
        self._session.cancel_turn()

    def render_config(self) -> str:
        """Return the engine's rendered configuration text."""
        return self._session.render_config()
def render_config(session: RustSession) -> str:
    """Return the rendered configuration text of *session*."""
    rendered = session.render_config()
    return rendered
def apply_settings(session: RustSession, settings: Any) -> None:
    """Apply UI settings to a live Rust session before each agent turn.

    The "auto" / "auto:<size>" pseudo-selections are NOT real model tags: the
    Router picks the model and sets it on the session (see router.run_live).
    They must therefore never be pushed through set_model. Only concrete
    model pins are applied here.

    Calls are made in the order think -> model (if pinned) -> agent.
    """
    session.set_think(settings.think)
    chosen = settings.model or ""
    router_picks = chosen == "auto" or chosen.startswith("auto:")
    if chosen and not router_picks:
        session.set_model(chosen)
    session.set_agent(settings.agent)
def list_commands(workspace: str) -> list[str]:
    """Names of the commands the engine finds for *workspace* ([] without smolcode_core)."""
    return [] if _rust is None else _rust.list_commands(workspace)
def expand_command(workspace: str, name: str, args: str = "") -> str | None:
    """Expand command *name* with *args* via the engine; None without smolcode_core."""
    if _rust is not None:
        return _rust.expand_command(workspace, name, args)
    return None
def list_rules(workspace: str) -> list[dict[str, Any]]:
    """Rule entries the engine reports for *workspace* ([] without smolcode_core)."""
    return [] if _rust is None else _rust.list_rules(workspace)
def list_skills(workspace: str) -> list[dict[str, Any]]:
    """Skill entries the engine reports for *workspace* ([] without smolcode_core)."""
    return [] if _rust is None else _rust.list_skills(workspace)
def expand_skill(workspace: str, name: str, args: str = "") -> str | None:
    """Expand skill *name* with *args* via the engine; None without smolcode_core."""
    if _rust is not None:
        return _rust.expand_skill(workspace, name, args)
    return None
def list_mcp(session: RustSession) -> list[dict[str, Any]]:
    """Return whatever the native session's ``list_mcp`` reports for *session*."""
    native = session._session
    return native.list_mcp()
def list_background_jobs() -> str:
    """Engine's background-job listing text ("" without smolcode_core)."""
    return "" if _rust is None else _rust.list_background_jobs()
def write_agents_md(workspace: str) -> str:
    """Have the engine write AGENTS.md for *workspace* and return its result.

    Raises:
        RuntimeError: If smolcode_core is not installed.
    """
    if _rust is not None:
        return _rust.write_agents_md(workspace)
    raise RuntimeError("smolcode_core not installed")
def git_status(workspace: str) -> str:
    """Engine-reported git status text for *workspace* ("" without smolcode_core)."""
    return "" if _rust is None else _rust.git_status(workspace)
def workspace_tree(workspace: str, depth: int = 3) -> str:
    """Engine-rendered directory tree of *workspace* down to *depth* levels.

    Returns "" when smolcode_core is unavailable.
    """
    if _rust is not None:
        return _rust.workspace_tree(workspace, depth=depth)
    return ""
# Default cap on paths returned by workspace_paths() for UI sidebars.
UI_FILE_LIMIT = 1500
# Presumably the cap on file suggestions for @-autocomplete.
# NOTE(review): not referenced in this module; confirm where it is used.
AUTOCOMPLETE_FILE_LIMIT = 200
# Default truncation threshold for @-attached file content in read_workspace_file().
ATTACH_FILE_MAX_BYTES = 8192
def read_workspace_file(
    workspace: str,
    path: str,
    *,
    max_bytes: int = ATTACH_FILE_MAX_BYTES,
    rust: RustSession | None = None,
) -> str | None:
    """Read a workspace file for @-attachment inlining. Returns None if missing.

    Args:
        workspace: Workspace root; only used to open a throwaway session when
            ``rust`` is not supplied.
        path: File path as accepted by the native ``read_file``.
        max_bytes: Budget for the UTF-8 encoded size of the returned text.
            Longer content is cut at a character boundary at or below this
            many bytes and suffixed with a truncation marker.
        rust: Optional existing session to reuse instead of creating one.

    Returns:
        The (possibly truncated) file text. Returns None when smolcode_core is
        unavailable, the file is missing, or reading fails.
    """
    if _rust is None:
        return None
    try:
        session = rust if rust is not None else RustSession(workspace=workspace, yolo=True)
        content = session._session.read_file(path)
        if content is None:
            return None
        # Measure in UTF-8 bytes, not code points, so non-ASCII files respect
        # the byte budget that ``max_bytes`` names.
        encoded = content.encode("utf-8")
        if len(encoded) <= max_bytes:
            return content
        # errors="ignore" drops a multi-byte character split by the cut.
        head = encoded[:max_bytes].decode("utf-8", errors="ignore")
        return head + "\n… (truncated)"
    except Exception:
        # Deliberately best-effort: attachment inlining must not break the caller.
        return None
def workspace_paths(workspace: str, *, limit: int = UI_FILE_LIMIT) -> tuple[list[str], int]:
    """Workspace paths for UI sidebars (no file reads). Returns (paths, total_count).

    Paths are sorted and capped at *limit*. The count is the number of paths
    before capping. Returns ([], 0) without smolcode_core.
    """
    if _rust is None:
        return [], 0
    listing = sorted(RustSession(workspace=workspace, yolo=True)._session.workspace_files())
    count = len(listing)
    shown = listing[:limit] if count > limit else listing
    return shown, count
def workspace_files(workspace: str) -> dict[str, str]:
    """Return ``{path: content}`` for the files in *workspace*.

    Opens a throwaway RustSession, so this raises RuntimeError when
    smolcode_core is not installed.
    """
    return RustSession(workspace=workspace, yolo=True).files()
def export_transcript(session_id: str, path: str | None = None) -> str:
    """Export the transcript of *session_id* (optionally to *path*) via the engine.

    Raises:
        RuntimeError: If smolcode_core is not installed.
    """
    if _rust is not None:
        return _rust.export_transcript(session_id, path)
    raise RuntimeError("smolcode_core not installed")
def session_timeline(session_id: str) -> list[str]:
    """Timeline entries for *session_id* from the engine ([] without smolcode_core)."""
    return [] if _rust is None else _rust.session_timeline(session_id)
def get_session_chat(session_id: str) -> list[dict[str, str]]:
    """Stored chat lines for *session_id* from the engine ([] without smolcode_core)."""
    return [] if _rust is None else _rust.get_session_chat(session_id)
def chat_from_stored(lines: list[dict[str, str]]) -> list[dict[str, str]]:
    """Convert stored session lines to Gradio chat messages.

    Each ``{"role", "text"}`` entry becomes ``{"role", "content"}``. Any role
    other than "user" (including a missing one) is shown as "assistant", and
    missing text becomes "".
    """

    def _to_message(stored: dict[str, str]) -> dict[str, str]:
        speaker = "user" if stored.get("role", "assistant") == "user" else "assistant"
        return {"role": speaker, "content": stored.get("text", "")}

    return [_to_message(entry) for entry in lines]
def session_choices() -> list[str]:
    """Dropdown labels: `title (id)`, one per stored session."""
    sessions = RustSession.list_sessions()
    return [f"{record['title']} ({record['id']})" for record in sessions]
def parse_session_label(label: str) -> str | None:
    """Extract the id from a `title (id)` dropdown label.

    Takes the text after the last "(" and strips trailing ")" characters.
    Returns None for an empty label or one without "(".
    """
    if label and "(" in label:
        _, _, tail = label.rpartition("(")
        return tail.rstrip(")")
    return None
def load_rust_config(
    *,
    model: str | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    agent: str | None = None,
    yolo: bool = False,
) -> dict[str, Any]:
    """Load layered config.toml via Rust Config.

    The keyword arguments are forwarded to ``smolcode_core.Config.load``.
    Returns a dict with "model", "base_url", "agent" and "yolo" read from the
    loaded config; the API key is not included. Returns {} when smolcode_core
    is unavailable.
    """
    if _rust is not None:
        loaded = _rust.Config.load(
            model=model,
            base_url=base_url,
            api_key=api_key,
            agent=agent,
            yolo=yolo,
        )
        return {key: getattr(loaded, key) for key in ("model", "base_url", "agent", "yolo")}
    return {}