# neuralcad/agents/tools.py
# Last change: 83f3ff9 (CallMeDaniel) — "refactor: type ContextVar design state as DesignState"
"""CrewAI tools for CadQuery code execution and CNC validation.
Uses BaseTool subclasses with Pydantic args_schema for structured input.
"""
from __future__ import annotations

import json
import logging
from contextvars import ContextVar, Token
from typing import Type

from pydantic import BaseModel, Field

from agents.design_state import DesignState
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
try:
    from crewai.tools import BaseTool
except ImportError:
    # Minimal stand-in so this module and the tool classes below can still be
    # imported (e.g. in unit tests) when crewai is not installed. It declares
    # the same class attributes the tools set and performs no validation.
    class BaseTool:  # type: ignore[no-redef]
        name: str = ""
        description: str = ""
        args_schema: type | None = None

        def run(self, *args, **kwargs):
            """Invoke the tool by delegating to ``_run``.

            Provides the public entry point so callers that use ``tool.run(...)``
            keep working without crewai.
            """
            return self._run(*args, **kwargs)

        def _run(self, **kwargs) -> str:
            """Tool body; subclasses override this. The stub returns ``""``."""
            return ""
# ── Per-request state (ContextVar — async-safe) ─────────────────────────
# The tools share the last shape and the design state through ContextVars
# rather than module globals. Each asyncio Task runs in its own copy of the
# context, so concurrent requests stay isolated provided each runs in its
# own task/context.
_last_shape_var: ContextVar[object | None] = ContextVar("last_shape", default=None)
_design_state_var: ContextVar[DesignState | None] = ContextVar("design_state", default=None)


def set_last_shape(shape: object | None) -> Token[object | None]:
    """Record *shape* as the most recent CadQuery result in the current context.

    Args:
        shape: Geometry produced by the last execution, or ``None`` to clear it.

    Returns:
        The token returned by ``ContextVar.set``. Pass it to
        ``token.var.reset(token)`` to restore the previous value, e.g. at the
        end of a request or test.
    """
    return _last_shape_var.set(shape)


def get_last_shape() -> object | None:
    """Return the shape recorded in the current context, or ``None`` if unset."""
    return _last_shape_var.get()


def set_design_state(state: DesignState | None) -> Token[DesignState | None]:
    """Make *state* the design state visible to tools in the current context.

    Args:
        state: The design state to expose, or ``None`` to clear it.

    Returns:
        The token returned by ``ContextVar.set``, usable with
        ``token.var.reset(token)`` to restore the previous value.
    """
    return _design_state_var.set(state)


def get_design_state() -> DesignState | None:
    """Return the design state for the current context, or ``None`` if unset."""
    return _design_state_var.get()
# ── Tool input schemas ──────────────────────────────────────────────────
class ExecuteCadInput(BaseModel):
    # Arguments for ExecuteCadTool: a single required CadQuery source string.
    code: str = Field(
        ...,
        description="CadQuery Python code. Must assign result to `result` as cq.Workplane. Import cadquery as cq.",
    )
class ValidateCadInput(BaseModel):
    # Arguments for ValidateCadTool; the check type defaults to "full".
    check_type: str = Field(
        "full",
        description="Validation type: 'full' for complete CNC manufacturability check.",
    )
class GenerateGcodeInput(BaseModel):
    # Arguments for GenerateGcodeTool.
    # operations: toolpath operations, applied in the given order.
    operations: list[str] = Field(..., description="Ordered list of operations: adaptive, pocket, profile, face, drill, surface, waterline")
    # A zero or negative cutter diameter is physically meaningless. Reject it
    # here so the agent gets a validation error instead of a bad ToolConfig.
    tool_diameter: float = Field(default=6.0, gt=0, description="Endmill diameter in mm")
    # post_processor: output dialect name, forwarded to the CAM backend.
    post_processor: str = Field(default="grbl", description="G-code format: grbl, linuxcnc, fanuc")
# Accepted values for the `check` argument of the design-state query tool:
# "all" for the full report, or one specific field name.
VALID_CHECKS = {
    "all",
    "axis",
    "constraints",
    "dimensions",
    "features",
    "material",
}
class QueryDesignStateInput(BaseModel):
    # Arguments for QueryDesignStateTool; `check` defaults to "all".
    check: str = Field(
        "all",
        description="What to check: 'all' for full state, or a specific field (material, dimensions, features, constraints, axis).",
    )
# ── Tool implementations ────────────────────────────────────────────────
class ExecuteCadTool(BaseTool):
    name: str = "Execute CadQuery Code"
    description: str = "Execute CadQuery Python code and return geometry info: volume, bounding box, face count, edge count."
    args_schema: Type[BaseModel] = ExecuteCadInput

    def _run(self, code: str) -> str:
        """Run CadQuery *code* and remember the resulting shape for later tools.

        On success (with a non-None result) the shape is stored with
        ``set_last_shape``. Otherwise the stored shape is cleared.

        Args:
            code: CadQuery source, as described by ``ExecuteCadInput``.

        Returns:
            The executor result, dumped with field aliases, as indented JSON.
        """
        from core.executor import execute_cadquery

        result = execute_cadquery(code)
        if result.success and result.result is not None:
            set_last_shape(result.result)
        else:
            # Drop any shape left over from an earlier successful run.
            # Otherwise the validation and G-code tools would silently work on
            # stale geometry that no longer matches the code just executed.
            set_last_shape(None)
        return json.dumps(result.model_dump(by_alias=True), indent=2)
class ValidateCadTool(BaseTool):
    name: str = "Validate CNC Manufacturability"
    description: str = "Run CNC manufacturability checks on the last executed shape. Returns machinable status, axis recommendation, and issues list."
    args_schema: Type[BaseModel] = ValidateCadInput

    def _run(self, check_type: str = "full") -> str:
        """Validate the most recently recorded shape for CNC manufacturing.

        Args:
            check_type: Accepted for schema compatibility. It is not inspected,
                so the same full validation always runs.

        Returns:
            JSON with ``success: true`` and the dumped validation report, or
            ``success: false`` with an error when no shape has been recorded.
        """
        from core.validator import validate_for_cnc

        current_shape = get_last_shape()
        if current_shape is not None:
            report = validate_for_cnc(current_shape).model_dump()
            return json.dumps({"success": True, "validation": report}, indent=2)
        return json.dumps({"success": False, "error": "No shape available. Run Execute CadQuery Code first."})
class GenerateGcodeTool(BaseTool):
    name: str = "Generate G-code Toolpath"
    description: str = "Generate CNC G-code toolpath from the last executed CadQuery shape."
    args_schema: Type[BaseModel] = GenerateGcodeInput

    def _run(self, operations: list[str], tool_diameter: float = 6.0, post_processor: str = "grbl") -> str:
        """Generate G-code for the most recently recorded shape.

        Args:
            operations: Operation names, forwarded unchanged to ``generate_gcode``.
            tool_diameter: Cutter diameter placed into the ``ToolConfig``.
            post_processor: Output dialect name, forwarded to ``generate_gcode``.

        Returns:
            The dumped CAM result as indented JSON, or ``success: false`` with
            an error when no shape has been recorded.
        """
        from core.cam import ToolConfig
        from core.cam import generate_gcode

        shape = get_last_shape()
        if shape is None:
            return json.dumps({"success": False, "error": "No shape available. Run Execute CadQuery Code first."})

        # Only the diameter is agent-configurable; feeds and speed are fixed.
        cutter = ToolConfig(
            diameter=tool_diameter,
            h_feed=800,
            v_feed=200,
            speed=18000,
        )
        gcode_result = generate_gcode(
            shape=shape,
            operations=operations,
            tool_config=cutter,
            post_processor=post_processor,
        )
        payload = gcode_result.model_dump()
        return json.dumps(payload, indent=2)
class QueryDesignStateTool(BaseTool):
    name: str = "Query Design State"
    description: str = "Query the orchestrator for current design state and readiness. Call BEFORE saying NOT READY to check what information is already available."
    args_schema: Type[BaseModel] = QueryDesignStateInput

    def _run(self, check: str = "all") -> str:
        """Report known and missing design-state fields plus the readiness score.

        Args:
            check: ``"all"`` for the full report, or one of the specific names
                in ``VALID_CHECKS``. ``"axis"`` refers to the
                ``axis_recommendation`` field.

        Returns:
            A JSON string. For ``"all"`` it contains ``known``, ``missing``,
            ``readiness_score``, ``threshold``, ``ready`` and ``phase``. For a
            single field it contains ``field``, ``value`` and ``ready``, plus
            ``missing: true`` when the field has no value. Problems (invalid
            check, no state) are returned as ``{"error": ...}``, not raised.
        """
        from agents.design_state import compute_score
        from config.settings import settings

        if check not in VALID_CHECKS:
            return json.dumps({"error": f"Invalid check: {check!r}. Valid: {sorted(VALID_CHECKS)}"})
        state = get_design_state()
        if state is None:
            return json.dumps({"error": "No design state available."})

        score = compute_score(state)
        threshold = settings.planning.threshold
        ready = score >= threshold

        # Fields reported as known when truthy, otherwise listed as missing.
        required_fields = (
            "part_name",
            "material",
            "dimensions",
            "features",
            "constraints",
            "axis_recommendation",
        )
        known = {}
        missing = []
        for field_name in required_fields:
            value = getattr(state, field_name)
            if value:
                known[field_name] = value
            else:
                missing.append(field_name)
        # Extra context: included when present, never reported as missing.
        if state.description:
            known["description"] = state.description
        if state.decisions:
            known["recent_decisions"] = state.decisions[-5:]

        if check != "all":
            # VALID_CHECKS uses the short name "axis" for axis_recommendation.
            # Without this mapping an "axis" request matched neither dict and
            # fell through to the full report.
            field_name = {"axis": "axis_recommendation"}.get(check, check)
            if field_name in known:
                return json.dumps({"field": field_name, "value": known[field_name], "ready": ready})
            if field_name in missing:
                return json.dumps({"field": field_name, "value": None, "missing": True, "ready": ready})

        result = {
            "known": known,
            "missing": missing,
            "readiness_score": score,
            "threshold": threshold,
            "ready": ready,
            "phase": state.phase,
        }
        return json.dumps(result, indent=2)