Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ Created on Tue Jun 10 11:16:28 2025
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
from PIL import Image
|
| 10 |
-
import io, base64, json
|
| 11 |
import torch
|
| 12 |
from inference import inference
|
| 13 |
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
|
|
@@ -71,135 +71,192 @@ pipe = StableDiffusionInstructPix2PixPipeline(
|
|
| 71 |
pipe = pipe.to(torch.float32).to(device)
|
| 72 |
# @spaces.GPU
|
| 73 |
|
| 74 |
-
def
|
| 75 |
buf = io.BytesIO()
|
| 76 |
-
img.save(buf, format=
|
| 77 |
-
|
| 78 |
-
return f"data:image/
|
| 79 |
-
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
<script>
|
| 92 |
-
(function(){{
|
| 93 |
-
const
|
| 94 |
-
const container = document.getElementById('three-root');
|
| 95 |
|
| 96 |
-
//
|
| 97 |
-
const
|
| 98 |
-
|
| 99 |
-
container.appendChild(renderer.domElement);
|
| 100 |
|
| 101 |
const scene = new THREE.Scene();
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
camera.position.set(1.8, 1.1, 2.4);
|
| 105 |
-
camera.lookAt(0,0,0);
|
| 106 |
|
| 107 |
-
const
|
| 108 |
-
|
| 109 |
-
|
|
|
|
| 110 |
|
| 111 |
-
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
//
|
| 114 |
-
|
| 115 |
-
const
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
new THREE.
|
| 135 |
-
new THREE.
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
}})();
|
| 156 |
</script>
|
| 157 |
-
</body>
|
| 158 |
-
</html>
|
| 159 |
-
"""
|
| 160 |
-
|
| 161 |
-
# --- new gradio_generate that returns both gallery and 3D html ---
|
| 162 |
-
def gradio_generate(fibers: Image.Image, rings: Image.Image, num_steps: int):
|
| 163 |
-
"""
|
| 164 |
-
fibers: PIL image containing 4 tiles in a 2x2 square (order 1,4,2,3)
|
| 165 |
-
rings: same
|
| 166 |
-
returns: (list_of_4_PIL_images, html_string_for_3d)
|
| 167 |
-
"""
|
| 168 |
-
# call your inference function (unchanged)
|
| 169 |
-
preds = inference(pipe, fibers, rings, int(num_steps)) # preds order: [1,4,2,3] -> TL,TR,BL,BR
|
| 170 |
-
|
| 171 |
-
# convert preds to RGB for gallery
|
| 172 |
-
preds_rgb = [p.convert("RGB") if p.mode != "RGB" else p for p in preds]
|
| 173 |
-
|
| 174 |
-
# build data-urls in same order
|
| 175 |
-
data_urls = [pil_to_dataurl(im, fmt="PNG") for im in preds_rgb] # [TL,TR,BL,BR]
|
| 176 |
-
|
| 177 |
-
# safe JSON array literal for injection into HTML template
|
| 178 |
-
js_img_list = json.dumps(data_urls)
|
| 179 |
-
|
| 180 |
-
html = HTML_TEMPLATE.replace("___IMGS___", js_img_list)
|
| 181 |
-
|
| 182 |
-
# gr.Interface expects outputs matching signature; we return (gallery_images, html_string)
|
| 183 |
-
return preds_rgb, html
|
| 184 |
-
|
| 185 |
-
# replace the Interface outputs: first a Gallery, second a HTML component
|
| 186 |
-
iface = gr.Interface(
|
| 187 |
-
fn=gradio_generate,
|
| 188 |
-
inputs=[
|
| 189 |
-
gr.Image(type="pil", label="Fiber"),
|
| 190 |
-
gr.Image(type="pil", label="Ring"),
|
| 191 |
-
gr.Number(value=10, label="Number of inference steps"),
|
| 192 |
-
],
|
| 193 |
-
outputs=[
|
| 194 |
-
gr.HTML(label="3D preview")
|
| 195 |
-
],
|
| 196 |
-
title="Photorealistic wood generator (4 faces)",
|
| 197 |
-
description="""
|
| 198 |
-
Upload 2 images (one with four fiber maps and one with four ring maps) arranged as a 2x2 square.
|
| 199 |
-
The model returns four generated faces (one per tile) and a 3D preview with those faces textured on a box.
|
| 200 |
"""
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
if __name__ == "__main__":
|
| 204 |
-
|
| 205 |
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
from PIL import Image
|
| 10 |
+
import io, base64, json, traceback
|
| 11 |
import torch
|
| 12 |
from inference import inference
|
| 13 |
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
|
|
|
|
| 71 |
pipe = pipe.to(torch.float32).to(device)
|
| 72 |
# @spaces.GPU
|
| 73 |
|
| 74 |
+
def pil_to_data_uri(img: Image.Image) -> str:
|
| 75 |
buf = io.BytesIO()
|
| 76 |
+
img.save(buf, format="PNG")
|
| 77 |
+
b = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 78 |
+
return f"data:image/png;base64,{b}"
|
| 79 |
+
|
| 80 |
+
# # --- new gradio_generate that returns both gallery and 3D html ---
|
| 81 |
+
# def gradio_generate(fibers: Image.Image, rings: Image.Image, num_steps: int):
|
| 82 |
+
# """
|
| 83 |
+
# fibers: PIL image containing 4 tiles in a 2x2 square (order 1,4,2,3)
|
| 84 |
+
# rings: same
|
| 85 |
+
# returns: (list_of_4_PIL_images, html_string_for_3d)
|
| 86 |
+
# """
|
| 87 |
+
# # call your inference function (unchanged)
|
| 88 |
+
# preds = inference(pipe, fibers, rings, int(num_steps)) # preds order: [1,4,2,3] -> TL,TR,BL,BR
|
| 89 |
+
|
| 90 |
+
# # convert preds to RGB for gallery
|
| 91 |
+
# preds_rgb = [p.convert("RGB") if p.mode != "RGB" else p for p in preds]
|
| 92 |
+
|
| 93 |
+
# # build data-urls in same order
|
| 94 |
+
# data_urls = [pil_to_dataurl(im, fmt="PNG") for im in preds_rgb] # [TL,TR,BL,BR]
|
| 95 |
+
|
| 96 |
+
# # safe JSON array literal for injection into HTML template
|
| 97 |
+
# js_img_list = json.dumps(data_urls)
|
| 98 |
+
|
| 99 |
+
# html = HTML_TEMPLATE.replace("___IMGS___", js_img_list)
|
| 100 |
+
|
| 101 |
+
# # gr.Interface expects outputs matching signature; we return (gallery_images, html_string)
|
| 102 |
+
# return preds_rgb, html
|
| 103 |
+
|
| 104 |
+
# # replace the Interface outputs: first a Gallery, second a HTML component
|
| 105 |
+
# iface = gr.Interface(
|
| 106 |
+
# fn=gradio_generate,
|
| 107 |
+
# inputs=[
|
| 108 |
+
# gr.Image(type="pil", label="Fiber"),
|
| 109 |
+
# gr.Image(type="pil", label="Ring"),
|
| 110 |
+
# gr.Number(value=10, label="Number of inference steps"),
|
| 111 |
+
# ],
|
| 112 |
+
# outputs=[
|
| 113 |
+
# gr.HTML(label="3D preview")
|
| 114 |
+
# ],
|
| 115 |
+
# title="Photorealistic wood generator (4 faces)",
|
| 116 |
+
# description="""
|
| 117 |
+
# Upload 2 images (one with four fiber maps and one with four ring maps) arranged as a 2x2 square.
|
| 118 |
+
# The model returns four generated faces (one per tile) and a 3D preview with those faces textured on a box.
|
| 119 |
+
# """
|
| 120 |
+
# )
|
| 121 |
+
|
| 122 |
+
def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
|
| 123 |
+
try:
|
| 124 |
+
# appelle la fonction d'inference de l'utilisateur
|
| 125 |
+
outputs = inference(pipe, fibers, rings, int(num_steps))
|
| 126 |
+
if not (isinstance(outputs, (list,tuple)) and len(outputs) >= 4):
|
| 127 |
+
raise ValueError("La fonction d'inference doit renvoyer une liste/tuple de 4 images.")
|
| 128 |
+
# Prendre les 4 premières images
|
| 129 |
+
imgs = outputs[:4]
|
| 130 |
+
# Convertir en PIL si nécessaire et assurer taille raisonnable
|
| 131 |
+
pil_imgs = []
|
| 132 |
+
for im in imgs:
|
| 133 |
+
if isinstance(im, np.ndarray):
|
| 134 |
+
im = Image.fromarray(im)
|
| 135 |
+
# for safety, ensure RGB
|
| 136 |
+
if im.mode != "RGB":
|
| 137 |
+
im = im.convert("RGB")
|
| 138 |
+
pil_imgs.append(im)
|
| 139 |
+
|
| 140 |
+
# Préparer miniatures (pour les 4 gr.Image)
|
| 141 |
+
thumb_imgs = [im.copy().resize((256,256)) for im in pil_imgs]
|
| 142 |
+
|
| 143 |
+
# Convertir en data-URIs pour Three.js
|
| 144 |
+
data_uris = [pil_to_data_uri(im) for im in pil_imgs]
|
| 145 |
+
|
| 146 |
+
# Construire le HTML/JS pour le visualiseur Three.js
|
| 147 |
+
html = f"""
|
| 148 |
+
<div id="viewer" style="width:100%;height:480px; border:1px solid #ddd;"></div>
|
| 149 |
+
<p style="font-size:0.9em;color:#444;margin-top:6px;">Manipulez la souris pour faire tourner, molette pour zoomer.</p>
|
| 150 |
+
<script src="https://unpkg.com/three@0.152.2/build/three.min.js"></script>
|
| 151 |
+
<script src="https://unpkg.com/three@0.152.2/examples/js/controls/OrbitControls.js"></script>
|
| 152 |
<script>
|
| 153 |
+
(function() {{
|
| 154 |
+
const images = {data_uris!r}; // array of 4 data URIs
|
|
|
|
| 155 |
|
| 156 |
+
// cleanup previous canvas if any
|
| 157 |
+
const container = document.getElementById('viewer');
|
| 158 |
+
container.innerHTML = "";
|
|
|
|
| 159 |
|
| 160 |
const scene = new THREE.Scene();
|
| 161 |
+
const camera = new THREE.PerspectiveCamera(45, container.clientWidth / container.clientHeight, 0.1, 1000);
|
| 162 |
+
camera.position.set(2.5, 2.0, 3.5);
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
const renderer = new THREE.WebGLRenderer({{antialias:true}});
|
| 165 |
+
renderer.setSize(container.clientWidth, container.clientHeight);
|
| 166 |
+
renderer.setPixelRatio(window.devicePixelRatio ? window.devicePixelRatio : 1);
|
| 167 |
+
container.appendChild(renderer.domElement);
|
| 168 |
|
| 169 |
+
// Lights (soft)
|
| 170 |
+
const hemi = new THREE.HemisphereLight(0xffffff, 0x444444, 1.0);
|
| 171 |
+
scene.add(hemi);
|
| 172 |
|
| 173 |
+
// Load textures from data URIs
|
| 174 |
+
const loader = new THREE.TextureLoader();
|
| 175 |
+
const texPromises = images.map((uri) => new Promise((res, rej) => {{
|
| 176 |
+
loader.load(uri, (tex) => {{ tex.flipY = false; res(tex); }}, undefined, rej);
|
| 177 |
+
}}));
|
| 178 |
+
|
| 179 |
+
Promise.all(texPromises).then((textures) => {{
|
| 180 |
+
// materials for box: order is [right, left, top, bottom, front, back]
|
| 181 |
+
const neutral = new THREE.MeshBasicMaterial({{ color:0xcccccc }});
|
| 182 |
+
const mats = [
|
| 183 |
+
new THREE.MeshBasicMaterial({{ map: textures[0] }}), // right
|
| 184 |
+
new THREE.MeshBasicMaterial({{ map: textures[1] }}), // left
|
| 185 |
+
neutral, // top
|
| 186 |
+
neutral, // bottom
|
| 187 |
+
new THREE.MeshBasicMaterial({{ map: textures[2] }}), // front
|
| 188 |
+
new THREE.MeshBasicMaterial({{ map: textures[3] }}) // back
|
| 189 |
+
];
|
| 190 |
+
|
| 191 |
+
// ensure correct filtering / orientation
|
| 192 |
+
mats.forEach(m => {{ if (m.map) {{ m.map.minFilter = THREE.LinearFilter; m.map.wrapS = THREE.ClampToEdgeWrapping; m.map.wrapT = THREE.ClampToEdgeWrapping; }} }});
|
| 193 |
+
|
| 194 |
+
const geometry = new THREE.BoxGeometry(1.6,1.6,1.6);
|
| 195 |
+
const cube = new THREE.Mesh(geometry, mats);
|
| 196 |
+
scene.add(cube);
|
| 197 |
+
|
| 198 |
+
// grid & axes for context
|
| 199 |
+
const grid = new THREE.GridHelper(6, 12, 0x888888, 0x444444);
|
| 200 |
+
grid.position.y = -1.2;
|
| 201 |
+
scene.add(grid);
|
| 202 |
+
|
| 203 |
+
// controls
|
| 204 |
+
const controls = new THREE.OrbitControls(camera, renderer.domElement);
|
| 205 |
+
controls.enableDamping = true;
|
| 206 |
+
controls.dampingFactor = 0.07;
|
| 207 |
+
controls.target.set(0,0,0);
|
| 208 |
+
|
| 209 |
+
// responsive
|
| 210 |
+
function onWindowResize() {{
|
| 211 |
+
renderer.setSize(container.clientWidth, container.clientHeight);
|
| 212 |
+
camera.aspect = container.clientWidth / container.clientHeight;
|
| 213 |
+
camera.updateProjectionMatrix();
|
| 214 |
+
}}
|
| 215 |
+
window.addEventListener('resize', onWindowResize);
|
| 216 |
+
|
| 217 |
+
// animation loop
|
| 218 |
+
(function animate() {{
|
| 219 |
+
requestAnimationFrame(animate);
|
| 220 |
+
controls.update();
|
| 221 |
+
renderer.render(scene, camera);
|
| 222 |
+
}})();
|
| 223 |
+
}}).catch((e) => {{
|
| 224 |
+
container.innerHTML = "<div style='padding:16px;color:#900;'>Erreur chargement textures : " + (e && e.message ? e.message : e) + "</div>";
|
| 225 |
+
console.error(e);
|
| 226 |
+
}});
|
| 227 |
}})();
|
| 228 |
</script>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
"""
|
| 230 |
+
# Retourner : 4 miniatures pour affichage + HTML viewer
|
| 231 |
+
return (*thumb_imgs, html)
|
| 232 |
+
except Exception as e:
|
| 233 |
+
traceback.print_exc()
|
| 234 |
+
# en cas d'erreur, renvoyer 4 images vides et message html d'erreur
|
| 235 |
+
blank = Image.new("RGB", (256,256), (220,220,220))
|
| 236 |
+
err_html = f"<div style='color:#900;padding:12px;'>Erreur serveur: {str(e)}</div>"
|
| 237 |
+
return (blank, blank, blank, blank, err_html)
|
| 238 |
+
|
| 239 |
+
# Interface Gradio
|
| 240 |
+
with gr.Blocks(title="3D Cube Viewer - 4 outputs on cube") as demo:
|
| 241 |
+
gr.Markdown("### App 3D : assemble 4 images produites par votre code d'inference sur les 4 faces d'un cube.")
|
| 242 |
+
with gr.Row():
|
| 243 |
+
with gr.Column(scale=1):
|
| 244 |
+
inp1 = gr.Image(label="Image 1 (entrée)", type="numpy")
|
| 245 |
+
inp2 = gr.Image(label="Image 2 (entrée)", type="numpy")
|
| 246 |
+
inp3 = gr.Number(value=10, label="Number of inference steps")
|
| 247 |
+
run_btn = gr.Button("Lancer l'inference")
|
| 248 |
+
gr.Markdown("Remplace `user_inference` dans `app.py` par ton code d'inférence existant.")
|
| 249 |
+
with gr.Column(scale=2):
|
| 250 |
+
viewer = gr.HTML("<div style='padding:12px;color:#666;'>Le rendu 3D s'affichera ici après inference.</div>", label="Visu 3D")
|
| 251 |
+
gr.Row().style(mobile_collapse=False)
|
| 252 |
+
with gr.Row():
|
| 253 |
+
out1 = gr.Image(label="Output 1")
|
| 254 |
+
out2 = gr.Image(label="Output 2")
|
| 255 |
+
out3 = gr.Image(label="Output 3")
|
| 256 |
+
out4 = gr.Image(label="Output 4")
|
| 257 |
+
|
| 258 |
+
run_btn.click(fn=run, inputs=[inp1, inp2, inp3], outputs=[out1, out2, out3, out4, viewer])
|
| 259 |
|
| 260 |
if __name__ == "__main__":
|
| 261 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
| 262 |
|