File size: 11,605 Bytes
3c8903a
 
 
 
 
 
 
 
 
8c9164c
3c8903a
 
 
 
 
 
 
98f01d7
 
 
 
 
 
3c8903a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0600e9e
3c8903a
 
 
 
0600e9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74eb73a
9299372
973f3eb
 
da6818c
98f01d7
0600e9e
98f01d7
da6818c
 
 
 
0600e9e
da6818c
0600e9e
da6818c
a28b5ab
da6818c
 
 
 
 
 
0600e9e
da6818c
 
 
 
 
 
0600e9e
da6818c
 
 
 
 
 
 
377bc40
 
da6818c
 
2ec02af
 
 
da6818c
0600e9e
 
 
a307280
da6818c
0600e9e
da6818c
 
0600e9e
 
 
da6818c
0600e9e
da6818c
0600e9e
da6818c
 
 
 
0600e9e
da6818c
 
 
 
 
 
 
 
0600e9e
da6818c
 
 
 
 
 
 
 
 
 
 
 
 
0600e9e
da6818c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a28b5ab
da6818c
04fbcf3
da6818c
 
007755c
04fbcf3
4348f60
da6818c
 
 
 
 
04fbcf3
da6818c
 
 
 
 
 
04fbcf3
4348f60
da6818c
 
 
 
 
04fbcf3
da6818c
 
 
 
 
 
04fbcf3
4348f60
da6818c
 
 
 
 
04fbcf3
4348f60
da6818c
 
 
 
 
007755c
 
04fbcf3
4348f60
da6818c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a28b5ab
da6818c
 
007755c
6324410
98f01d7
 
04fbcf3
98f01d7
c9c18ab
8c9164c
0600e9e
98f01d7
a28b5ab
04fbcf3
8c9164c
98f01d7
8c9164c
 
 
 
8ab5e1a
8c9164c
 
04fbcf3
ce50616
98f01d7
04fbcf3
98f01d7
 
 
8c9164c
 
 
98f01d7
8c9164c
98f01d7
8c9164c
98f01d7
6324410
04fbcf3
6324410
 
8c9164c
 
6324410
 
93cd9b7
6324410
8c9164c
6324410
8c9164c
04fbcf3
 
 
 
8c9164c
93cd9b7
3c8903a
 
98f01d7
 
6324410
3c8903a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# -*- coding: utf-8 -*-
"""

Created on Tue Jun 10 11:16:28 2025



@author: camaac

"""

import gradio as gr
from PIL import Image
import io, base64, json, traceback
import torch
from inference import inference
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler 
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import torch.nn as nn
import numpy as np

import tempfile
import os
import shutil
import uuid


class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional UNet2DModel stand in for the
    conditional UNet expected by StableDiffusionInstructPix2PixPipeline.

    The pipeline invokes its UNet with text/image-conditioning arguments
    (encoder_hidden_states, added_cond_kwargs, cross_attention_kwargs);
    this wrapper accepts those keywords and silently drops them before
    delegating to the wrapped, prompt-free model.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(self, sample, timestep,
                encoder_hidden_states=None,   # ignored: model is unconditional
                added_cond_kwargs=None,       # ignored
                cross_attention_kwargs=None,  # ignored
                return_dict=False,
                **kwargs):
        # Forward only what the unconditional UNet understands.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # nn.Module resolves registered submodules via __getattr__, so the
        # wrapped module itself (and a few guarded names) must be looked up
        # through the parent class. Everything else is proxied to the inner
        # UNet so the pipeline can read attributes such as its config
        # transparently.
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        # Persist the underlying UNet (weights + config), not the wrapper.
        return self.unet.save_pretrained(save_directory, **kwargs)
    
# ---------------------------------------------------------------------------
# Model loading: assemble an InstructPix2Pix pipeline around a custom,
# prompt-free UNet pulled from the Hugging Face Hub.
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
# 1) Load each pipeline component individually from its subfolder.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

# 2) Load the unconditional UNet and wrap it so it accepts (and ignores)
#    the conditioning kwargs the pipeline will pass at every step.
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

# 3) Build the pipeline manually from the parts above (safety checker
#    deliberately disabled: safety_checker=None).
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)

# Run everything in float32 on the selected device.
pipe = pipe.to(torch.float32).to(device)
# @spaces.GPU  (kept for reference: decorator used on HF Spaces ZeroGPU)



def build_textured_cube(pil_imgs, face_rotations=None, thickness_ratio=45 / 145):
    """Create a textured parallelepiped (OBJ + MTL + PNG textures) on disk.

    Parameters
    ----------
    pil_imgs : list/tuple of at least 4 PIL images in the order
        (front, right, back, left); only the first four are used.
    face_rotations : optional dict mapping face name -> rotation angle in
        degrees (CCW) applied to that face's texture before saving.
        Missing faces fall back to built-in defaults. The caller's dict
        is NOT modified.
    thickness_ratio : board thickness as a fraction of the front-face
        width (default 45/145, the previously hard-coded value).

    Returns
    -------
    (obj_path, tmpdir) : absolute path of the generated .obj file and the
        temporary directory holding all generated assets (obj, mtl,
        textures). The caller is responsible for cleaning up tmpdir.

    Raises
    ------
    ValueError : fewer than four images supplied.
    FileNotFoundError : an expected output file was not written.
    """
    if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
        raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")

    # Merge user rotations over the defaults WITHOUT mutating the caller's
    # dict (the previous implementation wrote into it via setdefault).
    default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
    face_rotations = {**default_rots, **(face_rotations or {})}

    # Prefer Gradio's temp root when writable so the app can serve the files.
    base_dir = "/tmp/gradio"
    if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
        tmpdir = tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
    else:
        tmpdir = tempfile.mkdtemp(prefix="parallelep_")

    # Relative texture file names (the .mtl references exactly these names).
    tex_names = {
        "front": "tex_front.png",
        "right": "tex_right.png",
        "back": "tex_back.png",
        "left": "tex_left.png",
        "top": "tex_top.png",
        "bottom": "tex_bottom.png",
    }

    front_w, front_h = pil_imgs[0].size
    # Board thickness is derived from the front width; the side image's own
    # dimensions are intentionally ignored (matches original behaviour).
    right_w = int(front_w * thickness_ratio)

    # Physical dimensions of the parallelepiped in "pixels".
    width_px = float(front_w)
    height_px = float(right_w)
    depth_px = float(front_h)

    # Normalize so every coordinate stays within ±0.5.
    max_dim = max(width_px, depth_px, height_px, 1.0)
    scale = 1.0 / max_dim
    half_x = (width_px * 0.5) * scale
    half_y = (depth_px * 0.5) * scale
    half_z = (height_px * 0.5) * scale

    def _save_texture(img, face_name):
        """Rotate (if requested) and save one face texture; best-effort chmod."""
        angle = face_rotations.get(face_name, 0)
        if angle % 360 != 0:
            # PIL rotate: angle in degrees, positive = CCW.
            img = img.rotate(angle, resample=Image.BICUBIC, expand=False)
        path = os.path.join(tmpdir, tex_names[face_name])
        img.save(path, format="PNG")
        try:
            os.chmod(path, 0o644)
        except Exception:
            pass  # chmod may fail on some filesystems; not fatal

    # Save the four side textures.
    mapping_order = ["front", "right", "back", "left"]
    for img, face_name in zip(pil_imgs[:4], mapping_order):
        _save_texture(img.convert("RGB"), face_name)

    # Top and bottom faces get a plain black texture.
    black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
    for face_name in ("top", "bottom"):
        _save_texture(black, face_name)

    # --- write the .mtl material file ---
    mtl_path = os.path.join(tmpdir, "parallelep.mtl")
    with open(mtl_path, "w", encoding="utf-8") as f:
        f.write("# Material file for parallelepiped\n")
        for mat_name, tex_file in tex_names.items():
            f.write(f"newmtl m_{mat_name}\n")
            f.write("Ka 1.000 1.000 1.000\n")
            f.write("Kd 1.000 1.000 1.000\n")
            f.write("Ks 0.000 0.000 0.000\n")
            f.write("Ns 10.000\n")
            f.write("illum 2\n")
            f.write(f"map_Kd {tex_file}\n\n")
    try:
        os.chmod(mtl_path, 0o644)
    except Exception:
        pass

    # --- geometry: one quad per face, CCW when seen from outside ---
    # Axis convention: +X = right, +Y = front, +Z = up.
    # NOTE: the texture names do NOT match the geometric faces one-to-one:
    # the "front" texture covers the +Z face, "back" covers -Z, "top"
    # covers +Y and "bottom" covers -Y (the board lies flat, so its large
    # faces point up/down).
    quads = {
        # +Z face (carries the "front" texture)
        "front": [
            (-half_x, -half_y,  half_z),
            ( half_x, -half_y,  half_z),
            ( half_x,  half_y,  half_z),
            (-half_x,  half_y,  half_z),
        ],
        # +X face (carries the "right" texture)
        "right": [
            ( half_x, -half_y, -half_z),
            ( half_x,  half_y, -half_z),
            ( half_x,  half_y,  half_z),
            ( half_x, -half_y,  half_z),
        ],
        # -Z face (carries the "back" texture)
        "back": [
            (-half_x,  half_y, -half_z),
            ( half_x,  half_y, -half_z),
            ( half_x, -half_y, -half_z),
            (-half_x, -half_y, -half_z),
        ],
        # -X face (carries the "left" texture)
        "left": [
            (-half_x, -half_y,  half_z),
            (-half_x,  half_y,  half_z),
            (-half_x,  half_y, -half_z),
            (-half_x, -half_y, -half_z),
        ],
        # +Y face (carries the "top" texture)
        "top": [
            (-half_x,  half_y, -half_z),
            (-half_x,  half_y,  half_z),
            ( half_x,  half_y,  half_z),
            ( half_x,  half_y, -half_z),
        ],
        # -Y face (carries the "bottom" texture)
        "bottom": [
            ( half_x, -half_y, -half_z),
            ( half_x, -half_y,  half_z),
            (-half_x, -half_y,  half_z),
            (-half_x, -half_y, -half_z),
        ],
    }

    face_order = ["top", "right", "bottom", "left", "front", "back"]
    obj_path = os.path.join(tmpdir, "parallelep.obj")
    with open(obj_path, "w", encoding="utf-8") as f:
        f.write("# Parallelepiped OBJ generated by build_textured_cube\n")
        f.write("mtllib parallelep.mtl\n\n")

        # Vertices: 4 per face, in face_order.
        for face_name in face_order:
            for v in quads[face_name]:
                f.write("v {:.6f} {:.6f} {:.6f}\n".format(*v))
        f.write("\n")

        # UVs: the same full-quad mapping repeated for each of the 6 faces.
        uvs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
        for _ in range(6):
            for (u, v) in uvs:
                f.write("vt {:.6f} {:.6f}\n".format(u, v))
        f.write("\n")

        # Faces: two triangles per quad, indices are 1-based in OBJ.
        for i, face_name in enumerate(face_order):
            f.write(f"usemtl m_{face_name}\n")
            v_base = i * 4 + 1
            t_base = i * 4 + 1
            v1, v2, v3, v4 = v_base, v_base + 1, v_base + 2, v_base + 3
            t1, t2, t3, t4 = t_base, t_base + 1, t_base + 2, t_base + 3
            f.write(f"f {v1}/{t1} {v2}/{t2} {v3}/{t3}\n")
            f.write(f"f {v1}/{t1} {v3}/{t3} {v4}/{t4}\n\n")
    try:
        os.chmod(obj_path, 0o644)
    except Exception:
        pass

    # Sanity check: every asset the OBJ/MTL references must exist on disk.
    for fname in ["parallelep.obj", "parallelep.mtl"] + list(tex_names.values()):
        p = os.path.join(tmpdir, fname)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Expected file not found : {p}")

    return (os.path.abspath(obj_path), tmpdir)
    


# -------------------------
# Inference entry point for the UI.
# Returns 4 PIL images (front, right, back, left) + path to the .obj (str).
# -------------------------
def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
    """Generate the four board faces and a textured 3D model.

    On any failure the four outputs are grey placeholder images and the
    model path is None, so the Gradio UI never crashes.
    """
    try:
        outputs = inference(pipe, fibers, rings, num_steps)
        if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
            raise ValueError("user_inference must return a list/tuple of 4 images.")

        # Normalize each of the four outputs to an RGB PIL image.
        faces = []
        for raw in outputs[:4]:
            img = Image.fromarray(raw) if isinstance(raw, np.ndarray) else raw
            if img.mode != "RGB":
                img = img.convert("RGB")
            print(img.size)
            faces.append(img)

        # Independent copies for the gallery thumbnails (the originals are
        # rotated/saved by build_textured_cube).
        thumbs = [face.copy() for face in faces]

        obj_path, _tmpdir = build_textured_cube(faces)
        return (*thumbs, obj_path)
    except Exception:
        traceback.print_exc()
        blank = Image.new("RGB", (256,256), (220,220,220))
        return (blank, blank, blank, blank, None)

# -------------------------
# Gradio interface: inputs + run button on the left, interactive 3D viewer
# on the right, and the four generated face images in a row below.
# -------------------------
with gr.Blocks(title="Photorealistic wood generator (4 faces)") as demo:
    gr.HTML("<h1 style='text-align:center; margin-bottom:8px;'>Photorealistic wood generator (4 faces)</h1>")
    gr.Markdown("""Upload 2 images (four fiber maps and four ring maps) corresponding to the board faces. The model will return four generated images (one per face), produced in a single coherent pass.

    Set the number of inference steps. Higher values can improve quality but increase processing time.""")
    with gr.Row():
        with gr.Column(scale=1):
            # Inputs are delivered to run() as numpy arrays (type="numpy").
            inp1 = gr.Image(label="Fiber", type="numpy")
            inp2 = gr.Image(label="Ring", type="numpy")
            inp3 = gr.Number(value=10, label="Number of inference steps")
            run_btn = gr.Button("Run inference")
        with gr.Column(scale=2):
            # Displays the .obj path returned by run().
            model3d_out = gr.Model3D(label="3D board")
    with gr.Row():
        # Order matches the tuple returned by run(): front, right, back, left.
        out1 = gr.Image(label="Front")
        out2 = gr.Image(label="Right")
        out3 = gr.Image(label="Back")
        out4 = gr.Image(label="Left")

    run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, model3d_out])

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (standard for HF Spaces / Docker).
    demo.launch(server_name="0.0.0.0", server_port=7860)