# Apple
# Initial CADCoder Space with Gradio
# be0593d
# app.py
import os
import subprocess
import traceback
import gradio as gr
# Hugging Face Hub id of the model; used by the local loader and the API fallback.
MODEL_ID = "CADCODER/CAD-Coder"
# Upstream source repository and the local directory it is cloned into.
REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
REPO_DIR = "CAD-Coder"
# 1) Clone the upstream repo if it is not already present.
# Best-effort: a failed clone is logged and app start-up continues.
# NOTE(review): nothing visible in this file reads REPO_DIR after cloning — confirm it is needed.
if not os.path.isdir(REPO_DIR):
    print("Cloning CAD-Coder repo...")
    try:
        subprocess.run(["git", "clone", REPO_GIT, REPO_DIR], check=True)
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError: `git` executable missing or not runnable (e.g. FileNotFoundError).
        # CalledProcessError: git exited non-zero (network, auth, bad URL, ...).
        print("Could not clone repository:", e)
# 2) Prepare model loader with graceful fallback to HF Inference API.
# Optional Hugging Face access token, passed to both the local loader and the
# Inference API client. HF_TOKEN wins; HF_HUB_API_TOKEN is consulted only when
# HF_TOKEN is unset or empty.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    hf_token = os.environ.get("HF_HUB_API_TOKEN")

# Backend callables; each stays None unless its backend initialises below.
local_generate = None
api_generate = None
# Try to load the model locally (8-bit on GPU when bitsandbytes is available).
# Any failure is logged and leaves local_generate unset so the API fallback can run.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    try:
        import bitsandbytes as bnb  # noqa: F401  # optional; enables 8-bit loading
        has_bnb = True
    except Exception:
        has_bnb = False

    print("Loading tokenizer...")
    # NOTE(review): recent transformers deprecate `use_auth_token` in favour of
    # `token`; kept here for compatibility with older installs.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=hf_token, trust_remote_code=True)

    has_cuda = torch.cuda.is_available()
    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if has_bnb and has_cuda:
        # Only request 8-bit when a CUDA GPU is present: transformers' bitsandbytes
        # integration expects one, and on CPU-only hosts asking for 8-bit would make
        # the whole local load fail instead of falling back to a plain load.
        print("bitsandbytes available — will attempt 8-bit load (saves memory).")
        load_kwargs.update({"load_in_8bit": True, "torch_dtype": torch.float16})
    elif has_cuda:
        # fp16 on GPU halves weight memory compared with the fp32 default.
        load_kwargs["torch_dtype"] = torch.float16

    print("Loading model (this can take a while)...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=hf_token, **load_kwargs)
    # With device_map="auto" the weights may be spread over several devices;
    # inputs are sent to the device holding the first parameter.
    device = next(model.parameters()).device
    print("Model loaded on device:", device)

    def local_generate_fn(prompt, max_new_tokens=512):
        """Greedy-decode a completion for *prompt* with the locally loaded model.

        Args:
            prompt: Input text.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            Only the newly generated text. The prompt tokens that `generate`
            echoes back for a causal LM are sliced off before decoding.
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(gen[0][prompt_len:], skip_special_tokens=True)

    local_generate = local_generate_fn
except Exception:
    print("Local model load failed or not feasible in this environment.")
    traceback.print_exc()
# Fallback: Hugging Face Inference API (works without loading weights locally).
if local_generate is None:
    try:
        # InferenceClient supersedes the deprecated legacy InferenceApi class.
        from huggingface_hub import InferenceClient

        print("Setting up HF Inference API client as fallback...")
        api = InferenceClient(model=MODEL_ID, token=hf_token)

        def api_generate_fn(prompt, max_new_tokens=512):
            """Generate a completion for *prompt* on the hosted inference endpoint.

            Args:
                prompt: Input text.
                max_new_tokens: Upper bound on the number of generated tokens.

            Returns:
                Only the newly generated text (``return_full_text=False``).

            Request failures raise rather than being returned as text. The caller
            can then report them as errors instead of showing them as generated code.
            """
            return api.text_generation(
                prompt,
                max_new_tokens=max_new_tokens,
                return_full_text=False,
            )

        api_generate = api_generate_fn
        print("Inference API fallback ready.")
    except Exception as e:
        print("HF Inference API not available:", e)
        traceback.print_exc()
# Final generate function: prefer the local model, then the API fallback,
# otherwise report that no backend is available.
def generate(prompt, max_new_tokens=512):
    """Generate text for *prompt* with the first available backend.

    Backends are tried in priority order: the locally loaded model first, then
    the hosted Inference API. If neither was initialised, an error string is
    returned instead of an exception being raised.
    """
    for backend in (local_generate, api_generate):
        if backend:
            return backend(prompt, max_new_tokens=max_new_tokens)
    return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."
# Gradio UI
def run_prompt(prompt, max_tokens=512):
    """Gradio click handler: turn a text prompt into generated CadQuery code.

    Empty or whitespace-only prompts short-circuit with a usage hint. Any
    exception raised during generation is printed to stderr and reported back
    to the UI as a "Generation error: ..." string instead of propagating.
    """
    if prompt and prompt.strip() != "":
        try:
            return generate(prompt, max_new_tokens=int(max_tokens))
        except Exception as exc:
            # Full traceback goes to the server log; the UI gets a short message.
            traceback.print_exc()
            return f"Generation error: {exc}"
    return "Enter a prompt describing the CAD sketch you want (e.g., 'rectangle width 10 height 5 with hole radius 1')."
# Gradio layout: a prompt box and a token-budget slider in, generated code out.
with gr.Blocks() as demo:
    # Title arrow was mojibake ("β†’") for U+2192 RIGHTWARDS ARROW; repaired.
    gr.Markdown("# CAD-Coder (Text → CadQuery code)")
    prompt = gr.Textbox(
        label="Natural language prompt",
        lines=4,
        placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...",
    )
    max_tokens = gr.Slider(minimum=64, maximum=2048, step=64, value=512, label="Max new tokens")
    out = gr.Textbox(label="Generated CadQuery code", lines=18)
    btn = gr.Button("Generate")
    # Clicking runs run_prompt(prompt, max_tokens) and writes its return value into `out`.
    btn.click(run_prompt, inputs=[prompt, max_tokens], outputs=out)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)