"""Gradio demo for CADFusion text-to-CAD generation.

Loads the CADFusion causal language model once at import time and serves a
simple web UI that turns a textual description into generated CAD output.
"""

import json
import logging
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define model and checkpoint paths
MODEL_PATH = "microsoft/CADFusion"
# NOTE(review): this is an abbreviated commit hash; confirm the Hub resolves
# it, or pin the full 40-character SHA for reproducible downloads.
REVISION = "2687619"

# Upper bound on tokens generated per request. Counted separately from the
# prompt so that long descriptions do not eat into the output budget.
MAX_NEW_TOKENS = 512


def _load_model_and_tokenizer():
    """Load the CADFusion tokenizer and model pinned to ``REVISION``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        RuntimeError: if either component fails to load. The underlying
            exception is chained as ``__cause__``.
    """
    # float16 is only dependable on GPU; many CPU kernels have no Half
    # implementation, so fall back to float32 when CUDA is unavailable.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        logger.info("Loading tokenizer...")
        # trust_remote_code executes Python shipped with the checkpoint;
        # pinning REVISION limits that to a known snapshot of the repo.
        tok = AutoTokenizer.from_pretrained(
            MODEL_PATH,
            revision=REVISION,
            trust_remote_code=True,
        )
        logger.info("Loading model...")
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            revision=REVISION,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
    except Exception as e:
        logger.exception("Error loading model or tokenizer: %s", e)
        raise RuntimeError(
            f"Failed to load model from {MODEL_PATH} with revision {REVISION}. "
            "Please check the repository and revision ID."
        ) from e
    logger.info("Model and tokenizer loaded successfully.")
    return tok, mdl


# Load once at import time so every request reuses the same weights.
tokenizer, model = _load_model_and_tokenizer()


# Function to generate CAD model from text description
def generate_cad_model(text_description):
    """Generate CAD output for a natural-language description.

    Args:
        text_description: The user's description. ``None``, empty, or
            whitespace-only input is rejected.

    Returns:
        The newly generated text (the prompt itself is excluded). If that
        text parses as JSON it is returned pretty-printed with 2-space
        indentation; otherwise the raw decoded text is returned. On any
        failure a string starting with ``"Error:"`` is returned instead of
        raising, so Gradio displays it in the output box.
    """
    if not text_description or not text_description.strip():
        return "Error: Please provide a valid text description."

    try:
        inputs = tokenizer(text_description, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        # Many causal-LM tokenizers define no pad token; fall back to EOS so
        # generate() does not warn on every call. Compare against None
        # explicitly because 0 is a valid token id.
        pad_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )

        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=pad_id,
        )

        # Decoder-only models return prompt + continuation; decode only the
        # continuation so the user's text is not echoed back (which would
        # also make JSON parsing fail every time).
        generated_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()
    except Exception as e:
        logger.exception("Error during generation: %s", e)
        return f"Error: {e}"

    # Output format is not guaranteed; assume JSON-like output when possible
    # and fall back to the raw text otherwise.
    try:
        cad_data = json.loads(generated_text)
    except json.JSONDecodeError:
        return generated_text
    return json.dumps(cad_data, indent=2)


# Gradio interface
def create_gradio_interface():
    """Build the Gradio Blocks UI wired to ``generate_cad_model``.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# CADFusion: Text-to-CAD Generation")
        gr.Markdown(
            "Enter a textual description of the CAD model you want to generate. "
            "For example: 'A 3D model of a chair with four legs and a curved backrest.'"
        )

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text Description",
                    placeholder="Enter your CAD model description here...",
                    lines=5,
                )
                submit_button = gr.Button("Generate CAD Model")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated CAD Model (JSON or Text)",
                    placeholder="Generated output will appear here...",
                    lines=10,
                )

        submit_button.click(
            fn=generate_cad_model,
            inputs=text_input,
            outputs=output_text,
        )

        # Built from explicit newlines (not an indented triple-quoted string)
        # so the bullets render as a Markdown list.
        gr.Markdown(
            "**Note**:\n"
            "- CADFusion is for research purposes only. Generated models may not "
            "be technically accurate and require validation.\n"
            "- Ensure descriptions are clear and specific for best results.\n"
            "- For more details, visit the "
            "[CADFusion GitHub repo](https://github.com/microsoft/CADFusion).\n"
        )

    return demo


# Launch Gradio app
if __name__ == "__main__":
    demo = create_gradio_interface()
    # Explicit launch() arguments override Gradio's own env vars, so read them
    # here to keep them effective; defaults match the original hard-coded values.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )