Upload 2 files

- app.py: +44 -62
- inference.py: +40 -56
app.py
CHANGED
@@ -45,83 +45,67 @@ class UNetNoCondWrapper(nn.Module):
         return getattr(self.unet, name)
 
     def save_pretrained(self, save_directory, **kwargs):
         return self.unet.save_pretrained(save_directory, **kwargs)
 
 # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
+vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
+scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
+feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
+
+# 2) Load your unconditioned UNet and wrap it
+base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
+wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
+
+# 3) Build the pipeline manually
+pipe = StableDiffusionInstructPix2PixPipeline(
+    vae=vae,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    unet=wrapped_unet,
+    scheduler=scheduler,
+    safety_checker=None,
+    feature_extractor=feature_extractor,
+)
+
+pipe = pipe.to(torch.float32).to(device)
 # @spaces.GPU
 
-# def pil_to_data_uri(img: Image.Image) -> str:
-#     buf = io.BytesIO()
-#     img.save(buf, format="PNG")
-#     b = base64.b64encode(buf.getvalue()).decode("utf-8")
-#     return f"data:image/png;base64,{b}"
 
 
 def build_textured_cube(pil_imgs, face_rotations=None):
     """
-    - pil_imgs: list/tuple of 4 PIL.Image in the order [front, right, back, left]
-    - Returns: (absolute_obj_path, tmpdir)
-    Defaults:
-      default_rots = {"front": 0, "right": 270, "back": 180, "left": 90, "top": 0, "bottom": 0}
-      face_order = ["top","right","bottom","left","front","back"]
-    Notes:
-      - Writes the files to /tmp/gradio when possible (HF Spaces).
-      - front/back use the size of pil_imgs[0]; left/right use their own width (rectangles).
+    Creates a textured parallelepiped (OBJ + MTL + textures).
     """
     import os
     import tempfile
     from PIL import Image
 
     if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
-        raise ValueError("build_textured_cube
+        raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")
 
-    # default rotations & order
     default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
-    # default_rots = {"front": 0, "right": 0, "back": 0, "left": 0, "top": 0, "bottom": 0}
     if face_rotations is None:
         face_rotations = default_rots
     else:
         for k, v in default_rots.items():
             face_rotations.setdefault(k, v)
 
     base_dir = "/tmp/gradio"
     if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
         tmpdir = tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
     else:
         tmpdir = tempfile.mkdtemp(prefix="parallelep_")
 
+    # relative names for textures (mtl will use these names)
     tex_names = {
         "front": "tex_front.png",
         "right": "tex_right.png",
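The hunk above assembles the InstructPix2Pix pipeline from its individual components rather than via a single `from_pretrained` call, which is what allows the custom `UNetNoCondWrapper` to be injected in place of a conditioned UNet. Only fragments of the wrapper appear in this diff (its attribute delegation and `save_pretrained`); the following is a minimal sketch of what such a wrapper plausibly looks like, with the `forward` signature being an assumption, not code from the commit:

import torch.nn as nn

class UNetNoCondWrapper(nn.Module):
    """Sketch: adapts an unconditional UNet2DModel to the conditioned
    call signature used by StableDiffusionInstructPix2PixPipeline."""

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states=None, **kwargs):
        # Assumption: UNet2DModel.forward only takes (sample, timestep),
        # so the text conditioning is accepted here and dropped.
        return self.unet(sample, timestep)

    def __getattr__(self, name):
        try:
            # nn.Module resolves registered submodules/parameters here
            return super().__getattr__(name)
        except AttributeError:
            # delegate everything else (config, dtype, ...) to the wrapped UNet
            return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        return self.unet.save_pretrained(save_directory, **kwargs)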
@@ -131,34 +115,32 @@ def build_textured_cube(pil_imgs, face_rotations=None):
         "bottom": "tex_bottom.png",
     }
 
-    # get sizes (we assume inference has already resized left/right if needed)
     front_w, front_h = pil_imgs[0].size
     right_w, right_h = pil_imgs[1].size
 
     ratio = 45/145
     right_w = int(front_w * ratio)
 
+    # define the physical dimensions of the parallelepiped (in "px", then normalize)
     width_px = float(front_w)
     height_px = float(right_w)
-    # depth Y: use the width of the side faces (average of left/right)
     depth_px = float(front_h)
 
+    # normalization to keep coordinates within ±0.5
     max_dim = max(width_px, depth_px, height_px, 1.0)
     scale = 1.0 / max_dim
     half_x = (width_px * 0.5) * scale
     half_y = (depth_px * 0.5) * scale
     half_z = (height_px * 0.5) * scale
 
     mapping_order = ["front", "right", "back", "left"]
+    # save textures in tmpdir
     for img, face_name in zip(pil_imgs[:4], mapping_order):
         im = img.convert("RGB")
         angle = face_rotations.get(face_name, 0)
         if angle % 360 != 0:
-            # PIL rotate: angle
+            # PIL rotate: angle in degrees, positive = CCW
             im = im.rotate(angle, resample=Image.BICUBIC, expand=False)
         path = os.path.join(tmpdir, tex_names[face_name])
         im.save(path, format="PNG")
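A worked instance of the half-extent computation may help. Note that `height_px` comes from `right_w`, which is forced to the 45/145 aspect ratio regardless of the actual side image. With an illustrative 512x512 front face (sizes are not taken from the commit):

front_w, front_h = 512, 512
ratio = 45 / 145
right_w = int(front_w * ratio)          # 158

width_px, height_px, depth_px = float(front_w), float(right_w), float(front_h)
scale = 1.0 / max(width_px, depth_px, height_px, 1.0)   # 1/512

half_x = (width_px * 0.5) * scale       # 0.5
half_y = (depth_px * 0.5) * scale       # 0.5
half_z = (height_px * 0.5) * scale      # ~0.154 -> a thin slab, not a cube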
@@ -167,7 +149,7 @@ def build_textured_cube(pil_imgs, face_rotations=None):
     except Exception:
         pass
 
-    # top/bottom
+    # black top/bottom
     black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
     for face_name in ("top", "bottom"):
         im = black
@@ -181,7 +163,7 @@ def build_textured_cube(pil_imgs, face_rotations=None):
     except Exception:
         pass
 
-    # ---
+    # --- write .mtl ---
    mtl_path = os.path.join(tmpdir, "parallelep.mtl")
     with open(mtl_path, "w", encoding="utf-8") as f:
         f.write("# Material file for parallelepiped\n")
@@ -294,7 +276,7 @@ def build_textured_cube(pil_imgs, face_rotations=None):
 # -------------------------
 def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
     try:
-        outputs = inference(fibers, rings, num_steps)
+        outputs = inference(pipe, fibers, rings, num_steps)
         if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
             raise ValueError("user_inference must return a list/tuple of 4 images.")
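Since `build_textured_cube` is the bridge between the model outputs and the 3D viewer, a quick local check is useful. The removed docstring documented the return value as `(absolute_obj_path, tmpdir)`; assuming that contract is unchanged, a hypothetical smoke test (not part of the commit):

from PIL import Image

# four solid-colour faces are enough to verify that the textures,
# .mtl and .obj get written under /tmp
faces = [Image.new("RGB", (512, 512), c)
         for c in ((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0))]
obj_path, tmpdir = build_textured_cube(faces)  # order: front, right, back, left
print(obj_path)   # absolute path to the generated .obj inside tmpdir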
inference.py
CHANGED
@@ -21,8 +21,8 @@ def pil_from(x):
     if isinstance(x, str):
         return PIL.Image.open(x)
     return x
 
-def inference(fiber_imgs, ring_imgs, num_steps):
+def inference(pipe, fiber_imgs, ring_imgs, num_steps):
     """
     fiber_imgs: PIL.Image or paths
     ring_imgs: PIL.Image or paths
@@ -31,68 +31,52 @@ def inference(fiber_imgs, ring_imgs, num_steps):
     returns: list of 4 PIL.Image (L mode), order [1, 4, 3, 2]
     """
     # seed + generator
+    seed = random.randrange(0, 2**32)
+    torch.manual_seed(seed)
+    generator = torch.Generator("cpu").manual_seed(seed)
 
     # sizes
     tile = 512
     canvas_size = tile * 2
 
+    # stack channels: [fiber, ring, ring] -> H,W,3
+    arr_f = np.array(fiber_imgs).astype(np.uint8)
+    arr_r = np.array(ring_imgs).astype(np.uint8)
+    arr_in = np.stack([arr_f, arr_r, arr_r], axis=2)  # H,W,3
+    input_image = PIL.Image.fromarray(arr_in)  # PIL RGB
+
+    # run pipeline (use autocast consistent with device)
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(torch.device("cuda").type if torch.cuda.is_available() else "cpu")
+
+    with autocast_ctx:
+        out = pipe(
+            prompt="",  # empty prompt (your model ignores prompt)
+            image=input_image,
+            num_inference_steps=num_steps,
+            image_guidance_scale=1.9,
+            guidance_scale=10.0,
+            generator=generator,
+            safety_checker=None,
+            num_images_per_prompt=1,
+        )
+    # out.images may be a list; take first
+    pred = out.images[0]
+
-    # # ensure pred is canvas_size x canvas_size
-    # if pred.size != (canvas_size, canvas_size):
-    #     pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
+    # ensure pred is canvas_size x canvas_size
+    if pred.size != (canvas_size, canvas_size):
+        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
 
     # split into 4 tiles in same order TL, TR, BL, BR
-    ring_imgs = PIL.Image.fromarray(ring_imgs)
-    fiber_imgs = PIL.Image.fromarray(fiber_imgs)
-
-    tl = ring_imgs.crop((0, 0, tile, tile))
-    tr = ring_imgs.crop((tile, 0, canvas_size, tile))
-    bl = ring_imgs.crop((0, tile, tile, canvas_size))
-    br = ring_imgs.crop((tile, tile, canvas_size, canvas_size))
-
-    # tr = cv2.resize(np.asarray(tr), (new_width, original_height), interpolation=cv2.INTER_LANCZOS4)
-    # br = cv2.resize(np.asarray(br), (new_width, original_height), interpolation=cv2.INTER_LANCZOS4)
-
-    # tr = PIL.Image.fromarray(tr)
-    # br = PIL.Image.fromarray(br)
+    tl = pred.crop((0, 0, tile, tile))
+    tr = pred.crop((tile, 0, canvas_size, tile))
+    bl = pred.crop((0, tile, tile, canvas_size))
+    br = pred.crop((tile, tile, canvas_size, canvas_size))
 
     # close opened images to free handles
     fiber_imgs.close()
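The channel stacking is what lets a two-input problem fit an RGB-image pipeline: the fiber map rides in the red channel and the ring map is duplicated into green and blue. A small shape check, assuming the two grayscale inputs already arrive at canvas size (1024x1024); the pixel values below are stand-ins:

import numpy as np
import PIL.Image

arr_f = np.zeros((1024, 1024), dtype=np.uint8)        # fiber map -> R channel
arr_r = np.full((1024, 1024), 128, dtype=np.uint8)    # ring map  -> G and B
arr_in = np.stack([arr_f, arr_r, arr_r], axis=2)
print(arr_in.shape)                        # (1024, 1024, 3)
print(PIL.Image.fromarray(arr_in).mode)    # RGB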
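Taken together, the two files change the calling convention: `pipe` is now built once at module import in app.py and passed explicitly into `inference`. A sketch of the resulting chain, with hypothetical input paths and an illustrative step count:

import PIL.Image

fibers = PIL.Image.open("fibers.png")   # hypothetical grayscale fiber map
rings = PIL.Image.open("rings.png")     # hypothetical grayscale ring map

faces = inference(pipe, fibers, rings, num_steps=20)  # 4 PIL images, order [1, 4, 3, 2]
obj_path, tmpdir = build_textured_cube(faces)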