smolcode / engine /agent.py
seanpoyner's picture
Upload folder using huggingface_hub
daea45b verified
Raw
History Blame Contribute Delete
6.54 kB
"""smolcode agent engine — backed by the Rust smolcode_core agent loop."""
from __future__ import annotations
import asyncio
import os
import tempfile
from collections.abc import Callable
from dataclasses import dataclass
from .config import Preset, load_preset
from .rust_session import RustRunResult, RustSession, rust_available
from .sandbox import Workspace
from .trace_collector import TraceCollector
# Legacy system prompt, retained only for documentation purposes: the Rust
# agent builds its own system prompts (see prompts.rs) and does not use this.
SYSTEM_PROMPT = (
    "You are smolcode, a precise coding assistant running on a small local model."
)
@dataclass()
class Step:
    """One numbered entry in an agent run, derived from a trace event."""

    number: int  # position of the event in the trace (0-based where built by the agent)
    kind: str  # event kind copied from the trace event
    detail: str  # human-readable detail copied from the trace event
    total_tokens: int | None = None  # optional token count; left unset by default
class SmallCodeAgent:
    """Agent facade that drives the Rust ``smolcode_core`` agent loop.

    Construction requires the Rust extension (checked with ``rust_available()``).
    The agent runs inside a workspace that is either supplied by the caller,
    taken from ``workspace_dir`` / ``$SMALLCODE_WORKSPACE``, or created as a
    temporary directory that this object owns and removes in :meth:`cleanup`.
    """

    def __init__(
        self,
        preset: Preset | None = None,
        model: str | None = None,
        max_steps: int = 12,
        *,
        system_prompt: str | None = None,
        registry_builder: Callable | None = None,
        workspace: Workspace | None = None,
        name: str = "smolcode",
        agent: str = "build",
        profile: str = "full",
        yolo: bool = False,
        workspace_dir: str | None = None,
        approval_handler=None,
        rust_session: RustSession | None = None,
    ) -> None:
        """Configure the agent and create (or adopt) a Rust session.

        Args:
            preset: Model/endpoint preset; ``load_preset()`` is used when omitted.
            model: Model name; defaults to ``preset.default_model``.
            max_steps: Stored on the instance; not forwarded to ``RustSession`` here.
            system_prompt: Stored for API compatibility; unused by the Rust engine.
            registry_builder: When given, switches the session profile to
                ``"web"`` and registers the ``check_app`` tool (the builder
                itself is only stored, not called, here).
            workspace: Existing workspace to run in; takes precedence over
                ``workspace_dir``.
            name: Accepted for API compatibility; not used here.
            agent, yolo, approval_handler: Forwarded to a newly built ``RustSession``.
            profile: Session profile, overridden to ``"web"`` by ``registry_builder``.
            workspace_dir: Directory used when ``workspace`` is not given; falls
                back to ``$SMALLCODE_WORKSPACE``, then to a fresh temp directory.
            rust_session: Pre-built session to adopt instead of creating one.

        Raises:
            RuntimeError: If the ``smolcode_core`` extension is not installed.
        """
        self.preset = preset or load_preset()
        self.model = model or self.preset.default_model
        self.max_steps = max_steps
        self._system_prompt = system_prompt  # unused by Rust; kept for API compat
        self._registry_builder = registry_builder
        self.hit_max_steps = False
        self.errored = False
        # Only a temp directory created below is ours to delete in cleanup().
        self._owns_workspace = False

        # Check before touching the filesystem so a missing extension does not
        # leave an orphaned temp directory behind.
        if not rust_available():
            raise RuntimeError(
                "smolcode_core required; install with maturin in smolcode-cli/crates/smolcode-py"
            )

        if workspace is not None:
            ws_path = str(workspace.root)
        else:
            ws_path = workspace_dir or os.environ.get("SMALLCODE_WORKSPACE")
            if ws_path is None:
                ws_path = tempfile.mkdtemp(prefix="smallcode-")
                self._owns_workspace = True
            workspace = Workspace(root=ws_path)
        self.workspace = workspace

        profile_name = "web" if registry_builder is not None else profile
        try:
            if rust_session is not None:
                self._rust = rust_session
            else:
                self._rust = RustSession(
                    workspace=ws_path,
                    agent=agent,
                    yolo=yolo,
                    model=self.model,
                    base_url=self.preset.base_url,
                    api_key=self.preset.api_key,
                    profile=profile_name,
                    approval_handler=approval_handler,
                )
            self.trace_collector = self._rust.trace_collector
            if registry_builder is not None:
                self._register_web_tools()
        except BaseException:
            # Re-raised below; this only prevents leaking an owned temp workspace.
            self.cleanup()
            raise

    def _register_web_tools(self) -> None:
        """Expose ``check_app`` to the Rust agent as a Python-side tool."""
        from .tools import check_app_impl

        ws = self.workspace
        collector = self.trace_collector

        def check_app(args: dict) -> dict:
            return check_app_impl(ws, collector, args)

        self._rust.register_tool("check_app", check_app)

    async def run(self, task: str, *, think: str | None = None, yolo: bool | None = None) -> tuple[str, list[Step]]:
        """Run one complete agent turn.

        Resets ``hit_max_steps`` / ``errored``, awaits the Rust session's run,
        copies its status flags onto this object, and returns
        ``(final_text, steps)`` where ``steps`` is built from the trace.
        """
        self.hit_max_steps = False
        self.errored = False
        result: RustRunResult = await self._rust.run(task, think=think, yolo=yolo)
        self.hit_max_steps = result.hit_max_steps
        self.errored = result.errored
        return result.final, self._steps_from_trace()

    async def run_live_turn(
        self,
        task: str,
        *,
        think: str | None = None,
        yolo: bool | None = None,
        poll_interval: float = 0.35,
    ):
        """Async generator yielding LiveFrame snapshots during a Rust agent turn.

        Starts a turn on the underlying session and polls it until a ``"done"``
        event arrives or the session reports cancellation.

        Frames yielded:
            * one heartbeat snapshot each time a poll returns no event (followed
              by a ``poll_interval``-second sleep);
            * one snapshot, carrying ``raw_event``, after each ingested event;
            * a final frame with ``done=True`` and ``result=(final_text, steps)``.

        ``"approval"`` events are answered via the session's ``approval_handler``
        (approved when no handler is set); the handler may be sync or async.
        On cancellation ``errored`` is set and the final text defaults to
        ``"interrupted"``.
        """
        import inspect

        from .live_run import LiveFrame

        self.hit_max_steps = False
        self.errored = False
        self.trace_collector.events.clear()
        self._rust.clear_cancel()
        self._rust._session.start_turn(task, think=think, yolo=yolo)

        final_text = ""
        interrupted = False
        while True:
            if self._rust.cancelled:
                interrupted = True
                break
            # Poll in a worker thread so the event loop is not blocked by the
            # native call (NOTE(review): assumes poll_event may block briefly).
            ev = await asyncio.to_thread(self._rust._session.poll_event)
            if ev is None:
                yield LiveFrame(
                    events=self.trace_collector.snapshot(),
                    files=self.files(),
                )
                await asyncio.sleep(poll_interval)
                continue

            kind = ev.get("kind")
            if kind == "approval":
                approved = True
                handler = self._rust.approval_handler
                if handler is not None:
                    approved = handler(ev.get("desc", ""))
                    # Accept both coroutine functions and plain callables.
                    if inspect.isawaitable(approved):
                        approved = await approved
                self._rust._session.approve(approved)
                continue

            self._rust._ingest_event(ev)
            if kind == "final":
                final_text = ev.get("text", "")
            yield LiveFrame(
                events=self.trace_collector.snapshot(),
                files=self.files(),
                raw_event=ev,
            )
            if kind == "done":
                break

        # NOTE(review): unlike run(), this path never updates hit_max_steps from
        # the session; confirm whether the "done" event should carry it.
        if interrupted:
            final_text = final_text or "interrupted"
            self.errored = True
        elif final_text:
            # Only uninterrupted turns that produced text are passed to record_turn.
            self._rust._session.record_turn(task, final_text)

        steps = self._steps_from_trace()
        yield LiveFrame(
            steps=steps,
            events=self.trace_collector.snapshot(),
            files=self.files(),
            done=True,
            result=(final_text, steps),
        )

    def _steps_from_trace(self) -> list[Step]:
        """Convert collected trace events into ``Step`` records numbered from 0."""
        return [
            Step(number=index, kind=event.kind, detail=event.detail)
            for index, event in enumerate(self.trace_collector.events)
        ]

    def current_steps(self) -> list[Step]:
        """Return the steps derived from the current trace."""
        return self._steps_from_trace()

    def raw_history(self) -> list:
        """Return the current steps (alias kept for API compatibility)."""
        return self.current_steps()

    def files(self) -> dict[str, str]:
        """Return the workspace files as reported by the Rust session."""
        return self._rust.files()

    @property
    def rust_session(self) -> RustSession:
        """The underlying Rust session object."""
        return self._rust

    def cleanup(self) -> None:
        """Remove the workspace if this agent created it; repeat calls are no-ops."""
        if getattr(self, "_owns_workspace", False):
            self.workspace.cleanup()
            # Cleared only after success, so a failed cleanup can be retried.
            self._owns_workspace = False