# app.py — recovered from a Hugging Face Spaces page (previous build failed).
import io
import os
import pathlib
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# SHAP-E main branch (minimal import; the upstream API may change, so this
# is guarded — downstream code falls back to a stub mesh when it is absent).
try:
    from shap_e.diffusion.sample import build_model
except ImportError:
    build_model = None  # sentinel: SHAP-E not installed
# --- checkpoint directory ---
CKPT_DIR = pathlib.Path("checkpoints")
CKPT_DIR.mkdir(exist_ok=True)  # idempotent: no error if the directory exists
SAM_PATH = CKPT_DIR / "sam_vit_b.pth"           # Segment Anything ViT-B weights
CN_PATH = CKPT_DIR / "controlnet_scribble.bin"  # ControlNet scribble weights
def fetch(url: str, dst: pathlib.Path) -> None:
    """Download *url* to *dst*, skipping the download if the file exists.

    Uses ``urllib.request.urlretrieve`` instead of shelling out to ``wget``:
    it works on hosts without wget installed, raises on HTTP/network errors
    instead of failing silently, and is not vulnerable to shell injection
    through the URL or destination path.

    Args:
        url: Source URL of the checkpoint file.
        dst: Local path to write; left untouched when it already exists.
    """
    if dst.exists():
        return
    urllib.request.urlretrieve(url, dst)
# Download the checkpoints on first run (no-op when already cached on disk).
fetch(
    "https://github.com/facebookresearch/segment-anything/releases/download/v0.1.0/sam_vit_b_01ec64.pth",
    SAM_PATH,
)
fetch(
    "https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/pytorch_model.bin",
    CN_PATH,
)
device = torch.device("cpu")  # CPU-only deployment (title says so; no CUDA)

# SAM: load the ViT-B backbone from the local checkpoint and wrap it in the
# interactive predictor used by get_silhouette().
sam = sam_model_registry["vit_b"](checkpoint=str(SAM_PATH)).to(device)
predictor = SamPredictor(sam)
# Stable Diffusion + ControlNet (scribble conditioning).
# Load the ControlNet weights directly with ControlNetModel instead of
# instantiating a throwaway full pipeline just to read its .controlnet
# attribute: the scribble repo ships only ControlNet weights, so building a
# complete StableDiffusionControlNetPipeline from it is wasteful and fragile.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float32
).to(device)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float32,
).to(device)
pipe.enable_attention_slicing()  # reduce peak memory on CPU
# SHAP-E (only if the optional import above succeeded).
if build_model:
    # NOTE(review): the "diffusion" model name and this build_model signature
    # are assumed from the guarded import — confirm against the installed
    # shap_e version before relying on this path.
    shap_e_model = build_model("diffusion", device=device)
    shap_e_model.eval()  # inference mode
else:
    shap_e_model = None  # generate_mesh() falls back to a stub cube
# --- helper pipeline ---
def get_silhouette(img: Image.Image):
    """Segment *img* with SAM and return its silhouette as a 0/255 uint8 mask.

    Runs the predictor without any point/box prompts and takes the single
    mask produced by ``multimask_output=False``.
    """
    rgb = np.array(img.convert("RGB"))
    predictor.set_image(rgb)
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=None,
        multimask_output=False,
    )
    silhouette = masks[0].astype(np.uint8)
    return silhouette * 255
def generate_concept(mask: np.ndarray, prompt: str):
    """Render a 2D concept image from a silhouette mask via SD + ControlNet.

    The grayscale mask is expanded to RGB to serve as the scribble
    conditioning image for the pipeline.
    """
    conditioning = Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB))
    result = pipe(
        prompt=prompt,
        image=conditioning,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    return result.images[0]
def generate_mesh(concept_img: Image.Image):
    """Convert the 2D concept image into a triangle mesh.

    When SHAP-E is unavailable, returns a unit-cube stub so the rest of the
    pipeline (STL export, download widget) keeps working.
    """
    if shap_e_model is None:
        # Fallback stub: minimal unit-cube mesh.
        return trimesh.creation.box(extents=(1, 1, 1))
    png_buffer = io.BytesIO()
    concept_img.save(png_buffer, format="PNG")
    png_buffer.seek(0)
    # SHAP-E decode (API assumed stable whenever build_model imported).
    decoded = shap_e_model.sample_latents_and_decode(png_buffer)
    return trimesh.Trimesh(vertices=decoded.vertices, faces=decoded.faces)
def export_stl(mesh: trimesh.Trimesh):
    """Write *mesh* to a fixed STL file and return the file path."""
    stl_path = "output_eol.stl"
    mesh.export(stl_path)
    return stl_path
def eol_pipeline(image, prompt):
    """Full photo -> concept -> mesh -> STL pipeline driven by the Gradio UI.

    Returns the 2D concept image and the path of the exported STL file.
    """
    silhouette = get_silhouette(image)
    concept_img = generate_concept(silhouette, prompt)
    stl_path = export_stl(generate_mesh(concept_img))
    return concept_img, stl_path
# --- Gradio UI ---
# Wires eol_pipeline to a photo + prompt form; outputs the concept image and
# a downloadable STL. (Widget labels are user-facing strings, kept as-is.)
demo = gr.Interface(
    fn=eol_pipeline,
    inputs=[
        gr.Image(type="pil", label="Foto componente EoL"),
        gr.Textbox(lines=2, placeholder="Prompt creativo…")
    ],
    outputs=[
        gr.Image(type="pil", label="Concept 2D"),
        gr.File(label="Scarica STL")
    ],
    title="EoL Component Generative Design (CPU-only)",
    description="Upload foto componente EoL ➜ concept 2D ➜ mesh 3D STL"
)
if __name__ == "__main__":
    demo.launch()