Spaces:
Sleeping
Sleeping
File size: 11,605 Bytes
3c8903a 8c9164c 3c8903a 98f01d7 3c8903a 0600e9e 3c8903a 0600e9e 74eb73a 9299372 973f3eb da6818c 98f01d7 0600e9e 98f01d7 da6818c 0600e9e da6818c 0600e9e da6818c a28b5ab da6818c 0600e9e da6818c 0600e9e da6818c 377bc40 da6818c 2ec02af da6818c 0600e9e a307280 da6818c 0600e9e da6818c 0600e9e da6818c 0600e9e da6818c 0600e9e da6818c 0600e9e da6818c 0600e9e da6818c 0600e9e da6818c a28b5ab da6818c 04fbcf3 da6818c 007755c 04fbcf3 4348f60 da6818c 04fbcf3 da6818c 04fbcf3 4348f60 da6818c 04fbcf3 da6818c 04fbcf3 4348f60 da6818c 04fbcf3 4348f60 da6818c 007755c 04fbcf3 4348f60 da6818c a28b5ab da6818c 007755c 6324410 98f01d7 04fbcf3 98f01d7 c9c18ab 8c9164c 0600e9e 98f01d7 a28b5ab 04fbcf3 8c9164c 98f01d7 8c9164c 8ab5e1a 8c9164c 04fbcf3 ce50616 98f01d7 04fbcf3 98f01d7 8c9164c 98f01d7 8c9164c 98f01d7 8c9164c 98f01d7 6324410 04fbcf3 6324410 8c9164c 6324410 93cd9b7 6324410 8c9164c 6324410 8c9164c 04fbcf3 8c9164c 93cd9b7 3c8903a 98f01d7 6324410 3c8903a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
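# -*- coding: utf-8 -*-
"""
Gradio app that generates photorealistic wood images for four board faces
(front, right, back, left) and previews them on a textured 3D board.

Created on Tue Jun 10 11:16:28 2025
@author: camaac
"""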
import gradio as gr
from PIL import Image
import io, base64, json, traceback
import torch
from inference import inference
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import torch.nn as nn
import numpy as np
import tempfile
import os
import shutil
import uuid
class UNetNoCondWrapper(nn.Module):
    """Make an unconditional ``UNet2DModel`` callable like a conditioned UNet.

    ``forward`` accepts the conditioning keyword arguments
    (``encoder_hidden_states``, ``added_cond_kwargs``,
    ``cross_attention_kwargs``) and discards them. Only the sample, the
    timestep, ``return_dict`` and any remaining keyword arguments reach the
    wrapped model. Attribute lookups that fail on the wrapper itself are
    forwarded to the wrapped model.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        """Run the wrapped model on (sample, timestep), ignoring conditioning.

        ``return_dict`` defaults to ``False`` here and is forwarded as-is.
        """
        # Deliberately unused: the wrapped model receives no conditioning.
        del encoder_hidden_states, added_cond_kwargs, cross_attention_kwargs
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        """Resolve attributes missing on the wrapper from the wrapped model.

        Python only calls this after normal lookup fails. nn.Module keeps
        submodules outside ``__dict__``, so even ``unet`` arrives here. The
        reserved names go through nn.Module's own lookup so that reading
        ``self.unet`` cannot recurse back into this method.
        """
        if name not in {"unet", "forward", "__getstate__", "__setstate__"}:
            return getattr(self.unet, name)
        return super().__getattr__(name)

    def save_pretrained(self, save_directory, **kwargs):
        """Save the wrapped model (not the wrapper) via its own ``save_pretrained``."""
        return self.unet.save_pretrained(save_directory, **kwargs)
# Compute device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model repository id; every component is loaded from its own subfolder.
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"

# 1) Standard pipeline components (weighted models go straight to `device`).
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

# 2) Load the unconditioned UNet and wrap it so it accepts (and ignores) the
#    conditioning keyword arguments (see UNetNoCondWrapper.forward).
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

# 3) Assemble the pipeline by hand from the components above (no safety checker).
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)
# Cast everything to float32, then place it on the compute device.
pipe = pipe.to(torch.float32)
pipe = pipe.to(device)
def _chmod_readable(path):
    """Best-effort ``chmod 0o644`` on *path*; OS errors are ignored."""
    try:
        os.chmod(path, 0o644)
    except OSError:
        # Not fatal: the current process can still read the file.
        pass


def _make_work_dir(base_dir="/tmp/gradio"):
    """Create a fresh ``parallelep_*`` directory under *base_dir* when it is writable.

    Falls back to the system temp directory otherwise. The directory is never
    removed here; the caller owns it.
    """
    if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
        return tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
    return tempfile.mkdtemp(prefix="parallelep_")


def _write_mtl(mtl_path, tex_names):
    """Write one diffuse-textured material ``m_<face>`` per entry of *tex_names*."""
    with open(mtl_path, "w", encoding="utf-8") as f:
        f.write("# Material file for parallelepiped\n")
        for mat_name, tex_file in tex_names.items():
            f.write(f"newmtl m_{mat_name}\n")
            f.write("Ka 1.000 1.000 1.000\n")
            f.write("Kd 1.000 1.000 1.000\n")
            f.write("Ks 0.000 0.000 0.000\n")
            f.write("Ns 10.000\n")
            f.write("illum 2\n")
            f.write(f"map_Kd {tex_file}\n\n")
    _chmod_readable(mtl_path)


def _write_obj(obj_path, quads, face_order):
    """Write *quads* as an OBJ mesh: two triangles and one material per face.

    Every face has its own 4 vertices and 4 texture coordinates, so each
    texture is stretched over its whole face (UV 0..1).
    """
    uv_corners = ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0))
    with open(obj_path, "w", encoding="utf-8") as f:
        f.write("# Parallelepiped OBJ generated by build_textured_cube\n")
        f.write("mtllib parallelep.mtl\n\n")
        for face_name in face_order:
            for corner in quads[face_name]:
                f.write("v {:.6f} {:.6f} {:.6f}\n".format(*corner))
        f.write("\n")
        for _ in face_order:
            for u, v in uv_corners:
                f.write("vt {:.6f} {:.6f}\n".format(u, v))
        f.write("\n")
        for i, face_name in enumerate(face_order):
            f.write(f"usemtl m_{face_name}\n")
            # OBJ indices are 1-based; vertex i*4+k is paired with texcoord i*4+k.
            a, b, c, d = (i * 4 + k for k in range(1, 5))
            # two triangles (v/vt)
            f.write(f"f {a}/{a} {b}/{b} {c}/{c}\n")
            f.write(f"f {a}/{a} {c}/{c} {d}/{d}\n\n")
    _chmod_readable(obj_path)


def build_textured_cube(pil_imgs, face_rotations=None, thickness_ratio=45 / 145):
    """Create a textured parallelepiped (OBJ + MTL + PNG textures) in a new temp directory.

    The box is centred on the origin and scaled so its largest extent is 1,
    so every coordinate lies within +/-0.5. Its extents are:
    X = width of the front image, Y = height of the front image, and
    Z = ``int(front_width * thickness_ratio)``.
    The size of the "right" image is not used.

    Texture placement:
    front -> +Z, right -> +X, back -> -Z, left -> -X.
    The "top" (+Y) and "bottom" (-Y) end caps get a plain black texture.

    Args:
        pil_imgs: list/tuple of at least 4 PIL images in the order
            (front, right, back, left); extra items are ignored.
        face_rotations: optional mapping of face name to a rotation in degrees
            (positive = counter-clockwise, as in ``PIL.Image.rotate``). The
            rotation is applied to that face's texture before saving. Faces
            not listed use the defaults front 0, right 270, back 180,
            left 270, top 0, bottom 0. The mapping is not modified.
        thickness_ratio: board thickness as a fraction of the front image
            width (default 45/145).

    Returns:
        ``(obj_path, tmpdir)``, where *obj_path* is the absolute path of
        ``parallelep.obj``. *tmpdir* is the directory that also holds
        ``parallelep.mtl`` and the ``tex_*.png`` files; it is not cleaned up.

    Raises:
        ValueError: if *pil_imgs* is not a list/tuple of at least 4 images.
        FileNotFoundError: if an expected output file is missing afterwards.
    """
    from PIL import Image

    if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
        raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")

    default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
    # Merge into a new dict: the caller's mapping must not be modified.
    rotations = {**default_rots, **(face_rotations or {})}

    def rotated(im, face_name):
        """Return *im* rotated counter-clockwise by the angle configured for *face_name*."""
        angle = rotations[face_name]
        if angle % 360 == 0:
            return im
        # expand=True keeps the whole picture when a non-square texture is
        # turned by 90/270 degrees. expand=False would crop it to the
        # original canvas and pad the uncovered area with black.
        return im.rotate(angle, resample=Image.BICUBIC, expand=True)

    front_w, front_h = pil_imgs[0].size
    thickness_px = int(front_w * thickness_ratio)

    # Physical extents in pixels: X = width, Y = length, Z = thickness.
    extent_x = float(front_w)
    extent_y = float(front_h)
    extent_z = float(thickness_px)
    # Normalise so the largest extent becomes 1 (coordinates within +/-0.5).
    scale = 1.0 / max(extent_x, extent_y, extent_z, 1.0)
    hx = extent_x * 0.5 * scale
    hy = extent_y * 0.5 * scale
    hz = extent_z * 0.5 * scale

    tmpdir = _make_work_dir()

    # Relative texture file names (the .mtl refers to them by these names).
    tex_names = {
        "front": "tex_front.png",
        "right": "tex_right.png",
        "back": "tex_back.png",
        "left": "tex_left.png",
        "top": "tex_top.png",
        "bottom": "tex_bottom.png",
    }

    # Four side textures from the inputs, plus black end caps.
    textures = {
        name: img.convert("RGB")
        for name, img in zip(("front", "right", "back", "left"), pil_imgs)
    }
    black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
    textures["top"] = black
    textures["bottom"] = black
    for face_name, im in textures.items():
        path = os.path.join(tmpdir, tex_names[face_name])
        rotated(im, face_name).save(path, format="PNG")
        _chmod_readable(path)

    _write_mtl(os.path.join(tmpdir, "parallelep.mtl"), tex_names)

    # One quad per face, corners counter-clockwise when seen from outside.
    # Axes: +X = right, +Y = board length ("top" cap), +Z = front face.
    quads = {
        # front texture on the +Z face
        "front": [(-hx, -hy, hz), (hx, -hy, hz), (hx, hy, hz), (-hx, hy, hz)],
        # right texture on the +X face
        "right": [(hx, -hy, -hz), (hx, hy, -hz), (hx, hy, hz), (hx, -hy, hz)],
        # back texture on the -Z face
        "back": [(-hx, hy, -hz), (hx, hy, -hz), (hx, -hy, -hz), (-hx, -hy, -hz)],
        # left texture on the -X face
        "left": [(-hx, -hy, hz), (-hx, hy, hz), (-hx, hy, -hz), (-hx, -hy, -hz)],
        # black cap on the +Y face
        "top": [(-hx, hy, -hz), (-hx, hy, hz), (hx, hy, hz), (hx, hy, -hz)],
        # black cap on the -Y face
        "bottom": [(hx, -hy, -hz), (hx, -hy, hz), (-hx, -hy, hz), (-hx, -hy, -hz)],
    }
    face_order = ["top", "right", "bottom", "left", "front", "back"]
    obj_path = os.path.join(tmpdir, "parallelep.obj")
    _write_obj(obj_path, quads, face_order)

    for fname in ["parallelep.obj", "parallelep.mtl", *tex_names.values()]:
        p = os.path.join(tmpdir, fname)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Expected file not found : {p}")
    return (os.path.abspath(obj_path), tmpdir)
# -------------------------
# return : 4 img (PIL) + path to .obj (str)
# -------------------------
def run(fibers: np.ndarray, rings: np.ndarray, num_steps: int):
    """Generate the four face images and a textured 3D board from the inputs.

    Args:
        fibers: fiber map from the "Fiber" input. The Gradio component uses
            ``type="numpy"``, so this is an array, or ``None`` when empty.
        rings: ring map from the "Ring" input, in the same format as *fibers*.
        num_steps: number of inference steps. ``gr.Number`` may deliver a
            float, so it is rounded to ``int`` and must be at least 1.

    Returns:
        A 5-tuple ``(front, right, back, left, obj_path)``: four RGB PIL
        images and the absolute path of the generated ``.obj`` file. On any
        failure the traceback is printed and four grey 256x256 placeholders
        plus ``None`` are returned instead, so the UI handler never raises.
    """
    try:
        if num_steps is None or num_steps < 1:
            raise ValueError(f"Number of inference steps must be >= 1, got {num_steps!r}.")
        steps = int(round(num_steps))

        outputs = inference(pipe, fibers, rings, steps)
        if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
            raise ValueError("inference must return a list/tuple of 4 images.")

        pil_imgs = []
        for im in outputs[:4]:
            if isinstance(im, np.ndarray):
                im = Image.fromarray(im)
            if im.mode != "RGB":
                im = im.convert("RGB")
            pil_imgs.append(im)

        # Independent copies for the preview outputs.
        thumbs = [im.copy() for im in pil_imgs]
        obj_path, _tmpdir = build_textured_cube(pil_imgs)
        return (*thumbs, obj_path)
    except Exception:
        # Top-level UI boundary: log the failure and show placeholders.
        traceback.print_exc()
        blank = Image.new("RGB", (256, 256), (220, 220, 220))
        return (blank, blank, blank, blank, None)
# -------------------------
# Gradio interface
# -------------------------
with gr.Blocks(title="Photorealistic wood generator (4 faces)") as demo:
    gr.HTML("<h1 style='text-align:center; margin-bottom:8px;'>Photorealistic wood generator (4 faces)</h1>")
    gr.Markdown(
        "Upload 2 images (four fiber maps and four ring maps) corresponding to the board faces. "
        "The model will return four generated images (one per face), produced in a single coherent pass.\n"
        "Set the number of inference steps. Higher values can improve quality but increase processing time."
    )
    with gr.Row():
        # Left column: inputs and the run button.
        with gr.Column(scale=1):
            inp1 = gr.Image(label="Fiber", type="numpy")
            inp2 = gr.Image(label="Ring", type="numpy")
            # precision=0 makes the component return an int instead of a float.
            inp3 = gr.Number(value=10, label="Number of inference steps", precision=0)
            run_btn = gr.Button("Run inference")
        # Right column: 3D preview of the textured board.
        with gr.Column(scale=2):
            model3d_out = gr.Model3D(label="3D board")
    # One generated image per board face.
    with gr.Row():
        out1 = gr.Image(label="Front")
        out2 = gr.Image(label="Right")
        out3 = gr.Image(label="Back")
        out4 = gr.Image(label="Left")
    # run() returns (front, right, back, left, obj_path) in this output order.
    run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, model3d_out])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|