ForgeCAD / app.py
KaiWu
refactor(app): 移除 Gradio feedback 采集模块
2aa101f
import logging
import os
import shutil
from pathlib import Path
from uuid import uuid4

import gradio as gr

from agent_core.config import ARTIFACT_ROOT
from agent_core.service import run_chat_turn
def new_session_state() -> dict:
    """Build a fresh per-session state dict.

    Keys:
        session_id: first 12 hex chars of a random UUID4.
        agent_messages: message history passed to ``run_chat_turn``.
        chat_messages: ``{"role", "content"}`` dicts rendered in the Chatbot.
        latest_tool_payload: tool payload from the most recent turn, or None.
    """
    return dict(
        session_id=uuid4().hex[:12],
        agent_messages=[],
        chat_messages=[],
        latest_tool_payload=None,
    )
def copy_uploaded_image(image_path: str | None, session_id: str) -> str | None:
if not image_path:
return None
source = Path(image_path)
suffix = source.suffix or ".png"
upload_dir = ARTIFACT_ROOT / "uploads" / session_id
upload_dir.mkdir(parents=True, exist_ok=True)
target = upload_dir / f"{uuid4().hex[:8]}{suffix}"
shutil.copy2(source, target)
return str(target)
def parse_user_message(message) -> tuple[str, str | None]:
if message is None:
return "", None
if isinstance(message, str):
return message.strip(), None
text = (message.get("text") or "").strip()
files = message.get("files") or []
image_path = files[0] if files else None
return text, image_path
def format_status(result: dict) -> str:
    """Render the Markdown status line(s) for a chat-turn result.

    Reads ``result["tool_payload"]``:
      * missing/empty  -> "No model generated yet."
      * ``ok`` truthy  -> run id, output, preview (falls back to output) and manifest paths.
      * otherwise      -> failure stage and error, plus the manifest path if one exists.
    """
    payload = result.get("tool_payload") or {}
    if not payload:
        return "No model generated yet."

    manifest = payload.get("manifest_path")

    if payload.get("ok"):
        preview = payload.get("preview_path") or payload.get("output_path")
        lines = [
            f"Run: `{payload.get('run_id')}`",
            f"Output: `{payload.get('output_path')}`",
            f"Preview: `{preview}`",
            f"Manifest: `{manifest}`",
        ]
        return "\n".join(lines)

    lines = [
        "Generation failed.",
        f"Stage: `{payload.get('stage')}`",
        f"Error: `{payload.get('error')}`",
    ]
    if manifest:
        lines.append(f"Manifest: `{manifest}`")
    return "\n".join(lines)
def submit_message(message, state):
    """Handle one composer submission: store any upload, run an agent turn, update the UI.

    Args:
        message: Raw composer value (None, a plain string, or a
            ``{"text": ..., "files": [...]}`` mapping); see parse_user_message.
        state: Per-session dict from new_session_state, or None/empty to start
            a fresh session. Mutated in place and also returned.

    Returns:
        A 6-tuple matching the outputs wired to ``composer.submit``:
        (chat history, session state, preview path, download path,
        status Markdown, cleared composer value).
    """
    default_prompt = "Generate a 3D model from the uploaded image."
    cleared_composer = {"text": "", "files": []}

    state = state or new_session_state()
    text, image_path = parse_user_message(message)

    # Copying the upload touches the filesystem; report failures in the status
    # pane instead of letting the exception escape the Gradio handler.
    try:
        uploaded_image = copy_uploaded_image(image_path, state["session_id"])
    except OSError as exc:
        logging.getLogger(__name__).exception(
            "Failed to store uploaded image for session %s", state["session_id"]
        )
        return (
            state["chat_messages"],
            state,
            None,
            None,
            f"Failed to save uploaded image: `{exc}`",
            cleared_composer,
        )

    if not text and not uploaded_image:
        return (
            state["chat_messages"],
            state,
            None,
            None,
            "Enter a message or upload an image.",
            cleared_composer,
        )

    # An image-only submission is sent to the agent with a default instruction.
    prompt = text or default_prompt
    display_text = f"{prompt}\n\n[uploaded image]" if uploaded_image else prompt
    state["chat_messages"].append({"role": "user", "content": display_text})

    try:
        result = run_chat_turn(
            messages=state["agent_messages"],
            user_text=prompt,
            image_path=uploaded_image,
        )
        state["agent_messages"] = result["messages"]
        # Prefer the agent's reply, then any reported error, then a generic ack.
        # The error must be checked before "Done." or it could never be shown.
        assistant_text = result.get("assistant_text") or result.get("error") or "Done."
        state["chat_messages"].append({"role": "assistant", "content": assistant_text})
        # NOTE(review): this is overwritten with None whenever a turn returns no
        # tool_payload, and the preview/download panes below are likewise set
        # from this turn only — confirm that clearing them on a plain chat turn
        # is intended.
        state["latest_tool_payload"] = result.get("tool_payload")
        status = format_status(result)
    except Exception as exc:
        # Top-level UI boundary: keep the session usable, but log the traceback
        # so the failure can be diagnosed server-side.
        logging.getLogger(__name__).exception(
            "Chat turn failed for session %s", state["session_id"]
        )
        state["chat_messages"].append({"role": "assistant", "content": f"Failed: {exc}"})
        return (
            state["chat_messages"],
            state,
            None,
            None,
            f"Failed: `{exc}`",
            cleared_composer,
        )

    return (
        state["chat_messages"],
        state,
        result.get("preview_path"),
        result.get("download_path"),
        status,
        cleared_composer,
    )
def reset_session():
    """Start a brand-new session and reset every output pane.

    Returns a 6-tuple for the same outputs as ``composer.submit``: empty chat
    history, a fresh session state, no preview, no download, a status
    message, and an empty composer value.
    """
    fresh_state = new_session_state()
    empty_composer = {"text": "", "files": []}
    return [], fresh_state, None, None, "New session started.", empty_composer
with gr.Blocks(title="ForgeCAD", fill_height=True) as demo:
    # Layout: conversation + composer on the left, 3D preview / download /
    # status on the right.
    gr.Markdown("# AI CAD Agent")
    state = gr.State(new_session_state())

    with gr.Row():
        with gr.Column(scale=5):
            # NOTE(review): chat history entries are {"role", "content"} dicts
            # (see submit_message). Gradio versions whose Chatbot defaults to the
            # tuple format need type="messages" here — confirm against the
            # pinned Gradio version.
            chatbot = gr.Chatbot(label="Conversation", height=560)
            composer = gr.MultimodalTextbox(
                label="Message",
                placeholder="Describe a CAD model, ask for an edit, or upload an image as reference.",
                sources=["upload"],
                file_types=["image"],
                file_count="single",
                lines=2,
                max_lines=6,
                submit_btn="Send",
            )
            with gr.Row():
                clear = gr.Button("New Session")
        with gr.Column(scale=4):
            preview = gr.Model3D(label="Model Preview", height=560)
            download = gr.File(label="Download Latest Model")
            status = gr.Markdown("No model generated yet.")

    # submit_message and reset_session both return the same 6-tuple, so the
    # two events share one output list (order must match the tuple).
    session_outputs = [chatbot, state, preview, download, status, composer]
    composer.submit(submit_message, inputs=[composer, state], outputs=session_outputs)
    clear.click(reset_session, outputs=session_outputs)
if __name__ == "__main__":
    # Host and port can be overridden via GRADIO_SERVER_NAME / GRADIO_SERVER_PORT.
    # allowed_paths lets Gradio serve files written under ARTIFACT_ROOT
    # (e.g. copied uploads) back to the browser.
    launch_options = {
        "server_name": os.getenv("GRADIO_SERVER_NAME", "127.0.0.1"),
        "server_port": int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        "allowed_paths": [str(ARTIFACT_ROOT)],
    }
    # Event listeners without an explicit concurrency_limit run one request at a time.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch(**launch_options)