# KSAutoCAD / app.py
# Hugging Face file-viewer residue kept for provenance:
#   "Pasipid791's picture" / "Update app.py" / commit e7a03ef (verified)
#   "raw" / "history blame" / 16.6 kB
import os
import sys
import json
import torch
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
import tempfile
import subprocess
import shutil
from typing import Optional, List, Dict, Any
# Add the src directory to Python path for imports
sys.path.insert(0, './src')

# Import the Hugging Face stack; if it is missing, install it with pip and
# retry once. Both paths run at import time, before any definition below.
try:
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        LlamaTokenizer,
        LlamaForCausalLM
    )
    from huggingface_hub import snapshot_download
    print("✅ Successfully imported transformers and huggingface_hub")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Installing required packages...")
    subprocess.run([sys.executable, "-m", "pip", "install", "transformers", "huggingface_hub", "torch", "accelerate"])
    # Packages installed while the interpreter is running are not guaranteed
    # to be visible to the import system until its finder caches are reset.
    import importlib
    importlib.invalidate_caches()
    # NOTE: the retry only binds the names used further down in this file;
    # LlamaTokenizer / LlamaForCausalLM stay unbound on this path.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from huggingface_hub import snapshot_download
class CADFusionModel:
    """Text-to-CAD generator backed by a causal language model.

    On construction the model snapshot is downloaded and loaded. If loading
    fails for any reason, ``model`` and ``tokenizer`` are left as ``None``
    and :meth:`generate_cad_sequence` returns canned demo sequences instead.
    """

    def __init__(self, model_path: str = "microsoft/CADFusion", version: str = "v1_1"):
        """
        Initialize the CADFusion model

        Args:
            model_path: Path to the model on Hugging Face Hub
            version: Model version (v1_0 or v1_1)
        """
        self.model_path = model_path
        self.version = version
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Initializing CADFusion {version} on {self.device}")

        # Download model if not already present
        self.model_dir = self._download_model()

        # Initialize tokenizer and model
        self.tokenizer = None
        self.model = None
        self._load_model()

        # CAD sequence processing utilities
        self.max_sequence_length = 512

    def _download_model(self) -> str:
        """Download the model snapshot and return its local directory.

        ``self.version`` is passed to ``snapshot_download`` as the
        ``revision``. On any failure the error is printed and the relative
        path ``./<version>`` is returned as a local fallback.
        """
        try:
            cache_dir = "./model_cache"
            model_dir = snapshot_download(
                repo_id=self.model_path,
                revision=self.version,
                cache_dir=cache_dir,
                token=os.getenv("HF_TOKEN")  # Use HF token if available
            )
            print(f"✅ Model downloaded to: {model_dir}")
            return model_dir
        except Exception as e:
            print(f"❌ Error downloading model: {e}")
            # Fallback to local directory structure
            return f"./{self.version}"

    def _load_model(self):
        """Load the tokenizer and model from ``self.model_dir``.

        Only the top level of the directory is searched for ``*.bin`` /
        ``*.safetensors`` weights. Any failure (including no weights found)
        is printed and the instance falls back to placeholder mode.
        """
        try:
            weights_dir = Path(self.model_dir)
            model_files = list(weights_dir.glob("*.bin")) + list(weights_dir.glob("*.safetensors"))
            if not model_files:
                raise FileNotFoundError("No model files found")

            print(f"📦 Loading model from {self.model_dir}")
            on_cuda = self.device.type == "cuda"

            # NOTE(review): trust_remote_code=True lets the repository ship
            # Python code that is executed locally; keep only for trusted repos.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                padding_side="left"
            )
            # Ensure pad token exists
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                torch_dtype=torch.float16 if on_cuda else torch.float32,
                device_map="auto" if on_cuda else None,
                trust_remote_code=True
            )
            # With device_map="auto" placement is already done by the loader.
            if not on_cuda:
                self.model = self.model.to(self.device)
            self.model.eval()
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("📝 Using placeholder model for demo purposes")
            self._setup_placeholder_model()

    def _setup_placeholder_model(self):
        """Switch to placeholder mode by clearing ``model`` and ``tokenizer``."""
        print("⚠️ Setting up placeholder model")
        # This is a fallback when the actual model can't be loaded
        self.model = None
        self.tokenizer = None

    def preprocess_text(self, text: str) -> str:
        """Normalise a user description into a generation prompt.

        Strips surrounding whitespace; an empty result becomes a default
        prompt; text that contains none of the action words (substring
        match, case-insensitive) is prefixed with ``"Create a "``.
        """
        text = text.strip()
        if not text:
            return "Generate a simple 3D object"

        lowered = text.lower()
        if not any(word in lowered for word in ['create', 'design', 'make', 'generate', 'build']):
            text = f"Create a {text}"
        return text

    def generate_cad_sequence(self, text: str, max_length: int = 512, temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate CAD parametric sequence from text description

        Args:
            text: Text description of the CAD object
            max_length: Maximum sequence length, forwarded to ``generate`` as
                ``max_length`` (which counts the prompt tokens as well)
            temperature: Generation temperature

        Returns:
            Dictionary with at least ``success``, ``sequence``, ``text_input``
            and ``parameters``. Successful runs add ``full_output``; demo and
            failure results add ``message``. This method does not raise:
            any exception is reported through the returned dictionary.
        """
        # Every result shape carries the same parameters record.
        parameters = {
            "max_length": max_length,
            "temperature": temperature
        }
        try:
            if self.model is None or self.tokenizer is None:
                # Return placeholder response
                return {
                    "success": False,
                    "message": "Model not loaded - showing demo output",
                    "sequence": self._generate_demo_sequence(text),
                    "text_input": text,
                    "parameters": parameters
                }

            processed_text = self.preprocess_text(text)

            inputs = self.tokenizer(
                processed_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256
            ).to(self.device)
            prompt_length = inputs.input_ids.shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            generated_sequence = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )
            # Drop the prompt by token position rather than by string
            # replacement: replace() would also delete later repeats of the
            # prompt text and does nothing when decoding alters the prompt.
            generated_part = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            ).strip()

            return {
                "success": True,
                "sequence": generated_part,
                "full_output": generated_sequence,
                "text_input": processed_text,
                "parameters": parameters
            }
        except Exception as e:
            print(f"❌ Generation error: {e}")
            return {
                "success": False,
                "message": f"Generation failed: {str(e)}",
                "sequence": self._generate_demo_sequence(text),
                "text_input": text,
                "parameters": parameters
            }

    def _generate_demo_sequence(self, text: str) -> str:
        """Return a canned CAD sequence chosen by keyword in ``text``.

        The first key of the table (in insertion order) that occurs in the
        lower-cased text wins; otherwise a default box sequence is returned.
        """
        demo_sequences = {
            "cube": "Sketch('xy') -> Rectangle(0, 0, 10, 10) -> Extrude(10)",
            "cylinder": "Sketch('xy') -> Circle(0, 0, 5) -> Extrude(15)",
            "sphere": "Sketch('xy') -> Circle(0, 0, 5) -> Revolve(360)",
            "bracket": "Sketch('xy') -> Rectangle(0, 0, 20, 10) -> Extrude(5) -> Sketch('top') -> Circle(15, 5, 2) -> Cut(5)"
        }
        text_lower = text.lower()
        for key, sequence in demo_sequences.items():
            if key in text_lower:
                return sequence
        # Default sequence
        return "Sketch('xy') -> Rectangle(0, 0, 10, 10) -> Extrude(5)"
# Global model instance, created lazily by initialize_model()
model = None


def initialize_model():
    """Return the global model instance, creating it on first use.

    Constructs ``CADFusionModel()`` with its default arguments only when the
    module-level ``model`` is still ``None``; later calls return the same
    object.
    """
    global model
    if model is None:
        print("🔄 Initializing CADFusion model...")
        model = CADFusionModel()
    return model
def generate_cad(
    text_input: str,
    max_length: int = 512,
    temperature: float = 0.7
) -> tuple:
    """
    Gradio interface function for CAD generation

    Args:
        text_input: Free-text description of the object to generate
        max_length: Forwarded to ``model.generate_cad_sequence``
        temperature: Forwarded to ``model.generate_cad_sequence``

    Returns:
        Tuple of (generated_sequence, status_message, parameters_info).
        Never raises: unexpected errors are reported in the status message.
    """
    global model
    try:
        # Validate first so an empty prompt never triggers a model load.
        if not text_input or not text_input.strip():
            return "Please provide a text description.", "❌ Error: Empty input", "No parameters"

        # Initialize model if needed
        if model is None:
            model = initialize_model()

        result = model.generate_cad_sequence(
            text_input,
            max_length=max_length,
            temperature=temperature
        )

        # The sequence is shown either way (demo output on failure); only
        # the status line differs.
        sequence = result["sequence"]
        if result["success"]:
            status = "✅ Generation successful"
        else:
            status = f"⚠️ {result.get('message', 'Generation failed')}"

        # Fall back to the requested values when the result omits them.
        params = result.get("parameters", {})
        param_info = f"Max Length: {params.get('max_length', max_length)}, Temperature: {params.get('temperature', temperature)}"
        return sequence, status, param_info
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return "Generation failed", error_msg, "No parameters"
def create_gradio_interface():
    """Build and return the Gradio ``Blocks`` UI for the app.

    Layout: a header, an input column (description textbox, two sliders and
    the generate button), an output column (sequence, status, parameters),
    clickable example prompts and an "about" section. The generate button is
    wired to ``generate_cad``. The interface is returned un-launched.
    """
    # Multi-line string literals below are kept flush-left so their content
    # carries no leading indentation (indented Markdown can be read as a
    # code block).
    # Custom CSS for better styling
    css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.gr-button-primary {
background: linear-gradient(45deg, #1e3a8a, #3b82f6);
border: none;
}
.gr-panel {
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
"""
    with gr.Blocks(css=css, title="CADFusion - Text to CAD Generation") as interface:
        # Header
        gr.Markdown("""
# 🔧 CADFusion - Text to CAD Generation
Convert natural language descriptions into CAD parametric sequences using Microsoft's CADFusion model.
**Model**: microsoft/CADFusion v1.1
**Paper**: [Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models](https://arxiv.org/abs/2501.19054)
""")
        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("### 📝 Input")
                text_input = gr.Textbox(
                    label="CAD Description",
                    placeholder="Describe the CAD object you want to create (e.g., 'Create a cylindrical bracket with mounting holes')",
                    lines=3,
                    value="Create a simple rectangular bracket with two circular holes"
                )
                # Parameters section
                gr.Markdown("### ⚙️ Generation Parameters")
                with gr.Row():
                    max_length = gr.Slider(
                        label="Max Length",
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=64,
                        info="Maximum length of generated sequence"
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        info="Generation randomness (lower = more deterministic)"
                    )
                # Generate button
                generate_btn = gr.Button(
                    "🚀 Generate CAD Sequence",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=3):
                # Output section
                gr.Markdown("### 🎯 Generated CAD Sequence")
                sequence_output = gr.Textbox(
                    label="Parametric Sequence",
                    lines=8,
                    interactive=False,
                    placeholder="Generated CAD sequence will appear here..."
                )
                status_output = gr.Textbox(
                    label="Status",
                    lines=1,
                    interactive=False
                )
                params_output = gr.Textbox(
                    label="Parameters Used",
                    lines=1,
                    interactive=False
                )
        # Examples section: clicking a row fills the description textbox.
        gr.Markdown("### 💡 Example Prompts")
        gr.Examples(
            examples=[
                ["Create a cylindrical rod with a square base"],
                ["Design a mounting bracket with four holes"],
                ["Make a simple cube with rounded corners"],
                ["Create a T-shaped connector piece"],
                ["Design a gear wheel with 12 teeth"],
                ["Make a pipe elbow joint at 90 degrees"],
                ["Create a hexagonal bolt head"],
                ["Design a simple housing enclosure"]
            ],
            inputs=[text_input],
            label="Click on any example to try it out"
        )
        # Information section
        gr.Markdown("""
### ℹ️ About CADFusion
CADFusion is a state-of-the-art text-to-CAD generation model that:
- Uses visual feedback to enhance LLM performance
- Generates parametric sequences for CAD modeling
- Supports complex 3D object descriptions
- Based on alternating sequential and visual learning stages
**Usage Tips**:
- Be specific about shapes, dimensions, and features
- Use technical CAD terminology when possible
- Mention materials or constraints if relevant
- Start with simple descriptions and add complexity gradually
**Model Info**:
- Version: v1.1 (9 rounds of alternate training)
- Base Model: LLaMA architecture
- Training Data: SkexGen dataset with human annotations
""")
        # Connect the generate button to the function; the three outputs
        # line up with the tuple returned by generate_cad.
        generate_btn.click(
            fn=generate_cad,
            inputs=[text_input, max_length, temperature],
            outputs=[sequence_output, status_output, params_output],
            show_progress=True
        )
    return interface
def main():
    """Main function to run the Gradio app.

    Loads the model eagerly (so the first request does not pay for it),
    builds the interface and launches the server on port 7860, bound to all
    network interfaces.
    """
    print("🌟 Starting CADFusion Gradio App")

    # Initialize model
    print("🔄 Initializing model...")
    initialize_model()

    # Create and launch interface
    interface = create_gradio_interface()

    # Launch configuration
    interface.launch(
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,       # Standard Gradio port
        share=False,            # Set to True for public sharing
        debug=True,             # Enable debug mode
        show_error=True,        # Show errors in interface
        quiet=False             # Show startup logs
    )


if __name__ == "__main__":
    main()