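"""Gradio Space for CAD-Coder: turn natural-language prompts into CadQuery code.

Tries to load the model locally with transformers; if that is not feasible
(e.g. CPU-only hardware), falls back to the Hugging Face Inference API.
"""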
import os
import subprocess
import traceback

import gradio as gr

MODEL_ID = "CADCODER/CAD-Coder"
REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
REPO_DIR = "CAD-Coder"

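# Fetch the CAD-Coder repository on first startup (non-fatal if it fails).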
if not os.path.isdir(REPO_DIR):
    try:
        print("Cloning CAD-Coder repo...")
        subprocess.run(["git", "clone", REPO_GIT, REPO_DIR], check=True)
    except Exception as e:
        print("Could not clone repository:", e)

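# Auth token for gated/private model access; either env var name is accepted.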
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_API_TOKEN")

local_generate = None
api_generate = None

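# First choice: load the model locally with transformers. This only succeeds
# on hardware with enough RAM/VRAM for the checkpoint.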
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Optional dependency: bitsandbytes enables 8-bit loading to save memory.
    try:
        import bitsandbytes as bnb  # noqa: F401
        has_bnb = True
    except Exception:
        has_bnb = False

    print("Loading tokenizer...")
    # use_auth_token is deprecated in recent transformers; token replaces it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token, trust_remote_code=True)

    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if has_bnb:
        print("bitsandbytes available → will attempt 8-bit load (saves memory).")
        load_kwargs.update({"load_in_8bit": True, "torch_dtype": torch.float16})
    elif torch.cuda.is_available():
        # Without bitsandbytes, use fp16 on GPU; CPU keeps the default dtype.
        load_kwargs["torch_dtype"] = torch.float16

    print("Loading model (this can take a while)...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=hf_token, **load_kwargs)

    # device_map="auto" already placed the weights, so no manual model.to() is
    # needed; just record the first parameter's device for input tensors.
    device = next(model.parameters()).device
    print("Model loaded on device:", device)

    def local_generate_fn(prompt, max_new_tokens=512):
        # Greedy decoding (do_sample=False) keeps the generated code deterministic.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return tokenizer.decode(gen[0], skip_special_tokens=True)

    local_generate = local_generate_fn

except Exception:
    print("Local model load failed or not feasible in this environment.")
    traceback.print_exc()

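# Fallback: call the hosted Hugging Face Inference API instead of local weights.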
if local_generate is None:
    try:
        # InferenceApi is deprecated in huggingface_hub; InferenceClient is the
        # current replacement, and its text_generation() returns a plain string,
        # so no response-shape handling is needed.
        from huggingface_hub import InferenceClient

        print("Setting up HF Inference API client as fallback...")
        client = InferenceClient(model=MODEL_ID, token=hf_token)

        def api_generate_fn(prompt, max_new_tokens=512):
            return client.text_generation(prompt, max_new_tokens=max_new_tokens)

        api_generate = api_generate_fn
        print("Inference API fallback ready.")
    except Exception as e:
        print("HF Inference API not available:", e)
        traceback.print_exc()

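# Unified entry point: prefer the local model, fall back to the hosted API.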
def generate(prompt, max_new_tokens=512):
    if local_generate:
        return local_generate(prompt, max_new_tokens=max_new_tokens)
    elif api_generate:
        return api_generate(prompt, max_new_tokens=max_new_tokens)
    else:
        return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."

def run_prompt(prompt, max_tokens=512):
    if not prompt or prompt.strip() == "":
        return "Enter a prompt describing the CAD sketch you want (e.g., 'rectangle width 10 height 5 with hole radius 1')."
    try:
        return generate(prompt, max_new_tokens=int(max_tokens))
    except Exception as e:
        traceback.print_exc()
        return f"Generation error: {e}"

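# Minimal Gradio UI: natural-language prompt in, generated CadQuery code out.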
with gr.Blocks() as demo:
    gr.Markdown("# CAD-Coder (Text → CadQuery code)")
    prompt = gr.Textbox(label="Natural language prompt", lines=4, placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...")
    max_tokens = gr.Slider(minimum=64, maximum=2048, step=64, value=512, label="Max new tokens")
    out = gr.Textbox(label="Generated CadQuery code", lines=18)
    btn = gr.Button("Generate")
    btn.click(run_prompt, inputs=[prompt, max_tokens], outputs=out)

if __name__ == "__main__":
    # Listen on all interfaces; a PORT env var (if set) overrides the default 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))