Insta360-Research commited on
Commit
b8a21b1
·
verified ·
1 Parent(s): e3b58f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -24
app.py CHANGED
@@ -8,10 +8,7 @@ import spaces
8
  from src.pipeline import DiT360Pipeline
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- if torch.cuda.is_available():
12
- torch_dtype = torch.float16
13
- else:
14
- torch_dtype = torch.float32
15
 
16
  model_repo = "black-forest-labs/FLUX.1-dev"
17
  lora_weights = "fenghora/DiT360-Panorama-Image-Generation"
@@ -26,17 +23,13 @@ MAX_IMAGE_SIZE = 2048
26
  def infer(
27
  prompt,
28
  seed,
29
- randomize_seed,
30
  width,
31
  guidance_scale,
32
  num_inference_steps,
33
  progress=gr.Progress(track_tqdm=True),
34
  ):
35
- if randomize_seed:
36
- seed = random.randint(0, MAX_SEED)
37
-
38
  height = width // 2
39
- generator = torch.Generator(device=device).manual_seed(seed)
40
 
41
  full_prompt = f"This is a panorama. The images shows {prompt.strip()}"
42
 
@@ -50,8 +43,10 @@ def infer(
50
  ).images[0]
51
  image.save("test.png")
52
 
53
- return image, seed
54
 
 
 
55
 
56
  examples = [
57
  "A medieval castle stands proudly on a hilltop surrounded by autumn forests, with golden light spilling across the landscape.",
@@ -60,7 +55,6 @@ examples = [
60
  "A snowy mountain village under northern lights, with cozy cabins and smoke rising from chimneys.",
61
  ]
62
 
63
-
64
  css = """
65
  #main-container {
66
  display: flex;
@@ -70,24 +64,20 @@ css = """
70
  gap: 2rem;
71
  margin-top: 1rem;
72
  }
73
-
74
  #image-panel {
75
- flex: 2; /* 占2/3 */
76
  max-width: 900px;
77
  margin: 0 auto;
78
  }
79
-
80
  #settings-panel {
81
- flex: 1; /* 占1/3 */
82
  max-width: 280px;
83
  }
84
-
85
  #prompt-box textarea {
86
- resize: none !important; /* 去掉上下箭头 */
87
  }
88
  """
89
 
90
-
91
  with gr.Blocks(css=css) as demo:
92
  gr.Markdown("# 🌀 DiT360: High-Fidelity Panoramic Image Generation")
93
  gr.Markdown("Official Gradio demo for **DiT360**, a panoramic image generation model based on hybrid training.")
@@ -115,10 +105,10 @@ with gr.Blocks(css=css) as demo:
115
  "The height is automatically set to half the width (2:1 aspect ratio)."
116
  )
117
 
118
- seed = gr.Slider(0, MAX_SEED, value=0, step=1, label="Seed")
119
- randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
120
 
121
- width = gr.Slider(1024, MAX_IMAGE_SIZE, value=2048, step=64, label="Width (fixed 2:1)")
122
  height_display = gr.Number(value=1024, label="Height", interactive=False)
123
 
124
  guidance_scale = gr.Slider(0.0, 10.0, value=2.8, step=0.1, label="Guidance Scale")
@@ -128,6 +118,7 @@ with gr.Blocks(css=css) as demo:
128
  return width // 2
129
 
130
  width.change(fn=update_height, inputs=width, outputs=height_display)
 
131
 
132
  gr.Markdown(
133
  "💡 *Tip: Try descriptive prompts like “A mountain village at sunrise with mist over the valley.” "
@@ -137,10 +128,9 @@ with gr.Blocks(css=css) as demo:
137
  gr.on(
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
- inputs=[prompt, seed, randomize_seed, width, guidance_scale, num_inference_steps],
141
- outputs=[result, seed],
142
  )
143
 
144
-
145
  if __name__ == "__main__":
146
  demo.launch()
 
8
  from src.pipeline import DiT360Pipeline
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 
 
12
 
13
  model_repo = "black-forest-labs/FLUX.1-dev"
14
  lora_weights = "fenghora/DiT360-Panorama-Image-Generation"
 
23
  def infer(
24
  prompt,
25
  seed,
 
26
  width,
27
  guidance_scale,
28
  num_inference_steps,
29
  progress=gr.Progress(track_tqdm=True),
30
  ):
 
 
 
31
  height = width // 2
32
+ generator = torch.Generator(device=device).manual_seed(int(seed))
33
 
34
  full_prompt = f"This is a panorama. The images shows {prompt.strip()}"
35
 
 
43
  ).images[0]
44
  image.save("test.png")
45
 
46
+ return image
47
 
48
+ def generate_seed():
49
+ return random.randint(0, MAX_SEED)
50
 
51
  examples = [
52
  "A medieval castle stands proudly on a hilltop surrounded by autumn forests, with golden light spilling across the landscape.",
 
55
  "A snowy mountain village under northern lights, with cozy cabins and smoke rising from chimneys.",
56
  ]
57
 
 
58
  css = """
59
  #main-container {
60
  display: flex;
 
64
  gap: 2rem;
65
  margin-top: 1rem;
66
  }
 
67
  #image-panel {
68
+ flex: 2;
69
  max-width: 900px;
70
  margin: 0 auto;
71
  }
 
72
  #settings-panel {
73
+ flex: 1;
74
  max-width: 280px;
75
  }
 
76
  #prompt-box textarea {
77
+ resize: none !important;
78
  }
79
  """
80
 
 
81
  with gr.Blocks(css=css) as demo:
82
  gr.Markdown("# 🌀 DiT360: High-Fidelity Panoramic Image Generation")
83
  gr.Markdown("Official Gradio demo for **DiT360**, a panoramic image generation model based on hybrid training.")
 
105
  "The height is automatically set to half the width (2:1 aspect ratio)."
106
  )
107
 
108
+ seed_display = gr.Number(value=0, label="Seed", interactive=True)
109
+ random_seed_button = gr.Button("🎲 Random Seed")
110
 
111
+ width = gr.Slider(1024, MAX_IMAGE_SIZE, value=2048, step=64, label="Width")
112
  height_display = gr.Number(value=1024, label="Height", interactive=False)
113
 
114
  guidance_scale = gr.Slider(0.0, 10.0, value=2.8, step=0.1, label="Guidance Scale")
 
118
  return width // 2
119
 
120
  width.change(fn=update_height, inputs=width, outputs=height_display)
121
+ random_seed_button.click(fn=generate_seed, inputs=[], outputs=seed_display)
122
 
123
  gr.Markdown(
124
  "💡 *Tip: Try descriptive prompts like “A mountain village at sunrise with mist over the valley.” "
 
128
  gr.on(
129
  triggers=[run_button.click, prompt.submit],
130
  fn=infer,
131
+ inputs=[prompt, seed_display, width, guidance_scale, num_inference_steps],
132
+ outputs=[result],
133
  )
134
 
 
135
  if __name__ == "__main__":
136
  demo.launch()