"""EoL component generative design demo (CPU-only).

Pipeline: photo of an end-of-life component -> SAM silhouette mask ->
Stable Diffusion + ControlNet 2D concept -> (optional) SHAP-E 3D mesh -> STL.
"""

import io
import os
import pathlib
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from segment_anything import SamPredictor, sam_model_registry

# SHAP-E main branch (minimal import; optional — stub mesh used when absent)
try:
    from shap_e.diffusion.sample import build_model
except ImportError:
    build_model = None

# --- checkpoint directory ---
CKPT_DIR = pathlib.Path("checkpoints")
CKPT_DIR.mkdir(exist_ok=True)
SAM_PATH = CKPT_DIR / "sam_vit_b.pth"
CN_PATH = CKPT_DIR / "controlnet_scribble.bin"


def fetch(url: str, dst: pathlib.Path) -> None:
    """Download *url* to *dst* unless the file already exists.

    Uses stdlib urllib instead of shelling out to wget: portable (no wget
    dependency) and immune to shell injection through the URL or path.
    """
    if dst.exists():
        return
    urllib.request.urlretrieve(url, dst)


# download the checkpoints if absent
fetch(
    "https://github.com/facebookresearch/segment-anything/releases/download/v0.1.0/sam_vit_b_01ec64.pth",
    SAM_PATH,
)
fetch(
    "https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/pytorch_model.bin",
    CN_PATH,
)

device = torch.device("cpu")

# SAM
sam = sam_model_registry["vit_b"](checkpoint=str(SAM_PATH)).to(device)
predictor = SamPredictor(sam)

# Stable Diffusion + ControlNet.
# FIX: the "lllyasviel/sd-controlnet-scribble" repo contains only ControlNet
# weights, not a full pipeline — it must be loaded with ControlNetModel, not
# via StableDiffusionControlNetPipeline.from_pretrained(...).controlnet.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float32
).to(device)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32
).to(device)
pipe.enable_attention_slicing()

# SHAP-E (if present)
if build_model:
    shap_e_model = build_model("diffusion", device=device)
    shap_e_model.eval()
else:
    shap_e_model = None


# --- helper pipeline ---
def get_silhouette(img: Image.Image) -> np.ndarray:
    """Segment the main object and return a uint8 silhouette mask (0/255).

    FIX: SamPredictor.predict requires at least one prompt; calling it with
    point_coords/point_labels/box all None raises. We prompt with a single
    foreground point at the image centre, which suits photos where the
    component dominates the frame.
    """
    arr = np.array(img.convert("RGB"))
    predictor.set_image(arr)
    h, w = arr.shape[:2]
    center = np.array([[w // 2, h // 2]])
    mask, _, _ = predictor.predict(
        point_coords=center,
        point_labels=np.array([1]),  # 1 = foreground point
        box=None,
        multimask_output=False,
    )
    return mask[0].astype(np.uint8) * 255


def generate_concept(mask: np.ndarray, prompt: str) -> Image.Image:
    """Render a 2D concept image conditioned on the silhouette scribble."""
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    out = pipe(
        prompt=prompt,
        image=Image.fromarray(mask_rgb),
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    return out.images[0]


def generate_mesh(concept_img: Image.Image) -> trimesh.Trimesh:
    """Turn the 2D concept into a 3D mesh via SHAP-E, or a stub cube."""
    if shap_e_model is None:
        # fallback stub: return a minimal cube mesh
        return trimesh.creation.box(extents=(1, 1, 1))
    buf = io.BytesIO()
    concept_img.save(buf, format="PNG")
    buf.seek(0)
    # SHAP-E decode — NOTE(review): sample_latents_and_decode is assumed
    # available when the build_model import succeeded; verify against the
    # installed shap_e version.
    mesh = shap_e_model.sample_latents_and_decode(buf)
    return trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)


def export_stl(mesh: trimesh.Trimesh) -> str:
    """Write *mesh* to an STL file and return its path."""
    out_path = "output_eol.stl"
    mesh.export(out_path)
    return out_path


def eol_pipeline(image, prompt):
    """Full pipeline: photo -> mask -> 2D concept -> mesh -> STL path."""
    mask = get_silhouette(image)
    concept = generate_concept(mask, prompt)
    mesh = generate_mesh(concept)
    stl = export_stl(mesh)
    return concept, stl


# --- Gradio UI ---
demo = gr.Interface(
    fn=eol_pipeline,
    inputs=[
        gr.Image(type="pil", label="Foto componente EoL"),
        gr.Textbox(lines=2, placeholder="Prompt creativo…"),
    ],
    outputs=[
        gr.Image(type="pil", label="Concept 2D"),
        gr.File(label="Scarica STL"),
    ],
    title="EoL Component Generative Design (CPU-only)",
    description="Upload foto componente EoL ➜ concept 2D ➜ mesh 3D STL",
)

if __name__ == "__main__":
    demo.launch()