Apple committed on
Commit be0593d · 1 Parent(s): 9ba6e21

Initial CADCoder Space with Gradio

Files changed (2)
  1. app.py +117 -34
  2. requirements.txt +6 -4
app.py CHANGED
@@ -1,39 +1,122 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
-
- # Load CAD-Coder model from Hugging Face
- MODEL_NAME = "CADCODER/CAD-Coder"
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     trust_remote_code=True,
-     torch_dtype=torch.float16,
-     device_map="auto"
- )
-
- def generate_code(prompt):
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=512,
-             do_sample=True,
-             temperature=0.7,
-             top_p=0.9
-         )
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)

  # Gradio UI
- demo = gr.Interface(
-     fn=generate_code,
-     inputs=gr.Textbox(lines=5, placeholder="Enter your CAD design prompt..."),
-     outputs="text",
-     title="CAD-Coder Inference",
-     description="Generate CAD code from natural language using CAD-Coder."
- )

- if __name__ == "__main__":
-     demo.launch()
+ # app.py
+ import os
+ import subprocess
+ import traceback
  import gradio as gr
+
+ MODEL_ID = "CADCODER/CAD-Coder"  # HF model id
+ REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
+ REPO_DIR = "CAD-Coder"
+
+ # 1) git-clone the repo if missing (your preference)
+ if not os.path.isdir(REPO_DIR):
+     try:
+         print("Cloning CAD-Coder repo...")
+         subprocess.run(["git", "clone", REPO_GIT, REPO_DIR], check=True)
+     except Exception as e:
+         print("Could not clone repository:", e)
+
+ # 2) Prepare model loader with graceful fallback to HF Inference API
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_API_TOKEN")
+ local_generate = None
+ api_generate = None
+
+ # Try to load the model locally (8-bit if possible)
+ try:
+     import torch
+     from transformers import AutoTokenizer, AutoModelForCausalLM
+     try:
+         import bitsandbytes  # optional; enables 8-bit loading
+         has_bnb = True
+     except Exception:
+         has_bnb = False
+
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=hf_token, trust_remote_code=True)
+
+     load_kwargs = {"device_map": "auto", "trust_remote_code": True}
+     if has_bnb:
+         print("bitsandbytes available — will attempt 8-bit load (saves memory).")
+         load_kwargs.update({"load_in_8bit": True, "torch_dtype": torch.float16})
+     elif torch.cuda.is_available():
+         # no bitsandbytes: fall back to fp16 if a GPU is present
+         load_kwargs["torch_dtype"] = torch.float16
+
+     print("Loading model (this can take a while)...")
+     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=hf_token, **load_kwargs)
+
+     # device_map="auto" has already placed the weights; just record the device
+     device = next(model.parameters()).device
+     print("Model loaded on device:", device)
+
+     def local_generate_fn(prompt, max_new_tokens=512):
+         inputs = tokenizer(prompt, return_tensors="pt").to(device)
+         gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+         return tokenizer.decode(gen[0], skip_special_tokens=True)
+
+     local_generate = local_generate_fn
+
+ except Exception:
+     print("Local model load failed or not feasible in this environment.")
+     traceback.print_exc()
+
+ # Fallback: Hugging Face Inference API (works without loading weights locally)
+ if local_generate is None:
+     try:
+         from huggingface_hub import InferenceApi
+         print("Setting up HF Inference API client as fallback...")
+         api = InferenceApi(repo_id=MODEL_ID, token=hf_token)
+
+         def api_generate_fn(prompt, max_new_tokens=512):
+             # call the hosted inference endpoint
+             out = api(inputs=prompt, params={"max_new_tokens": max_new_tokens})
+             # Response can be a dict or list depending on pipeline; extract defensively
+             if isinstance(out, list):
+                 first = out[0]
+                 if isinstance(first, dict):
+                     return first.get("generated_text") or str(first)
+                 return str(first)
+             elif isinstance(out, dict):
+                 return out.get("generated_text") or str(out)
+             else:
+                 return str(out)
+
+         api_generate = api_generate_fn
+         print("Inference API fallback ready.")
+     except Exception as e:
+         print("HF Inference API not available:", e)
+         traceback.print_exc()
+
+ # Final generate function: prefer local, otherwise API fallback, otherwise error
+ def generate(prompt, max_new_tokens=512):
+     if local_generate:
+         return local_generate(prompt, max_new_tokens=max_new_tokens)
+     elif api_generate:
+         return api_generate(prompt, max_new_tokens=max_new_tokens)
+     else:
+         return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."

  # Gradio UI
+ def run_prompt(prompt, max_tokens=512):
+     if not prompt or prompt.strip() == "":
+         return "Enter a prompt describing the CAD sketch you want (e.g., 'rectangle width 10 height 5 with hole radius 1')."
+     try:
+         return generate(prompt, max_new_tokens=int(max_tokens))
+     except Exception as e:
+         traceback.print_exc()
+         return f"Generation error: {e}"
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# CAD-Coder (Text → CadQuery code)")
+     prompt = gr.Textbox(label="Natural language prompt", lines=4, placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...")
+     max_tokens = gr.Slider(minimum=64, maximum=2048, step=64, value=512, label="Max new tokens")
+     out = gr.Textbox(label="Generated CadQuery code", lines=18)
+     btn = gr.Button("Generate")
+     btn.click(run_prompt, inputs=[prompt, max_tokens], outputs=out)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
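
Note: InferenceApi is deprecated in recent huggingface_hub releases in favor of InferenceClient. A minimal sketch of the same fallback using the newer client, assuming a recent huggingface_hub and the MODEL_ID / hf_token values defined above:

# Sketch only: InferenceClient-based fallback (assumes huggingface_hub >= 0.19).
from huggingface_hub import InferenceClient

client = InferenceClient(model=MODEL_ID, token=hf_token)

def api_generate_fn(prompt, max_new_tokens=512):
    # text_generation returns the generated string directly, so the
    # defensive dict/list unpacking above is no longer needed
    return client.text_generation(prompt, max_new_tokens=max_new_tokens)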
requirements.txt CHANGED
@@ -1,5 +1,7 @@
- torch
- transformers
  gradio
- git+https://github.com/huggingface/accelerate.git
-

+ transformers>=4.30.0
+ accelerate
+ huggingface-hub
  gradio
+ torch
+ bitsandbytes
+ gitpython
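
Note: requirements.txt adds gitpython, but app.py clones the repo by shelling out to git via subprocess. If the dependency is intended to replace that call, a minimal GitPython sketch might look like this (assuming REPO_GIT and REPO_DIR as defined in app.py):

# Sketch only: clone with GitPython instead of subprocess (assumes gitpython is installed).
import os
from git import Repo

if not os.path.isdir(REPO_DIR):
    try:
        Repo.clone_from(REPO_GIT, REPO_DIR)
    except Exception as e:
        print("Could not clone repository:", e)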