Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import open3d as o3d | |
| import numpy as np | |
| import cadquery as cq | |
# Load the tokenizer from Qwen2-1.5B and model weights from filapro/cad-recode.
# NOTE(review): confirm that filapro/cad-recode loads correctly through
# AutoModelForCausalLM and accepts plain-text prompts; its published usage may
# require a custom model class with a dedicated point-cloud encoder.
@st.cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and model once and move the model to the best device.

    Streamlit re-executes this whole script on every rerun. ``st.cache_resource``
    keeps the loaded objects alive across reruns, so the weights are read from
    disk and moved to the device only once per server process.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is CUDA when
        available, CPU otherwise.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B", trust_remote_code=True)
    loaded_model = AutoModelForCausalLM.from_pretrained("filapro/cad-recode", trust_remote_code=True)
    # Set device (GPU if available, CPU otherwise)
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model.to(target_device)
    return loaded_tokenizer, loaded_model, target_device


# Module-level names kept so the functions below can keep using them unchanged.
tokenizer, model, device = _load_model_and_tokenizer()
print(f"Model loaded on {device}")
def load_point_cloud(file):
    """Load an uploaded point-cloud file into an Open3D point cloud.

    ``o3d.io.read_point_cloud`` takes a filesystem path and chooses its parser
    from the file extension; it cannot read a Streamlit ``UploadedFile`` object
    directly. The upload is therefore written to a temporary file that keeps the
    original suffix, read from there, and the temporary file is then removed.

    Args:
        file: The object returned by ``st.file_uploader`` (presumably a Streamlit
            ``UploadedFile`` exposing ``.name`` and ``.getvalue()``), or ``None``
            when nothing has been uploaded yet.

    Returns:
        The loaded point cloud. Returns ``None`` when no file was given, the
        extension is not a point-cloud format Open3D reads, reading raised, or
        the file yielded no points. Every failure except "no file" is reported
        to the user via ``st.error``.
    """
    import contextlib
    import os
    import tempfile

    if not file:
        return None

    # Browser-assigned MIME types are unreliable for these formats; Open3D
    # dispatches on the extension, so validate that instead.
    supported_suffixes = (".xyz", ".xyzn", ".xyzrgb", ".pts", ".ply", ".pcd")
    suffix = os.path.splitext(getattr(file, "name", "") or "")[1].lower()
    if suffix not in supported_suffixes:
        st.error("Please upload a point cloud file (.pcd, .xyz, etc.)")
        return None

    tmp_path = None
    try:
        # delete=False: the file must be closed before Open3D reopens it by path
        # (required on Windows); it is removed explicitly in ``finally``.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(file.getvalue())
            tmp_path = tmp.name
        point_cloud = o3d.io.read_point_cloud(tmp_path)
    except Exception as e:
        st.error(f"Error loading point cloud: {e}")
        return None
    finally:
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)

    # Open3D may return an empty cloud instead of raising on malformed input.
    if not point_cloud.has_points():
        st.error("Error loading point cloud: the file contains no readable points")
        return None
    return point_cloud
def prepare_input_data(point_cloud):
    """Serialize a point cloud's coordinates into one space-separated string.

    The ``points`` of the cloud are converted to a NumPy array, flattened in
    row-major order (x0 y0 z0 x1 y1 z1 ...), and each value is rendered with
    ``str``.

    Args:
        point_cloud: An object exposing a ``points`` sequence, or a falsy value.

    Returns:
        The serialized coordinates (an empty string when there are no points),
        or ``None`` if ``point_cloud`` is falsy.
    """
    if point_cloud:
        flat_coords = np.asarray(point_cloud.points).ravel()
        return " ".join(str(coord) for coord in flat_coords)
    return None
def generate_cad_code(input_text, max_input_tokens=512, max_new_tokens=256):
    """Run the language model on the serialized point cloud and decode its output.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        input_text: Prompt text (e.g. from ``prepare_input_data``). A falsy value
            short-circuits to ``None`` without touching the model.
        max_input_tokens: Prompt length cap in tokens; longer prompts are
            truncated (default 512, as before).
        max_new_tokens: Maximum number of tokens to generate (default 256, as
            before).

    Returns:
        The decoded continuation only (the prompt tokens are excluded), or
        ``None`` if ``input_text`` is falsy.
    """
    if not input_text:
        return None
    # NOTE(review): truncation is silent. A serialized point cloud is likely to
    # exceed ``max_input_tokens`` by far, so only a prefix of the points reaches
    # the model.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_input_tokens)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    # For a causal LM, generate() returns the prompt followed by the new tokens;
    # drop the prompt so the caller receives only the generated code.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def generate_cad_model(cad_code):
    """Execute generated CadQuery code and return the geometry it builds.

    The code runs in a fresh namespace pre-seeded with ``cq`` and ``np``. After
    execution, the namespace values are scanned for CadQuery objects
    (``cq.Workplane`` or ``cq.Shape``). The one bound to the most recently
    introduced name (namespace insertion order) is taken as the result, and a
    ``Workplane`` is reduced to its first object via ``.val()``.

    Args:
        cad_code: Python source text, presumably produced by ``generate_cad_code``.

    Returns:
        A CadQuery object, or ``None`` if ``cad_code`` is empty, execution
        raised, or no CadQuery geometry was produced. Failures are reported via
        ``st.error``.
    """
    if not cad_code:
        return None

    # SECURITY: exec() runs arbitrary model-generated Python with the full
    # privileges of this server process. The separate namespace only prevents
    # clobbering this module's globals; it is NOT a sandbox. Execute in an
    # isolated worker/container before exposing this app to untrusted users.
    namespace = {"cq": cq, "np": np}
    try:
        exec(cad_code, namespace)
    except Exception as e:
        st.error(f"Error generating CAD model: {e}")
        return None

    candidates = [
        value for value in namespace.values()
        if isinstance(value, (cq.Workplane, cq.Shape))
    ]
    if not candidates:
        st.error("Error generating CAD model: the generated code did not produce a CadQuery object")
        return None

    cad_model = candidates[-1]
    if isinstance(cad_model, cq.Workplane):
        # An empty Workplane's .val() would return the plane origin, not geometry.
        if not cad_model.vals():
            st.error("Error generating CAD model: the generated workplane contains no geometry")
            return None
        cad_model = cad_model.val()
    return cad_model
def main():
    """Streamlit app: upload a point cloud, generate CadQuery code, then try to run it.

    Each step stops the page early with a message when the previous step yields
    nothing, so the user is not left with a silently blank result.
    """
    st.title("Point Cloud to CAD Code Converter")
    st.write("This app uses the filapro/cad-recode model to generate Python code for a 3D CAD model from your point cloud data.")
    # Restrict the picker to point-cloud formats Open3D can read by extension.
    uploaded_file = st.file_uploader(
        "Upload Point Cloud File",
        type=["pcd", "ply", "xyz", "xyzn", "xyzrgb", "pts"],
    )

    point_cloud = load_point_cloud(uploaded_file)
    if not point_cloud:
        return

    input_text = prepare_input_data(point_cloud)
    if not input_text:
        st.warning("The uploaded point cloud contains no points.")
        return

    with st.spinner("Generating CAD code..."):
        cad_code = generate_cad_code(input_text)
    if not cad_code:
        st.warning("The model did not generate any code.")
        return

    st.success("Generated Python CAD Code:")
    st.code(cad_code, language="python")

    cad_model = generate_cad_model(cad_code)
    if cad_model:
        # TODO: render cad_model (e.g. tessellate it and show it in a 3D viewer).
        st.success("Generated CAD Model (Visualization not yet implemented)")


if __name__ == "__main__":
    main()