kshdes37 committed on
Commit
2f0a32a
·
verified ·
1 Parent(s): e74dac6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -0
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import json
4
+ import os
5
+ import tempfile
6
+ import subprocess
7
+ import sys
8
+ from pathlib import Path
9
+ from huggingface_hub import snapshot_download
10
+ import logging
11
+
12
# Setup logging: INFO level so model download/load progress shows up in the
# Space's runtime logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger shared by all functions below
15
+
16
class CADFusionInference:
    """Thin inference wrapper around the microsoft/CADFusion checkpoint.

    Holds the tokenizer/model pair, loads them lazily via :meth:`load_model`,
    and exposes text -> CAD-sequence generation plus a placeholder
    "visualization" summary of a generated sequence.
    """

    def __init__(self):
        # Model and tokenizer are populated by load_model(); until then
        # model_loaded stays False and generate_cad_sequence() refuses to run.
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_loaded = False

    def load_model(self, model_path="microsoft/CADFusion", revision="v1_1"):
        """Download and load the CADFusion model and tokenizer.

        Args:
            model_path: Hugging Face repo id of the checkpoint.
            revision: Repo revision/tag to download.

        Raises:
            Exception: any download or loading failure is logged and re-raised.
        """
        try:
            logger.info(f"Loading CADFusion model from {model_path} (revision: {revision})")

            # Download model files into a local cache directory.
            model_dir = snapshot_download(
                repo_id=model_path,
                revision=revision,
                cache_dir="./model_cache"
            )

            # NOTE(review): assumes the checkpoint is a causal-LM export —
            # confirm against the actual CADFusion release format.
            from transformers import AutoTokenizer, AutoModelForCausalLM

            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            # Causal LMs often ship without a pad token; reuse EOS so that
            # generate() can pad without erroring.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                # fp16 on GPU; fp32 on CPU (fp16 CPU inference is poorly supported).
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True
            )

            self.model_loaded = True
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            # Bare raise preserves the original traceback (vs. `raise e`).
            raise

    def generate_cad_sequence(self, text_prompt, max_length=512, temperature=0.8, top_p=0.9):
        """Generate a CAD sequence from a natural-language prompt.

        Args:
            text_prompt: free-form description of the desired CAD model.
            max_length: cap on total token length (prompt + generation).
            temperature: sampling temperature passed to generate().
            top_p: nucleus-sampling threshold passed to generate().

        Returns:
            The generated text with the echoed prompt stripped off.

        Raises:
            ValueError: if load_model() has not been called successfully.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        try:
            # Prompt template expected by the fine-tuned model.
            formatted_prompt = f"Generate CAD sequence for: {text_prompt}\nCAD:"

            # Tokenize via __call__ so we also get an attention_mask:
            # with pad_token == eos_token, generate() warns and may mis-infer
            # padding when no mask is supplied.
            encoded = self.tokenizer(formatted_prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the echoed prompt so only the CAD sequence remains.
            cad_sequence = generated_text[len(formatted_prompt):].strip()

            return cad_sequence

        except Exception as e:
            logger.error(f"Error generating CAD sequence: {str(e)}")
            raise

    def render_cad_visualization(self, cad_sequence):
        """Summarize a CAD sequence (placeholder for real rendering).

        A real implementation would parse the sequence into geometric
        operations and use the CADFusion rendering utilities to produce a 3D
        view; here we only count keyword occurrences.

        Returns:
            dict with the raw sequence, operation/sketch counts, and a status
            string — or {"error": ...} if counting fails.
        """
        try:
            # Naive, case-sensitive substring counts — purely informational.
            visualization_info = {
                "sequence": cad_sequence,
                "operations": cad_sequence.count("extrude") + cad_sequence.count("revolve"),
                "sketches": cad_sequence.count("sketch"),
                "status": "Generated (visualization placeholder)"
            }

            return visualization_info

        except Exception as e:
            logger.error(f"Error rendering CAD: {str(e)}")
            return {"error": str(e)}
117
+
118
# Module-level singleton shared by all Gradio requests; the model weights are
# loaded lazily on the first request (see generate_cad_from_text).
cad_fusion = CADFusionInference()
120
+
121
def generate_cad_from_text(text_prompt, max_length=512, temperature=0.8, top_p=0.9):
    """Gradio callback: turn a text prompt into a CAD sequence plus summary.

    Returns a (markdown_report, raw_sequence) pair; on any failure the report
    carries the error text and the raw sequence is empty.
    """
    try:
        # Lazy-load the model on the first request.
        if not cad_fusion.model_loaded:
            cad_fusion.load_model()

        sequence = cad_fusion.generate_cad_sequence(
            text_prompt,
            max_length=int(max_length),
            temperature=temperature,
            top_p=top_p,
        )

        # Placeholder "visualization": keyword counts over the sequence.
        summary = cad_fusion.render_cad_visualization(sequence)

        report = f"""
**Generated CAD Sequence:**
{sequence}

**Analysis:**
- Operations detected: {summary.get('operations', 0)}
- Sketches detected: {summary.get('sketches', 0)}
- Status: {summary.get('status', 'Generated')}
"""
        return report, sequence

    except Exception as exc:
        message = f"Error: {exc}"
        logger.error(message)
        return message, ""
156
+
157
def create_gradio_interface():
    """Create the Gradio Blocks interface (inputs, outputs, examples, wiring)."""

    with gr.Blocks(
        title="CADFusion - Text-to-CAD Generation",
        theme=gr.themes.Soft(),
        # Custom CSS: center the page content and the title block.
        css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        .title {
            text-align: center;
            margin-bottom: 20px;
        }
        """
    ) as demo:

        # Header / feature overview shown above the controls.
        gr.Markdown("""
        # 🔧 CADFusion - Text-to-CAD Generation

        Convert natural language descriptions into CAD model sequences using Microsoft's CADFusion framework.

        **Features:**
        - Generate parametric CAD sequences from text descriptions
        - Built on fine-tuned LLMs with visual feedback learning
        - Supports complex 3D modeling operations

        **Example prompts:**
        - "Create a cylindrical cup with a handle"
        - "Design a rectangular bracket with mounting holes"
        - "Generate a gear wheel with 12 teeth"
        """, elem_classes="title")

        with gr.Row():
            with gr.Column(scale=2):
                # Input section: prompt textbox + sampling controls.
                gr.Markdown("## 📝 Input")
                text_input = gr.Textbox(
                    label="CAD Description",
                    placeholder="Describe the CAD model you want to generate...",
                    lines=3,
                    value="Create a simple cylindrical cup with a handle on the side"
                )

                # Generation hyperparameters, collapsed by default.
                with gr.Accordion("Advanced Settings", open=False):
                    max_length = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=32,
                        label="Max Sequence Length"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-p"
                    )

                generate_btn = gr.Button(
                    "🚀 Generate CAD",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=3):
                # Output section: rendered markdown report + raw sequence text.
                gr.Markdown("## 🎯 Generated CAD")
                output_display = gr.Markdown(label="Results")

                with gr.Accordion("Raw CAD Sequence", open=False):
                    raw_sequence = gr.Textbox(
                        label="CAD Sequence",
                        lines=10,
                        max_lines=15,
                        show_copy_button=True
                    )

        # Examples section: clicking one fills text_input.
        gr.Markdown("## 📚 Example Prompts")
        examples = gr.Examples(
            examples=[
                ["Create a simple cylindrical cup with a handle"],
                ["Design a rectangular bracket with four mounting holes"],
                ["Generate a gear wheel with 10 teeth and a central hole"],
                ["Make a L-shaped bracket for wall mounting"],
                ["Create a hexagonal nut with internal threading"],
                ["Design a simple phone stand with an angled surface"],
            ],
            inputs=[text_input],
            label="Click on any example to try it"
        )

        # Event handlers: button -> generation callback -> both outputs.
        generate_btn.click(
            fn=generate_cad_from_text,
            inputs=[text_input, max_length, temperature, top_p],
            outputs=[output_display, raw_sequence],
            show_progress=True
        )

        # Footer: attribution and scope note.
        gr.Markdown("""
        ---
        **About CADFusion:**
        This model is based on the paper ["Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models"](https://arxiv.org/abs/2501.19054) by Microsoft Research.

        **Note:** This demo shows the text-to-sequence generation capability. Full 3D rendering would require additional computational resources and the complete CADFusion rendering pipeline.
        """)

    return demo
277
+
278
# Create and launch the interface
if __name__ == "__main__":
    try:
        # Model weights load lazily on the first request; here we only build the UI.
        logger.info("Initializing CADFusion model...")

        demo = create_gradio_interface()

        # queue() replaces the launch(enable_queue=...) kwarg, which — along
        # with show_tips — was removed from Blocks.launch() in Gradio 4.x and
        # would raise a TypeError at startup on a current Space.
        demo.queue()

        # Bind on all interfaces at the standard HF Spaces port.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            max_threads=4
        )

    except Exception as e:
        logger.error(f"Failed to launch application: {str(e)}")
        sys.exit(1)