File size: 3,523 Bytes
89684b8
 
 
 
 
 
 
 
4cb54c8
 
 
 
 
 
 
 
 
89684b8
 
 
 
 
 
 
4cb54c8
 
 
 
 
 
 
 
 
89684b8
 
 
4cb54c8
 
89684b8
 
4cb54c8
89684b8
 
 
 
 
 
 
 
 
 
4cb54c8
 
 
 
 
 
89684b8
4cb54c8
89684b8
 
 
4cb54c8
 
 
 
 
 
89684b8
 
4cb54c8
 
 
 
 
 
 
 
89684b8
 
4cb54c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89684b8
 
 
 
 
 
 
 
4cb54c8
89684b8
 
 
 
 
 
 
 
 
 
4cb54c8
 
89684b8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os, pathlib, io
import numpy as np, cv2, torch, trimesh
from PIL import Image
import gradio as gr

from segment_anything import sam_model_registry, SamPredictor
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# SHAP-E main branch (minimal import; may need adapting in the future)
try:
    from shap_e.diffusion.sample import build_model
except ImportError:
    build_model = None

# — checkpoint directory —
# Model weights are cached under ./checkpoints (relative to the current working
# directory); the directory is created at import time if it is missing.
CKPT_DIR = pathlib.Path("checkpoints")
CKPT_DIR.mkdir(exist_ok=True)
SAM_PATH = CKPT_DIR.joinpath("sam_vit_b.pth")
CN_PATH = CKPT_DIR.joinpath("controlnet_scribble.bin")

def fetch(url, dst):
    """Download *url* to *dst* unless a non-empty *dst* already exists.

    The payload is streamed into a sibling ``<name>.part`` file and only
    renamed onto *dst* after the transfer completes. An interrupted or failed
    download therefore never leaves a truncated file that a later existence
    check would mistake for a valid checkpoint. A zero-byte *dst* (the kind of
    leftover ``wget -O`` produces on an HTTP error) is treated as missing.

    The download runs in-process via ``urllib.request``. That avoids building
    a shell command from the URL/path (shell injection) and the dependency on
    an external ``wget`` binary.

    The helper stays best-effort: failures emit a ``RuntimeWarning`` instead
    of raising. Code that later loads *dst* will report the missing file.

    Args:
        url: source URL (any scheme supported by ``urllib.request``).
        dst: destination path (``str`` or ``os.PathLike``).

    Returns:
        ``True`` if *dst* is present after the call, ``False`` if the download
        failed.
    """
    import shutil
    import urllib.request
    import warnings

    dst = pathlib.Path(dst)
    if dst.exists() and dst.stat().st_size > 0:
        return True

    tmp = dst.with_name(dst.name + ".part")
    try:
        with urllib.request.urlopen(url) as resp, open(tmp, "wb") as out:
            shutil.copyfileobj(resp, out)
        # Atomic on the same filesystem: dst is either absent or complete.
        os.replace(tmp, dst)
    except (OSError, ValueError) as exc:  # URLError/HTTPError are OSErrors
        tmp.unlink(missing_ok=True)
        warnings.warn(
            f"download of {url} failed: {exc}", RuntimeWarning, stacklevel=2
        )
        return False
    return True

# Download the checkpoints if they are not present yet.
# NOTE(review): CN_PATH is not read anywhere in the visible code — the
# ControlNet weights are loaded from the Hugging Face Hub by repo id further
# down — and the control_v11p_sd15_scribble repo appears to publish its weights
# as "diffusion_pytorch_model.*" rather than "pytorch_model.bin". Confirm
# whether this second download is still needed.
fetch(
    url="https://github.com/facebookresearch/segment-anything/releases/download/v0.1.0/sam_vit_b_01ec64.pth",
    dst=SAM_PATH,
)
fetch(
    url="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/pytorch_model.bin",
    dst=CN_PATH,
)

# All models in this script are placed on the CPU.
device = torch.device("cpu")

# SAM: ViT-B variant restored from the local checkpoint and wrapped in a
# SamPredictor, which get_silhouette uses for mask prediction.
sam = sam_model_registry["vit_b"](checkpoint=os.fspath(SAM_PATH))
sam = sam.to(device)
predictor = SamPredictor(sam)

# Stable Diffusion + ControlNet
# "lllyasviel/sd-controlnet-scribble" holds only the ControlNet weights, not a
# full pipeline, so it has to be loaded with ControlNetModel. Going through
# StableDiffusionControlNetPipeline.from_pretrained expects a complete
# pipeline repository and fails.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float32
).to(device)
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo id may no longer be
# served by the Hub; if loading fails, point this at a maintained mirror of
# SD 1.5.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float32,
).to(device)
# Compute attention in slices to reduce peak memory on CPU.
pipe.enable_attention_slicing()

# SHAP-E (if available): build the model only when the optional import above
# produced a `build_model`; otherwise generate_mesh uses its fallback mesh.
# NOTE(review): confirm that the installed shap_e release exposes
# `shap_e.diffusion.sample.build_model` with this signature — if the import
# fails, this is always None.
shap_e_model = None
if build_model is not None:
    shap_e_model = build_model("diffusion", device=device)
    shap_e_model.eval()

# — helper pipeline —
def get_silhouette(img: Image.Image, point=None):
    """Segment the main object in *img* with SAM and return a binary mask.

    SAM is a promptable model. Calling it with no point/box prompt leaves the
    mask unconstrained, so a single foreground point is supplied instead. By
    default this is the image centre, which assumes the component is framed
    roughly in the middle of the photo. With a single-point prompt SAM returns
    several candidate masks; the one with the highest predicted quality score
    is kept.

    Args:
        img: input photo (any PIL mode; converted to RGB).
        point: optional ``(x, y)`` pixel coordinate of a point on the object.
            Defaults to the image centre.

    Returns:
        ``uint8`` array of shape ``(H, W)``: 255 inside the object, 0 outside.
    """
    arr = np.array(img.convert("RGB"))
    height, width = arr.shape[:2]
    if point is None:
        point = (width / 2, height / 2)

    predictor.set_image(arr)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([point], dtype=np.float32),
        point_labels=np.array([1]),  # 1 = foreground point
        multimask_output=True,
    )
    best = int(np.argmax(scores))
    return masks[best].astype(np.uint8) * 255

def generate_concept(mask: np.ndarray, prompt: str):
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    out = pipe(
        prompt=prompt,
        image=Image.fromarray(mask_rgb),
        num_inference_steps=30,
        guidance_scale=7.5
    )
    return out.images[0]

def generate_mesh(concept_img: Image.Image):
    """Turn the 2D concept image into a triangle mesh.

    When SHAP-E is unavailable (``shap_e_model is None``) a 1x1x1 box is
    returned as a placeholder, so the rest of the pipeline still yields an STL.

    NOTE(review): ``sample_latents_and_decode`` accepting a PNG byte stream,
    and its result exposing ``vertices``/``faces``, are assumptions of this
    code — confirm against the SHAP-E API actually installed.
    """
    if shap_e_model is not None:
        # Hand the concept image to SHAP-E as an in-memory PNG stream.
        png = io.BytesIO()
        concept_img.save(png, format="PNG")
        png.seek(0)
        decoded = shap_e_model.sample_latents_and_decode(png)
        return trimesh.Trimesh(vertices=decoded.vertices, faces=decoded.faces)

    # Fallback stub: return a minimal cube mesh to export as STL.
    return trimesh.creation.box(extents=(1, 1, 1))

def export_stl(mesh: "trimesh.Trimesh", out_path=None):
    """Write *mesh* to an STL file and return the file path.

    By default every call writes to a fresh, uniquely named file in the system
    temp directory. A single fixed file name would let concurrent Gradio
    requests overwrite each other's result before it is downloaded.

    Args:
        mesh: object with an ``export(path)`` method (a ``trimesh.Trimesh``);
            the format is chosen from the ``.stl`` extension.
        out_path: optional explicit destination path.

    Returns:
        The path of the written file, as ``str``.
    """
    import tempfile

    if out_path is None:
        fd, out_path = tempfile.mkstemp(prefix="output_eol_", suffix=".stl")
        os.close(fd)  # only the unique name is needed; export reopens the file
    mesh.export(out_path)
    return str(out_path)

def eol_pipeline(image, prompt):
    """Run the full EoL pipeline: photo -> silhouette -> 2D concept -> STL.

    Args:
        image: uploaded photo as a PIL image (the Gradio input uses
            ``type="pil"``), or ``None`` when the user submits without one.
        prompt: creative text prompt; a missing value is treated as ``""``.

    Returns:
        ``(concept_image, stl_path)`` for the two Gradio outputs.

    Raises:
        gr.Error: if no image was uploaded. Gradio then shows a readable
            message instead of an AttributeError raised inside get_silhouette.
    """
    if image is None:
        raise gr.Error("Carica una foto del componente EoL.")
    mask = get_silhouette(image)
    concept = generate_concept(mask, prompt or "")
    mesh = generate_mesh(concept)
    stl = export_stl(mesh)
    return concept, stl

# — Gradio UI —
# Single-shot interface: photo + prompt in, 2D concept image + STL download out.
_ui_inputs = [
    gr.Image(type="pil", label="Foto componente EoL"),
    gr.Textbox(lines=2, placeholder="Prompt creativo…"),
]
_ui_outputs = [
    gr.Image(type="pil", label="Concept 2D"),
    gr.File(label="Scarica STL"),
]
demo = gr.Interface(
    fn=eol_pipeline,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="EoL Component Generative Design (CPU-only)",
    description="Upload foto componente EoL ➜ concept 2D ➜ mesh 3D STL",
)

if __name__ == "__main__":
    # Start the local Gradio server only when run as a script.
    demo.launch()