zimhe committed
Commit 5790b69 · 1 Parent(s): c52cd6c

Update space

Files changed (2):
  1. app.py +372 -46
  2. requirements.txt +10 -3
app.py CHANGED
@@ -3,27 +3,133 @@ import numpy as np
 import random
 
 # import spaces #[uncomment to use ZeroGPU]
+from scripts.cubemap_vae import CubemapVAE
+from scripts.cubemap_unet import CubemapUNet
 from diffusers import DiffusionPipeline
+from scripts.cubemap_diffusion_pipeline import CubemapDiffusionInpaintPipeline
+from scripts.utils import resize_and_crop, convert_to_equirectangular, to_cubemap_dict, cubemap_unfold
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from contextlib import nullcontext
 import torch
+from PIL import Image
+import base64
+from io import BytesIO
+import json
+
+import os
+from datetime import datetime
+import time
+
+from realesrgan import RealESRGANer
+from basicsr.archs.rrdbnet_arch import RRDBNet
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace with the model you would like to use
+model_repo_id = "zimhe/SpatialDiffusion"  # Replace with the model you would like to use
+upscale_model_id = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
 
 if torch.cuda.is_available():
+    print("CUDA is available")
     torch_dtype = torch.float16
 else:
     torch_dtype = torch.float32
+
+pretrained_vae = AutoencoderKL.from_pretrained(
+    model_repo_id, subfolder="vae", torch_dtype=torch_dtype
+)
+pretrained_unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder="unet", torch_dtype=torch_dtype)
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+cubemap_unet = CubemapUNet(pretrained_unet=pretrained_unet)
+cubemap_vae = CubemapVAE(num_views=6, pretrained_vae=pretrained_vae, in_channels=3)  # your VAE architecture
+
+pipe = CubemapDiffusionInpaintPipeline.from_pretrained(model_repo_id, vae=cubemap_vae, unet=cubemap_unet, torch_dtype=torch_dtype, safety_checker=None)
 pipe = pipe.to(device)
 
+model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+upsampler = RealESRGANer(
+    scale=4,
+    model_path=upscale_model_id,
+    model=model,
+    tile=512,
+    tile_pad=32,
+    pre_pad=0,
+    device=device,
+    half=True
+)
+
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
+MAX_IMAGE_SIZE = 512
+
+# Get the directory containing this script
+current_dir = os.path.dirname(os.path.abspath(__file__))
+viewer_html_path = os.path.join(current_dir, "viewer.html")
+default_image_path = os.path.join(current_dir, "img", "004.png")
+
+# Read the contents of viewer.html
+with open(viewer_html_path, 'r', encoding='utf-8') as f:
+    viewer_html_content = f.read()
 
+# Read the default image and encode it as base64
+with open(default_image_path, 'rb') as f:
+    default_image_data = f.read()
+default_image_base64 = base64.b64encode(default_image_data).decode('utf-8')
+default_image_url = f"data:image/png;base64,{default_image_base64}"
+
+with open("examples/examples.json", "r") as f:
+    examples_data = json.load(f)
+examples = []
+example_labels = []
+for key in examples_data:
+    example = examples_data[key]
+    example_list = [
+        example["img"],
+        example["global"],
+        example["front"],
+        example["back"],
+        example["left"],
+        example["right"],
+        example["top"],
+        example["bottom"]
+    ]
+    examples.append(example_list)
+    example_labels.append(key)
 
-# @spaces.GPU #[uncomment to use ZeroGPU]
+
+def process_panorama(image):
+    """Process the uploaded panorama image for the client-side viewer."""
+    if image is None:
+        return None
+
+    try:
+        # Convert the image to JPEG-encoded binary data
+        buffered = BytesIO()
+        if isinstance(image, Image.Image):
+            image.save(buffered, format="JPEG", quality=95, optimize=True)
+        else:
+            Image.fromarray(image).save(buffered, format="JPEG", quality=95, optimize=True)
+
+        # Encode the image as a base64 string
+        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
+        return img_str
+    except Exception as e:
+        print(f"Error while processing image: {str(e)}")
+        return None
+
+
 def infer(
     prompt,
+    front_prompt,
+    back_prompt,
+    left_prompt,
+    right_prompt,
+    top_prompt,
+    bottom_prompt,
+    cond_img: Image.Image,  # cond_img arrives as a PIL Image
    negative_prompt,
     seed,
     randomize_seed,
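The hunk above loads the gr.Examples rows from examples/examples.json: one object per top-level key, holding an "img" path plus a global prompt and six per-face prompts. A minimal sketch of the schema that loop implies, written as the Python that would produce such a file — the label, path, and prompt strings are hypothetical placeholders; only the key names come from the code:

```python
import json

# Hypothetical example entry; only the key names ("img", "global", "front", ...)
# are taken from the loading loop in app.py.
examples_data = {
    "living_room": {                        # top-level key becomes the example label
        "img": "examples/living_room.png",  # condition image path (placeholder)
        "global": "a bright modern living room, photorealistic",
        "front": "sofa and coffee table",
        "back": "bookshelf along the wall",
        "left": "large window with curtains",
        "right": "doorway into the kitchen",
        "top": "white ceiling with a pendant lamp",
        "bottom": "wooden floor with a rug",
    }
}

with open("examples/examples.json", "w") as f:
    json.dump(examples_data, f, indent=2)
```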
@@ -31,64 +137,169 @@ def infer(
     height,
     guidance_scale,
     num_inference_steps,
+    upscale=False,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator().manual_seed(seed)
+
+    # Preprocess the input image to make it square (1:1 aspect ratio)
+    # by center-cropping to the smaller dimension
+    W, H = cond_img.size
+    min_dim = min(W, H)
+    left = (W - min_dim) // 2
+    top = (H - min_dim) // 2
+    right = left + min_dim
+    bottom = top + min_dim
+    cond_img = cond_img.crop((left, top, right, bottom))
+
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    elif torch.cuda.is_available():
+        autocast_ctx = torch.amp.autocast(device_type="cuda")
+    else:
+        autocast_ctx = torch.cpu.amp.autocast()
+
+    face_prompt_dict = {
+        "front": front_prompt,
+        "back": back_prompt,
+        "left": left_prompt,
+        "right": right_prompt,
+        "top": top_prompt,
+        "bottom": bottom_prompt,
+    }
+
+    with autocast_ctx:
+        images = pipe(
+            global_prompt=prompt,
+            per_face_prompts=face_prompt_dict,
+            image=cond_img,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            output_type="np",
+            generator=generator,
+        ).images
+
+    cubemaps = [resize_and_crop(image=image, padding=16) for image in images]
 
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
+    cubemap_dict = to_cubemap_dict(cubemaps)
+    pano_img = convert_to_equirectangular(cubemap_dict, width=2048, height=1024)
+
+    if device == "cuda":
+        torch.cuda.empty_cache()
+
+    if upscale:
+        try:
+            # Reuse the existing autocast_ctx instead of creating a new one
+            img_np = np.array(pano_img).astype(np.uint8)
+            output, _ = upsampler.enhance(img=img_np, outscale=2)
+            pano_img = Image.fromarray(output)
+        except Exception as e:
+            print(f"Upscaling error: {str(e)}")
+
+        if device == "cuda":
+            torch.cuda.empty_cache()
 
-    return image, seed
+    return cubemap_dict["F"], cubemap_dict["B"], cubemap_dict["L"], cubemap_dict["R"], cubemap_dict["U"], cubemap_dict["D"], pano_img, seed
 
 
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
 
 css = """
 #col-container {
+    margin: 0 auto;
+    max-width: 980px;
+}
+
+#input_container {
     margin: 0 auto;
     max-width: 640px;
 }
+
+#square_image {
+    width: 100%;
+    height: auto;
+    aspect-ratio: 1 / 1;
+}
+
+#pano_image {
+    width: 100%;
+    height: auto;
+    aspect-ratio: 2 / 1;
+}
+
 """
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
+        gr.Markdown(" # Spatial Diffusion")
+        pano_html = gr.HTML(label="panorama viewer", elem_classes=["panorama-output"], container=True)
+        gr.Markdown("## Input Parameters")
+
         with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
+            with gr.Column(scale=1):
+                # Image upload with 1:1 aspect ratio
+                cond_img = gr.Image(
+                    label="Condition Image",
+                    type="pil",
+                    sources=["upload", "webcam", "clipboard"],
+                    elem_id="square_image",
+                    container=True,
+                )
+
+            with gr.Column(scale=1):
+                global_prompt = gr.Text(
+                    label="Global Prompt",
+                    show_label=True,
+                    max_lines=2,
+                    placeholder="Enter global prompt",
+                    container=True,
+                )
+
+                face_prompts = {}
+                for face in ["front", "back", "left", "right", "top", "bottom"]:
+                    face_prompts[face] = gr.Text(
+                        label=f"{face.capitalize()} Prompt",
+                        show_label=True,
+                        max_lines=1,
+                        placeholder=f"Enter {face.lower()} prompt",
+                        container=False,
+                    )
+
+        run_button = gr.Button("Run", variant="primary")
+
+        gr.Examples(
+            examples=examples,
+            example_labels=example_labels,
+            inputs=[
+                cond_img,
+                global_prompt,
+                face_prompts["front"],
+                face_prompts["back"],
+                face_prompts["left"],
+                face_prompts["right"],
+                face_prompts["top"],
+                face_prompts["bottom"]
+            ],
+        )
+
         with gr.Accordion("Advanced Settings", open=False):
             negative_prompt = gr.Text(
+                value='''grids, lines, texts, labels, blurry, bad quality, bad image, wrong scale, clear seams, distorted objects, disconnected edges, replicated items,
+                overexposed, chaotic, low resolution, 3D render, overly dramatic, unrealistic''',
                 label="Negative prompt",
                 max_lines=1,
                 placeholder="Enter a negative prompt",
-                visible=False,
             )
-
+
+
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
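infer() leans on resize_and_crop, to_cubemap_dict, and convert_to_equirectangular from scripts.utils, which this diff does not show. Since the commit adds py360convert to requirements.txt and the returned dict is keyed "F"/"B"/"L"/"R"/"U"/"D" — exactly the face keys py360convert.c2e accepts with cube_format="dict" — the conversion is plausibly a thin wrapper around that call. A hedged sketch; the face order assumed in to_cubemap_dict is not confirmed by the diff:

```python
import numpy as np
import py360convert
from PIL import Image


def to_cubemap_dict(faces):
    # Assumed face order from the pipeline: front, back, left, right, top, bottom;
    # only the F/B/L/R/U/D key names are confirmed by infer()'s return statement.
    keys = ["F", "B", "L", "R", "U", "D"]
    return {k: np.asarray(face) for k, face in zip(keys, faces)}


def convert_to_equirectangular(cubemap_dict, width=2048, height=1024):
    # py360convert.c2e stitches six square face arrays (cube_format="dict",
    # keys F/R/B/L/U/D) into a width x height equirectangular projection.
    pano = py360convert.c2e(cubemap_dict, h=height, w=width, cube_format="dict")
    # diffusers' output_type="np" yields floats in [0, 1]; scale before display
    if pano.max() <= 1.0:
        pano = pano * 255.0
    return Image.fromarray(np.clip(pano, 0, 255).astype(np.uint8))
```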
@@ -98,6 +309,8 @@ with gr.Blocks(css=css) as demo:
             )
 
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+            upscale = gr.Checkbox(label="Upscale", value=False)
 
             with gr.Row():
                 width = gr.Slider(
@@ -105,7 +318,7 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=512,  # Replace with defaults that work for your model
                 )
 
                 height = gr.Slider(
@@ -113,16 +326,16 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=512,  # Replace with defaults that work for your model
                 )
 
             with gr.Row():
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
                     minimum=0.0,
-                    maximum=10.0,
+                    maximum=15.0,
                     step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
+                    value=9.0,  # Replace with defaults that work for your model
                 )
 
                 num_inference_steps = gr.Slider(
@@ -130,15 +343,84 @@ with gr.Blocks(css=css) as demo:
                     minimum=1,
                     maximum=50,
                     step=1,
-                    value=2,  # Replace with defaults that work for your model
+                    value=30,  # Replace with defaults that work for your model
                 )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
+
+        gr.Markdown("## Result")
+        with gr.Row():
+            left_face = gr.Image(label="Left", show_label=True, elem_id="square_image", format="png")
+            front_face = gr.Image(label="Front", show_label=True, elem_id="square_image", format="png")
+            right_face = gr.Image(label="Right", show_label=True, elem_id="square_image", format="png")
+        with gr.Row():
+            back_face = gr.Image(label="Back", show_label=True, elem_id="square_image", format="png")
+            top_face = gr.Image(label="Top", show_label=True, elem_id="square_image", format="png")
+            bottom_face = gr.Image(label="Bottom", show_label=True, elem_id="square_image", format="png")
+
+        pano = gr.Image(label="Equirectangular Image", show_label=True, interactive=False, type="pil", elem_id="pano_image", format="png")
+        save_button = gr.Button("Save All", variant="primary")
+
+    # Listen for changes to the result image
+    pano.change(
+        fn=process_panorama,  # converts the image to a base64 string
+        inputs=[pano],
+        outputs=[pano_html],
+        js=f"""
+        async (img_obj) => {{
+            if (!img_obj || !img_obj.url) return;
+
+            // Create the iframe container
+            const container = document.querySelector('.panorama-output');
+            if (container) {{
+                // Wrap the viewer.html content in a Blob object URL
+                const viewerHtml = `{viewer_html_content}`;
+                const viewerBlob = new Blob([viewerHtml], {{ type: 'text/html' }});
+                const viewerUrl = URL.createObjectURL(viewerBlob);
+
+                container.innerHTML = `<iframe id="panorama-viewer" style="width: 100%; height: 480px; border: none;" src="${{viewerUrl}}"></iframe>`;
+
+                // Wait for the iframe to finish loading
+                const iframe = document.getElementById('panorama-viewer');
+                iframe.onload = async () => {{
+                    try {{
+                        // Fetch the image data from the URL
+                        const response = await fetch(img_obj.url);
+                        const blob = await response.blob();
+                        const reader = new FileReader();
+
+                        reader.onloadend = () => {{
+                            // Send the image data to the iframe
+                            iframe.contentWindow.postMessage({{
+                                type: 'loadPanorama',
+                                image: reader.result
+                            }}, '*');
+                        }};
+
+                        reader.readAsDataURL(blob);
+                    }} catch (error) {{
+                        console.error('Error processing image:', error);
+                        console.log('Image object:', img_obj);
+                    }}
+                }};
+            }}
+        }}
+        """
+    )
+
+    run_button.click(
         fn=infer,
         inputs=[
-            prompt,
+            global_prompt,
+            face_prompts["front"],  # pass the component for each face explicitly
+            face_prompts["back"],
+            face_prompts["left"],
+            face_prompts["right"],
+            face_prompts["top"],
+            face_prompts["bottom"],
+            cond_img,
             negative_prompt,
             seed,
             randomize_seed,
@@ -146,8 +428,52 @@ with gr.Blocks(css=css) as demo:
             height,
             guidance_scale,
             num_inference_steps,
+            upscale
         ],
-        outputs=[result, seed],
+        outputs=[
+            front_face,   # update with "front"
+            back_face,    # update with "back"
+            left_face,    # update with "left"
+            right_face,   # update with "right"
+            top_face,     # update with "top"
+            bottom_face,  # update with "bottom"
+            pano,         # update with "pano"
+            seed,         # update with "seed"
+        ],
+    )
+
+    # Show the default panorama at startup
+    demo.load(
+        fn=None,
+        inputs=None,
+        outputs=None,
+        js=f"""
+        () => {{
+            // Create the iframe container
+            const container = document.querySelector('.panorama-output');
+            if (container) {{
+                // Wrap the viewer.html content in a Blob object URL
+                const viewerHtml = `{viewer_html_content}`;
+                const viewerBlob = new Blob([viewerHtml], {{ type: 'text/html' }});
+                const viewerUrl = URL.createObjectURL(viewerBlob);
+
+                container.innerHTML = `<iframe id="panorama-viewer" style="width: 100%; height: 480px; border: none;" src="${{viewerUrl}}"></iframe>`;
+
+                // Wait for the iframe to finish loading
+                const iframe = document.getElementById('panorama-viewer');
+                iframe.onload = () => {{
+                    // Use the bundled default panorama
+                    const defaultImage = '{default_image_url}';
+
+                    // Send the image data to the iframe
+                    iframe.contentWindow.postMessage({{
+                        type: 'loadPanorama',
+                        image: defaultImage
+                    }}, '*');
+                }};
+            }}
+        }}
+        """
     )
 
 if __name__ == "__main__":
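The diff leaves the file's tail untouched, so the guard below the last context line presumably still ends the way the stock text-to-image template does — a minimal sketch, assuming no custom launch options were added:

```python
if __name__ == "__main__":
    demo.launch()  # assumed unchanged from the template: launch the Blocks app with defaults
```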
requirements.txt CHANGED
@@ -1,6 +1,13 @@
 accelerate
+numpy
+pillow
 diffusers
-invisible_watermark
-torch
+--index-url https://download.pytorch.org/whl/cu121
+torch==2.4.0+cu121
+torchvision
+torchaudio
 transformers
-xformers
+xformers
+realesrgan
+py360convert
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
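Note the pinned binary wheel on the last line: it only installs on a CPython 3.12 / torch 2.4 / CUDA 12 image, and the `--index-url` line switches the package index for the whole requirements file, not just the torch entries. A small sanity-check sketch for the target environment:

```python
# Quick check that the runtime matches what the pinned wheels assume; the
# flash-attn wheel filename encodes cp312, torch 2.4, and CUDA 12, so a
# mismatched Space image fails at pip install or import time.
import sys

import torch

assert sys.version_info[:2] == (3, 12), "flash-attn wheel is cp312-only"
assert torch.__version__.startswith("2.4."), "torch must match the 2.4 wheel ABI"
assert torch.version.cuda and torch.version.cuda.startswith("12"), "CUDA 12 build required"
```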