CarolineM5 commited on
Commit
8c9164c
·
verified ·
1 Parent(s): bffac01

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -119
app.py CHANGED
@@ -7,7 +7,7 @@ Created on Tue Jun 10 11:16:28 2025
7
 
8
  import gradio as gr
9
  from PIL import Image
10
- import io, base64, json
11
  import torch
12
  from inference import inference
13
  from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
@@ -71,135 +71,192 @@ pipe = StableDiffusionInstructPix2PixPipeline(
71
  pipe = pipe.to(torch.float32).to(device)
72
  # @spaces.GPU
73
 
74
- def pil_to_dataurl(img: Image.Image, fmt="PNG"):
75
  buf = io.BytesIO()
76
- img.save(buf, format=fmt)
77
- b64 = base64.b64encode(buf.getvalue()).decode("ascii")
78
- return f"data:image/{fmt.lower()};base64,{b64}"
79
-
80
- # small three.js HTML template: imgs_list will be a JSON array of 4 data-urls in JS order [TL,TR,BL,BR]
81
- HTML_TEMPLATE = """
82
- <!doctype html>
83
- <html>
84
- <head>
85
- <meta charset="utf-8">
86
- <style>body{{ margin:0; }} #three-root{{ width:100%; height:420px; }}</style>
87
- </head>
88
- <body>
89
- <div id="three-root"></div>
90
- <script src="https://cdn.jsdelivr.net/npm/three@0.158.0/build/three.min.js"></script>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  <script>
92
- (function(){{
93
- const imgs = {imgs_list}; // [TL,TR,BL,BR]
94
- const container = document.getElementById('three-root');
95
 
96
- // minimal renderer + camera
97
- const renderer = new THREE.WebGLRenderer({{antialias:true}});
98
- renderer.setSize(container.clientWidth, 420);
99
- container.appendChild(renderer.domElement);
100
 
101
  const scene = new THREE.Scene();
102
- scene.background = new THREE.Color(0xf7f7f7);
103
- const camera = new THREE.PerspectiveCamera(45, container.clientWidth / 420, 0.1, 1000);
104
- camera.position.set(1.8, 1.1, 2.4);
105
- camera.lookAt(0,0,0);
106
 
107
- const hemi = new THREE.HemisphereLight(0xffffff, 0x444444, 0.9);
108
- hemi.position.set(0, 20, 0); scene.add(hemi);
109
- const dir = new THREE.DirectionalLight(0xffffff, 0.7); dir.position.set(3,10,5); scene.add(dir);
 
110
 
111
- const loader = new THREE.TextureLoader();
 
 
112
 
113
- // mapping: imgs[0]=TL, imgs[1]=TR, imgs[2]=BL, imgs[3]=BR
114
- // We assign images to faces: left, right, front, back (top/bottom neutral)
115
- const neutral = 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAE0lEQVQImWNgYGBgYAAAAAQAAV6CkXkAAAAAASUVORK5CYII=';
116
-
117
- const tex_right = loader.load(imgs[1]);
118
- const tex_left = loader.load(imgs[0]);
119
- const tex_front = loader.load(imgs[2]);
120
- const tex_back = loader.load(imgs[3]);
121
- const tex_top = loader.load(neutral);
122
- const tex_bottom= loader.load(neutral);
123
-
124
- [tex_right,tex_left,tex_front,tex_back,tex_top,tex_bottom].forEach(t=>{
125
- t.wrapS = THREE.ClampToEdgeWrapping;
126
- t.wrapT = THREE.ClampToEdgeWrapping;
127
- t.encoding = THREE.sRGBEncoding;
128
- });
129
-
130
- const materials = [
131
- new THREE.MeshStandardMaterial({map: tex_right}), // px
132
- new THREE.MeshStandardMaterial({map: tex_left}), // nx
133
- new THREE.MeshStandardMaterial({map: tex_top}), // py
134
- new THREE.MeshStandardMaterial({map: tex_bottom}),// ny
135
- new THREE.MeshStandardMaterial({map: tex_front}), // pz
136
- new THREE.MeshStandardMaterial({map: tex_back}) // nz
137
- ];
138
-
139
- // Box geometry (width,height,depth) depth small to simulate board thickness
140
- const geometry = new THREE.BoxGeometry(1.0, 1.0, 0.26);
141
- const box = new THREE.Mesh(geometry, materials);
142
- scene.add(box);
143
-
144
- // simple drag-to-rotate
145
- let dragging=false, lastX=0, lastY=0;
146
- container.addEventListener('pointerdown', e=>{{ dragging=true; lastX=e.clientX; lastY=e.clientY; }});
147
- window.addEventListener('pointerup', ()=>dragging=false);
148
- window.addEventListener('pointermove', e=>{{ if(!dragging) return; const dx=(e.clientX-lastX)/200; const dy=(e.clientY-lastY)/200; box.rotation.y += dx; box.rotation.x += dy; lastX=e.clientX; lastY=e.clientY; }});
149
-
150
- function animate(){{ requestAnimationFrame(animate); renderer.render(scene, camera); }}
151
- animate();
152
-
153
- // responsive
154
- new ResizeObserver(()=>{{ const w = container.clientWidth; renderer.setSize(w,420); camera.aspect = w/420; camera.updateProjectionMatrix(); }}).observe(container);
 
 
 
 
 
 
 
 
 
 
 
 
155
  }})();
156
  </script>
157
- </body>
158
- </html>
159
- """
160
-
161
- # --- new gradio_generate that returns both gallery and 3D html ---
162
- def gradio_generate(fibers: Image.Image, rings: Image.Image, num_steps: int):
163
- """
164
- fibers: PIL image containing 4 tiles in a 2x2 square (order 1,4,2,3)
165
- rings: same
166
- returns: (list_of_4_PIL_images, html_string_for_3d)
167
- """
168
- # call your inference function (unchanged)
169
- preds = inference(pipe, fibers, rings, int(num_steps)) # preds order: [1,4,2,3] -> TL,TR,BL,BR
170
-
171
- # convert preds to RGB for gallery
172
- preds_rgb = [p.convert("RGB") if p.mode != "RGB" else p for p in preds]
173
-
174
- # build data-urls in same order
175
- data_urls = [pil_to_dataurl(im, fmt="PNG") for im in preds_rgb] # [TL,TR,BL,BR]
176
-
177
- # safe JSON array literal for injection into HTML template
178
- js_img_list = json.dumps(data_urls)
179
-
180
- html = HTML_TEMPLATE.replace("___IMGS___", js_img_list)
181
-
182
- # gr.Interface expects outputs matching signature; we return (gallery_images, html_string)
183
- return preds_rgb, html
184
-
185
- # replace the Interface outputs: first a Gallery, second a HTML component
186
- iface = gr.Interface(
187
- fn=gradio_generate,
188
- inputs=[
189
- gr.Image(type="pil", label="Fiber"),
190
- gr.Image(type="pil", label="Ring"),
191
- gr.Number(value=10, label="Number of inference steps"),
192
- ],
193
- outputs=[
194
- gr.HTML(label="3D preview")
195
- ],
196
- title="Photorealistic wood generator (4 faces)",
197
- description="""
198
- Upload 2 images (one with four fiber maps and one with four ring maps) arranged as a 2x2 square.
199
- The model returns four generated faces (one per tile) and a 3D preview with those faces textured on a box.
200
  """
201
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  if __name__ == "__main__":
204
- iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
205
 
 
7
 
8
  import gradio as gr
9
  from PIL import Image
10
+ import io, base64, json, traceback
11
  import torch
12
  from inference import inference
13
  from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
 
71
  pipe = pipe.to(torch.float32).to(device)
72
  # @spaces.GPU
73
 
74
def pil_to_data_uri(img: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG data URI.

    The resulting string can be embedded directly as an <img> src or as a
    three.js TextureLoader URL inside the generated HTML viewer.
    """
    with io.BytesIO() as out:
        img.save(out, format="PNG")
        payload = out.getvalue()
    encoded = base64.b64encode(payload).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
79
+
80
+ # # --- new gradio_generate that returns both gallery and 3D html ---
81
+ # def gradio_generate(fibers: Image.Image, rings: Image.Image, num_steps: int):
82
+ # """
83
+ # fibers: PIL image containing 4 tiles in a 2x2 square (order 1,4,2,3)
84
+ # rings: same
85
+ # returns: (list_of_4_PIL_images, html_string_for_3d)
86
+ # """
87
+ # # call your inference function (unchanged)
88
+ # preds = inference(pipe, fibers, rings, int(num_steps)) # preds order: [1,4,2,3] -> TL,TR,BL,BR
89
+
90
+ # # convert preds to RGB for gallery
91
+ # preds_rgb = [p.convert("RGB") if p.mode != "RGB" else p for p in preds]
92
+
93
+ # # build data-urls in same order
94
+ # data_urls = [pil_to_dataurl(im, fmt="PNG") for im in preds_rgb] # [TL,TR,BL,BR]
95
+
96
+ # # safe JSON array literal for injection into HTML template
97
+ # js_img_list = json.dumps(data_urls)
98
+
99
+ # html = HTML_TEMPLATE.replace("___IMGS___", js_img_list)
100
+
101
+ # # gr.Interface expects outputs matching signature; we return (gallery_images, html_string)
102
+ # return preds_rgb, html
103
+
104
+ # # replace the Interface outputs: first a Gallery, second a HTML component
105
+ # iface = gr.Interface(
106
+ # fn=gradio_generate,
107
+ # inputs=[
108
+ # gr.Image(type="pil", label="Fiber"),
109
+ # gr.Image(type="pil", label="Ring"),
110
+ # gr.Number(value=10, label="Number of inference steps"),
111
+ # ],
112
+ # outputs=[
113
+ # gr.HTML(label="3D preview")
114
+ # ],
115
+ # title="Photorealistic wood generator (4 faces)",
116
+ # description="""
117
+ # Upload 2 images (one with four fiber maps and one with four ring maps) arranged as a 2x2 square.
118
+ # The model returns four generated faces (one per tile) and a 3D preview with those faces textured on a box.
119
+ # """
120
+ # )
121
+
122
def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
    """Run inference and build the interactive 3D cube preview.

    Parameters
    ----------
    fibers, rings : PIL.Image.Image
        The two 2x2 tiled input maps forwarded unchanged to ``inference``.
    num_steps : int
        Number of diffusion inference steps (coerced with ``int()``).

    Returns
    -------
    tuple
        Four 256x256 RGB thumbnails followed by an HTML string embedding a
        three.js viewer that textures the four generated faces onto a cube.
        On any failure, four blank placeholder images plus an error <div>
        are returned instead, so the five Gradio outputs stay consistent.
    """
    try:
        # Call the user's inference function.
        outputs = inference(pipe, fibers, rings, int(num_steps))
        if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
            raise ValueError("La fonction d'inference doit renvoyer une liste/tuple de 4 images.")
        # Keep the first four images only.
        imgs = outputs[:4]
        # Normalize everything to RGB PIL images. Anything that is not
        # already a PIL image (e.g. a numpy array) goes through
        # Image.fromarray — this avoids referencing the `np` name, which
        # is NOT imported at the top of this file (the original
        # `isinstance(im, np.ndarray)` check raised NameError).
        pil_imgs = []
        for im in imgs:
            if not isinstance(im, Image.Image):
                im = Image.fromarray(im)
            if im.mode != "RGB":
                im = im.convert("RGB")
            pil_imgs.append(im)

        # Thumbnails for the four gr.Image outputs.
        thumb_imgs = [im.copy().resize((256, 256)) for im in pil_imgs]

        # Data URIs for the three.js textures. json.dumps yields a valid
        # JS array literal — safer than Python repr for HTML injection.
        data_uris_js = json.dumps([pil_to_data_uri(im) for im in pil_imgs])

        # NOTE: three.js removed the non-module `examples/js` builds after
        # r147, so 0.152.x has no examples/js/controls/OrbitControls.js
        # (the original URL 404'd and THREE.OrbitControls was undefined).
        # r134 is pinned here: the last line shipping both UMD builds.
        html = f"""
<div id="viewer" style="width:100%;height:480px; border:1px solid #ddd;"></div>
<p style="font-size:0.9em;color:#444;margin-top:6px;">Manipulez la souris pour faire tourner, molette pour zoomer.</p>
<script src="https://unpkg.com/three@0.134.0/build/three.min.js"></script>
<script src="https://unpkg.com/three@0.134.0/examples/js/controls/OrbitControls.js"></script>
<script>
(function() {{
  const images = {data_uris_js}; // array of 4 data URIs

  // cleanup previous canvas if any
  const container = document.getElementById('viewer');
  container.innerHTML = "";

  const scene = new THREE.Scene();
  const camera = new THREE.PerspectiveCamera(45, container.clientWidth / container.clientHeight, 0.1, 1000);
  camera.position.set(2.5, 2.0, 3.5);

  const renderer = new THREE.WebGLRenderer({{antialias:true}});
  renderer.setSize(container.clientWidth, container.clientHeight);
  renderer.setPixelRatio(window.devicePixelRatio ? window.devicePixelRatio : 1);
  container.appendChild(renderer.domElement);

  // Lights (soft)
  const hemi = new THREE.HemisphereLight(0xffffff, 0x444444, 1.0);
  scene.add(hemi);

  // Load textures from data URIs
  const loader = new THREE.TextureLoader();
  const texPromises = images.map((uri) => new Promise((res, rej) => {{
    loader.load(uri, (tex) => {{ tex.flipY = false; res(tex); }}, undefined, rej);
  }}));

  Promise.all(texPromises).then((textures) => {{
    // materials for box: order is [right, left, top, bottom, front, back]
    const neutral = new THREE.MeshBasicMaterial({{ color:0xcccccc }});
    const mats = [
      new THREE.MeshBasicMaterial({{ map: textures[0] }}), // right
      new THREE.MeshBasicMaterial({{ map: textures[1] }}), // left
      neutral, // top
      neutral, // bottom
      new THREE.MeshBasicMaterial({{ map: textures[2] }}), // front
      new THREE.MeshBasicMaterial({{ map: textures[3] }}) // back
    ];

    // ensure correct filtering / orientation
    mats.forEach(m => {{ if (m.map) {{ m.map.minFilter = THREE.LinearFilter; m.map.wrapS = THREE.ClampToEdgeWrapping; m.map.wrapT = THREE.ClampToEdgeWrapping; }} }});

    const geometry = new THREE.BoxGeometry(1.6,1.6,1.6);
    const cube = new THREE.Mesh(geometry, mats);
    scene.add(cube);

    // grid & axes for context
    const grid = new THREE.GridHelper(6, 12, 0x888888, 0x444444);
    grid.position.y = -1.2;
    scene.add(grid);

    // controls
    const controls = new THREE.OrbitControls(camera, renderer.domElement);
    controls.enableDamping = true;
    controls.dampingFactor = 0.07;
    controls.target.set(0,0,0);

    // responsive
    function onWindowResize() {{
      renderer.setSize(container.clientWidth, container.clientHeight);
      camera.aspect = container.clientWidth / container.clientHeight;
      camera.updateProjectionMatrix();
    }}
    window.addEventListener('resize', onWindowResize);

    // animation loop
    (function animate() {{
      requestAnimationFrame(animate);
      controls.update();
      renderer.render(scene, camera);
    }})();
  }}).catch((e) => {{
    container.innerHTML = "<div style='padding:16px;color:#900;'>Erreur chargement textures : " + (e && e.message ? e.message : e) + "</div>";
    console.error(e);
  }});
}})();
</script>
"""
        # Return: 4 thumbnails for display + the HTML viewer.
        return (*thumb_imgs, html)
    except Exception as e:
        traceback.print_exc()
        # On error, return 4 blank images and an HTML error message so the
        # five Gradio output components all receive a value.
        blank = Image.new("RGB", (256, 256), (220, 220, 220))
        err_html = f"<div style='color:#900;padding:12px;'>Erreur serveur: {str(e)}</div>"
        return (blank, blank, blank, blank, err_html)
239
# Gradio interface: two input images + step count on the left,
# the 3D viewer on the right, and the four generated faces below.
with gr.Blocks(title="3D Cube Viewer - 4 outputs on cube") as demo:
    gr.Markdown("### App 3D : assemble 4 images produites par votre code d'inference sur les 4 faces d'un cube.")
    with gr.Row():
        with gr.Column(scale=1):
            inp1 = gr.Image(label="Image 1 (entrée)", type="numpy")
            inp2 = gr.Image(label="Image 2 (entrée)", type="numpy")
            inp3 = gr.Number(value=10, label="Number of inference steps")
            run_btn = gr.Button("Lancer l'inference")
            gr.Markdown("Remplace `user_inference` dans `app.py` par ton code d'inférence existant.")
        with gr.Column(scale=2):
            viewer = gr.HTML("<div style='padding:12px;color:#666;'>Le rendu 3D s'affichera ici après inference.</div>", label="Visu 3D")
    # NOTE: the former `gr.Row().style(mobile_collapse=False)` call was
    # removed — Block.style() was deprecated in Gradio 3.x and deleted in
    # 4.x (AttributeError), and the empty Row had no visible effect.
    with gr.Row():
        out1 = gr.Image(label="Output 1")
        out2 = gr.Image(label="Output 2")
        out3 = gr.Image(label="Output 3")
        out4 = gr.Image(label="Output 4")

    # Wire the button: run() returns 4 thumbnails + the viewer HTML.
    run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, viewer])

if __name__ == "__main__":
    # Bind on all interfaces for containerized (HF Spaces) deployment.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
262