# KSAutoCAD / app.py
# Last commit by Pasipid791: "Update app.py" (6b11178, verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import json
import logging
# Hugging Face Hub repository that hosts the CADFusion weights, pinned to a
# fixed revision so every (re)start of the app loads the same snapshot.
MODEL_PATH: str = "microsoft/CADFusion"
# Commit id passed as `revision=` to `from_pretrained`.
# NOTE(review): this is an abbreviated hash; confirm the Hub resolves it.
REVISION: str = "2687619"

# Module-wide logging: configure the root logger at INFO and create this
# module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Load the tokenizer and model once, at import time, so every request served by
# the UI reuses the same instances (exposed as module globals `tokenizer` and
# `model`). Any failure aborts startup with a RuntimeError that chains the
# underlying cause, so the real error stays visible in the traceback.
#
# `trust_remote_code=True` lets the repository's custom modeling code run
# locally; pinning REVISION keeps that code fixed to a single snapshot.
# NOTE(review): weights are loaded in float16 regardless of device; on a
# CPU-only host this may be slow or hit unsupported-op errors — confirm the
# deployment has a GPU before relying on this.
try:
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        revision=REVISION,
        trust_remote_code=True,
    )
    logger.info("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        revision=REVISION,
        torch_dtype=torch.float16,
        device_map="auto",  # let transformers place the weights on available devices
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Error loading model or tokenizer: %s", e)
    # RuntimeError is still an Exception, so existing `except Exception`
    # handlers keep working; `from e` preserves the root cause.
    raise RuntimeError(
        f"Failed to load model from {MODEL_PATH} with revision {REVISION}. "
        "Please check the repository and revision ID."
    ) from e
# Function to generate CAD model from text description
def generate_cad_model(text_description, *, max_new_tokens=512):
    """Generate a CAD model description from a free-form text prompt.

    Errors are returned as strings (prefixed with ``"Error:"``) rather than
    raised, so they are displayed in the Gradio output textbox.

    Args:
        text_description: Prompt text from the UI. ``None``, non-string, or
            whitespace-only values are rejected with an error message.
        max_new_tokens: Maximum number of tokens to generate *beyond* the
            prompt (the prompt length no longer eats into this budget).

    Returns:
        The model's continuation pretty-printed as JSON (indent=2) when it
        parses as JSON, otherwise the raw continuation text, or an
        ``"Error: ..."`` string on invalid input / generation failure.
    """
    if not isinstance(text_description, str) or not text_description.strip():
        return "Error: Please provide a valid text description."

    try:
        # Tokenize on the same device as the model weights.
        inputs = tokenizer(text_description, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # A causal LM returns prompt + continuation in one sequence; slice off
        # the prompt tokens so only the model's answer is decoded. Without
        # this the echoed prompt prefixes the text and JSON parsing fails.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        # Top-level UI boundary: log with traceback, surface a message.
        logger.exception("Error during generation: %s", e)
        return f"Error: {str(e)}"

    # The output format is assumed to be JSON-like; fall back to raw text.
    try:
        cad_data = json.loads(generated_text)
    except json.JSONDecodeError:
        return generated_text
    return json.dumps(cad_data, indent=2)
# Gradio interface
def create_gradio_interface():
    """Assemble the text-to-CAD Gradio UI.

    Layout: a title and usage hint, then one row with two columns — the
    prompt textbox and a "Generate CAD Model" button on the left, the result
    textbox on the right — followed by a usage note. Clicking the button
    sends the prompt through ``generate_cad_model`` and shows its return
    value in the result textbox.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    # Footer text, kept byte-identical to a triple-quoted block with a
    # leading and trailing newline.
    usage_note = (
        "\n"
        "**Note**:\n"
        "- CADFusion is for research purposes only. "
        "Generated models may not be technically accurate and require validation.\n"
        "- Ensure descriptions are clear and specific for best results.\n"
        "- For more details, visit the "
        "[CADFusion GitHub repo](https://github.com/microsoft/CADFusion).\n"
    )
    intro = (
        "Enter a textual description of the CAD model you want to generate. "
        "For example: 'A 3D model of a chair with four legs and a curved backrest.'"
    )

    with gr.Blocks() as app:
        gr.Markdown("# CADFusion: Text-to-CAD Generation")
        gr.Markdown(intro)

        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Text Description",
                    placeholder="Enter your CAD model description here...",
                    lines=5,
                )
                generate_btn = gr.Button("Generate CAD Model")
            with gr.Column():
                result_box = gr.Textbox(
                    label="Generated CAD Model (JSON or Text)",
                    placeholder="Generated output will appear here...",
                    lines=10,
                )

        # Wire the button to the generator: one textbox in, one textbox out.
        generate_btn.click(
            fn=generate_cad_model,
            inputs=prompt_box,
            outputs=result_box,
        )
        gr.Markdown(usage_note)

    return app
# Launch Gradio app
if __name__ == "__main__":
    demo = create_gradio_interface()
    # Listen on all interfaces on port 7860 by default. Passing explicit
    # values to launch() overrides Gradio's GRADIO_SERVER_NAME /
    # GRADIO_SERVER_PORT settings, so read those here to keep the app
    # configurable from the environment while preserving the old defaults.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )