Pasipid791 commited on
Commit
6b11178
·
verified ·
1 Parent(s): 2ac7eec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -476
app.py CHANGED
@@ -1,502 +1,109 @@
 
 
 
1
  import os
2
- import sys
3
  import json
4
- import torch
5
- import gradio as gr
6
- import numpy as np
7
- from PIL import Image
8
- from pathlib import Path
9
- import tempfile
10
- import subprocess
11
- import shutil
12
- from typing import Optional, List, Dict, Any
13
 
14
- # Add the src directory to Python path for imports
15
- sys.path.insert(0, './src')
 
16
 
 
17
  try:
18
- from transformers import (
19
- AutoTokenizer,
20
- AutoModelForCausalLM,
21
- LlamaTokenizer,
22
- LlamaForCausalLM
23
  )
24
- from huggingface_hub import snapshot_download, hf_hub_download
25
- print("✅ Successfully imported transformers and huggingface_hub")
26
- except ImportError as e:
27
- print(f"❌ Import error: {e}")
28
- print("Installing required packages...")
29
- subprocess.run([sys.executable, "-m", "pip", "install", "transformers", "huggingface_hub", "torch", "accelerate"])
30
- from transformers import AutoTokenizer, AutoModelForCausalLM
31
- from huggingface_hub import snapshot_download, hf_hub_download
32
-
33
- class CADFusionModel:
34
- def __init__(self, model_path: str = "microsoft/CADFusion", revision: str = "main"):
35
- """
36
- Initialize the CADFusion model
37
-
38
- Args:
39
- model_path: Path to the model on Hugging Face Hub
40
- revision: Model revision/branch (use 'main' instead of version numbers)
41
- """
42
- self.model_path = model_path
43
- self.revision = revision
44
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
-
46
- print(f"🚀 Initializing CADFusion from {model_path}@{revision} on {self.device}")
47
-
48
- # Initialize tokenizer and model
49
- self.tokenizer = None
50
- self.model = None
51
- self._load_model()
52
-
53
- # CAD sequence processing utilities
54
- self.max_sequence_length = 512
55
-
56
- def _load_model(self):
57
- """Load the tokenizer and model directly from Hugging Face Hub"""
58
- try:
59
- print(f"📦 Loading model from {self.model_path}")
60
-
61
- # Load tokenizer
62
- self.tokenizer = AutoTokenizer.from_pretrained(
63
- self.model_path,
64
- revision=self.revision,
65
- trust_remote_code=True,
66
- padding_side="left",
67
- token=os.getenv("HF_TOKEN") # Use HF token if available
68
- )
69
-
70
- # Ensure pad token exists
71
- if self.tokenizer.pad_token is None:
72
- self.tokenizer.pad_token = self.tokenizer.eos_token
73
-
74
- # Load model with appropriate dtype based on device
75
- model_kwargs = {
76
- "revision": self.revision,
77
- "trust_remote_code": True,
78
- "torch_dtype": torch.float16 if self.device.type == "cuda" else torch.float32,
79
- "token": os.getenv("HF_TOKEN")
80
- }
81
-
82
- # Add device mapping for CUDA
83
- if self.device.type == "cuda":
84
- model_kwargs["device_map"] = "auto"
85
- model_kwargs["low_cpu_mem_usage"] = True
86
-
87
- self.model = AutoModelForCausalLM.from_pretrained(
88
- self.model_path,
89
- **model_kwargs
90
- )
91
-
92
- # Move to device if not using device_map
93
- if self.device.type != "cuda":
94
- self.model = self.model.to(self.device)
95
-
96
- self.model.eval()
97
- print("✅ Model loaded successfully")
98
-
99
- except Exception as e:
100
- print(f"❌ Error loading model: {e}")
101
- print("📝 Setting up placeholder model for demo purposes")
102
- self._setup_placeholder_model()
103
-
104
- def _setup_placeholder_model(self):
105
- """Setup a placeholder model for demo purposes"""
106
- print("⚠️ Setting up placeholder model")
107
- # This is a fallback when the actual model can't be loaded
108
- self.model = None
109
- self.tokenizer = None
110
-
111
- def preprocess_text(self, text: str) -> str:
112
- """Preprocess input text for CAD generation"""
113
- # Basic text cleaning and formatting
114
- text = text.strip()
115
- if not text:
116
- return "Generate a simple 3D object"
117
-
118
- # Add any specific preprocessing for CAD descriptions
119
- if not any(word in text.lower() for word in ['create', 'design', 'make', 'generate', 'build']):
120
- text = f"Create a {text}"
121
-
122
- return text
123
-
124
- def generate_cad_sequence(self, text: str, max_length: int = 512, temperature: float = 0.7) -> Dict[str, Any]:
125
- """
126
- Generate CAD parametric sequence from text description
127
-
128
- Args:
129
- text: Text description of the CAD object
130
- max_length: Maximum sequence length
131
- temperature: Generation temperature
132
-
133
- Returns:
134
- Dictionary containing the generated sequence and metadata
135
- """
136
- try:
137
- if self.model is None or self.tokenizer is None:
138
- # Return placeholder response
139
- return {
140
- "success": False,
141
- "message": "Model not loaded - showing demo output",
142
- "sequence": self._generate_demo_sequence(text),
143
- "text_input": text,
144
- "parameters": {
145
- "max_length": max_length,
146
- "temperature": temperature
147
- }
148
- }
149
-
150
- # Preprocess input text
151
- processed_text = self.preprocess_text(text)
152
-
153
- # Add special formatting for CADFusion if needed
154
- # CADFusion may expect specific prompt formatting
155
- prompt = f"Design a CAD model: {processed_text}\nCAD sequence:"
156
-
157
- # Tokenize input
158
- inputs = self.tokenizer(
159
- prompt,
160
- return_tensors="pt",
161
- padding=True,
162
- truncation=True,
163
- max_length=256
164
- ).to(self.device)
165
-
166
- # Generate sequence
167
- with torch.no_grad():
168
- outputs = self.model.generate(
169
- inputs.input_ids,
170
- attention_mask=inputs.attention_mask,
171
- max_length=max_length,
172
- temperature=temperature,
173
- do_sample=True,
174
- top_p=0.9,
175
- top_k=50,
176
- pad_token_id=self.tokenizer.pad_token_id,
177
- eos_token_id=self.tokenizer.eos_token_id,
178
- repetition_penalty=1.1
179
- )
180
-
181
- # Decode output
182
- generated_sequence = self.tokenizer.decode(
183
- outputs[0],
184
- skip_special_tokens=True
185
- )
186
-
187
- # Extract the generated part (remove input prompt)
188
- if "CAD sequence:" in generated_sequence:
189
- generated_part = generated_sequence.split("CAD sequence:")[-1].strip()
190
- elif prompt in generated_sequence:
191
- generated_part = generated_sequence.replace(prompt, "").strip()
192
- else:
193
- generated_part = generated_sequence
194
-
195
- return {
196
- "success": True,
197
- "sequence": generated_part,
198
- "full_output": generated_sequence,
199
- "text_input": processed_text,
200
- "parameters": {
201
- "max_length": max_length,
202
- "temperature": temperature
203
- }
204
- }
205
-
206
- except Exception as e:
207
- print(f"❌ Generation error: {e}")
208
- return {
209
- "success": False,
210
- "message": f"Generation failed: {str(e)}",
211
- "sequence": self._generate_demo_sequence(text),
212
- "text_input": text
213
- }
214
-
215
- def _generate_demo_sequence(self, text: str) -> str:
216
- """Generate a demo CAD sequence for demonstration purposes"""
217
- # This is a simplified demo sequence based on the input text
218
- demo_sequences = {
219
- "cube": "NewSketch().Rectangle(0, 0, 10, 10).Extrude(10)",
220
- "cylinder": "NewSketch().Circle(0, 0, 5).Extrude(15)",
221
- "sphere": "NewSketch().Circle(0, 0, 5).Revolve(360, [0, 0, 1])",
222
- "bracket": "NewSketch().Rectangle(0, 0, 20, 10).Extrude(5).NewSketch('top').Circle(15, 5, 2).Cut(5)",
223
- "hole": "NewSketch().Rectangle(0, 0, 15, 8).Extrude(4).NewSketch('top').Circle(7.5, 4, 1.5).Cut(4)",
224
- "gear": "NewSketch().Circle(0, 0, 10).Extrude(3).NewSketch('top').Circle(0, 0, 2).Cut(3)",
225
- "pipe": "NewSketch().Circle(0, 0, 8).Extrude(20).NewSketch('top').Circle(0, 0, 6).Cut(20)",
226
- "bolt": "NewSketch().Circle(0, 0, 4).Extrude(15).NewSketch('top').RegularPolygon(6, 0, 0, 6).Extrude(3)"
227
- }
228
-
229
- text_lower = text.lower()
230
- for key, sequence in demo_sequences.items():
231
- if key in text_lower:
232
- return sequence
233
-
234
- # Default sequence for rectangular objects
235
- return "NewSketch().Rectangle(0, 0, 10, 10).Extrude(5)"
236
-
237
- # Global model instance
238
- model = None
239
-
240
- def initialize_model():
241
- """Initialize the global model instance"""
242
- global model
243
- if model is None:
244
- print("🔄 Initializing CADFusion model...")
245
- try:
246
- model = CADFusionModel()
247
- if model.model is None:
248
- print("⚠️ Model loaded in demo mode - using simulated responses")
249
- else:
250
- print("✅ Model loaded successfully!")
251
- except Exception as e:
252
- print(f"❌ Failed to initialize model: {e}")
253
- print("🔄 Creating fallback demo model...")
254
- model = CADFusionModel()
255
- return model
256
 
257
- def generate_cad(
258
- text_input: str,
259
- max_length: int = 512,
260
- temperature: float = 0.7
261
- ) -> tuple:
262
- """
263
- Gradio interface function for CAD generation
264
-
265
- Returns:
266
- Tuple of (generated_sequence, status_message, parameters_info)
267
- """
268
  try:
269
- # Initialize model if needed
270
- global model
271
- if model is None:
272
- model = initialize_model()
273
-
274
- # Validate inputs
275
- if not text_input or not text_input.strip():
276
- return "Please provide a text description.", "❌ Error: Empty input", "No parameters"
277
-
278
- # Generate CAD sequence
279
- result = model.generate_cad_sequence(
280
- text_input,
281
- max_length=max_length,
282
- temperature=temperature
283
  )
284
 
285
- # Format output
286
- if result["success"]:
287
- status = "✅ Generation successful"
288
- sequence = result["sequence"]
289
- else:
290
- status = f"⚠️ {result.get('message', 'Generation failed')}"
291
- sequence = result["sequence"]
292
-
293
- # Format parameters info
294
- params = result.get("parameters", {})
295
- param_info = f"Max Length: {params.get('max_length', max_length)}, Temperature: {params.get('temperature', temperature)}"
296
-
297
- return sequence, status, param_info
298
 
 
 
 
 
 
 
299
  except Exception as e:
300
- error_msg = f" Error: {str(e)}"
301
- return "Generation failed", error_msg, "No parameters"
302
 
 
303
  def create_gradio_interface():
304
- """Create the Gradio interface"""
305
-
306
- # Custom CSS for better styling
307
- css = """
308
- .gradio-container {
309
- font-family: 'Arial', sans-serif;
310
- }
311
- .gr-button-primary {
312
- background: linear-gradient(45deg, #1e3a8a, #3b82f6);
313
- border: none;
314
- }
315
- .gr-panel {
316
- border-radius: 8px;
317
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
318
- }
319
- .title-container {
320
- text-align: center;
321
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
322
- padding: 2rem;
323
- border-radius: 10px;
324
- margin-bottom: 2rem;
325
- color: white;
326
- }
327
- """
328
-
329
- with gr.Blocks(css=css, title="CADFusion - Text to CAD Generation") as interface:
330
-
331
- # Header
332
- with gr.HTML():
333
- gr.HTML("""
334
- <div class="title-container">
335
- <h1>🔧 CADFusion - Text to CAD Generation</h1>
336
- <p>Convert natural language descriptions into CAD parametric sequences using Microsoft's CADFusion model.</p>
337
- </div>
338
- """)
339
-
340
- gr.Markdown("""
341
- **Model**: microsoft/CADFusion (based on LLaMA-3-8B)
342
- **Paper**: [Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models](https://arxiv.org/abs/2501.19054)
343
- **Repository**: [GitHub](https://github.com/microsoft/CADFusion)
344
- """)
345
 
346
  with gr.Row():
347
- with gr.Column(scale=2):
348
- # Input section
349
- gr.Markdown("### 📝 Input")
350
  text_input = gr.Textbox(
351
- label="CAD Description",
352
- placeholder="Describe the CAD object you want to create (e.g., 'Create a cylindrical bracket with mounting holes')",
353
- lines=4,
354
- value="Create a rectangular bracket with two circular mounting holes"
355
- )
356
-
357
- # Parameters section
358
- gr.Markdown("### ⚙️ Generation Parameters")
359
- with gr.Row():
360
- max_length = gr.Slider(
361
- label="Max Length",
362
- minimum=128,
363
- maximum=1024,
364
- value=512,
365
- step=64,
366
- info="Maximum length of generated sequence"
367
- )
368
- temperature = gr.Slider(
369
- label="Temperature",
370
- minimum=0.1,
371
- maximum=1.5,
372
- value=0.7,
373
- step=0.1,
374
- info="Generation randomness (lower = more deterministic)"
375
- )
376
-
377
- # Generate button
378
- generate_btn = gr.Button(
379
- "🚀 Generate CAD Sequence",
380
- variant="primary",
381
- size="lg"
382
  )
 
383
 
384
- with gr.Column(scale=3):
385
- # Output section
386
- gr.Markdown("### 🎯 Generated CAD Sequence")
387
- sequence_output = gr.Textbox(
388
- label="Parametric Sequence",
389
- lines=10,
390
- interactive=False,
391
- placeholder="Generated CAD sequence will appear here..."
392
  )
393
-
394
- status_output = gr.Textbox(
395
- label="Status",
396
- lines=1,
397
- interactive=False
398
- )
399
-
400
- params_output = gr.Textbox(
401
- label="Parameters Used",
402
- lines=1,
403
- interactive=False
404
- )
405
-
406
- # Examples section
407
- gr.Markdown("### 💡 Example Prompts")
408
- examples = gr.Examples(
409
- examples=[
410
- ["Create a cylindrical rod with a square base"],
411
- ["Design a mounting bracket with four holes"],
412
- ["Make a simple cube with rounded corners"],
413
- ["Create a T-shaped connector piece"],
414
- ["Design a gear wheel with 12 teeth"],
415
- ["Make a pipe elbow joint at 90 degrees"],
416
- ["Create a hexagonal bolt head"],
417
- ["Design a simple housing enclosure"],
418
- ["Create a rectangular plate with center hole"],
419
- ["Design a cylindrical bearing housing"]
420
- ],
421
- inputs=[text_input],
422
- label="Click on any example to try it out"
423
- )
424
-
425
- # Information section
426
- with gr.Accordion("ℹ️ About CADFusion", open=False):
427
- gr.Markdown("""
428
- ### Model Overview
429
-
430
- CADFusion is a state-of-the-art text-to-CAD generation model that:
431
- - Uses visual feedback to enhance LLM performance
432
- - Generates parametric sequences for CAD modeling
433
- - Supports complex 3D object descriptions
434
- - Based on alternating sequential and visual learning stages
435
-
436
- ### Training Approach
437
- - **Sequential Learning**: Fine-tuning LLM with paired text-CAD data
438
- - **Visual Feedback**: Using vision-language models to improve generation quality
439
- - **Alternating Training**: 9 rounds of SL and VF stages for optimal performance
440
-
441
- ### Usage Tips
442
- - Be specific about shapes, dimensions, and features
443
- - Use technical CAD terminology when possible
444
- - Mention materials or constraints if relevant
445
- - Start with simple descriptions and add complexity gradually
446
-
447
- ### Model Specifications
448
- - **Base Model**: LLaMA-3-8B
449
- - **Training Data**: SkexGen dataset with human annotations
450
- - **License**: MIT License
451
- - **Intended Use**: Research and educational purposes
452
-
453
- ### Performance
454
- CADFusion significantly outperforms baselines like GPT-4o and Text2CAD:
455
- - **VLM Score**: 8.96 (vs 5.13 for GPT-4o, 2.01 for Text2CAD)
456
- - **Better**: Generation diversity, visual quality, and technical accuracy
457
- """)
458
 
459
- # Connect the generate button to the function
460
- generate_btn.click(
461
- fn=generate_cad,
462
- inputs=[text_input, max_length, temperature],
463
- outputs=[sequence_output, status_output, params_output],
464
- show_progress=True
465
  )
466
 
467
- # Auto-generate on example selection
468
- examples.click(
469
- fn=generate_cad,
470
- inputs=[text_input, max_length, temperature],
471
- outputs=[sequence_output, status_output, params_output],
472
- show_progress=True
473
- )
474
-
475
- return interface
476
-
477
- def main():
478
- """Main function to run the Gradio app"""
479
- print("===== Application Startup at {} =====".format(
480
- __import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')
481
- ))
482
- print("🌟 Starting CADFusion Gradio App")
483
-
484
- # Initialize model
485
- print("🔄 Initializing model...")
486
- initialize_model()
487
-
488
- # Create and launch interface
489
- interface = create_gradio_interface()
490
 
491
- # Launch configuration
492
- interface.launch(
493
- server_name="0.0.0.0", # Allow external access
494
- server_port=7860, # Standard Gradio port
495
- share=False, # Set to True for public sharing
496
- debug=False, # Disable debug mode in production
497
- show_error=True, # Show errors in interface
498
- quiet=False # Show startup logs
499
- )
500
 
 
501
  if __name__ == "__main__":
502
- main()
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import os
 
5
  import json
6
+ import logging
7
+
8
+ # Set up logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
 
 
 
 
11
 
12
+ # Define model and checkpoint paths
13
+ MODEL_PATH = "microsoft/CADFusion"
14
+ REVISION = "2687619" # Use commit hash from the document
15
 
16
+ # Load model and tokenizer
17
  try:
18
+ logger.info("Loading tokenizer...")
19
+ tokenizer = AutoTokenizer.from_pretrained(
20
+ MODEL_PATH,
21
+ revision=REVISION,
22
+ trust_remote_code=True
23
  )
24
+ logger.info("Loading model...")
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ MODEL_PATH,
27
+ revision=REVISION,
28
+ torch_dtype=torch.float16,
29
+ device_map="auto",
30
+ trust_remote_code=True
31
+ )
32
+ logger.info("Model and tokenizer loaded successfully.")
33
+ except Exception as e:
34
+ logger.error(f"Error loading model or tokenizer: {e}")
35
+ raise Exception(f"Failed to load model from {MODEL_PATH} with revision {REVISION}. Please check the repository and revision ID.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # Function to generate CAD model from text description
38
+ def generate_cad_model(text_description):
 
 
 
 
 
 
 
 
 
39
  try:
40
+ if not text_description.strip():
41
+ return "Error: Please provide a valid text description."
42
+
43
+ # Tokenize input
44
+ inputs = tokenizer(text_description, return_tensors="pt").to(model.device)
45
+
46
+ # Generate output
47
+ outputs = model.generate(
48
+ **inputs,
49
+ max_length=512,
50
+ num_return_sequences=1,
51
+ do_sample=True,
52
+ temperature=0.7,
53
+ top_p=0.9
54
  )
55
 
56
+ # Decode output
57
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # Parse generated text to extract CAD model data (assuming JSON-like output)
60
+ try:
61
+ cad_data = json.loads(generated_text)
62
+ return json.dumps(cad_data, indent=2)
63
+ except json.JSONDecodeError:
64
+ return generated_text # Return raw text if JSON parsing fails
65
  except Exception as e:
66
+ logger.error(f"Error during generation: {e}")
67
+ return f"Error: {str(e)}"
68
 
69
+ # Gradio interface
70
  def create_gradio_interface():
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown("# CADFusion: Text-to-CAD Generation")
73
+ gr.Markdown("Enter a textual description of the CAD model you want to generate. For example: 'A 3D model of a chair with four legs and a curved backrest.'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  with gr.Row():
76
+ with gr.Column():
 
 
77
  text_input = gr.Textbox(
78
+ label="Text Description",
79
+ placeholder="Enter your CAD model description here...",
80
+ lines=5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
+ submit_button = gr.Button("Generate CAD Model")
83
 
84
+ with gr.Column():
85
+ output_text = gr.Textbox(
86
+ label="Generated CAD Model (JSON or Text)",
87
+ placeholder="Generated output will appear here...",
88
+ lines=10
 
 
 
89
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ submit_button.click(
92
+ fn=generate_cad_model,
93
+ inputs=text_input,
94
+ outputs=output_text
 
 
95
  )
96
 
97
+ gr.Markdown("""
98
+ **Note**:
99
+ - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation.
100
+ - Ensure descriptions are clear and specific for best results.
101
+ - For more details, visit the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion).
102
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ return demo
 
 
 
 
 
 
 
 
105
 
106
+ # Launch Gradio app
107
  if __name__ == "__main__":
108
+ demo = create_gradio_interface()
109
+ demo.launch(server_name="0.0.0.0", server_port=7860)