Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import os | |
| import json | |
| import logging | |
# Configure module-level logging at INFO so model-load and
# generation progress messages are visible in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define model and checkpoint paths
MODEL_PATH = "microsoft/CADFusion"
REVISION = "2687619"  # Pin to a specific commit hash for reproducibility

# Load model and tokenizer once at import time so every request reuses them.
try:
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        revision=REVISION,
        trust_remote_code=True,
    )
    logger.info("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        revision=REVISION,
        torch_dtype=torch.float16,  # half precision to fit smaller GPUs
        device_map="auto",          # let accelerate place layers on available devices
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    # logger.exception keeps the full traceback in the logs, and `from e`
    # chains the original error so the real failure cause is not hidden
    # behind the generic message (the original code discarded both).
    logger.exception("Error loading model or tokenizer")
    raise RuntimeError(
        f"Failed to load model from {MODEL_PATH} with revision {REVISION}. "
        "Please check the repository and revision ID."
    ) from e
| # Function to generate CAD model from text description | |
| def generate_cad_model(text_description): | |
| try: | |
| if not text_description.strip(): | |
| return "Error: Please provide a valid text description." | |
| # Tokenize input | |
| inputs = tokenizer(text_description, return_tensors="pt").to(model.device) | |
| # Generate output | |
| outputs = model.generate( | |
| **inputs, | |
| max_length=512, | |
| num_return_sequences=1, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.9 | |
| ) | |
| # Decode output | |
| generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # Parse generated text to extract CAD model data (assuming JSON-like output) | |
| try: | |
| cad_data = json.loads(generated_text) | |
| return json.dumps(cad_data, indent=2) | |
| except json.JSONDecodeError: | |
| return generated_text # Return raw text if JSON parsing fails | |
| except Exception as e: | |
| logger.error(f"Error during generation: {e}") | |
| return f"Error: {str(e)}" | |
| # Gradio interface | |
| def create_gradio_interface(): | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# CADFusion: Text-to-CAD Generation") | |
| gr.Markdown("Enter a textual description of the CAD model you want to generate. For example: 'A 3D model of a chair with four legs and a curved backrest.'") | |
| with gr.Row(): | |
| with gr.Column(): | |
| text_input = gr.Textbox( | |
| label="Text Description", | |
| placeholder="Enter your CAD model description here...", | |
| lines=5 | |
| ) | |
| submit_button = gr.Button("Generate CAD Model") | |
| with gr.Column(): | |
| output_text = gr.Textbox( | |
| label="Generated CAD Model (JSON or Text)", | |
| placeholder="Generated output will appear here...", | |
| lines=10 | |
| ) | |
| submit_button.click( | |
| fn=generate_cad_model, | |
| inputs=text_input, | |
| outputs=output_text | |
| ) | |
| gr.Markdown(""" | |
| **Note**: | |
| - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation. | |
| - Ensure descriptions are clear and specific for best results. | |
| - For more details, visit the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion). | |
| """) | |
| return demo | |
| # Launch Gradio app | |
| if __name__ == "__main__": | |
| demo = create_gradio_interface() | |
| demo.launch(server_name="0.0.0.0", server_port=7860) |