# app.py
#
# Gradio Space for CAD-Coder: turns a natural-language prompt into CadQuery
# code. Tries to load the model locally (8-bit when bitsandbytes is available)
# and falls back to the hosted Hugging Face Inference API when local loading
# is not feasible on the current hardware.

import os
import subprocess
import traceback

import gradio as gr

MODEL_ID = "CADCODER/CAD-Coder"  # HF model id
REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
REPO_DIR = "CAD-Coder"

# 1) Clone the companion repo if missing. Nothing below imports from it; it is
#    kept alongside the app for reference, so a clone failure is non-fatal.
if not os.path.isdir(REPO_DIR):
    try:
        print("Cloning CAD-Coder repo...")
        subprocess.run(["git", "clone", REPO_GIT, REPO_DIR], check=True)
    except Exception as e:
        print("Could not clone repository:", e)

# 2) Prepare the model loader with a graceful fallback to the HF Inference API.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_API_TOKEN")
local_generate = None
api_generate = None

# Try to load the model locally (8-bit if possible).
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    try:
        import bitsandbytes  # noqa: F401 -- optional; enables 8-bit loading
        has_bnb = True
    except Exception:
        has_bnb = False

    print("Loading tokenizer...")
    # `token` replaces the deprecated `use_auth_token` argument.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, token=hf_token, trust_remote_code=True
    )

    # device_map="auto" requires the `accelerate` package and handles weight
    # placement across available devices, so no manual .to() call is needed.
    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if has_bnb:
        print("bitsandbytes available; attempting 8-bit load (saves memory).")
        # quantization_config is the current replacement for the deprecated
        # bare load_in_8bit=True keyword.
        load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif torch.cuda.is_available():
        # No bitsandbytes: fall back to fp16 on GPU.
        load_kwargs["torch_dtype"] = torch.float16

    print("Loading model (this can take a while)...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, token=hf_token, **load_kwargs
    )
    model.eval()
    device = next(model.parameters()).device
    print("Model loaded on device:", device)

    def local_generate_fn(prompt, max_new_tokens=512):
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return tokenizer.decode(gen[0], skip_special_tokens=True)

    local_generate = local_generate_fn
except Exception:
    print("Local model load failed or not feasible in this environment.")
    traceback.print_exc()

# Fallback: Hugging Face Inference API (works without loading weights locally).
if local_generate is None:
    try:
        # InferenceClient replaces the deprecated InferenceApi client.
        from huggingface_hub import InferenceClient

        print("Setting up HF Inference API client as fallback...")
        client = InferenceClient(model=MODEL_ID, token=hf_token)

        def api_generate_fn(prompt, max_new_tokens=512):
            # text_generation returns the generated text as a plain string,
            # so no defensive unpacking of dict/list responses is needed.
            return client.text_generation(prompt, max_new_tokens=max_new_tokens)

        api_generate = api_generate_fn
        print("Inference API fallback ready.")
    except Exception as e:
        print("HF Inference API not available:", e)
        traceback.print_exc()

# Final generate function: prefer local, then the API fallback, else report.
def generate(prompt, max_new_tokens=512):
    if local_generate:
        return local_generate(prompt, max_new_tokens=max_new_tokens)
    if api_generate:
        return api_generate(prompt, max_new_tokens=max_new_tokens)
    return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."

# Gradio UI
def run_prompt(prompt, max_tokens=512):
    if not prompt or not prompt.strip():
        return ("Enter a prompt describing the CAD sketch you want "
                "(e.g., 'rectangle width 10 height 5 with hole radius 1').")
    try:
        return generate(prompt, max_new_tokens=int(max_tokens))
    except Exception as e:
        traceback.print_exc()
        return f"Generation error: {e}"

with gr.Blocks() as demo:
    gr.Markdown("# CAD-Coder (Text → CadQuery code)")
    prompt = gr.Textbox(
        label="Natural language prompt",
        lines=4,
        placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...",
    )
    max_tokens = gr.Slider(minimum=64, maximum=2048, step=64, value=512, label="Max new tokens")
    out = gr.Textbox(label="Generated CadQuery code", lines=18)
    btn = gr.Button("Generate")
    btn.click(run_prompt, inputs=[prompt, max_tokens], outputs=out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
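
# ---------------------------------------------------------------------------
# A note on dependencies (a sketch, not authoritative): every package below is
# either imported above or required by the device_map="auto" load path, so a
# Space running this file would need roughly the following requirements.txt.
# Versions are deliberately left unpinned here; pin to whatever the CAD-Coder
# repo recommends.
#
#   gradio
#   transformers
#   torch
#   accelerate        # required for device_map="auto"
#   huggingface_hub   # InferenceClient fallback
#   bitsandbytes      # optional; enables the 8-bit load path
# ---------------------------------------------------------------------------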