File size: 3,772 Bytes
6b11178
 
 
c18fcc0
 
6b11178
 
 
 
 
bcd6ea3
6b11178
 
 
bcd6ea3
6b11178
bcd6ea3
6b11178
 
 
 
 
bcd6ea3
6b11178
 
 
 
 
 
 
 
 
 
 
 
e7a03ef
6b11178
 
c18fcc0
6b11178
 
 
 
 
 
 
 
 
 
 
 
 
 
c18fcc0
 
6b11178
 
c18fcc0
6b11178
 
 
 
 
 
c18fcc0
6b11178
 
fd70c1b
6b11178
c18fcc0
6b11178
 
 
c18fcc0
 
6b11178
c18fcc0
6b11178
 
 
c18fcc0
6b11178
c18fcc0
6b11178
 
 
 
 
e7a03ef
 
6b11178
 
 
 
e7a03ef
4e1acb7
6b11178
 
 
 
 
 
e7a03ef
6b11178
fd70c1b
6b11178
fd70c1b
6b11178
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import json
import logging

# Set up logging
# Module-level logger named after this module, per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define model and checkpoint paths
MODEL_PATH = "microsoft/CADFusion"  # HuggingFace Hub repository id
REVISION = "2687619"  # Use commit hash from the document

# Load model and tokenizer at import time so the Gradio handlers can use them.
# Any failure here is fatal: re-raise with the original error chained so the
# traceback shows the root cause (missing repo, bad revision, network, etc.).
try:
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        revision=REVISION,
        trust_remote_code=True
    )
    logger.info("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        revision=REVISION,
        torch_dtype=torch.float16,   # half precision to reduce GPU memory
        device_map="auto",           # let accelerate place layers on available devices
        trust_remote_code=True
    )
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Error loading model or tokenizer")
    # Chain the cause (`from e`) instead of discarding it.
    raise Exception(
        f"Failed to load model from {MODEL_PATH} with revision {REVISION}. "
        "Please check the repository and revision ID."
    ) from e

# Function to generate CAD model from text description
def generate_cad_model(text_description: str) -> str:
    """Generate a CAD model description from free-form text.

    Args:
        text_description: Natural-language description of the desired model.

    Returns:
        Pretty-printed JSON if the model emits valid JSON, otherwise the raw
        generated text; on failure, a human-readable "Error: ..." string
        (Gradio displays the return value either way).
    """
    try:
        if not text_description.strip():
            return "Error: Please provide a valid text description."

        # Tokenize input and move tensors to the model's device.
        inputs = tokenizer(text_description, return_tensors="pt").to(model.device)

        # Generate output. Inference only, so disable autograd to save memory.
        # NOTE: max_new_tokens (not max_length) so a long prompt cannot
        # silently eat the generation budget.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )

        # Decode output
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Parse generated text to extract CAD model data (assuming JSON-like output)
        try:
            cad_data = json.loads(generated_text)
            return json.dumps(cad_data, indent=2)
        except json.JSONDecodeError:
            return generated_text  # Return raw text if JSON parsing fails
    except Exception as e:
        # Broad catch is deliberate: this is a UI boundary, so surface the
        # error as text rather than crashing the Gradio handler.
        logger.exception("Error during generation")
        return f"Error: {str(e)}"

# Gradio interface
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for text-to-CAD generation.

    Layout: a title/instructions header, an input column (textbox + button),
    an output column (textbox), and a usage note at the bottom. The button
    routes its click through generate_cad_model.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# CADFusion: Text-to-CAD Generation")
        gr.Markdown("Enter a textual description of the CAD model you want to generate. For example: 'A 3D model of a chair with four legs and a curved backrest.'")

        with gr.Row():
            # Left column: user input and the trigger button.
            with gr.Column():
                description_box = gr.Textbox(
                    label="Text Description",
                    placeholder="Enter your CAD model description here...",
                    lines=5,
                )
                generate_btn = gr.Button("Generate CAD Model")

            # Right column: generated result (JSON when parseable).
            with gr.Column():
                result_box = gr.Textbox(
                    label="Generated CAD Model (JSON or Text)",
                    placeholder="Generated output will appear here...",
                    lines=10,
                )

        # Wire the button to the generation function.
        generate_btn.click(fn=generate_cad_model, inputs=description_box, outputs=result_box)

        gr.Markdown("""
        **Note**: 
        - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation.
        - Ensure descriptions are clear and specific for best results.
        - For more details, visit the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion).
        """)

    return demo

# Launch Gradio app when run as a script (listen on all interfaces, port 7860).
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch(server_name="0.0.0.0", server_port=7860)