"""Gradio chat front end for the ForgeCAD CAD agent.

User turns (text plus an optional reference image) are forwarded to
``run_chat_turn``; the resulting model preview, download path and a status
summary are shown next to the conversation.
"""

import logging
import os
import shutil
from pathlib import Path
from uuid import uuid4

import gradio as gr

from agent_core.config import ARTIFACT_ROOT
from agent_core.service import run_chat_turn

logger = logging.getLogger(__name__)

# Prompt sent (and displayed) when the user uploads an image without text.
DEFAULT_IMAGE_PROMPT = "Generate a 3D model from the uploaded image."


def _empty_composer() -> dict:
    """Return a fresh value that clears the MultimodalTextbox composer."""
    return {"text": "", "files": []}


def new_session_state() -> dict:
    """Return a fresh per-session state dict with a random 12-hex-char id."""
    return {
        "session_id": uuid4().hex[:12],
        # History passed to and returned by run_chat_turn.
        "agent_messages": [],
        # Role/content dicts rendered by the Chatbot component.
        "chat_messages": [],
        # tool_payload of the most recent turn (None if that turn had none).
        "latest_tool_payload": None,
    }


def copy_uploaded_image(image_path: str | None, session_id: str) -> str | None:
    """Copy an uploaded image into ``ARTIFACT_ROOT/uploads/<session_id>/``.

    Presumably done so the file lives under ARTIFACT_ROOT, which is the extra
    root passed to ``allowed_paths`` at launch — confirm with the agent side.

    Args:
        image_path: Path of the file Gradio received, or None/"" for no upload.
        session_id: Id naming the per-session upload subdirectory.

    Returns:
        The copied file's path as a string, or None when nothing was uploaded.

    Raises:
        OSError: If the directory cannot be created or the copy fails.
    """
    if not image_path:
        return None
    source = Path(image_path)
    # Keep the original extension; extensionless uploads are saved as .png.
    suffix = source.suffix or ".png"
    upload_dir = ARTIFACT_ROOT / "uploads" / session_id
    upload_dir.mkdir(parents=True, exist_ok=True)
    # Random file name so repeated uploads in one session never overwrite.
    target = upload_dir / f"{uuid4().hex[:8]}{suffix}"
    shutil.copy2(source, target)
    return str(target)


def parse_user_message(message) -> tuple[str, str | None]:
    """Split a composer value into ``(stripped_text, first_file_path_or_None)``.

    Accepts None, a plain string, or a MultimodalTextbox dict with optional
    ``"text"`` and ``"files"`` keys. Only the first attached file is used.
    """
    if message is None:
        return "", None
    if isinstance(message, str):
        return message.strip(), None
    text = (message.get("text") or "").strip()
    files = message.get("files") or []
    image_path = files[0] if files else None
    return text, image_path


def format_status(result: dict) -> str:
    """Render the Markdown status panel from a ``run_chat_turn`` result.

    Reads ``result["tool_payload"]``: a falsy payload means no generation ran,
    a payload with truthy ``ok`` lists the produced artifacts, and anything
    else is reported as a failure (the manifest line only when present).
    """
    payload = result.get("tool_payload") or {}
    if not payload:
        return "No model generated yet."
    if payload.get("ok"):
        return "\n".join([
            f"Run: `{payload.get('run_id')}`",
            f"Output: `{payload.get('output_path')}`",
            f"Preview: `{payload.get('preview_path') or payload.get('output_path')}`",
            f"Manifest: `{payload.get('manifest_path')}`",
        ])
    lines = [
        "Generation failed.",
        f"Stage: `{payload.get('stage')}`",
        f"Error: `{payload.get('error')}`",
    ]
    if payload.get("manifest_path"):
        lines.append(f"Manifest: `{payload.get('manifest_path')}`")
    return "\n".join(lines)


def submit_message(message, state):
    """Handle one composer submission.

    Args:
        message: The MultimodalTextbox value (see ``parse_user_message``).
        state: The session state dict, or None for a session that has not
            sent anything yet (a fresh state is created in that case).

    Returns:
        A 6-tuple for the outputs ``(chatbot, state, preview, download,
        status, composer)``. Failures of the image copy or of the agent call
        are reported in the chat and status panel rather than raised.
    """
    state = state or new_session_state()
    text, image_path = parse_user_message(message)
    if not text and not image_path:
        return state["chat_messages"], state, None, None, "Enter a message or upload an image.", _empty_composer()

    prompt = text or DEFAULT_IMAGE_PROMPT
    display_text = f"{prompt}\n\n[uploaded image]" if image_path else prompt
    state["chat_messages"].append({"role": "user", "content": display_text})
    try:
        # Inside the try so a failed copy is reported in the UI, not raised.
        uploaded_image = copy_uploaded_image(image_path, state["session_id"])
        result = run_chat_turn(
            messages=state["agent_messages"],
            user_text=prompt,
            image_path=uploaded_image,
        )
        state["agent_messages"] = result["messages"]
        # Prefer the agent's reply; when it is empty, surface the service
        # error so failures are visible; only then fall back to "Done.".
        assistant_text = result.get("assistant_text") or result.get("error") or "Done."
        state["chat_messages"].append({"role": "assistant", "content": assistant_text})
        state["latest_tool_payload"] = result.get("tool_payload")
        # NOTE(review): a turn without preview/download paths clears the
        # previously shown model — confirm this is intended.
        return (
            state["chat_messages"],
            state,
            result.get("preview_path"),
            result.get("download_path"),
            format_status(result),
            _empty_composer(),
        )
    except Exception as exc:  # UI boundary: report instead of crashing the handler
        logger.exception("Chat turn failed for session %s", state["session_id"])
        state["chat_messages"].append({"role": "assistant", "content": f"Failed: {exc}"})
        return state["chat_messages"], state, None, None, f"Failed: `{exc}`", _empty_composer()


def reset_session():
    """Start a new session: clear chat, preview, download and composer."""
    state = new_session_state()
    return [], state, None, None, "New session started.", _empty_composer()


with gr.Blocks(title="ForgeCAD", fill_height=True) as demo:
    gr.Markdown("# AI CAD Agent")
    # Start as None: a non-callable initial value is built once at import and
    # deep-copied into every session, so all sessions would share one
    # session_id. submit_message creates a fresh state on the first turn.
    state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=5):
            # NOTE(review): history entries are role/content dicts; some Gradio
            # versions need type="messages" for that — confirm installed version.
            chatbot = gr.Chatbot(height=560, label="Conversation")
            composer = gr.MultimodalTextbox(
                sources=["upload"],
                file_types=["image"],
                file_count="single",
                lines=2,
                max_lines=6,
                label="Message",
                placeholder="Describe a CAD model, ask for an edit, or upload an image as reference.",
                submit_btn="Send",
            )
            with gr.Row():
                clear = gr.Button("New Session")
        with gr.Column(scale=4):
            preview = gr.Model3D(label="Model Preview", height=560)
            download = gr.File(label="Download Latest Model")
            status = gr.Markdown("No model generated yet.")

    composer.submit(
        submit_message,
        inputs=[composer, state],
        outputs=[chatbot, state, preview, download, status, composer],
    )
    clear.click(
        reset_session,
        outputs=[chatbot, state, preview, download, status, composer],
    )


if __name__ == "__main__":
    server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    # One event at a time per handler; presumably because generation is
    # resource-heavy — confirm before raising the limit.
    demo.queue(default_concurrency_limit=1).launch(
        server_name=server_name,
        server_port=server_port,
        allowed_paths=[str(ARTIFACT_ROOT)],
    )