# Source: Hugging Face Space by "valegro", commit 4cb54c8 ("Update app.py").
import io
import os
import pathlib
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
# SHAP-E, main branch (minimal import; this API may need adapting in future).
try:
    from shap_e.diffusion.sample import build_model
except ImportError:
    # SHAP-E not installed: generate_mesh below falls back to a stub cube.
    build_model = None
# — checkpoint directory —
# Local cache for model weights downloaded at startup.
CKPT_DIR = pathlib.Path("checkpoints")
CKPT_DIR.mkdir(exist_ok=True)  # idempotent: reuses an existing directory
SAM_PATH = CKPT_DIR / "sam_vit_b.pth"           # Segment Anything ViT-B weights
CN_PATH = CKPT_DIR / "controlnet_scribble.bin"  # ControlNet scribble weights
def fetch(url, dst):
    """Download *url* to *dst* (a pathlib.Path), skipping if it already exists.

    Replaces the previous ``os.system(f"wget -q {url} -O {dst}")``: that form
    is non-portable (needs wget on PATH), open to shell injection through the
    URL, and silently leaves a zero-byte/partial file when the download fails.
    Here a failure raises (urllib.error.URLError) and no partial file is kept.
    """
    if dst.exists():
        return
    tmp = dst.with_suffix(dst.suffix + ".part")
    try:
        urllib.request.urlretrieve(url, tmp)
        tmp.replace(dst)  # publish only a fully-downloaded file
    finally:
        tmp.unlink(missing_ok=True)  # no-op after a successful replace
# Download the checkpoints at import time if they are missing.
fetch(
    "https://github.com/facebookresearch/segment-anything/releases/download/v0.1.0/sam_vit_b_01ec64.pth",
    SAM_PATH,
)
# NOTE(review): CN_PATH is downloaded here but never loaded afterwards — the
# pipeline below pulls "lllyasviel/sd-controlnet-scribble" from the hub
# directly. Confirm whether this download is still needed.
fetch(
    "https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/pytorch_model.bin",
    CN_PATH,
)
# Everything runs on CPU (Space has no GPU).
device = torch.device("cpu")

# — SAM: segmentation backbone + predictor wrapper —
sam = sam_model_registry["vit_b"](checkpoint=str(SAM_PATH)).to(device)
predictor = SamPredictor(sam)

# — Stable Diffusion + ControlNet (scribble conditioning) —
# BUG FIX: "lllyasviel/sd-controlnet-scribble" is a ControlNetModel
# repository, not a full pipeline (it has no model_index.json), so loading it
# through StableDiffusionControlNetPipeline.from_pretrained fails — and the
# old code also instantiated a second full pipeline just to read .controlnet.
# Load the ControlNet weights directly instead.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float32
).to(device)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float32,
).to(device)
pipe.enable_attention_slicing()  # lowers peak RAM on CPU-only inference

# — SHAP-E (only if the optional import above succeeded) —
if build_model:
    # NOTE(review): build_model("diffusion", ...) follows the SHAP-E main
    # branch at time of writing — confirm against the installed version.
    shap_e_model = build_model("diffusion", device=device)
    shap_e_model.eval()
else:
    shap_e_model = None
# — helper pipeline —
def get_silhouette(img: Image.Image):
    """Segment *img* with SAM (no prompts) and return a uint8 0/255 mask."""
    rgb = np.array(img.convert("RGB"))
    predictor.set_image(rgb)
    masks, _scores, _logits = predictor.predict(
        point_coords=None, point_labels=None, box=None, multimask_output=False
    )
    # Single mask requested -> take the first channel, scale booleans to 255.
    return masks[0].astype(np.uint8) * 255
def generate_concept(mask: np.ndarray, prompt: str):
    """Render a 2D concept image from a grayscale silhouette + text prompt."""
    # ControlNet expects a 3-channel conditioning image.
    scribble = Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB))
    result = pipe(
        prompt=prompt,
        image=scribble,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    return result.images[0]
def generate_mesh(concept_img: Image.Image):
    """Lift the 2D concept image into a 3D mesh via SHAP-E.

    When SHAP-E is unavailable, returns a minimal unit-cube stub so the
    STL export step downstream still produces a file.
    """
    if shap_e_model is None:
        # Fallback stub: minimal cube mesh.
        return trimesh.creation.box(extents=(1, 1, 1))
    png = io.BytesIO()
    concept_img.save(png, format="PNG")
    png.seek(0)
    # SHAP-E decode (API assumed stable as long as build_model imported).
    decoded = shap_e_model.sample_latents_and_decode(png)
    return trimesh.Trimesh(vertices=decoded.vertices, faces=decoded.faces)
def export_stl(mesh: trimesh.Trimesh):
    """Write *mesh* to a fixed STL file in the working dir; return its path."""
    path = "output_eol.stl"
    mesh.export(path)  # format inferred from the .stl extension
    return path
def eol_pipeline(image, prompt):
    """Full flow: photo -> SAM mask -> SD concept -> mesh -> STL file path.

    Returns the 2D concept image and the STL path, matching the Gradio
    (Image, File) outputs.
    """
    silhouette = get_silhouette(image)
    concept_img = generate_concept(silhouette, prompt)
    stl_path = export_stl(generate_mesh(concept_img))
    return concept_img, stl_path
# — Gradio UI —
# Wires eol_pipeline into a two-input / two-output web form.
# (Labels/placeholders are user-facing runtime strings, kept in Italian.)
demo = gr.Interface(
    fn=eol_pipeline,
    inputs=[
        gr.Image(type="pil", label="Foto componente EoL"),   # source photo
        gr.Textbox(lines=2, placeholder="Prompt creativo…")  # text prompt
    ],
    outputs=[
        gr.Image(type="pil", label="Concept 2D"),  # generated concept image
        gr.File(label="Scarica STL")               # downloadable STL path
    ],
    title="EoL Component Generative Design (CPU-only)",
    description="Upload foto componente EoL ➜ concept 2D ➜ mesh 3D STL"
)
# Launch only when run as a script (Gradio / HF Spaces entry point).
if __name__ == "__main__":
    demo.launch()