# CADFusion — Text to CAD Generation (Gradio Space entry point)
| import os | |
| import sys | |
| import json | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| import tempfile | |
| import subprocess | |
| import shutil | |
| from typing import Optional, List, Dict, Any | |
| # Add the src directory to Python path for imports | |
| sys.path.insert(0, './src') | |
# Import the model stack; on a cold Space the packages may be missing, so fall
# back to installing them with pip and retrying the import.
try:
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        LlamaTokenizer,
        LlamaForCausalLM
    )
    from huggingface_hub import snapshot_download
    print("β Successfully imported transformers and huggingface_hub")
except ImportError as e:
    print(f"β Import error: {e}")
    print("Installing required packages...")
    # check=True so a failed install raises here instead of surfacing as a
    # confusing second ImportError below
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "transformers", "huggingface_hub", "torch", "accelerate"],
        check=True,
    )
    # Re-import the same names as the first attempt (the original fallback
    # dropped the Llama* names, leaving them undefined after a repair install)
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        LlamaTokenizer,
        LlamaForCausalLM
    )
    from huggingface_hub import snapshot_download
class CADFusionModel:
    """Wrapper around the microsoft/CADFusion text-to-CAD model.

    Downloads the requested revision from the Hugging Face Hub, loads the
    tokenizer and model, and exposes generate_cad_sequence() which turns a
    free-text description into a CAD parametric sequence. Falls back to a
    canned demo sequence whenever the real model cannot be downloaded/loaded.
    """

    def __init__(self, model_path: str = "microsoft/CADFusion", version: str = "v1_1"):
        """
        Initialize the CADFusion model

        Args:
            model_path: Path to the model on Hugging Face Hub
            version: Model version (v1_0 or v1_1)
        """
        self.model_path = model_path
        self.version = version
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"π Initializing CADFusion {version} on {self.device}")
        # Download model if not already present
        self.model_dir = self._download_model()
        # Set to None first so the placeholder fallback leaves the object in a
        # consistent state even if _load_model() fails partway through
        self.tokenizer = None
        self.model = None
        self._load_model()
        # CAD sequence processing utilities
        self.max_sequence_length = 512

    def _download_model(self) -> str:
        """Download the model snapshot from the Hugging Face Hub.

        Returns:
            Local directory containing the snapshot, or the local fallback
            path "./<version>" when the download fails.
        """
        try:
            cache_dir = "./model_cache"
            model_dir = snapshot_download(
                repo_id=self.model_path,
                revision=self.version,  # model version doubles as the git revision
                cache_dir=cache_dir,
                token=os.getenv("HF_TOKEN")  # Use HF token if available
            )
            print(f"β Model downloaded to: {model_dir}")
            return model_dir
        except Exception as e:
            print(f"β Error downloading model: {e}")
            # Fallback to local directory structure
            return f"./{self.version}"

    def _load_model(self):
        """Load tokenizer and model; fall back to placeholder mode on any error."""
        try:
            # Only attempt a real load when weight files are actually present
            model_files = list(Path(self.model_dir).glob("*.bin")) + list(Path(self.model_dir).glob("*.safetensors"))
            if model_files:
                print(f"π¦ Loading model from {self.model_dir}")
                # Left padding is required for decoder-only generation
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_dir,
                    trust_remote_code=True,
                    padding_side="left"
                )
                # Ensure pad token exists
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                # fp16 + device_map="auto" only make sense on CUDA
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_dir,
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                    device_map="auto" if self.device.type == "cuda" else None,
                    trust_remote_code=True
                )
                if self.device.type != "cuda":
                    self.model = self.model.to(self.device)
                self.model.eval()
                print("β Model loaded successfully")
            else:
                raise FileNotFoundError("No model files found")
        except Exception as e:
            print(f"β Error loading model: {e}")
            print("π Using placeholder model for demo purposes")
            self._setup_placeholder_model()

    def _setup_placeholder_model(self):
        """Put the instance into demo mode: no tokenizer/model, canned outputs only."""
        print("β οΈ Setting up placeholder model")
        # This is a fallback when the actual model can't be loaded
        self.model = None
        self.tokenizer = None

    def preprocess_text(self, text: str) -> str:
        """Normalize the user prompt: default empty input, ensure an action verb."""
        text = text.strip()
        if not text:
            return "Generate a simple 3D object"
        # Prepend an action verb when the prompt is a bare object description
        if not any(word in text.lower() for word in ['create', 'design', 'make', 'generate', 'build']):
            text = f"Create a {text}"
        return text

    def generate_cad_sequence(self, text: str, max_length: int = 512, temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate CAD parametric sequence from text description

        Args:
            text: Text description of the CAD object
            max_length: Maximum sequence length (input + generated tokens)
            temperature: Generation temperature

        Returns:
            Dictionary containing the generated sequence and metadata;
            "success" is False in placeholder mode or on generation errors,
            in which case "sequence" holds a demo sequence.
        """
        try:
            if self.model is None or self.tokenizer is None:
                # Placeholder mode: return a canned demo sequence
                return {
                    "success": False,
                    "message": "Model not loaded - showing demo output",
                    "sequence": self._generate_demo_sequence(text),
                    "text_input": text,
                    "parameters": {
                        "max_length": max_length,
                        "temperature": temperature
                    }
                }
            # Preprocess input text
            processed_text = self.preprocess_text(text)
            # Tokenize input
            inputs = self.tokenizer(
                processed_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256
            ).to(self.device)
            # Generate sequence
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            # Decode the full output (prompt + continuation) for reference
            generated_sequence = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )
            # BUGFIX: slice off the prompt *tokens* rather than str.replace()ing
            # the prompt text — replace() removed every occurrence of the prompt
            # and silently failed when decoding normalized whitespace.
            prompt_token_count = inputs.input_ids.shape[1]
            generated_part = self.tokenizer.decode(
                outputs[0][prompt_token_count:],
                skip_special_tokens=True
            ).strip()
            return {
                "success": True,
                "sequence": generated_part,
                "full_output": generated_sequence,
                "text_input": processed_text,
                "parameters": {
                    "max_length": max_length,
                    "temperature": temperature
                }
            }
        except Exception as e:
            print(f"β Generation error: {e}")
            return {
                "success": False,
                "message": f"Generation failed: {str(e)}",
                "sequence": self._generate_demo_sequence(text),
                "text_input": text
            }

    def _generate_demo_sequence(self, text: str) -> str:
        """Return a canned CAD sequence keyed off shape words found in the text."""
        # Simplified demo sequences matched by keyword
        demo_sequences = {
            "cube": "Sketch('xy') -> Rectangle(0, 0, 10, 10) -> Extrude(10)",
            "cylinder": "Sketch('xy') -> Circle(0, 0, 5) -> Extrude(15)",
            "sphere": "Sketch('xy') -> Circle(0, 0, 5) -> Revolve(360)",
            "bracket": "Sketch('xy') -> Rectangle(0, 0, 20, 10) -> Extrude(5) -> Sketch('top') -> Circle(15, 5, 2) -> Cut(5)"
        }
        text_lower = text.lower()
        for key, sequence in demo_sequences.items():
            if key in text_lower:
                return sequence
        # Default sequence
        return "Sketch('xy') -> Rectangle(0, 0, 10, 10) -> Extrude(5)"
# Lazily-created singleton model instance shared by every request
model = None


def initialize_model():
    """Return the shared CADFusionModel, constructing it on first use."""
    global model
    if model is not None:
        return model
    print("π Initializing CADFusion model...")
    model = CADFusionModel()
    return model
def generate_cad(
    text_input: str,
    max_length: int = 512,
    temperature: float = 0.7
) -> tuple:
    """
    Gradio interface function for CAD generation

    Returns:
        Tuple of (generated_sequence, status_message, parameters_info)
    """
    global model
    try:
        # Lazy model initialization on first request
        if model is None:
            model = initialize_model()

        # Reject blank descriptions up front
        if not text_input or not text_input.strip():
            return "Please provide a text description.", "β Error: Empty input", "No parameters"

        # Run generation with the user-selected parameters
        result = model.generate_cad_sequence(
            text_input,
            max_length=max_length,
            temperature=temperature
        )

        # Build the status line from the result flag
        sequence = result["sequence"]
        status = (
            "β Generation successful"
            if result["success"]
            else f"β οΈ {result.get('message', 'Generation failed')}"
        )

        # Echo back the parameters that were actually used
        params = result.get("parameters", {})
        param_info = (
            f"Max Length: {params.get('max_length', max_length)}, "
            f"Temperature: {params.get('temperature', temperature)}"
        )
        return sequence, status, param_info
    except Exception as e:
        # Top-level boundary: surface the error in the UI instead of crashing
        return "Generation failed", f"β Error: {str(e)}", "No parameters"
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the CADFusion demo.

    Wires the text input and generation sliders to generate_cad(); the three
    output boxes receive (sequence, status, parameters) respectively.
    """
    # Custom CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .gr-button-primary {
        background: linear-gradient(45deg, #1e3a8a, #3b82f6);
        border: none;
    }
    .gr-panel {
        border-radius: 8px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    """
    with gr.Blocks(css=css, title="CADFusion - Text to CAD Generation") as interface:
        # Header
        gr.Markdown("""
        # π§ CADFusion - Text to CAD Generation
        Convert natural language descriptions into CAD parametric sequences using Microsoft's CADFusion model.
        **Model**: microsoft/CADFusion v1.1
        **Paper**: [Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models](https://arxiv.org/abs/2501.19054)
        """)
        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("### π Input")
                text_input = gr.Textbox(
                    label="CAD Description",
                    placeholder="Describe the CAD object you want to create (e.g., 'Create a cylindrical bracket with mounting holes')",
                    lines=3,
                    value="Create a simple rectangular bracket with two circular holes"
                )
                # Parameters section
                gr.Markdown("### βοΈ Generation Parameters")
                with gr.Row():
                    max_length = gr.Slider(
                        label="Max Length",
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=64,
                        info="Maximum length of generated sequence"
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        info="Generation randomness (lower = more deterministic)"
                    )
                # Generate button
                generate_btn = gr.Button(
                    "π Generate CAD Sequence",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=3):
                # Output section
                gr.Markdown("### π― Generated CAD Sequence")
                sequence_output = gr.Textbox(
                    label="Parametric Sequence",
                    lines=8,
                    interactive=False,
                    placeholder="Generated CAD sequence will appear here..."
                )
                status_output = gr.Textbox(
                    label="Status",
                    lines=1,
                    interactive=False
                )
                params_output = gr.Textbox(
                    label="Parameters Used",
                    lines=1,
                    interactive=False
                )
        # Examples section — gr.Examples registers itself on construction, so
        # the return value does not need to be kept (previously bound to an
        # unused local)
        gr.Markdown("### π‘ Example Prompts")
        gr.Examples(
            examples=[
                ["Create a cylindrical rod with a square base"],
                ["Design a mounting bracket with four holes"],
                ["Make a simple cube with rounded corners"],
                ["Create a T-shaped connector piece"],
                ["Design a gear wheel with 12 teeth"],
                ["Make a pipe elbow joint at 90 degrees"],
                ["Create a hexagonal bolt head"],
                ["Design a simple housing enclosure"]
            ],
            inputs=[text_input],
            label="Click on any example to try it out"
        )
        # Information section
        gr.Markdown("""
        ### βΉοΈ About CADFusion
        CADFusion is a state-of-the-art text-to-CAD generation model that:
        - Uses visual feedback to enhance LLM performance
        - Generates parametric sequences for CAD modeling
        - Supports complex 3D object descriptions
        - Based on alternating sequential and visual learning stages
        **Usage Tips**:
        - Be specific about shapes, dimensions, and features
        - Use technical CAD terminology when possible
        - Mention materials or constraints if relevant
        - Start with simple descriptions and add complexity gradually
        **Model Info**:
        - Version: v1.1 (9 rounds of alternate training)
        - Base Model: LLaMA architecture
        - Training Data: SkexGen dataset with human annotations
        """)
        # Connect the generate button to the function
        generate_btn.click(
            fn=generate_cad,
            inputs=[text_input, max_length, temperature],
            outputs=[sequence_output, status_output, params_output],
            show_progress=True
        )
    return interface
def main():
    """Entry point: warm up the model, build the UI, and serve it."""
    print("π Starting CADFusion Gradio App")

    # Warm the model cache before the first request arrives
    print("π Initializing model...")
    initialize_model()

    # Standard Hugging Face Space serving configuration
    launch_options = {
        "server_name": "0.0.0.0",  # Allow external access
        "server_port": 7860,       # Standard Gradio port
        "share": False,            # Set to True for public sharing
        "debug": True,             # Enable debug mode
        "show_error": True,        # Show errors in interface
        "quiet": False,            # Show startup logs
    }
    create_gradio_interface().launch(**launch_options)


if __name__ == "__main__":
    main()