Pasipid791 committed on
Commit
e7a03ef
·
verified ·
1 Parent(s): 3f7ecb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +429 -120
app.py CHANGED
@@ -1,149 +1,458 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from huggingface_hub import snapshot_download
5
  import os
 
6
  import json
7
- import logging
8
-
9
- # Set up logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
- # Define model and checkpoint paths
14
- MODEL_REPO = "microsoft/CADFusion"
15
- CHECKPOINT_REVISION = "main"
16
- CHECKPOINT_SUBFOLDER = "exp/model_ckpt/v1_1"
17
- LOCAL_CHECKPOINT_DIR = "./model_ckpt/v1_1"
18
- FALLBACK_MODEL = "meta-llama/Llama-2-7b"
19
 
20
- # Ensure local checkpoint directory exists
21
- os.makedirs(LOCAL_CHECKPOINT_DIR, exist_ok=True)
22
 
23
- # Download checkpoint files
24
  try:
25
- logger.info("Downloading checkpoint files...")
26
- snapshot_download(
27
- repo_id=MODEL_REPO,
28
- revision=CHECKPOINT_REVISION,
29
- allow_patterns=f"{CHECKPOINT_SUBFOLDER}/*",
30
- local_dir=LOCAL_CHECKPOINT_DIR,
31
- local_dir_use_symlinks=False
32
  )
33
- logger.info("Checkpoint files downloaded successfully.")
34
- except Exception as e:
35
- logger.error(f"Error downloading checkpoint files: {str(e)}")
36
- raise e
 
 
 
 
37
 
38
- # Load model and tokenizer
39
- try:
40
- logger.info("Loading tokenizer from local checkpoint...")
41
- tokenizer = AutoTokenizer.from_pretrained(
42
- LOCAL_CHECKPOINT_DIR,
43
- trust_remote_code=True
44
- )
45
- logger.info("Loading model from local checkpoint...")
46
- model = AutoModelForCausalLM.from_pretrained(
47
- LOCAL_CHECKPOINT_DIR,
48
- torch_dtype=torch.float16,
49
- device_map="auto",
50
- trust_remote_code=True
51
- )
52
- logger.info("Model and tokenizer loaded successfully from local checkpoint.")
53
- except Exception as e:
54
- logger.error(f"Error loading from local checkpoint: {str(e)}")
55
- logger.info(f"Falling back to {FALLBACK_MODEL}...")
56
- try:
57
- tokenizer = AutoTokenizer.from_pretrained(
58
- FALLBACK_MODEL,
59
- trust_remote_code=True
60
- )
61
- model = AutoModelForCausalLM.from_pretrained(
62
- FALLBACK_MODEL,
63
- torch_dtype=torch.float16,
64
- device_map="auto",
65
- trust_remote_code=True
66
- )
67
- logger.info(f"Fallback model {FALLBACK_MODEL} loaded successfully.")
68
- except Exception as fallback_e:
69
- logger.error(f"Error loading fallback model: {str(fallback_e)}")
70
- raise fallback_e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # Function to generate CAD model from text description
73
- def generate_cad_model(text_description):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  try:
75
- if not text_description.strip():
76
- return "Error: Please provide a valid text description."
77
-
78
- # Tokenize input
79
- inputs = tokenizer(text_description, return_tensors="pt").to(model.device)
80
-
81
- # Generate output
82
- outputs = model.generate(
83
- **inputs,
84
- max_length=512,
85
- num_return_sequences=1,
86
- do_sample=True,
87
- temperature=0.7,
88
- top_p=0.9
89
  )
90
 
91
- # Decode output
92
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # Parse generated text to extract CAD model data (assuming JSON-like output)
95
- try:
96
- cad_data = json.loads(generated_text)
97
- return json.dumps(cad_data, indent=2)
98
- except json.JSONDecodeError:
99
- return generated_text # Return raw text if JSON parsing fails
100
  except Exception as e:
101
- logger.error(f"Error during generation: {str(e)}")
102
- return f"Error: {str(e)}"
103
 
104
- # Gradio interface
105
  def create_gradio_interface():
106
- with gr.Blocks() as demo:
107
- gr.Markdown("# CADFusion: Text-to-CAD Generation")
108
- gr.Markdown("Enter a textual description of the CAD model you want to generate. For example: 'A 3D model of a chair with four legs and a curved backrest.'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  with gr.Row():
111
- with gr.Column():
 
 
112
  text_input = gr.Textbox(
113
- label="Text Description",
114
- placeholder="Enter your CAD model description here...",
115
- lines=5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  )
117
- submit_button = gr.Button("Generate CAD Model")
118
 
119
- with gr.Column():
120
- output_text = gr.Textbox(
121
- label="Generated CAD Model (JSON or Text)",
122
- placeholder="Generated output will appear here...",
123
- lines=10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  )
125
 
126
- submit_button.click(
127
- fn=generate_cad_model,
128
- inputs=text_input,
129
- outputs=output_text
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
 
 
132
  gr.Markdown("""
133
- **Note**:
134
- - This deployment may use a fallback model (Llama-2-7b) due to issues with the CADFusion v1_1 checkpoint.
135
- - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation.
136
- - For full CADFusion functionality, follow the setup instructions in the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion).
137
- - Contact Shizhao Sun (shizsu@microsoft.com) for checkpoint access or issues.
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  """)
 
 
 
 
 
 
 
 
139
 
140
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- # Launch Gradio app
143
  if __name__ == "__main__":
144
- try:
145
- demo = create_gradio_interface()
146
- demo.launch()
147
- except Exception as e:
148
- logger.error(f"Error launching Gradio app: {str(e)}")
149
- raise e
 
 
 
 
 
1
  import os
2
+ import sys
3
  import json
4
+ import torch
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ import tempfile
10
+ import subprocess
11
+ import shutil
12
+ from typing import Optional, List, Dict, Any
 
 
 
13
 
14
+ # Add the src directory to Python path for imports
15
+ sys.path.insert(0, './src')
16
 
 
17
# Import the ML stack; if it is missing, install it at runtime and retry.
# NOTE(review): the fallback branch re-imports only AutoTokenizer /
# AutoModelForCausalLM and snapshot_download — LlamaTokenizer /
# LlamaForCausalLM are not re-imported (they are unused below, but the
# asymmetry is worth confirming).
try:
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        LlamaTokenizer,
        LlamaForCausalLM
    )
    from huggingface_hub import snapshot_download
    print("✅ Successfully imported transformers and huggingface_hub")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Installing required packages...")
    # Best-effort install into the *current* interpreter's environment
    # (sys.executable), a common pattern on Hugging Face Spaces.
    subprocess.run([sys.executable, "-m", "pip", "install", "transformers", "huggingface_hub", "torch", "accelerate"])
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from huggingface_hub import snapshot_download
32
 
33
class CADFusionModel:
    """Wrapper around the microsoft/CADFusion checkpoint.

    Handles: downloading the checkpoint from the Hugging Face Hub, loading
    tokenizer + model (with a no-op placeholder fallback when loading fails),
    and generating CAD parametric sequences from text descriptions.
    """

    def __init__(self, model_path: str = "microsoft/CADFusion", version: str = "v1_1"):
        """
        Initialize the CADFusion model

        Args:
            model_path: Path to the model on Hugging Face Hub
            version: Model version (v1_0 or v1_1)
        """
        self.model_path = model_path
        self.version = version
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"🚀 Initializing CADFusion {version} on {self.device}")

        # Download model if not already present
        self.model_dir = self._download_model()

        # Initialize tokenizer and model
        self.tokenizer = None
        self.model = None
        self._load_model()

        # CAD sequence processing utilities
        self.max_sequence_length = 512

    def _download_model(self) -> str:
        """Download the model from Hugging Face Hub.

        Returns the local snapshot directory, or a local fallback path
        ("./<version>") if the download fails.
        """
        try:
            cache_dir = "./model_cache"
            # NOTE(review): `version` is passed as the git *revision* of the
            # repo — confirm the repo actually tags revisions "v1_0"/"v1_1".
            model_dir = snapshot_download(
                repo_id=self.model_path,
                revision=self.version,
                cache_dir=cache_dir,
                token=os.getenv("HF_TOKEN")  # Use HF token if available
            )
            print(f"✅ Model downloaded to: {model_dir}")
            return model_dir
        except Exception as e:
            print(f"❌ Error downloading model: {e}")
            # Fallback to local directory structure
            return f"./{self.version}"

    def _load_model(self):
        """Load the tokenizer and model; fall back to placeholder on failure."""
        try:
            # Only attempt a real load if weight files are present at the top
            # level of the snapshot directory.
            model_files = list(Path(self.model_dir).glob("*.bin")) + list(Path(self.model_dir).glob("*.safetensors"))

            if model_files:
                print(f"📦 Loading model from {self.model_dir}")

                # Load tokenizer (left padding for decoder-only generation)
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_dir,
                    trust_remote_code=True,
                    padding_side="left"
                )

                # Ensure pad token exists
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                # Load model: fp16 + device_map on GPU, fp32 on CPU
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_dir,
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                    device_map="auto" if self.device.type == "cuda" else None,
                    trust_remote_code=True
                )

                if self.device.type != "cuda":
                    self.model = self.model.to(self.device)

                self.model.eval()
                print("✅ Model loaded successfully")

            else:
                raise FileNotFoundError("No model files found")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("📝 Using placeholder model for demo purposes")
            self._setup_placeholder_model()

    def _setup_placeholder_model(self):
        """Setup a placeholder model for demo purposes"""
        print("⚠️ Setting up placeholder model")
        # This is a fallback when the actual model can't be loaded;
        # generate_cad_sequence() detects the Nones and serves demo output.
        self.model = None
        self.tokenizer = None

    def preprocess_text(self, text: str) -> str:
        """Preprocess input text for CAD generation.

        Strips whitespace, substitutes a default prompt for empty input, and
        prefixes "Create a " when no imperative verb is present.
        """
        # Basic text cleaning and formatting
        text = text.strip()
        if not text:
            return "Generate a simple 3D object"

        # Add any specific preprocessing for CAD descriptions
        if not any(word in text.lower() for word in ['create', 'design', 'make', 'generate', 'build']):
            text = f"Create a {text}"

        return text

    def generate_cad_sequence(self, text: str, max_length: int = 512, temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate CAD parametric sequence from text description

        Args:
            text: Text description of the CAD object
            max_length: Maximum sequence length
            temperature: Generation temperature

        Returns:
            Dictionary containing the generated sequence and metadata
        """
        try:
            if self.model is None or self.tokenizer is None:
                # Placeholder path: model never loaded — serve a canned demo
                return {
                    "success": False,
                    "message": "Model not loaded - showing demo output",
                    "sequence": self._generate_demo_sequence(text),
                    "text_input": text,
                    "parameters": {
                        "max_length": max_length,
                        "temperature": temperature
                    }
                }

            # Preprocess input text
            processed_text = self.preprocess_text(text)

            # Tokenize input (prompt capped at 256 tokens)
            inputs = self.tokenizer(
                processed_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256
            ).to(self.device)

            # Generate sequence
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode output
            generated_sequence = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )

            # Extract the generated part (remove input prompt).
            # FIX: previously used str.replace(), which removed EVERY
            # occurrence of the prompt anywhere in the output and could
            # mangle the generated sequence; strip only the leading prompt.
            if generated_sequence.startswith(processed_text):
                generated_part = generated_sequence[len(processed_text):].strip()
            else:
                generated_part = generated_sequence

            return {
                "success": True,
                "sequence": generated_part,
                "full_output": generated_sequence,
                "text_input": processed_text,
                "parameters": {
                    "max_length": max_length,
                    "temperature": temperature
                }
            }

        except Exception as e:
            print(f"❌ Generation error: {e}")
            # Degrade gracefully: report failure but still show a demo sequence
            return {
                "success": False,
                "message": f"Generation failed: {str(e)}",
                "sequence": self._generate_demo_sequence(text),
                "text_input": text
            }

    def _generate_demo_sequence(self, text: str) -> str:
        """Generate a demo CAD sequence for demonstration purposes"""
        # This is a simplified demo sequence based on the input text
        demo_sequences = {
            "cube": "Sketch('xy') -> Rectangle(0, 0, 10, 10) -> Extrude(10)",
            "cylinder": "Sketch('xy') -> Circle(0, 0, 5) -> Extrude(15)",
            "sphere": "Sketch('xy') -> Circle(0, 0, 5) -> Revolve(360)",
            "bracket": "Sketch('xy') -> Rectangle(0, 0, 20, 10) -> Extrude(5) -> Sketch('top') -> Circle(15, 5, 2) -> Cut(5)"
        }

        text_lower = text.lower()
        for key, sequence in demo_sequences.items():
            if key in text_lower:
                return sequence

        # Default sequence
        return "Sketch('xy') -> Rectangle(0, 0, 10, 10) -> Extrude(5)"
239
+
240
# Process-wide singleton; populated lazily by initialize_model().
model = None

def initialize_model():
    """Return the shared CADFusionModel, constructing it on first use."""
    global model
    if model is not None:
        return model
    print("🔄 Initializing CADFusion model...")
    model = CADFusionModel()
    return model
250
+
251
def generate_cad(
    text_input: str,
    max_length: int = 512,
    temperature: float = 0.7
) -> tuple:
    """
    Gradio callback: turn a text prompt into a CAD parametric sequence.

    Returns:
        Tuple of (generated_sequence, status_message, parameters_info)
    """
    try:
        # Bring up the shared model first (lazy initialization).
        global model
        if model is None:
            model = initialize_model()

        # Guard clause: reject blank prompts.
        if not text_input or not text_input.strip():
            return "Please provide a text description.", "❌ Error: Empty input", "No parameters"

        result = model.generate_cad_sequence(
            text_input,
            max_length=max_length,
            temperature=temperature
        )

        # Shape the three output fields from the result dict.
        sequence = result["sequence"]
        if result["success"]:
            status = "✅ Generation successful"
        else:
            status = f"⚠️ {result.get('message', 'Generation failed')}"

        used = result.get("parameters", {})
        param_info = (
            f"Max Length: {used.get('max_length', max_length)}, "
            f"Temperature: {used.get('temperature', temperature)}"
        )
        return sequence, status, param_info

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return "Generation failed", error_msg, "No parameters"
296
 
 
297
def create_gradio_interface():
    """Create the Gradio interface.

    Builds a two-column Blocks layout (prompt + generation sliders on the
    left; sequence / status / parameters outputs on the right), adds example
    prompts and an info panel, wires the generate button to generate_cad(),
    and returns the Blocks instance for launch().
    """

    # Custom CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .gr-button-primary {
        background: linear-gradient(45deg, #1e3a8a, #3b82f6);
        border: none;
    }
    .gr-panel {
        border-radius: 8px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    """

    with gr.Blocks(css=css, title="CADFusion - Text to CAD Generation") as interface:

        # Header
        gr.Markdown("""
        # 🔧 CADFusion - Text to CAD Generation

        Convert natural language descriptions into CAD parametric sequences using Microsoft's CADFusion model.

        **Model**: microsoft/CADFusion v1.1
        **Paper**: [Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models](https://arxiv.org/abs/2501.19054)
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("### 📝 Input")
                text_input = gr.Textbox(
                    label="CAD Description",
                    placeholder="Describe the CAD object you want to create (e.g., 'Create a cylindrical bracket with mounting holes')",
                    lines=3,
                    value="Create a simple rectangular bracket with two circular holes"
                )

                # Parameters section
                gr.Markdown("### ⚙️ Generation Parameters")
                with gr.Row():
                    max_length = gr.Slider(
                        label="Max Length",
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=64,
                        info="Maximum length of generated sequence"
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        info="Generation randomness (lower = more deterministic)"
                    )

                # Generate button
                generate_btn = gr.Button(
                    "🚀 Generate CAD Sequence",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=3):
                # Output section
                gr.Markdown("### 🎯 Generated CAD Sequence")
                sequence_output = gr.Textbox(
                    label="Parametric Sequence",
                    lines=8,
                    interactive=False,
                    placeholder="Generated CAD sequence will appear here..."
                )

                status_output = gr.Textbox(
                    label="Status",
                    lines=1,
                    interactive=False
                )

                params_output = gr.Textbox(
                    label="Parameters Used",
                    lines=1,
                    interactive=False
                )

        # Examples section
        # NOTE(review): the `examples` binding is never read afterwards;
        # gr.Examples registers itself on construction, so that is harmless.
        gr.Markdown("### 💡 Example Prompts")
        examples = gr.Examples(
            examples=[
                ["Create a cylindrical rod with a square base"],
                ["Design a mounting bracket with four holes"],
                ["Make a simple cube with rounded corners"],
                ["Create a T-shaped connector piece"],
                ["Design a gear wheel with 12 teeth"],
                ["Make a pipe elbow joint at 90 degrees"],
                ["Create a hexagonal bolt head"],
                ["Design a simple housing enclosure"]
            ],
            inputs=[text_input],
            label="Click on any example to try it out"
        )

        # Information section
        gr.Markdown("""
        ### ℹ️ About CADFusion

        CADFusion is a state-of-the-art text-to-CAD generation model that:
        - Uses visual feedback to enhance LLM performance
        - Generates parametric sequences for CAD modeling
        - Supports complex 3D object descriptions
        - Based on alternating sequential and visual learning stages

        **Usage Tips**:
        - Be specific about shapes, dimensions, and features
        - Use technical CAD terminology when possible
        - Mention materials or constraints if relevant
        - Start with simple descriptions and add complexity gradually

        **Model Info**:
        - Version: v1.1 (9 rounds of alternate training)
        - Base Model: LLaMA architecture
        - Training Data: SkexGen dataset with human annotations
        """)

        # Connect the generate button to the function
        generate_btn.click(
            fn=generate_cad,
            inputs=[text_input, max_length, temperature],
            outputs=[sequence_output, status_output, params_output],
            show_progress=True
        )

    return interface
435
+
436
def main():
    """Entry point: warm up the model, build the UI, start the server."""
    print("🌟 Starting CADFusion Gradio App")

    # Load (or attempt to load) the model before serving requests.
    print("🔄 Initializing model...")
    initialize_model()

    app = create_gradio_interface()

    # Bind to all interfaces on the standard Gradio port; set share=True
    # for a public share link.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        show_error=True,
        quiet=False
    )


if __name__ == "__main__":
    main()