bmarci's picture
correct sizing, and limit
3669017
raw
history blame
6.7 kB
import gradio as gr
import numpy as np
import spaces
from PIL import Image
import torch
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModel
from models.gen_pipeline import NextStepPipeline
# Hugging Face Hub repository id used for both the tokenizer and the model.
HF_HUB = "stepfun-ai/NextStep-1-Large"

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Both loaders are allowed to reach the Hub (local_files_only=False) and to
# execute the repository's own modelling code (trust_remote_code=True).
tokenizer = AutoTokenizer.from_pretrained(
    HF_HUB,
    trust_remote_code=True,
    local_files_only=False,
)

model = AutoModel.from_pretrained(
    HF_HUB,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    local_files_only=False,
)
model = model.to(device)

# Generation pipeline built from the tokenizer and model, moved to the same
# device and kept in bfloat16.
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model)
pipeline = pipeline.to(device=device, dtype=torch.bfloat16)

# Upper bound of the seed slider: the largest int16 value (32767).
# NOTE(review): int32 is the more common seed range -- confirm int16 is intended.
MAX_SEED = np.iinfo(np.int16).max

# Defaults for the optional prompt modifiers accepted by `infer`.
DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = None
def _ensure_pil(x):
    """Convert a pipeline output into a ``PIL.Image.Image``.

    Supported inputs:
      * a PIL image, which is returned unchanged;
      * any object exposing ``detach()`` (handled as a torch tensor);
      * a NumPy array.

    Conversion rules:
      * uint8 tensors and arrays are used as-is (no rescaling);
      * every other tensor is cast to float32 and clamped to [0, 1];
      * every other array is multiplied by 255, clipped to [0, 255] and cast
        to uint8;
      * a 4-D input whose leading axis has length 1 loses that axis;
      * a 3-D array whose first axis has length 1, 3 or 4 is treated as
        channels-first and moved to channels-last;
      * a trailing channel axis of length 1 is dropped so that grayscale
        data reaches ``Image.fromarray`` as a plain 2-D array.

    Raises:
        TypeError: if ``x`` is neither a PIL image, a tensor-like object nor
            a NumPy array.
    """
    if isinstance(x, Image.Image):
        return x

    if hasattr(x, "detach"):
        tensor = x.detach()
        if tensor.dtype != torch.uint8:
            # Only non-uint8 data is normalised; clamping a uint8 tensor to
            # [0, 1] would collapse every pixel to 0 or 255 after scaling.
            # .float() also upcasts bfloat16, which NumPy cannot represent.
            tensor = tensor.float().clamp(0, 1)
        x = tensor.cpu().numpy()

    if not isinstance(x, np.ndarray):
        raise TypeError("Unsupported image type returned by pipeline.")

    if x.dtype != np.uint8:
        x = (x * 255.0).clip(0, 255).astype(np.uint8)

    if x.ndim == 4 and x.shape[0] == 1:
        # Drop a batch axis of one: (1, ...) -> (...).
        x = x[0]

    if x.ndim == 3 and x.shape[0] in (1, 3, 4):
        # CHW -> HWC
        x = np.moveaxis(x, 0, -1)

    if x.ndim == 3 and x.shape[-1] == 1:
        # Image.fromarray has no mapping for (H, W, 1); pass (H, W) instead.
        x = x[..., 0]

    return Image.fromarray(x)
@spaces.GPU(duration=300)
def infer(
    prompt=None,
    seed=0,
    width=512,
    height=512,
    num_inference_steps=28,
    positive_prompt=DEFAULT_POSITIVE_PROMPT,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for ``prompt`` at exactly (width, height).

    Args:
        prompt: Text prompt. ``None``, empty or whitespace-only values are
            rejected with a Gradio warning.
        seed: Sampling seed; converted with ``int()``.
        width: Output width, converted with ``int()``.
        height: Output height, converted with ``int()``.
        num_inference_steps: Number of sampling steps, converted with
            ``int()``.
        positive_prompt: Optional positive prompt, forwarded unchanged.
        negative_prompt: Optional negative prompt, forwarded unchanged.
        progress: Gradio progress tracker. It is not read in the body; it is
            declared so Gradio can attach tqdm-based progress reporting.

    Returns:
        The first generated image as a ``PIL.Image.Image``, or ``None`` when
        no usable prompt was supplied.

    Raises:
        gr.Error: if the pipeline returns no images.
    """
    if prompt is None or not str(prompt).strip():
        gr.Warning("⚠️ Please enter a prompt!")
        return None

    # autocast expects the bare device type ("cuda"/"cpu"); torch.device
    # strips any index such as "cuda:0".
    autocast_device = torch.device(device).type

    with autocast(device_type=autocast_device, dtype=torch.bfloat16):
        # Guidance and schedule settings are fixed here; only the size, step
        # count, seed and prompts come from the caller.
        imgs = pipeline.generate_image(
            prompt,
            hw=(int(height), int(width)),  # the pipeline takes (height, width)
            num_images_per_caption=1,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            cfg=7.5,
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=int(num_inference_steps),
            timesteps_shift=1.0,
            seed=int(seed),
            progress=True,
        )

    if imgs is None or len(imgs) == 0:
        # Surface a readable message in the UI instead of a bare IndexError.
        raise gr.Error("The pipeline did not return an image.")

    # Return the raw output exactly as generated (no resizing).
    return _ensure_pil(imgs[0])
# Page CSS: centre the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
"""

# NOTE(review): the container nesting below is reconstructed from a flattened
# listing -- confirm the result row and examples belong inside the column.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# NextStep-1-Large — Exact Output Size")

        # Prompt box with the Run / Cancel buttons beside it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                placeholder="Enter your prompt",
                max_lines=2,
                show_label=False,
                container=False,
            )
            run_button = gr.Button("Run", variant="primary", scale=0)
            cancel_button = gr.Button("Cancel", variant="secondary", scale=0)

        # Optional prompt modifiers and sampling controls.
        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                positive_prompt = gr.Text(
                    label="Positive Prompt",
                    placeholder="Optional: add positives",
                    max_lines=1,
                    show_label=True,
                    container=True,
                )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    placeholder="Optional: add negatives",
                    max_lines=2,
                    show_label=True,
                    container=True,
                )
                with gr.Row():
                    seed = gr.Slider(
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=3407,
                        label="Seed",
                    )
                    num_inference_steps = gr.Slider(
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=28,
                        label="Sampling steps",
                    )
                with gr.Row():
                    # Output size is limited to 256-512 px in 64 px steps.
                    width = gr.Slider(
                        minimum=256,
                        maximum=512,
                        step=64,
                        value=512,
                        label="Width",
                    )
                    height = gr.Slider(
                        minimum=256,
                        maximum=512,
                        step=64,
                        value=512,
                        label="Height",
                    )

        # Read-only image output, served as PNG.
        with gr.Row():
            result_1 = gr.Image(
                label="Result",
                format="png",
                interactive=False,
                show_label=True,
                container=True,
            )

        # Components fed to `infer`, listed in its positional parameter order.
        # The example rows below use the same column order.
        generation_inputs = [
            prompt,
            seed,
            width,
            height,
            num_inference_steps,
            positive_prompt,
            negative_prompt,
        ]

        # Click & Fill Examples (all <=512px)
        examples = [
            [
                "A cozy wooden cabin by a frozen lake, northern lights in the sky",
                123,
                512,
                512,
                28,
                "photorealistic, cinematic lighting, starry night, glowing reflections",
                "low-res, distorted, extra objects",
            ],
            [
                "Futuristic city skyline at sunset, flying cars, neon reflections",
                456,
                512,
                384,
                30,
                "detailed, vibrant, cinematic, sharp edges",
                "washed out, cartoon, blurry",
            ],
            [
                "Close-up of a rare orchid in a greenhouse with soft morning light",
                789,
                384,
                512,
                32,
                "macro lens effect, ultra-detailed petals, dew drops",
                "grainy, noisy, oversaturated",
            ],
        ]
        gr.Examples(
            label="Click & Fill Examples (Exact Size)",
            examples=examples,
            inputs=generation_inputs,
        )

    def show_result():
        """Return a Gradio update that makes a component visible.

        Not attached to any event in this block.
        """
        return gr.update(visible=True)

    # Clicking Run or submitting the prompt box both start a generation.
    generation_event = gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=generation_inputs,
        outputs=[result_1],
    )

    # Cancel runs no function; it only cancels the generation event.
    cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])

if __name__ == "__main__":
    demo.launch()