# -*- coding: utf-8 -*-
"""
Created on Tue Jun 10 11:16:28 2025

@author: camaac

Gradio app: feeds a fiber image and a ring image through an InstructPix2Pix
pipeline whose UNet ignores text conditioning, expects four generated face
images back, and assembles them into a textured 3D board (OBJ + MTL + PNGs).
"""
import base64
import io
import json
import os
import shutil
import tempfile
import traceback
import uuid

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInstructPix2PixPipeline,
    UNet2DModel,
)
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from inference import inference


class UNetNoCondWrapper(nn.Module):
    """Adapt an unconditional ``UNet2DModel`` to the UNet call signature used
    by ``StableDiffusionInstructPix2PixPipeline``.

    ``forward`` accepts the conditioning arguments (``encoder_hidden_states``,
    ``added_cond_kwargs``, ``cross_attention_kwargs``) and drops them, passing
    only ``sample``, ``timestep``, ``return_dict`` and any extra kwargs to the
    wrapped UNet. Attributes not defined on the wrapper are looked up on the
    wrapped UNet, so code reading e.g. ``unet.config`` sees its values.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs,
    ):
        # Conditioning inputs are intentionally discarded: UNet2DModel has no
        # cross-attention/conditioning parameters to receive them.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # nn.Module keeps submodules (including ``unet``) outside __dict__, so
        # these names must use nn.Module's own lookup; otherwise resolving
        # ``self.unet`` would re-enter this method and recurse forever.
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        """Save only the wrapped UNet (not the wrapper) to *save_directory*."""
        return self.unet.save_pretrained(save_directory, **kwargs)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"

# 1) Load the standard pipeline components from the checkpoint subfolders.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

# 2) Load the unconditional UNet and wrap it so the pipeline can call it.
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

# 3) Build the pipeline manually (no safety checker).
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)
pipe = pipe.to(torch.float32).to(device)


# ---------------------------------------------------------------------------
# 3D board export
# ---------------------------------------------------------------------------

# Default texture rotation per face, in degrees counter-clockwise
# (PIL.Image.rotate convention), applied before each texture is saved.
_DEFAULT_FACE_ROTATIONS = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}

# Texture file names, relative to the output dir (referenced by the .mtl).
_TEX_NAMES = {
    "front": "tex_front.png",
    "right": "tex_right.png",
    "back": "tex_back.png",
    "left": "tex_left.png",
    "top": "tex_top.png",
    "bottom": "tex_bottom.png",
}

# Order of the generated images passed to build_textured_cube.
_GENERATED_FACES = ("front", "right", "back", "left")

# Order in which faces are written to the .obj (vertex/uv indices follow it).
_FACE_ORDER = ("top", "right", "bottom", "left", "front", "back")

# Box thickness (Z extent) as a fraction of the front image width (X extent).
# NOTE(review): presumably a 45 x 145 board cross-section — confirm.
_DEFAULT_THICKNESS_RATIO = 45 / 145


def _make_output_dir():
    """Create a fresh temp directory, under /tmp/gradio when it is writable."""
    base_dir = "/tmp/gradio"
    if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
        return tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
    return tempfile.mkdtemp(prefix="parallelep_")


def _chmod_readable(path):
    """Best-effort chmod 0o644 so the served files are world-readable."""
    try:
        os.chmod(path, 0o644)
    except OSError:
        # Not fatal: e.g. filesystems without POSIX permissions.
        pass


def _save_texture(img, path, angle):
    """Save *img* as an RGB PNG at *path*, rotated *angle* degrees CCW."""
    im = img.convert("RGB")
    if angle % 360 != 0:
        # expand=True keeps the whole picture on 90/270-degree turns of
        # non-square images (expand=False would crop it); the UV mapping
        # stretches the texture over the face either way. For square images
        # the result is identical to expand=False.
        im = im.rotate(angle, resample=Image.BICUBIC, expand=True)
    im.save(path, format="PNG")
    _chmod_readable(path)


def _write_mtl(mtl_path):
    """Write one material ``m_<face>`` per face, textured with its PNG."""
    lines = ["# Material file for parallelepiped\n"]
    for mat_name, tex_file in _TEX_NAMES.items():
        lines += [
            f"newmtl m_{mat_name}\n",
            "Ka 1.000 1.000 1.000\n",
            "Kd 1.000 1.000 1.000\n",
            "Ks 0.000 0.000 0.000\n",
            "Ns 10.000\n",
            "illum 2\n",
            f"map_Kd {tex_file}\n\n",
        ]
    with open(mtl_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    _chmod_readable(mtl_path)


def _face_quads(hx, hy, hz):
    """Return the four corners of each box face, keyed by texture name.

    Corners are counter-clockwise as seen from outside, so normals point
    outward. Texture name -> outward normal: front +Z, right +X, back -Z,
    left -X, top +Y, bottom -Y. The four generated faces therefore wrap
    around the Y axis, and "top"/"bottom" are the two end caps.
    """
    return {
        "front": [(-hx, -hy, hz), (hx, -hy, hz), (hx, hy, hz), (-hx, hy, hz)],
        "right": [(hx, -hy, -hz), (hx, hy, -hz), (hx, hy, hz), (hx, -hy, hz)],
        "back": [(-hx, hy, -hz), (hx, hy, -hz), (hx, -hy, -hz), (-hx, -hy, -hz)],
        "left": [(-hx, -hy, hz), (-hx, hy, hz), (-hx, hy, -hz), (-hx, -hy, -hz)],
        "top": [(-hx, hy, -hz), (-hx, hy, hz), (hx, hy, hz), (hx, hy, -hz)],
        "bottom": [(hx, -hy, -hz), (hx, -hy, hz), (-hx, -hy, hz), (-hx, -hy, -hz)],
    }


def _write_obj(obj_path, quads):
    """Write the box as 12 textured triangles (2 per face) referencing the .mtl."""
    lines = [
        "# Parallelepiped OBJ generated by build_textured_cube\n",
        "mtllib parallelep.mtl\n\n",
    ]
    for face_name in _FACE_ORDER:
        lines += ["v {:.6f} {:.6f} {:.6f}\n".format(*v) for v in quads[face_name]]
    lines.append("\n")

    # Every face maps the full texture: corner k of the quad gets uv k.
    uvs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
    for _ in _FACE_ORDER:
        lines += ["vt {:.6f} {:.6f}\n".format(u, v) for u, v in uvs]
    lines.append("\n")

    for i, face_name in enumerate(_FACE_ORDER):
        # OBJ indices are 1-based; vertex and uv indices coincide per face.
        b = i * 4 + 1
        v1, v2, v3, v4 = b, b + 1, b + 2, b + 3
        lines.append(f"usemtl m_{face_name}\n")
        # Two triangles (v/vt) splitting the quad along v1-v3.
        lines.append(f"f {v1}/{v1} {v2}/{v2} {v3}/{v3}\n")
        lines.append(f"f {v1}/{v1} {v3}/{v3} {v4}/{v4}\n\n")

    with open(obj_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    _chmod_readable(obj_path)


# @spaces.GPU
def build_textured_cube(pil_imgs, face_rotations=None, *, thickness_ratio=_DEFAULT_THICKNESS_RATIO):
    """Create a textured parallelepiped (OBJ + MTL + PNG textures) in a new temp dir.

    Args:
        pil_imgs: list/tuple of at least 4 PIL images in the order
            (front, right, back, left); extra items are ignored. Only the
            front image's size determines the geometry.
        face_rotations: optional ``{face_name: degrees CCW}`` overriding the
            defaults per face. The caller's dict is not modified.
        thickness_ratio: Z extent of the box as a fraction of the front
            image width.

    Returns:
        ``(absolute path of parallelep.obj, path of the temp directory)``.
        The temp directory also holds parallelep.mtl and the six PNGs; it is
        not deleted by this function.

    Raises:
        ValueError: if *pil_imgs* is not a list/tuple of at least 4 items.
        FileNotFoundError: if an expected output file is missing after writing.
    """
    if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
        raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")

    # Merge into a fresh dict so the caller's face_rotations is never mutated.
    rotations = dict(_DEFAULT_FACE_ROTATIONS)
    if face_rotations is not None:
        rotations.update(face_rotations)

    tmpdir = _make_output_dir()

    # Box size in pixels: X = front width, Y = front height, Z = thickness
    # derived from the front width. Normalized so the longest side is 1,
    # i.e. all coordinates lie within +-0.5.
    front_w, front_h = pil_imgs[0].size
    width_px = float(front_w)
    length_px = float(front_h)
    thickness_px = float(int(front_w * thickness_ratio))
    scale = 1.0 / max(width_px, length_px, thickness_px, 1.0)
    half_x = 0.5 * width_px * scale
    half_y = 0.5 * length_px * scale
    half_z = 0.5 * thickness_px * scale

    for img, face_name in zip(pil_imgs[:4], _GENERATED_FACES):
        _save_texture(img, os.path.join(tmpdir, _TEX_NAMES[face_name]), rotations.get(face_name, 0))

    # The end caps are not generated: plain black textures.
    black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
    for face_name in ("top", "bottom"):
        _save_texture(black, os.path.join(tmpdir, _TEX_NAMES[face_name]), rotations.get(face_name, 0))

    _write_mtl(os.path.join(tmpdir, "parallelep.mtl"))
    obj_path = os.path.join(tmpdir, "parallelep.obj")
    _write_obj(obj_path, _face_quads(half_x, half_y, half_z))

    # Sanity check that every file referenced by the OBJ/MTL exists.
    for fname in ["parallelep.obj", "parallelep.mtl"] + list(_TEX_NAMES.values()):
        p = os.path.join(tmpdir, fname)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Expected file not found : {p}")

    return (os.path.abspath(obj_path), tmpdir)


# ---------------------------------------------------------------------------
# Gradio callback: returns 4 images (PIL) + path to the .obj (str)
# ---------------------------------------------------------------------------

def _coerce_num_steps(value, default=10):
    """Return the steps field value as a positive int.

    A ``gr.Number`` without ``precision=0`` delivers a float, and ``None``
    when the field is cleared; the step count must be an integer.

    Raises:
        ValueError: if the rounded value is below 1.
    """
    if value is None:
        return default
    steps = int(round(float(value)))
    if steps < 1:
        raise ValueError(f"Number of inference steps must be at least 1 (got {value!r}).")
    return steps


def run(fibers: np.ndarray, rings: np.ndarray, num_steps: int):
    """Generate the four board faces and the textured 3D board.

    Args:
        fibers, rings: images from the two ``gr.Image(type="numpy")``
            inputs; ``None`` when nothing was uploaded.
        num_steps: value of the "Number of inference steps" field.

    Returns:
        ``(front, right, back, left, obj_path)``: four RGB PIL images and the
        path to the generated .obj. On any failure the traceback is printed
        and four grey 256x256 placeholders plus ``None`` are returned.
    """
    try:
        if fibers is None or rings is None:
            raise ValueError("Both a fiber image and a ring image must be provided.")
        steps = _coerce_num_steps(num_steps)

        outputs = inference(pipe, fibers, rings, steps)
        if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
            raise ValueError("inference must return a list/tuple of 4 images.")

        pil_imgs = []
        for im in outputs[:4]:
            if isinstance(im, np.ndarray):
                # NOTE(review): assumes uint8 HxWxC arrays — confirm against inference().
                im = Image.fromarray(im)
            if im.mode != "RGB":
                im = im.convert("RGB")
            pil_imgs.append(im)

        # build_textured_cube saves converted/rotated copies and never
        # modifies these images, so they can be returned as-is.
        obj_path, _tmpdir = build_textured_cube(pil_imgs)
        # TODO: _tmpdir is never removed; every call leaves a directory behind.
        return (*pil_imgs, obj_path)
    except Exception:
        # UI boundary: log the error and show placeholders instead of failing.
        traceback.print_exc()
        blank = Image.new("RGB", (256, 256), (220, 220, 220))
        return (blank, blank, blank, blank, None)


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

with gr.Blocks(title="Photorealistic wood generator (4 faces)") as demo:
    # NOTE(review): this header string was split across lines (invalid
    # syntax) with its markup missing; rebuilt as a centered <h1>.
    gr.HTML(
        "<div style='text-align: center'>"
        "<h1>Photorealistic wood generator (4 faces)</h1>"
        "</div>"
    )
    gr.Markdown("""Upload 2 images (four fiber maps and four ring maps) corresponding to the board faces. The model will return four generated images (one per face), produced in a single coherent pass. Set the number of inference steps. Higher values can improve quality but increase processing time.""")

    with gr.Row():
        with gr.Column(scale=1):
            inp1 = gr.Image(label="Fiber", type="numpy")
            inp2 = gr.Image(label="Ring", type="numpy")
            # precision=0 makes Gradio deliver an int step count.
            inp3 = gr.Number(value=10, label="Number of inference steps", precision=0)
            run_btn = gr.Button("Run inference")
        with gr.Column(scale=2):
            model3d_out = gr.Model3D(label="3D board")

    with gr.Row():
        out1 = gr.Image(label="Front")
        out2 = gr.Image(label="Right")
        out3 = gr.Image(label="Back")
        out4 = gr.Image(label="Left")

    run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, model3d_out])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)