# CarolineM5's picture
# Upload 2 files
# 0600e9e verified
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 10 11:16:28 2025
@author: camaac
"""
import gradio as gr
from PIL import Image
import io, base64, json, traceback
import torch
from inference import inference
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import torch.nn as nn
import numpy as np
import tempfile
import os
import shutil
import uuid
class UNetNoCondWrapper(nn.Module):
    """
    Make an unconditional ``UNet2DModel`` usable where a conditional UNet is expected.

    ``forward`` accepts the conditioning keyword arguments of a conditional UNet
    call (``encoder_hidden_states``, ``added_cond_kwargs``,
    ``cross_attention_kwargs``) and drops them; any other keyword argument is
    passed through to the wrapped model. Attribute lookups the wrapper cannot
    resolve itself are delegated to the wrapped UNet, so code reading e.g.
    ``unet.config`` sees the inner model's values.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        # Assigning a Module registers it as a child module, so .to() and
        # .parameters() on the wrapper reach the inner UNet.
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        # The conditioning arguments above exist only for signature
        # compatibility and are deliberately not forwarded.
        inner_model = self.unet
        return inner_model(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # These names must be resolved by nn.Module's own lookup (registered
        # submodules/parameters) rather than delegated. Delegating "unet" would
        # recurse forever.
        own_names = ("unet", "forward", "__getstate__", "__setstate__")
        if name not in own_names:
            return getattr(self.unet, name)
        return super().__getattr__(name)

    def save_pretrained(self, save_directory, **kwargs):
        """Save only the wrapped UNet by delegating to its ``save_pretrained``."""
        inner_model = self.unet
        return inner_model.save_pretrained(save_directory, **kwargs)
# Module-level model setup: runs once at import time and builds the
# InstructPix2Pix pipeline from the components stored in the `model_id` repo.
# Use the GPU when PyTorch sees one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
# 1) Load the pipeline components from their subfolders of the model repo.
#    The tokenizer/text encoder are loaded even though the wrapped UNet ignores
#    text conditioning (presumably because the pipeline constructor needs them).
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
# 2) Load the unconditional UNet and wrap it so it accepts the conditional
#    UNet call signature (conditioning arguments are dropped by the wrapper).
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
# 3) Build the pipeline by hand around the wrapped UNet, with the safety
#    checker disabled.
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)
# Cast all weights to float32, then move the whole pipeline to `device`.
pipe = pipe.to(torch.float32).to(device)
# @spaces.GPU
def _rotate_texture(im, angle):
    """
    Return *im* rotated counter-clockwise by *angle* degrees (PIL convention).

    Quarter turns use ``expand=True`` so a non-square texture is transposed
    instead of being cropped to its original canvas. Other angles keep the
    original canvas size. A multiple of 360 returns *im* unchanged.
    """
    if angle % 360 == 0:
        return im
    return im.rotate(angle, resample=Image.BICUBIC, expand=(angle % 90 == 0))


def _make_world_readable(path):
    """Best-effort ``chmod 0o644`` on *path*; an OSError is ignored."""
    try:
        os.chmod(path, 0o644)
    except OSError:
        # e.g. filesystems without POSIX permission support; not fatal.
        pass


def build_textured_cube(pil_imgs, face_rotations=None, thickness_ratio=45 / 145):
    """
    Create a textured parallelepiped (OBJ + MTL + PNG textures) in a new temp dir.

    The four input images are mapped onto the four long faces of a board-like
    box. The two remaining end faces get a black texture.

    Geometry (+X = right, +Y = length, +Z = up) is normalized so the largest
    dimension spans 1.0 (all coordinates lie within +/-0.5):
      * x extent = front image width
      * y extent = front image height
      * z extent = int(front image width * thickness_ratio)
    Texture placement: "front" -> +Z, "back" -> -Z, "right" -> +X,
    "left" -> -X, black "top" -> +Y, black "bottom" -> -Y.

    Parameters
    ----------
    pil_imgs : list or tuple of PIL.Image.Image
        At least four images, ordered (front, right, back, left). Extra
        items are ignored.
    face_rotations : dict, optional
        Counter-clockwise rotation in degrees per face name ("front", "right",
        "back", "left", "top", "bottom"). Faces that are not given use the
        built-in defaults. The caller's dict is never modified.
    thickness_ratio : float, optional
        Box thickness (z extent) as a fraction of the front image width.

    Returns
    -------
    tuple of (str, str)
        Absolute path of the ``.obj`` file, and the directory holding it
        together with its ``.mtl`` and texture files.

    Raises
    ------
    ValueError
        If *pil_imgs* is not a list/tuple of at least four images.
    FileNotFoundError
        If an expected output file is missing after writing.
    """
    if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
        raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")

    default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
    # Merge into a fresh dict. The previous setdefault() loop wrote the
    # defaults into the caller's mapping, mutating the argument.
    rotations = dict(default_rots)
    if face_rotations is not None:
        rotations.update(face_rotations)

    # Prefer /tmp/gradio when it exists and is writable (presumably so Gradio
    # can serve the files); otherwise use the system default temp location.
    base_dir = "/tmp/gradio"
    use_base = os.path.isdir(base_dir) and os.access(base_dir, os.W_OK)
    tmpdir = tempfile.mkdtemp(prefix="parallelep_", dir=base_dir if use_base else None)

    # Texture file names, referenced relatively from the .mtl file.
    tex_names = {
        "front": "tex_front.png",
        "right": "tex_right.png",
        "back": "tex_back.png",
        "left": "tex_left.png",
        "top": "tex_top.png",
        "bottom": "tex_bottom.png",
    }

    # Physical dimensions in pixel units. The thickness is derived from the
    # front width; the right image's own size is not used.
    front_w, front_h = pil_imgs[0].size
    thickness_px = int(front_w * thickness_ratio)
    width_px = float(front_w)       # x extent
    depth_px = float(front_h)       # y extent
    height_px = float(thickness_px)  # z extent (board thickness)

    # Normalize so the largest dimension spans 1.0 (guarding against zero).
    scale = 1.0 / max(width_px, depth_px, height_px, 1.0)
    half_x = width_px * 0.5 * scale
    half_y = depth_px * 0.5 * scale
    half_z = height_px * 0.5 * scale

    # --- textures: four generated faces + black end faces ---
    black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
    textures = {
        "front": pil_imgs[0].convert("RGB"),
        "right": pil_imgs[1].convert("RGB"),
        "back": pil_imgs[2].convert("RGB"),
        "left": pil_imgs[3].convert("RGB"),
        "top": black,
        "bottom": black,
    }
    for face_name, im in textures.items():
        im = _rotate_texture(im, rotations.get(face_name, 0))
        path = os.path.join(tmpdir, tex_names[face_name])
        im.save(path, format="PNG")
        _make_world_readable(path)

    # --- write .mtl: one material per face, each with its own diffuse map ---
    mtl_path = os.path.join(tmpdir, "parallelep.mtl")
    with open(mtl_path, "w", encoding="utf-8") as f:
        f.write("# Material file for parallelepiped\n")
        for mat_name, tex_file in tex_names.items():
            f.write(f"newmtl m_{mat_name}\n")
            f.write("Ka 1.000 1.000 1.000\n")
            f.write("Kd 1.000 1.000 1.000\n")
            f.write("Ks 0.000 0.000 0.000\n")
            f.write("Ns 10.000\n")
            f.write("illum 2\n")
            f.write(f"map_Kd {tex_file}\n\n")
    _make_world_readable(mtl_path)

    # --- geometry: one quad per face, vertices CCW when seen from outside ---
    # Each key is the texture/material name. The comment gives the box side it
    # is placed on.
    quads = {
        # "front" texture on the +Z side (seen from +Z)
        "front": [
            (-half_x, -half_y, half_z),
            ( half_x, -half_y, half_z),
            ( half_x,  half_y, half_z),
            (-half_x,  half_y, half_z),
        ],
        # "right" texture on the +X side (seen from +X)
        "right": [
            ( half_x, -half_y, -half_z),
            ( half_x,  half_y, -half_z),
            ( half_x,  half_y,  half_z),
            ( half_x, -half_y,  half_z),
        ],
        # "back" texture on the -Z side (seen from -Z)
        "back": [
            (-half_x,  half_y, -half_z),
            ( half_x,  half_y, -half_z),
            ( half_x, -half_y, -half_z),
            (-half_x, -half_y, -half_z),
        ],
        # "left" texture on the -X side (seen from -X)
        "left": [
            (-half_x, -half_y,  half_z),
            (-half_x,  half_y,  half_z),
            (-half_x,  half_y, -half_z),
            (-half_x, -half_y, -half_z),
        ],
        # black "top" texture on the +Y end (seen from +Y)
        "top": [
            (-half_x, half_y, -half_z),
            (-half_x, half_y,  half_z),
            ( half_x, half_y,  half_z),
            ( half_x, half_y, -half_z),
        ],
        # black "bottom" texture on the -Y end (seen from -Y)
        "bottom": [
            ( half_x, -half_y, -half_z),
            ( half_x, -half_y,  half_z),
            (-half_x, -half_y,  half_z),
            (-half_x, -half_y, -half_z),
        ],
    }
    face_order = ["top", "right", "bottom", "left", "front", "back"]

    obj_path = os.path.join(tmpdir, "parallelep.obj")
    with open(obj_path, "w", encoding="utf-8") as f:
        f.write("# Parallelepiped OBJ generated by build_textured_cube\n")
        f.write("mtllib parallelep.mtl\n\n")
        for face_name in face_order:
            for v in quads[face_name]:
                f.write("v {:.6f} {:.6f} {:.6f}\n".format(*v))
        f.write("\n")
        # The same four UV corners are repeated for every face, so each face
        # shows its whole texture.
        uvs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
        for _ in range(len(face_order)):
            for (u, v) in uvs:
                f.write("vt {:.6f} {:.6f}\n".format(u, v))
        f.write("\n")
        for i, face_name in enumerate(face_order):
            f.write(f"usemtl m_{face_name}\n")
            # OBJ indices are 1-based; face i owns vertices/UVs 4*i+1 .. 4*i+4.
            base = i * 4 + 1
            v1, v2, v3, v4 = base, base + 1, base + 2, base + 3
            # Split the quad into two triangles (v/vt pairs).
            f.write(f"f {v1}/{v1} {v2}/{v2} {v3}/{v3}\n")
            f.write(f"f {v1}/{v1} {v3}/{v3} {v4}/{v4}\n\n")
    _make_world_readable(obj_path)

    for fname in ["parallelep.obj", "parallelep.mtl"] + list(tex_names.values()):
        p = os.path.join(tmpdir, fname)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Expected file not found : {p}")
    return (os.path.abspath(obj_path), tmpdir)
# -------------------------
# run() returns: the 4 generated face images (PIL) + the path to the .obj file (str); on failure, 4 grey placeholders + None
# -------------------------
def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
    """
    Gradio callback: generate the four face images and the textured 3D board.

    Parameters
    ----------
    fibers, rings :
        The uploaded fiber and ring maps, forwarded unchanged to ``inference``.
        The Gradio inputs in this file are declared with ``type="numpy"``, so
        these arrive as NumPy arrays (or None when nothing was uploaded).
    num_steps : int or float
        Number of inference steps. ``gr.Number`` may deliver a float, so the
        value is converted with ``int()`` before being passed on.

    Returns
    -------
    tuple
        ``(front, right, back, left, obj_path)``: four RGB PIL images and the
        absolute path of the generated ``.obj`` file. On any error the
        traceback is printed, and four grey 256x256 placeholders plus ``None``
        are returned so the UI does not crash.
    """
    try:
        outputs = inference(pipe, fibers, rings, int(num_steps))
        if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
            raise ValueError("inference must return a list/tuple of 4 images.")
        pil_imgs = []
        for im in outputs[:4]:
            if isinstance(im, np.ndarray):
                im = Image.fromarray(im)
            if im.mode != "RGB":
                im = im.convert("RGB")
            pil_imgs.append(im)
        # Copies handed to the UI are independent of the images used for textures.
        thumbs = [im.copy() for im in pil_imgs]
        obj_path, _tmpdir = build_textured_cube(pil_imgs)
        return (*thumbs, obj_path)
    except Exception:
        # Top-level UI boundary: log the full traceback and show placeholders.
        traceback.print_exc()
        blank = Image.new("RGB", (256, 256), (220, 220, 220))
        return (blank, blank, blank, blank, None)
# -------------------------
# Gradio interface
# -------------------------
# Layout: the two input maps and the step count on the left; the 3D board
# preview and the four generated face images on the right.
with gr.Blocks(title="Photorealistic wood generator (4 faces)") as demo:
    gr.HTML("<h1 style='text-align:center; margin-bottom:8px;'>Photorealistic wood generator (4 faces)</h1>")
    gr.Markdown("""Upload 2 images (four fiber maps and four ring maps) corresponding to the board faces. The model will return four generated images (one per face), produced in a single coherent pass.
Set the number of inference steps. Higher values can improve quality but increase processing time.""")
    with gr.Row():
        with gr.Column(scale=1):
            inp1 = gr.Image(label="Fiber", type="numpy")
            inp2 = gr.Image(label="Ring", type="numpy")
            # precision=0 makes Gradio deliver an int step count; with the
            # default precision the value arrives as a float.
            inp3 = gr.Number(value=10, label="Number of inference steps", precision=0)
            run_btn = gr.Button("Run inference")
        with gr.Column(scale=2):
            model3d_out = gr.Model3D(label="3D board")
            with gr.Row():
                out1 = gr.Image(label="Front")
                out2 = gr.Image(label="Right")
                out3 = gr.Image(label="Back")
                out4 = gr.Image(label="Left")
    # The outputs order must match run()'s return tuple: front, right, back, left, .obj path.
    run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, model3d_out])

if __name__ == "__main__":
    # Bind to all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)