File size: 4,833 Bytes
be0593d
 
 
 
9ba6e21
be0593d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba6e21
 
be0593d
 
 
 
 
 
 
 
9ba6e21
be0593d
 
 
 
 
 
 
9ba6e21
be0593d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# app.py
import os
import subprocess
import traceback
import gradio as gr

# Hugging Face Hub id of the model, the Git URL of its companion source
# repository, and the local directory that repository is cloned into.
MODEL_ID = "CADCODER/CAD-Coder"
REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
REPO_DIR = "CAD-Coder"


# 1) Make sure a local checkout of the repository exists.
def _clone_repo_if_missing():
    """Clone REPO_GIT into REPO_DIR unless that directory already exists.

    Best effort: any failure (including a missing ``git`` executable or a
    non-zero exit status) is printed and then ignored so start-up continues.
    """
    if os.path.isdir(REPO_DIR):
        return
    clone_cmd = ["git", "clone", REPO_GIT, REPO_DIR]
    try:
        print("Cloning CAD-Coder repo...")
        subprocess.run(clone_cmd, check=True)
    except Exception as err:
        print("Could not clone repository:", err)


_clone_repo_if_missing()

# 2) Prepare model loader with graceful fallback to HF Inference API
# Either environment variable may carry the Hub token; None if neither is set.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_API_TOKEN")
# Backend callables; each stays None until its set-up succeeds.
local_generate = None
api_generate = None

# Try to load the model locally (8-bit when bitsandbytes imports). Any failure
# inside this block is printed with its traceback and leaves local_generate
# set to None.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    try:
        import bitsandbytes as bnb  # noqa: F401 -- imported only to probe availability
        has_bnb = True
    except Exception:
        # Broad on purpose: any import-time failure means "no 8-bit support".
        has_bnb = False

    print("Loading tokenizer...")
    # NOTE(review): newer transformers releases reportedly prefer `token=` over
    # `use_auth_token=` -- confirm against the installed version before changing.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=hf_token, trust_remote_code=True)

    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if has_bnb:
        print("bitsandbytes available — will attempt 8-bit load (saves memory).")
        load_kwargs.update({"load_in_8bit": True, "torch_dtype": torch.float16})
    elif torch.cuda.is_available():
        # No 8-bit support: request fp16 only when a CUDA device is present.
        load_kwargs["torch_dtype"] = torch.float16

    print("Loading model (this can take a while)...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=hf_token, **load_kwargs)

    # Tokenized inputs are moved to the device of the model's first parameter.
    device = next(model.parameters()).device
    print("Model loaded on device:", device)

    def local_generate_fn(prompt, max_new_tokens=512):
        """Generate text for ``prompt`` with the locally loaded model.

        Args:
            prompt: Text passed to the tokenizer.
            max_new_tokens: Forwarded to ``model.generate``.

        Returns:
            The first generated sequence decoded with special tokens skipped.
            Sampling is disabled (``do_sample=False``).
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return tokenizer.decode(gen[0], skip_special_tokens=True)

    local_generate = local_generate_fn

except Exception:
    # Deliberate top-level boundary: report the problem and continue so the
    # rest of the script can still run without a local model.
    print("Local model load failed or not feasible in this environment.")
    traceback.print_exc()

# Fallback: Hugging Face Inference API (works without loading weights locally)
def _extract_generated_text(out):
    """Pull the generated text out of an Inference API response.

    Args:
        out: The value returned by the API call. Lists, dicts and any other
            object are handled.

    Returns:
        The ``"generated_text"`` entry when the (first) payload is a dict that
        has one, otherwise ``str()`` of the payload.

    Raises:
        RuntimeError: If the response is an empty list.
    """
    if isinstance(out, list):
        if not out:
            # Previously surfaced as a bare "list index out of range".
            raise RuntimeError("Inference API returned an empty response")
        out = out[0]
    if isinstance(out, dict):
        text = out.get("generated_text")
        # Compare against None so an empty generation ("") is returned as-is
        # instead of being replaced by the dict's repr.
        return text if text is not None else str(out)
    return str(out)


if local_generate is None:
    try:
        # NOTE(review): InferenceApi is reportedly deprecated in favour of
        # InferenceClient in recent huggingface_hub releases -- confirm against
        # the installed version.
        from huggingface_hub import InferenceApi
        print("Setting up HF Inference API client as fallback...")
        api = InferenceApi(repo_id=MODEL_ID, token=hf_token)

        def api_generate_fn(prompt, max_new_tokens=512):
            """Generate text for ``prompt`` through the hosted inference endpoint.

            Args:
                prompt: Text sent as the request ``inputs``.
                max_new_tokens: Sent in the request ``params``.

            Returns:
                The text extracted from the response by
                ``_extract_generated_text``.
            """
            out = api(inputs=prompt, params={"max_new_tokens": max_new_tokens})
            return _extract_generated_text(out)

        api_generate = api_generate_fn
        print("Inference API fallback ready.")
    except Exception as e:
        # Deliberate top-level boundary: without the client the script keeps
        # going and api_generate is left unchanged.
        print("HF Inference API not available:", e)
        traceback.print_exc()

# Final generate function: prefer local, otherwise API fallback, otherwise error
def generate(prompt, max_new_tokens=512):
    """Generate text for ``prompt`` using the first available backend.

    Args:
        prompt: Text handed to the backend.
        max_new_tokens: Passed through to the backend as a keyword argument.

    Returns:
        The backend's result. ``local_generate`` is used when set, otherwise
        ``api_generate``; when neither is set an ``ERROR: ...`` string is
        returned instead of raising.
    """
    backend = local_generate or api_generate
    if not backend:
        return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."
    return backend(prompt, max_new_tokens=max_new_tokens)

# Gradio UI
def run_prompt(prompt, max_tokens=512):
    """Handle a click on the Generate button.

    Args:
        prompt: Text from the prompt box.
        max_tokens: Token limit; converted with ``int()`` before use.

    Returns:
        A usage hint when ``prompt`` is empty or only whitespace, otherwise the
        result of ``generate``. Any exception raised while generating is
        printed with its traceback and returned as a
        ``Generation error: ...`` string.
    """
    has_text = bool(prompt) and prompt.strip() != ""
    if has_text:
        try:
            return generate(prompt, max_new_tokens=int(max_tokens))
        except Exception as err:
            traceback.print_exc()
            return f"Generation error: {err}"
    return "Enter a prompt describing the CAD sketch you want (e.g., 'rectangle width 10 height 5 with hole radius 1')."

# Page layout: a prompt box and a token-limit slider feed run_prompt, whose
# return value is shown in the output box.
with gr.Blocks() as demo:
    # Heading arrow restored; it had been stored as mis-decoded bytes.
    gr.Markdown("# CAD-Coder (Text → CadQuery code)")
    prompt = gr.Textbox(label="Natural language prompt", lines=4, placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...")
    max_tokens = gr.Slider(minimum=64, maximum=2048, step=64, value=512, label="Max new tokens")
    out = gr.Textbox(label="Generated CadQuery code", lines=18)
    btn = gr.Button("Generate")
    btn.click(run_prompt, inputs=[prompt, max_tokens], outputs=out)

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))