Spaces:
Running
on
Zero
Running
on
Zero
correct sizing, and limit
Browse files
app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
-
import random
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
| 6 |
import torch
|
|
@@ -24,63 +23,23 @@ model = AutoModel.from_pretrained(
|
|
| 24 |
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
|
| 25 |
|
| 26 |
MAX_SEED = np.iinfo(np.int16).max
|
| 27 |
-
|
| 28 |
DEFAULT_POSITIVE_PROMPT = None
|
| 29 |
DEFAULT_NEGATIVE_PROMPT = None
|
| 30 |
|
| 31 |
def _ensure_pil(x):
|
|
|
|
| 32 |
if isinstance(x, Image.Image):
|
| 33 |
return x
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
return Image.fromarray(x)
|
| 45 |
-
except Exception:
|
| 46 |
-
pass
|
| 47 |
-
raise TypeError("Unsupported image type returned by pipeline; expected PIL or array/torch image.")
|
| 48 |
-
|
| 49 |
-
def resize_to_target(img: Image.Image, tw: int, th: int, mode: str = "fit"):
    """Return a PIL image of exactly (tw, th) using the selected mode.

    Args:
        img: Source image.
        tw: Target width in pixels; values below 1 are raised to 1.
        th: Target height in pixels; values below 1 are raised to 1.
        mode: ``"stretch"`` resizes ignoring aspect ratio, ``"fill"`` scales
            to cover the target and center-crops, and anything else
            (including ``None`` or ``"fit"``) letterboxes onto a black RGB
            canvas.

    Returns:
        A PIL image whose size is ``(tw, th)``.
    """
    mode = (mode or "fit").lower()
    # Safety: never ask PIL for a non-positive target size.
    tw = int(max(1, tw))
    th = int(max(1, th))

    if mode == "stretch":
        return img.resize((tw, th), resample=Image.Resampling.LANCZOS)

    iw, ih = img.size
    if iw == 0 or ih == 0:
        # Nothing to scale; still honour the "exactly (tw, th)" contract
        # instead of handing back the empty input.
        return Image.new("RGB", (tw, th), (0, 0, 0))

    cover = mode == "fill"
    # "fill" must cover the whole target (larger ratio wins); "fit" must
    # stay inside it (smaller ratio wins).
    scale = max(tw / iw, th / ih) if cover else min(tw / iw, th / ih)
    # Extreme aspect ratios can round a side down to 0, which PIL rejects.
    nw = max(1, int(round(iw * scale)))
    nh = max(1, int(round(ih * scale)))

    if cover:
        # Guard against rounding leaving the scaled image short of the
        # target, which would make the crop box reach outside the image.
        nw = max(nw, tw)
        nh = max(nh, th)
        resized = img.resize((nw, nh), resample=Image.Resampling.LANCZOS)
        left = (nw - tw) // 2
        top = (nh - th) // 2
        return resized.crop((left, top, left + tw, top + th))

    # "fit": letterbox the scaled image, centered on a black canvas.
    resized = img.resize((nw, nh), resample=Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (tw, th), (0, 0, 0))
    left = (tw - nw) // 2
    top = (th - nh) // 2
    canvas.paste(resized, (left, top))
    return canvas
|
| 84 |
|
| 85 |
@spaces.GPU(duration=300)
|
| 86 |
def infer(
|
|
@@ -91,14 +50,13 @@ def infer(
|
|
| 91 |
num_inference_steps=28,
|
| 92 |
positive_prompt=DEFAULT_POSITIVE_PROMPT,
|
| 93 |
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
| 94 |
-
resize_mode="fit (letterbox)",
|
| 95 |
progress=gr.Progress(track_tqdm=True),
|
| 96 |
):
|
|
|
|
| 97 |
if prompt in [None, ""]:
|
| 98 |
gr.Warning("⚠️ Please enter a prompt!")
|
| 99 |
return None
|
| 100 |
|
| 101 |
-
# Generate at (height, width). Some models may return bucketed sizes.
|
| 102 |
with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
|
| 103 |
imgs = pipeline.generate_image(
|
| 104 |
prompt,
|
|
@@ -116,23 +74,18 @@ def infer(
|
|
| 116 |
progress=True,
|
| 117 |
)
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
# Force output to exactly Width x Height based on user preference
|
| 122 |
-
mode_key = "fit" if "fit" in resize_mode else ("fill" if "fill" in resize_mode else "stretch")
|
| 123 |
-
out = resize_to_target(img, int(width), int(height), mode=mode_key)
|
| 124 |
-
return out
|
| 125 |
|
| 126 |
css = """
|
| 127 |
#col-container {
|
| 128 |
margin: 0 auto;
|
| 129 |
-
max-width:
|
| 130 |
}
|
| 131 |
"""
|
| 132 |
|
| 133 |
with gr.Blocks(css=css) as demo:
|
| 134 |
with gr.Column(elem_id="col-container"):
|
| 135 |
-
gr.Markdown("# NextStep-1-Large —
|
| 136 |
|
| 137 |
with gr.Row():
|
| 138 |
prompt = gr.Text(
|
|
@@ -180,26 +133,19 @@ with gr.Blocks(css=css) as demo:
|
|
| 180 |
width = gr.Slider(
|
| 181 |
label="Width",
|
| 182 |
minimum=256,
|
| 183 |
-
maximum=
|
| 184 |
step=64,
|
| 185 |
-
value=
|
| 186 |
)
|
| 187 |
height = gr.Slider(
|
| 188 |
label="Height",
|
| 189 |
minimum=256,
|
| 190 |
-
maximum=
|
| 191 |
step=64,
|
| 192 |
-
value=
|
| 193 |
)
|
| 194 |
-
resize_mode = gr.Radio(
|
| 195 |
-
label="Resize mode (final output)",
|
| 196 |
-
choices=["fit (letterbox)", "fill (center-crop)", "stretch"],
|
| 197 |
-
value="fit (letterbox)",
|
| 198 |
-
)
|
| 199 |
|
| 200 |
with gr.Row():
|
| 201 |
-
# Remove fixed height so the component can display any size; it will scale in the UI,
|
| 202 |
-
# but the returned image file is exactly width x height.
|
| 203 |
result_1 = gr.Image(
|
| 204 |
label="Result",
|
| 205 |
show_label=True,
|
|
@@ -208,29 +154,25 @@ with gr.Blocks(css=css) as demo:
|
|
| 208 |
format="png",
|
| 209 |
)
|
| 210 |
|
| 211 |
-
#
|
| 212 |
examples = [
|
| 213 |
-
# [prompt, seed, width, height, steps, positive, negative, resize_mode]
|
| 214 |
[
|
| 215 |
-
"
|
| 216 |
-
|
| 217 |
-
"
|
| 218 |
-
"
|
| 219 |
-
"fit (letterbox)",
|
| 220 |
],
|
| 221 |
[
|
| 222 |
-
"
|
| 223 |
-
|
| 224 |
-
"
|
| 225 |
-
"
|
| 226 |
-
"fill (center-crop)",
|
| 227 |
],
|
| 228 |
[
|
| 229 |
-
"
|
| 230 |
-
|
| 231 |
-
"
|
| 232 |
-
"
|
| 233 |
-
"stretch",
|
| 234 |
],
|
| 235 |
]
|
| 236 |
|
|
@@ -244,9 +186,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 244 |
num_inference_steps,
|
| 245 |
positive_prompt,
|
| 246 |
negative_prompt,
|
| 247 |
-
resize_mode,
|
| 248 |
],
|
| 249 |
-
label="Click & Fill Examples",
|
| 250 |
)
|
| 251 |
|
| 252 |
def show_result():
|
|
@@ -263,7 +204,6 @@ with gr.Blocks(css=css) as demo:
|
|
| 263 |
num_inference_steps,
|
| 264 |
positive_prompt,
|
| 265 |
negative_prompt,
|
| 266 |
-
resize_mode,
|
| 267 |
],
|
| 268 |
outputs=[result_1],
|
| 269 |
)
|
|
@@ -271,5 +211,4 @@ with gr.Blocks(css=css) as demo:
|
|
| 271 |
cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])
|
| 272 |
|
| 273 |
if __name__ == "__main__":
|
| 274 |
-
# Set share=True if you want a public link
|
| 275 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
import spaces
|
| 4 |
from PIL import Image
|
| 5 |
import torch
|
|
|
|
| 23 |
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
|
| 24 |
|
| 25 |
MAX_SEED = np.iinfo(np.int16).max
|
|
|
|
| 26 |
DEFAULT_POSITIVE_PROMPT = None
|
| 27 |
DEFAULT_NEGATIVE_PROMPT = None
|
| 28 |
|
| 29 |
def _ensure_pil(x):
    """Convert a pipeline output into a ``PIL.Image.Image``.

    Args:
        x: A PIL image (returned unchanged), a NumPy array, or any object
            exposing ``detach`` (presumed to be a torch tensor).

    Returns:
        The image as a ``PIL.Image.Image``.

    Raises:
        TypeError: If ``x`` is none of the supported types.

    Floating-point and boolean data are treated as being in [0, 1] and are
    scaled to 8-bit; integer data is treated as already being on the 0-255
    scale and is only clipped.
    """
    if isinstance(x, Image.Image):
        return x

    if hasattr(x, "detach"):
        x = x.detach()
        if x.is_floating_point():
            # Half/bfloat16 have no NumPy dtype, so widen before converting.
            x = x.float().clamp(0, 1)
        x = x.cpu().numpy()

    if not isinstance(x, np.ndarray):
        raise TypeError("Unsupported image type returned by pipeline.")

    # Drop a leading batch axis of size one (e.g. 1xCxHxW).
    if x.ndim == 4 and x.shape[0] == 1:
        x = x[0]

    if x.dtype != np.uint8:
        if np.issubdtype(x.dtype, np.integer):
            # Integer pixels are already 0-255; rescaling would saturate them.
            x = x.clip(0, 255).astype(np.uint8)
        else:
            x = (x * 255.0).clip(0, 255).astype(np.uint8)

    channel_counts = (1, 3, 4)
    # CHW -> HWC, but leave arrays alone when the last axis already looks
    # like a channel axis (a very short HWC image would otherwise be
    # transposed by mistake).
    if (
        x.ndim == 3
        and x.shape[0] in channel_counts
        and x.shape[-1] not in channel_counts
    ):
        x = np.moveaxis(x, 0, -1)

    # Image.fromarray rejects (H, W, 1); present grayscale as plain (H, W).
    if x.ndim == 3 and x.shape[-1] == 1:
        x = x[..., 0]

    return Image.fromarray(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
@spaces.GPU(duration=300)
|
| 45 |
def infer(
|
|
|
|
| 50 |
num_inference_steps=28,
|
| 51 |
positive_prompt=DEFAULT_POSITIVE_PROMPT,
|
| 52 |
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
|
|
|
| 53 |
progress=gr.Progress(track_tqdm=True),
|
| 54 |
):
|
| 55 |
+
"""Run inference at exactly (width, height)."""
|
| 56 |
if prompt in [None, ""]:
|
| 57 |
gr.Warning("⚠️ Please enter a prompt!")
|
| 58 |
return None
|
| 59 |
|
|
|
|
| 60 |
with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
|
| 61 |
imgs = pipeline.generate_image(
|
| 62 |
prompt,
|
|
|
|
| 74 |
progress=True,
|
| 75 |
)
|
| 76 |
|
| 77 |
+
return _ensure_pil(imgs[0]) # Return raw output exactly as generated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
css = """
|
| 80 |
#col-container {
|
| 81 |
margin: 0 auto;
|
| 82 |
+
max-width: 800px;
|
| 83 |
}
|
| 84 |
"""
|
| 85 |
|
| 86 |
with gr.Blocks(css=css) as demo:
|
| 87 |
with gr.Column(elem_id="col-container"):
|
| 88 |
+
gr.Markdown("# NextStep-1-Large — Exact Output Size")
|
| 89 |
|
| 90 |
with gr.Row():
|
| 91 |
prompt = gr.Text(
|
|
|
|
| 133 |
width = gr.Slider(
|
| 134 |
label="Width",
|
| 135 |
minimum=256,
|
| 136 |
+
maximum=512,
|
| 137 |
step=64,
|
| 138 |
+
value=512,
|
| 139 |
)
|
| 140 |
height = gr.Slider(
|
| 141 |
label="Height",
|
| 142 |
minimum=256,
|
| 143 |
+
maximum=512,
|
| 144 |
step=64,
|
| 145 |
+
value=512,
|
| 146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
with gr.Row():
|
|
|
|
|
|
|
| 149 |
result_1 = gr.Image(
|
| 150 |
label="Result",
|
| 151 |
show_label=True,
|
|
|
|
| 154 |
format="png",
|
| 155 |
)
|
| 156 |
|
| 157 |
+
# Click & Fill Examples (all <=512px)
|
| 158 |
examples = [
|
|
|
|
| 159 |
[
|
| 160 |
+
"A cozy wooden cabin by a frozen lake, northern lights in the sky",
|
| 161 |
+
123, 512, 512, 28,
|
| 162 |
+
"photorealistic, cinematic lighting, starry night, glowing reflections",
|
| 163 |
+
"low-res, distorted, extra objects"
|
|
|
|
| 164 |
],
|
| 165 |
[
|
| 166 |
+
"Futuristic city skyline at sunset, flying cars, neon reflections",
|
| 167 |
+
456, 512, 384, 30,
|
| 168 |
+
"detailed, vibrant, cinematic, sharp edges",
|
| 169 |
+
"washed out, cartoon, blurry"
|
|
|
|
| 170 |
],
|
| 171 |
[
|
| 172 |
+
"Close-up of a rare orchid in a greenhouse with soft morning light",
|
| 173 |
+
789, 384, 512, 32,
|
| 174 |
+
"macro lens effect, ultra-detailed petals, dew drops",
|
| 175 |
+
"grainy, noisy, oversaturated"
|
|
|
|
| 176 |
],
|
| 177 |
]
|
| 178 |
|
|
|
|
| 186 |
num_inference_steps,
|
| 187 |
positive_prompt,
|
| 188 |
negative_prompt,
|
|
|
|
| 189 |
],
|
| 190 |
+
label="Click & Fill Examples (Exact Size)",
|
| 191 |
)
|
| 192 |
|
| 193 |
def show_result():
|
|
|
|
| 204 |
num_inference_steps,
|
| 205 |
positive_prompt,
|
| 206 |
negative_prompt,
|
|
|
|
| 207 |
],
|
| 208 |
outputs=[result_1],
|
| 209 |
)
|
|
|
|
| 211 |
cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])
|
| 212 |
|
| 213 |
if __name__ == "__main__":
|
|
|
|
| 214 |
demo.launch()
|