Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Tue Jun 10 11:16:28 2025 | |
| @author: camaac | |
| """ | |
| import gradio as gr | |
| from PIL import Image | |
| import io, base64, json, traceback | |
| import torch | |
| from inference import inference | |
| from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler | |
| from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor | |
| import torch.nn as nn | |
| import numpy as np | |
| import tempfile | |
| import os | |
| import shutil | |
| import uuid | |
class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional ``UNet2DModel`` stand in for the
    text-conditioned UNet expected by ``StableDiffusionInstructPix2PixPipeline``.

    The pipeline invokes its UNet with conditioning arguments
    (``encoder_hidden_states``, ``added_cond_kwargs``,
    ``cross_attention_kwargs``); this wrapper silently discards them and
    forwards only ``(sample, timestep)`` to the wrapped model, while proxying
    every other attribute access (``config``, ``dtype``, ...) to it.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        # Conditioning inputs are intentionally dropped: the wrapped
        # UNet2DModel is unconditional and would not accept them.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # Proxy unknown attributes to the wrapped UNet; a handful of names
        # must instead go through nn.Module's own lookup (self.unet lives in
        # _modules, not __dict__) to avoid infinite recursion.
        if name not in ("unet", "forward", "__getstate__", "__setstate__"):
            return getattr(self.unet, name)
        return super().__getattr__(name)

    def save_pretrained(self, save_directory, **kwargs):
        # Persist only the wrapped model so it reloads as a plain UNet2DModel.
        return self.unet.save_pretrained(save_directory, **kwargs)
# ---------------------------------------------------------------------------
# Model loading (runs once at import time). All components are fetched from
# the Hugging Face Hub repository below; GPU is used when available.
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"

# 1) Standard Stable Diffusion components (VAE, scheduler, text stack).
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

# 2) Load the unconditional UNet and wrap it so the pipeline can call it
#    with text-conditioning arguments (which the wrapper discards).
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

# 3) Assemble the InstructPix2Pix pipeline manually (safety checker disabled).
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)
# Force fp32 on the chosen device (fp16 is never used here).
pipe = pipe.to(torch.float32).to(device)
# @spaces.GPU
def build_textured_cube(pil_imgs, face_rotations=None):
    """Create a textured parallelepiped on disk (OBJ + MTL + PNG textures).

    Parameters
    ----------
    pil_imgs : list/tuple of at least 4 PIL images, ordered
        (front, right, back, left). Only the front image's size drives the
        box dimensions; the board thickness is derived from the front width
        via a fixed 45/145 aspect ratio.
    face_rotations : optional dict mapping face name ("front", "right",
        "back", "left", "top", "bottom") to a rotation angle in degrees
        (CCW, applied to the texture before saving). Missing faces fall back
        to built-in defaults. The caller's dict is NOT modified.

    Returns
    -------
    (obj_path, tmpdir) : absolute path to the generated .obj file, and the
        temporary directory holding all generated assets (kept on disk so
        the viewer can load the textures).

    Raises
    ------
    ValueError
        If fewer than 4 images are supplied.
    FileNotFoundError
        If any expected output file is missing after generation.
    """
    if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
        raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")

    # Merge caller-supplied rotations over the defaults WITHOUT mutating the
    # caller's dict (the previous setdefault loop modified it in place).
    default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
    rotations = dict(default_rots)
    if face_rotations is not None:
        rotations.update(face_rotations)

    # Prefer Gradio's serving directory so the Model3D viewer can reach the
    # files; fall back to the system temp dir otherwise.
    base_dir = "/tmp/gradio"
    if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
        tmpdir = tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
    else:
        tmpdir = tempfile.mkdtemp(prefix="parallelep_")

    # Relative texture file names (the .mtl references these names).
    tex_names = {
        "front": "tex_front.png",
        "right": "tex_right.png",
        "back": "tex_back.png",
        "left": "tex_left.png",
        "top": "tex_top.png",
        "bottom": "tex_bottom.png",
    }

    # Physical dimensions (in "px", normalized below). The side thickness is
    # a fixed fraction of the front width rather than the side image's size.
    front_w, front_h = pil_imgs[0].size
    ratio = 45 / 145
    right_w = int(front_w * ratio)
    width_px = float(front_w)
    height_px = float(right_w)
    depth_px = float(front_h)

    # Normalize so every coordinate stays within +/- 0.5.
    max_dim = max(width_px, depth_px, height_px, 1.0)
    scale = 1.0 / max_dim
    half_x = (width_px * 0.5) * scale
    half_y = (depth_px * 0.5) * scale
    half_z = (height_px * 0.5) * scale

    def _save_texture(img, face_name):
        # Save one face texture into tmpdir, rotated per `rotations`.
        im = img.convert("RGB")
        angle = rotations.get(face_name, 0)
        if angle % 360 != 0:
            # PIL rotate: angle in degrees, positive = CCW.
            im = im.rotate(angle, resample=Image.BICUBIC, expand=False)
        path = os.path.join(tmpdir, tex_names[face_name])
        im.save(path, format="PNG")
        try:
            os.chmod(path, 0o644)  # best effort: some sandboxes forbid chmod
        except Exception:
            pass

    mapping_order = ["front", "right", "back", "left"]
    for img, face_name in zip(pil_imgs[:4], mapping_order):
        _save_texture(img, face_name)

    # Top and bottom faces get a plain black texture.
    black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
    for face_name in ("top", "bottom"):
        _save_texture(black, face_name)

    # --- write .mtl: one textured material per face ---
    mtl_path = os.path.join(tmpdir, "parallelep.mtl")
    with open(mtl_path, "w", encoding="utf-8") as f:
        f.write("# Material file for parallelepiped\n")
        for mat_name, tex_file in tex_names.items():
            f.write(f"newmtl m_{mat_name}\n")
            f.write("Ka 1.000 1.000 1.000\n")
            f.write("Kd 1.000 1.000 1.000\n")
            f.write("Ks 0.000 0.000 0.000\n")
            f.write("Ns 10.000\n")
            f.write("illum 2\n")
            f.write(f"map_Kd {tex_file}\n\n")
    try:
        os.chmod(mtl_path, 0o644)
    except Exception:
        pass

    # --- geometry: one quad per face (CCW when seen from outside) ---
    # Convention: +X = right, +Y = front, +Z = up.
    # NOTE: keys below are TEXTURE names; the comment on each entry states
    # which geometric face of the box that texture is applied to.
    quads = {
        # "front" texture -> +Z (top) face, seen from +Z
        "front": [
            (-half_x, -half_y, half_z),
            ( half_x, -half_y, half_z),
            ( half_x, half_y, half_z),
            (-half_x, half_y, half_z),
        ],
        # "right" texture -> +X face, seen from +X
        "right": [
            ( half_x, -half_y, -half_z),
            ( half_x, half_y, -half_z),
            ( half_x, half_y, half_z),
            ( half_x, -half_y, half_z),
        ],
        # "back" texture -> -Z (bottom) face, seen from -Z
        "back": [
            (-half_x, half_y, -half_z),
            ( half_x, half_y, -half_z),
            ( half_x, -half_y, -half_z),
            (-half_x, -half_y, -half_z),
        ],
        # "left" texture -> -X face, seen from -X
        "left": [
            (-half_x, -half_y, half_z),
            (-half_x, half_y, half_z),
            (-half_x, half_y, -half_z),
            (-half_x, -half_y, -half_z),
        ],
        # "top" texture -> +Y (front) face, seen from +Y
        "top": [
            (-half_x, half_y, -half_z),
            (-half_x, half_y, half_z),
            ( half_x, half_y, half_z),
            ( half_x, half_y, -half_z),
        ],
        # "bottom" texture -> -Y (back) face, seen from -Y
        "bottom": [
            ( half_x, -half_y, -half_z),
            ( half_x, -half_y, half_z),
            (-half_x, -half_y, half_z),
            (-half_x, -half_y, -half_z),
        ],
    }
    face_order = ["top", "right", "bottom", "left", "front", "back"]

    obj_path = os.path.join(tmpdir, "parallelep.obj")
    with open(obj_path, "w", encoding="utf-8") as f:
        f.write("# Parallelepiped OBJ generated by build_textured_cube\n")
        f.write("mtllib parallelep.mtl\n\n")
        # Vertices: 4 per face, in face_order.
        for face_name in face_order:
            for v in quads[face_name]:
                f.write("v {:.6f} {:.6f} {:.6f}\n".format(*v))
        f.write("\n")
        # UVs: the same unit square for every face.
        uvs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
        for _ in range(6):
            for (u, v) in uvs:
                f.write("vt {:.6f} {:.6f}\n".format(u, v))
        f.write("\n")
        # Faces: each quad as two triangles (v/vt indices are 1-based).
        for i, face_name in enumerate(face_order):
            f.write(f"usemtl m_{face_name}\n")
            v_base = i * 4 + 1
            t_base = i * 4 + 1
            v1, v2, v3, v4 = v_base, v_base + 1, v_base + 2, v_base + 3
            t1, t2, t3, t4 = t_base, t_base + 1, t_base + 2, t_base + 3
            f.write(f"f {v1}/{t1} {v2}/{t2} {v3}/{t3}\n")
            f.write(f"f {v1}/{t1} {v3}/{t3} {v4}/{t4}\n\n")
    try:
        os.chmod(obj_path, 0o644)
    except Exception:
        pass

    # Sanity check: every asset the OBJ/MTL references must exist.
    for fname in ["parallelep.obj", "parallelep.mtl"] + list(tex_names.values()):
        p = os.path.join(tmpdir, fname)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Expected file not found : {p}")

    return (os.path.abspath(obj_path), tmpdir)
# -------------------------
# Inference entry point
# return : 4 img (PIL) + path to .obj (str)
# -------------------------
def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
    """Run the diffusion model on the two input maps and build the 3D board.

    Wired from Gradio: `fibers` and `rings` actually arrive as numpy arrays
    (the inputs are declared with type="numpy") and `num_steps` as a number;
    everything is forwarded to `inference` unchanged.

    Returns a 5-tuple (front, right, back, left, obj_path). On any failure
    a gray placeholder is returned for each face and None for the model, so
    the UI never surfaces a crash (the traceback is printed to the logs).
    """
    try:
        outputs = inference(pipe, fibers, rings, num_steps)
        if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
            # Fixed message: the callable is `inference`, not `user_inference`.
            raise ValueError("inference must return a list/tuple of 4 images.")
        # Normalize every output to an RGB PIL image.
        pil_imgs = []
        for im in outputs[:4]:
            if isinstance(im, np.ndarray):
                im = Image.fromarray(im)
            if im.mode != "RGB":
                im = im.convert("RGB")
            pil_imgs.append(im)
        # Independent copies for the thumbnails (build_textured_cube may
        # rotate the textures it saves).
        thumbs = [im.copy() for im in pil_imgs]
        # The temp dir is deliberately kept on disk so Gradio can serve
        # the OBJ/MTL/texture files.
        obj_path, _tmpdir = build_textured_cube(pil_imgs)
        return (*thumbs, obj_path)
    except Exception:
        # Best-effort UI: log the error and show placeholders instead of
        # propagating the exception into Gradio.
        traceback.print_exc()
        blank = Image.new("RGB", (256, 256), (220, 220, 220))
        return (blank, blank, blank, blank, None)
# -------------------------
# Gradio interface
# -------------------------
with gr.Blocks(title="Photorealistic wood generator (4 faces)") as demo:
    gr.HTML("<h1 style='text-align:center; margin-bottom:8px;'>Photorealistic wood generator (4 faces)</h1>")
    gr.Markdown("""Upload 2 images (four fiber maps and four ring maps) corresponding to the board faces. The model will return four generated images (one per face), produced in a single coherent pass.
Set the number of inference steps. Higher values can improve quality but increase processing time.""")
    with gr.Row():
        with gr.Column(scale=1):
            # type="numpy": the run() callback receives np.ndarray, not PIL.
            inp1 = gr.Image(label="Fiber", type="numpy")
            inp2 = gr.Image(label="Ring", type="numpy")
            inp3 = gr.Number(value=10, label="Number of inference steps")
            run_btn = gr.Button("Run inference")
        with gr.Column(scale=2):
            # Interactive 3D viewer fed with the generated .obj path.
            model3d_out = gr.Model3D(label="3D board")
    with gr.Row():
        out1 = gr.Image(label="Front")
        out2 = gr.Image(label="Right")
        out3 = gr.Image(label="Back")
        out4 = gr.Image(label="Left")
    # Wire the button: 3 inputs -> 4 face previews + the 3D model.
    run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, model3d_out])

if __name__ == "__main__":
    # Bind on all interfaces, port 7860 (Hugging Face Spaces convention).
    demo.launch(server_name="0.0.0.0", server_port=7860)