neuralcad / core /backends.py
CallMeDaniel's picture
refactor: config-drive all backends, use LLMBackend ABC from types
33f166d
"""
LLM backend implementations for CadQuery code generation.
Supports multiple backends:
- Anthropic Claude
- OpenAI GPT-4o
- Google Gemini (free tier available)
- Mock (dynamic generation, no API key required)
- NeuralCAD (local neural pipeline, not yet implemented)
"""
import base64
import mimetypes
import os
import re
from pathlib import Path
from core.types import LLMBackend
# ── LLM Backends ──────────────────────────────────────────────────────────
class AnthropicBackend(LLMBackend):
    """Generate CadQuery code using Anthropic Claude.

    The model name is taken from the ``model`` argument, falling back to
    ``settings.model_for["anthropic"]`` and then to a built-in default.
    The API key is taken from ``api_key``, then ``settings.anthropic_api_key``,
    then the ``ANTHROPIC_API_KEY`` environment variable.
    """

    def __init__(self, model: str | None = None, api_key: str | None = None):
        """Create the Anthropic client.

        Args:
            model: Model identifier; when falsy the configured/default model is used.
            api_key: API key; when falsy the configured key or environment
                variable is used.
        """
        # Imported inside the constructor rather than at module level
        # (presumably so the module loads without the SDK installed — confirm).
        import anthropic
        from config.settings import settings

        self.model = model or settings.model_for.get("anthropic", "claude-sonnet-4-20250514")
        key = api_key or settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=key)

    def _create_message(self, system_msg, user_messages: list[dict]) -> str:
        """Send one Messages API request and return the first content block's text."""
        from config.settings import settings

        response = self.client.messages.create(
            model=self.model,
            max_tokens=settings.max_tokens,
            system=system_msg,
            messages=user_messages,
        )
        return response.content[0].text

    def generate(self, messages: list[dict]) -> str:
        """Generate a completion for a text-only conversation.

        Args:
            messages: Chat messages; the system message is separated out by
                ``split_system_message`` and sent via the ``system`` parameter.

        Returns:
            The text of the first content block of the response.
        """
        system_msg, user_messages = self.split_system_message(messages)
        return self._create_message(system_msg, user_messages)

    def generate_with_image(self, messages: list[dict], image_path: str | Path) -> str:
        """Generate a completion with an image attached to the last message.

        The image is read from disk, base64-encoded and placed in front of the
        text of the last non-system message.

        Args:
            messages: Chat messages. The caller's list and dicts are left
                unmodified.
            image_path: Path to the image file; the media type is guessed from
                the file name and defaults to ``image/png``.

        Returns:
            The text of the first content block of the response.

        Raises:
            IndexError: If there is no non-system message to attach the image to.
            OSError: If the image file cannot be read.
        """
        image_path = Path(image_path)
        media_type = mimetypes.guess_type(str(image_path))[0] or "image/png"
        image_data = base64.b64encode(image_path.read_bytes()).decode("utf-8")
        system_msg, user_messages = self.split_system_message(messages)

        # Shallow-copy each message before patching: split_system_message may
        # hand back the caller's own dict objects, and rewriting "content" in
        # place would leak the image blocks into the caller's conversation
        # (a second call would then nest a list inside the text block).
        user_messages = [dict(m) for m in user_messages]

        # Replace last message content with multimodal blocks.
        # NOTE(review): assumes that message's content is a plain string — confirm.
        last_user = user_messages[-1]
        last_user["content"] = [
            {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": image_data}},
            {"type": "text", "text": last_user["content"]},
        ]
        return self._create_message(system_msg, user_messages)
class OpenAIBackend(LLMBackend):
    """CadQuery code generator backed by the OpenAI chat-completions API.

    Explicit constructor arguments take precedence over ``config.settings``;
    the API key additionally falls back to the ``OPENAI_API_KEY`` environment
    variable.
    """

    def __init__(self, model: str | None = None, api_key: str | None = None):
        """Resolve the model name and API key, then build the OpenAI client."""
        import openai
        from config.settings import settings

        configured_model = settings.model_for.get("openai", "gpt-4o")
        self.model = model or configured_model
        resolved_key = api_key or settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
        self.client = openai.OpenAI(api_key=resolved_key)

    def generate(self, messages: list[dict]) -> str:
        """Return the content of the first choice for a text-only conversation."""
        from config.settings import settings

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def generate_with_image(self, messages: list[dict], image_path: str | Path) -> str:
        """Return the first choice's content with an image attached to the last message.

        The image is embedded as a base64 ``data:`` URL placed ahead of the
        last message's original content. Each message dict is shallow-copied
        first, so the caller's messages are not modified.
        """
        from config.settings import settings

        image_file = Path(image_path)
        mime = mimetypes.guess_type(str(image_file))[0] or "image/png"
        encoded = base64.b64encode(image_file.read_bytes()).decode("utf-8")
        data_url = f"data:{mime};base64,{encoded}"

        # Work on shallow copies; only the final message gets new content.
        conversation = [dict(entry) for entry in messages]
        final_turn = conversation[-1]
        final_turn["content"] = [
            {"type": "image_url", "image_url": {"url": data_url}},
            {"type": "text", "text": final_turn["content"]},
        ]

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=conversation,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
class GeminiBackend(LLMBackend):
    """CadQuery code generator backed by Google Gemini (free tier available).

    Explicit constructor arguments take precedence over ``config.settings``;
    the API key additionally falls back to ``GEMINI_API_KEY`` and then
    ``GOOGLE_API_KEY`` from the environment.
    """

    def __init__(self, model: str | None = None, api_key: str | None = None):
        """Resolve the model name and API key, then build the genai client."""
        from google import genai
        from config.settings import settings

        configured_model = settings.model_for.get("gemini", "gemini-2.5-flash")
        self.model = model or configured_model
        resolved_key = (
            api_key
            or settings.google_api_key
            or os.environ.get("GEMINI_API_KEY")
            or os.environ.get("GOOGLE_API_KEY")
        )
        self.client = genai.Client(api_key=resolved_key)

    @staticmethod
    def _to_contents(chat_messages: list[dict]) -> list[dict]:
        """Convert chat messages to Gemini ``contents`` entries.

        ``user`` messages keep their role, ``assistant`` messages are sent as
        ``model``; messages with any other role are dropped.
        """
        return [
            {
                "role": "user" if entry["role"] == "user" else "model",
                "parts": [{"text": entry["content"]}],
            }
            for entry in chat_messages
            if entry["role"] in ("user", "assistant")
        ]

    def generate(self, messages: list[dict]) -> str:
        """Return the response text for a text-only conversation."""
        from config.settings import settings
        from google.genai import types

        system_msg, chat_messages = self.split_system_message(messages)
        turns = self._to_contents(chat_messages)
        generation_config = types.GenerateContentConfig(
            system_instruction=system_msg,
            max_output_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        reply = self.client.models.generate_content(
            model=self.model,
            contents=turns,
            config=generation_config,
        )
        return reply.text

    def generate_with_image(self, messages: list[dict], image_path: str | Path) -> str:
        """Return the response text with an image attached to the last user turn.

        The raw image bytes are inserted as the first part of the final
        converted entry, but only when that entry has the ``user`` role;
        otherwise the request is sent without the image.
        """
        from config.settings import settings
        from google.genai import types

        image_file = Path(image_path)
        raw_bytes = image_file.read_bytes()
        mime = mimetypes.guess_type(str(image_file))[0] or "image/png"

        system_msg, chat_messages = self.split_system_message(messages)
        turns = self._to_contents(chat_messages)

        # Add image to the last user message
        if turns:
            final_turn = turns[-1]
            if final_turn["role"] == "user":
                image_part = {"inline_data": {"mime_type": mime, "data": raw_bytes}}
                final_turn["parts"].insert(0, image_part)

        generation_config = types.GenerateContentConfig(
            system_instruction=system_msg,
            max_output_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        reply = self.client.models.generate_content(
            model=self.model,
            contents=turns,
            config=generation_config,
        )
        return reply.text
class MockBackend(LLMBackend):
    """
    Mock backend that dynamically generates CadQuery code from any prompt.

    Parses dimensions, shape type, and features from the text, then assembles
    parametric code. No API key required.

    Prompts containing a key of ``_CURATED`` get the matching hand-written
    script; everything else goes through ``_parse_prompt`` followed by
    ``_generate_code``.
    """

    # Word-to-number mapping for natural language counts
    _WORD_NUMS = {
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
        "ten": 10,
        "twelve": 12,
        "sixteen": 16,
        "twenty": 20,
    }

    # Shape detection patterns → base shape key.
    # Order matters: the first entry with a matching substring wins. The
    # L-bracket entry must come before "plate", because "l-bracket",
    # "l bracket", "angle bracket" and "corner bracket" all contain "bracket"
    # and would otherwise always be classified as a plate.
    _SHAPE_PATTERNS = {
        "l_bracket": [
            "l-bracket",
            "l bracket",
            "angle bracket",
            "corner bracket",
            "l-shaped",
        ],
        "cylinder": [
            "cylinder",
            "rod",
            "shaft",
            "axle",
            "spacer",
            "washer",
            "bushing",
            "sleeve",
            "tube",
            "pipe",
            "dowel",
            "pin",
        ],
        "plate": [
            "plate",
            "bracket",
            "mount",
            "flange",
            "baseplate",
            "panel",
            "shim",
            "cover",
            "lid",
        ],
        "box": [
            "box",
            "block",
            "enclosure",
            "housing",
            "case",
            "cube",
            "container",
            "shell",
        ],
    }

    # Feature detection keywords (matched as substrings of the lowered prompt)
    _FEATURE_KEYWORDS = {
        "holes": ["hole", "holes", "bolt", "bolts", "screw", "screws", "bore", "bores"],
        "pocket": ["pocket", "recess", "cavity", "cutout", "mortise"],
        "slot": ["slot", "slots", "groove", "channel", "keyway"],
        "fillet": ["fillet", "fillets", "round", "rounded"],
        "chamfer": ["chamfer", "chamfers", "bevel", "beveled"],
        "through_hole": ["through hole", "through-hole", "thru hole", "thru-hole"],
        "counterbore": ["counterbore", "counterbored", "cbore"],
        "fins": ["fin", "fins", "cooling", "heatsink", "heat sink", "radiator"],
        "ribs": ["rib", "ribs", "stiffener", "stiffeners", "web"],
        "boss": ["boss", "bosses", "standoff", "standoffs", "pillar"],
    }

    # Nouns that a count can refer to ("4 holes", "six fins"); shared by the
    # numeric and the word-based count patterns so both accept the same nouns.
    _COUNT_NOUNS = "hole|bolt|screw|bore|fin|rib|slot|boss"

    @property
    def _thread_clearance(self) -> dict[str, float]:
        """Fastener table from ``settings.fasteners``, looked up by keys like ``"m6"``."""
        from config.settings import settings

        return settings.fasteners

    def _parse_prompt(self, text: str) -> dict:
        """Extract dimensions, shape, and features from natural language.

        Args:
            text: The user's prompt.

        Returns:
            A dict with keys ``dimensions`` (list of floats in prompt order),
            ``shape`` (a ``_SHAPE_PATTERNS`` key, default ``"box"``),
            ``features`` (set of ``_FEATURE_KEYWORDS`` keys), ``hole_dia``
            (default 5.5), ``count`` (default 4) and ``prompt`` (the
            original text).
        """
        lower = text.lower()

        # Detect metric thread sizes (M3, M6, etc.); sizes missing from the
        # fastener table fall back to 110% of the nominal diameter.
        thread_match = re.search(r"\bm(\d+)\b", lower)
        hole_dia = None
        if thread_match:
            nominal = thread_match.group(1)
            hole_dia = self._thread_clearance.get(f"m{nominal}", float(nominal) * 1.1)

        # Detect hole diameter from "Xmm hole" (a thread size takes precedence)
        hole_dim_match = re.search(
            r"(\d+\.?\d*)\s*mm\s*(?:hole|bore|holes|bores)", lower
        )
        if hole_dim_match and not hole_dia:
            hole_dia = float(hole_dim_match.group(1))

        # Detect count (numeric or word). The lookbehind stops the digits of
        # a thread size or a decimal fraction from being read as a count
        # ("M6 holes" is not six holes), and the optional group lets a thread
        # size sit between the count and the noun ("4 M6 holes", "4x M6 bolts").
        count = None
        count_match = re.search(
            rf"(?<![\w.])(\d+)\s*(?:(?:x\s*)?m\d+\s+)?(?:{self._COUNT_NOUNS})",
            lower,
        )
        if count_match:
            count = int(count_match.group(1))
        else:
            for word, num in self._WORD_NUMS.items():
                if re.search(rf"\b{word}\b.*(?:{self._COUNT_NOUNS})", lower):
                    count = num
                    break

        # Blank out numbers that are not part dimensions (thread sizes, the
        # hole diameter and the feature count) so that e.g. "M6 plate 80x50"
        # does not yield a 6 mm wide plate. Spans are replaced by spaces of
        # equal length so match offsets taken from `lower` stay valid.
        dim_text = re.sub(r"\bm\d+\b", lambda m: " " * len(m.group(0)), lower)
        for match in (hole_dim_match, count_match):
            if match:
                start, end = match.span(1)
                dim_text = dim_text[:start] + " " * (end - start) + dim_text[end:]

        # Extract all remaining numbers with optional units. The unit is only
        # skipped over, not converted: every value is used as-is.
        raw_nums = re.findall(r"(\d+\.?\d*)\s*(?:mm|cm|m\b)?", dim_text)
        dimensions = [float(n) for n in raw_nums if 0.1 < float(n) < 2000]

        # Detect base shape: first pattern group with a matching keyword
        shape = next(
            (
                shape_key
                for shape_key, keywords in self._SHAPE_PATTERNS.items()
                if any(kw in lower for kw in keywords)
            ),
            "box",
        )

        # Detect features
        features = {
            feat
            for feat, keywords in self._FEATURE_KEYWORDS.items()
            if any(kw in lower for kw in keywords)
        }

        return {
            "dimensions": dimensions,
            "shape": shape,
            "features": features,
            "hole_dia": hole_dia or 5.5,
            "count": count or 4,
            "prompt": text,
        }

    @staticmethod
    def _corner_points(ox: float, oy: float) -> list[tuple]:
        """Return the four points (±ox, ±oy) of a rectangular pattern."""
        return [(-ox, -oy), (-ox, oy), (ox, -oy), (ox, oy)]

    def _cylinder_lines(self, p: dict) -> list[str]:
        """Build the script lines for a cylinder, optionally with bore and fins."""
        dims = p["dimensions"]
        features = p["features"]
        radius = dims[0] / 2 if dims else 15.0  # first number is the diameter
        height = dims[1] if len(dims) > 1 else radius * 2

        lines = [
            f"# Cylinder: radius={radius}mm, height={height}mm",
            "result = (",
            " cq.Workplane('XY')",
            f" .cylinder({height}, {radius})",
        ]
        if "holes" in features or "through_hole" in features:
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .hole({p['hole_dia']})")
        # Chamfer is the default edge treatment unless only a fillet was asked for
        if "chamfer" in features or "fillet" not in features:
            lines.append(" .edges('>Z or <Z').chamfer(0.5)")
        if "fillet" in features:
            lines.append(" .edges('>Z or <Z').fillet(1.0)")
        lines.append(")")

        if "fins" in features:
            # Counts of four or fewer fall back to eight fins
            n_fins = p["count"] if p["count"] > 4 else 8
            fin_h = max(height * 0.8, 5)
            fin_w = 1.5
            lines += [
                "",
                f"# Add {n_fins} cooling fins",
                f"for i in range({n_fins}):",
                f" angle = i * 360 / {n_fins}",
                " rad = math.radians(angle)",
                f" fx = {radius + 3} * math.cos(rad)",
                f" fy = {radius + 3} * math.sin(rad)",
                " fin = (",
                " cq.Workplane('XY')",
                " .transformed(offset=(fx, fy, 0), rotate=(0, 0, angle))",
                f" .rect({fin_w}, {radius * 0.6})",
                f" .extrude({fin_h})",
                " )",
                " result = result.union(fin)",
            ]
        return lines

    def _plate_lines(self, p: dict) -> list[str]:
        """Build the script lines for a flat plate with holes, pocket and slot."""
        dims = p["dimensions"]
        features = p["features"]
        w = dims[0] if dims else 80.0
        h = dims[1] if len(dims) > 1 else w * 0.6
        t = dims[2] if len(dims) > 2 else 5.0

        lines = [
            f"# Plate: {w}x{h}x{t}mm",
            "result = (",
            " cq.Workplane('XY')",
            f" .box({w}, {h}, {t})",
        ]
        if "holes" in features or "through_hole" in features:
            n = p["count"]
            dia = p["hole_dia"]
            prompt_lower = p["prompt"].lower()
            is_flange = "flange" in prompt_lower
            # Distribute holes in a grid or circle
            if is_flange or n >= 6:
                # Bolt circle pattern
                r = min(w, h) * 0.35
                lines.append(" .faces('>Z').workplane()")
                lines.append(f" .polarArray({r}, 0, 360, {n})")
                lines.append(f" .hole({dia})")
                if "bore" in prompt_lower or is_flange:
                    lines.append(" .faces('>Z').workplane()")
                    lines.append(f" .hole({dia * 3}) # Center bore")
            else:
                # Rectangular pattern; any count other than 1 or 2 uses the
                # four corner positions.
                ox = w * 0.35
                oy = h * 0.35
                if n == 1:
                    pts = [(0, 0)]
                elif n == 2:
                    pts = [(-ox, 0), (ox, 0)]
                else:
                    pts = self._corner_points(ox, oy)
                lines.append(" .faces('>Z').workplane()")
                lines.append(f" .pushPoints({pts})")
                lines.append(f" .hole({dia})")
        if "pocket" in features:
            pw = w * 0.4
            ph = h * 0.35
            pd = t * 0.6
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .rect({pw}, {ph})")
            lines.append(f" .cutBlind(-{pd}) # Central pocket")
        if "slot" in features:
            sl = w * 0.35
            sw = max(t * 0.8, 4)
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .slot2D({sl}, {sw}).cutBlind(-{t})")
        if "fillet" in features:
            lines.append(f" .edges('|Z').fillet({max(t * 0.4, 1.5)})")
        else:
            lines.append(" .edges('>Z').chamfer(0.5)")
        lines.append(")")
        return lines

    def _l_bracket_lines(self, p: dict) -> list[str]:
        """Build the script lines for an extruded L-profile bracket."""
        dims = p["dimensions"]
        arm = dims[0] if dims else 50.0
        width = dims[1] if len(dims) > 1 else 20.0
        t = dims[2] if len(dims) > 2 else 4.0
        dia = p["hole_dia"]

        lines = [
            f"# L-bracket: {arm}mm arms, {width}mm wide, {t}mm thick",
            "result = (",
            " cq.Workplane('XZ')",
            " .moveTo(0, 0)",
            f" .lineTo({arm}, 0)",
            f" .lineTo({arm}, {t})",
            f" .lineTo({t}, {t})",
            f" .lineTo({t}, {arm})",
            f" .lineTo(0, {arm})",
            " .close()",
            f" .extrude({width})",
            f" .edges('|Y').fillet({max(t * 0.5, 1.5)})",
        ]
        if "holes" in p["features"]:
            # One hole per arm
            lines.append(" .faces('>Z').workplane(centerOption='CenterOfBoundBox')")
            lines.append(f" .center({arm * 0.5}, 0)")
            lines.append(f" .hole({dia})")
            lines.append(" .faces('>X').workplane(centerOption='CenterOfBoundBox')")
            lines.append(f" .center(0, {arm * 0.5})")
            lines.append(f" .hole({dia})")
        # NOTE(review): the source's indentation was lost; this finishing
        # chamfer is assumed to apply whether or not holes were requested,
        # like the edge treatment of the other shapes — confirm.
        lines.append(" .edges().chamfer(0.5)")
        lines.append(")")
        return lines

    def _box_lines(self, p: dict) -> list[str]:
        """Build the script lines for a box / enclosure / housing."""
        dims = p["dimensions"]
        features = p["features"]
        w = dims[0] if dims else 60.0
        h = dims[1] if len(dims) > 1 else w * 0.65
        d = dims[2] if len(dims) > 2 else 20.0

        lines = [
            f"# Box: {w}x{h}x{d}mm",
            "result = (",
            " cq.Workplane('XY')",
            f" .box({w}, {h}, {d})",
        ]
        if "holes" in features or "through_hole" in features:
            pts = self._corner_points(w * 0.35, h * 0.35)
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .pushPoints({pts})")
            lines.append(f" .hole({p['hole_dia']})")
        if "pocket" in features:
            pw = w * 0.5
            ph = h * 0.4
            pd = d * 0.4
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .rect({pw}, {ph})")
            lines.append(f" .cutBlind(-{pd})")
        if "slot" in features:
            sl = w * 0.4
            sw = 6
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .slot2D({sl}, {sw}).cutBlind(-{d})")
        if "boss" in features:
            # At most four bosses, taken from the corner positions in order
            n = min(p["count"], 4)
            boss_pts = self._corner_points(w * 0.3, h * 0.3)[:n]
            lines.append(" .faces('>Z').workplane()")
            lines.append(f" .pushPoints({boss_pts})")
            lines.append(" .circle(4).extrude(6) # Mounting bosses")
        if "ribs" in features:
            # Counts above eight fall back to four ribs
            n_ribs = p["count"] if p["count"] <= 8 else 4
            spacing = w / (n_ribs + 1)
            first_offset = -w / 2 + spacing
            lines.append(" .faces('>Z').workplane()")
            for i in range(n_ribs):
                # The first .center() moves to the first rib position; every
                # following one shifts by one spacing from the previous rib.
                lines.append(f" .center({first_offset if i == 0 else spacing}, 0)")
                lines.append(f" .rect(2, {h * 0.8}).extrude({d * 0.3})")
        if "fillet" in features:
            lines.append(f" .edges('|Z').fillet({min(d * 0.2, 3)})")
        elif "chamfer" in features:
            lines.append(" .edges('>Z').chamfer(1.0)")
        else:
            lines.append(" .edges('>Z').chamfer(0.5)")
        lines.append(")")
        return lines

    def _generate_code(self, p: dict) -> str:
        """Build CadQuery code from parsed parameters.

        Args:
            p: A parameter dict as returned by ``_parse_prompt``.

        Returns:
            A Python script (newline-terminated) that imports ``cadquery``
            and assigns the generated solid to ``result``.
        """
        shape = p["shape"]

        # The prompt is embedded in a `#` comment of the generated script.
        # Collapse all whitespace first: a raw newline in the prompt would end
        # the comment and turn the rest of the prompt into executable code.
        prompt = " ".join(str(p["prompt"]).split())

        lines = ["import cadquery as cq"]
        # Only the cylinder fin loop uses math
        if shape == "cylinder" and "fins" in p["features"]:
            lines.append("import math")
        lines.append("")
        lines.append(f"# Generated from: {prompt}")

        if shape == "cylinder":
            lines += self._cylinder_lines(p)
        elif shape == "plate":
            lines += self._plate_lines(p)
        elif shape == "l_bracket":
            lines += self._l_bracket_lines(p)
        else:  # box / enclosure / housing
            lines += self._box_lines(p)
        return "\n".join(lines) + "\n"

    # Curated hero responses for specific prompts
    _CURATED = {
        "gear": """\
import cadquery as cq
import math
# Simple spur gear approximation: 20 teeth, module 2, 10mm thick
module = 2
teeth = 20
pitch_radius = module * teeth / 2
outer_radius = pitch_radius + module
tooth_angle = 360 / teeth
result = (
    cq.Workplane("XY")
    .cylinder(10, outer_radius)
    .faces(">Z").workplane()
    .hole(12)
)
for i in range(teeth):
    angle = i * tooth_angle
    rad = math.radians(angle)
    gap_x = pitch_radius * math.cos(rad)
    gap_y = pitch_radius * math.sin(rad)
    cutter = (
        cq.Workplane("XY")
        .transformed(offset=(gap_x, gap_y, 0), rotate=(0, 0, angle))
        .rect(module * 0.8, module * 2.5)
        .extrude(12)
    )
    result = result.cut(cutter)
result = result.edges(">Z or <Z").chamfer(0.3)
""",
    }

    def generate(self, messages: list[dict]) -> str:
        """Return CadQuery code for the content of the last message.

        Args:
            messages: Chat messages; only the last one's ``"content"`` is used.

        Returns:
            The curated script whose key occurs in the prompt, otherwise a
            dynamically generated script.
        """
        user_msg = messages[-1]["content"]
        lower = user_msg.lower()
        # Check curated responses first
        for key, code in self._CURATED.items():
            if key in lower:
                return code
        # Dynamic generation for everything else
        params = self._parse_prompt(user_msg)
        return self._generate_code(params)
class NeuralCADBackend(LLMBackend):
    """
    Backend for the local neural CAD pipeline (all stages are still stubs).

    Intended data flow:
        text / image → CLIP encoder → contrastive embedding
                     → diffusion prior → CAD latent
                     → transformer decoder → CAD command sequence
                     → OpenCascade kernel → B-rep solid

    The pipeline works on CAD command sequences rather than on CadQuery
    source; ``generate`` and ``generate_from_image`` serialize the decoded
    commands back to a CadQuery code string so the result fits the existing
    execution/validation/export pipeline. Every stage currently raises
    ``NotImplementedError``.
    """

    def __init__(
        self,
        model_dir: str | Path = "./models",
        device: str = "cuda",
        clip_model: str = "clip_encoder.pt",
        prior_model: str = "diffusion_prior.pt",
        decoder_model: str = "transformer_decoder.pt",
    ):
        """Record where the model files live; nothing is loaded yet.

        Args:
            model_dir: Directory expected to contain the model files.
            device: Device identifier stored for later use.
            clip_model: File name of the CLIP encoder weights.
            prior_model: File name of the diffusion prior weights.
            decoder_model: File name of the transformer decoder weights.
        """
        self.model_dir = Path(model_dir)
        self.device = device

        # Model handles stay None until they are loaded.
        self.clip_encoder = None
        self.diffusion_prior = None
        self.transformer_decoder = None

        self._model_config = dict(
            clip=clip_model,
            prior=prior_model,
            decoder=decoder_model,
        )

    def load_models(self):
        """Load all model weights from disk (not implemented yet)."""
        message = (
            "Model loading not yet implemented. "
            f"Expected model files in: {self.model_dir}"
        )
        raise NotImplementedError(message)

    def encode_text(self, text: str):
        """Turn a text prompt into a CLIP embedding (not implemented yet)."""
        raise NotImplementedError("CLIP text encoder not yet implemented")

    def encode_image(self, image_path: str | Path):
        """Turn a photo or sketch into a CLIP embedding (not implemented yet)."""
        raise NotImplementedError("CLIP image encoder not yet implemented")

    def run_diffusion_prior(self, clip_embedding):
        """Map a CLIP embedding to a CAD latent (not implemented yet)."""
        raise NotImplementedError("Diffusion prior not yet implemented")

    def decode_to_cad_sequence(self, latent):
        """Decode a CAD latent into a command sequence (not implemented yet)."""
        raise NotImplementedError("Transformer decoder not yet implemented")

    def cad_sequence_to_solid(self, cad_commands: list[dict]):
        """Run a command sequence through the CAD kernel (not implemented yet)."""
        raise NotImplementedError("CAD kernel execution not yet implemented")

    def _embedding_to_code(self, clip_emb) -> str:
        """Run prior → decoder → serializer on a CLIP embedding."""
        cad_latent = self.run_diffusion_prior(clip_emb)
        commands = self.decode_to_cad_sequence(cad_latent)
        return self._cad_commands_to_code(commands)

    def generate(self, messages: list[dict]) -> str:
        """
        LLMBackend-compatible entry point.

        Uses the ``"content"`` of the last message as the text prompt, runs it
        through the full pipeline and returns the result as a CadQuery code
        string.
        """
        prompt_text = messages[-1]["content"]
        return self._embedding_to_code(self.encode_text(prompt_text))

    def generate_from_image(self, image_path: str | Path, text_hint: str = "") -> str:
        """
        Image-conditioned generation.

        Args:
            image_path: Path to photo or sketch of the desired part.
            text_hint: Optional text to guide generation alongside the image.

        Returns:
            CadQuery code string for pipeline compatibility.
        """
        image_embedding = self.encode_image(image_path)
        if not text_hint:
            return self._embedding_to_code(image_embedding)

        text_embedding = self.encode_text(text_hint)
        # Placeholder fusion: plain average of the two embeddings
        # (final strategy TBD — average, concat, cross-attn).
        fused = (image_embedding + text_embedding) / 2
        return self._embedding_to_code(fused)

    def _cad_commands_to_code(self, cad_commands: list[dict]) -> str:
        """Serialize a CAD command sequence to CadQuery source (not implemented yet)."""
        raise NotImplementedError(
            "CAD command → CadQuery code serializer not yet implemented"
        )