MiniT2I / app.py
multimodalart's picture
multimodalart HF Staff
Move Citrus theme to launch() (Gradio 6)
c930678 verified
Raw
History Blame Contribute Delete
4.1 kB
import random
from pathlib import Path
import spaces
import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, T5EncoderModel
from minit2i_pipeline import (
MiniT2IMMJiTModel,
MiniT2IFlowMatchScheduler,
MiniT2ITextToImagePipeline,
)
# Hugging Face Hub repository that holds every MiniT2I checkpoint variant.
REPO_ID = "MiniT2I/MiniT2I"
# Text-encoder checkpoint shared by all variants listed in MODELS.
TEXT_ENCODER = "google/flan-t5-large"
# dtype used when loading the MiniT2I transformer weights (the text encoder
# below is deliberately loaded in float32 instead).
DTYPE = torch.bfloat16
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = (1 << 31) - 1
# UI label -> sub-directory of the repo snapshot containing that variant.
MODELS = {
    "MiniT2I-B/16 (base)": "minit2i-b-16",
    "MiniT2I-L/16 (large)": "minit2i-l-16",
}

# Local path of the downloaded repo snapshot; variants live in subfolders of it.
root = Path(snapshot_download(REPO_ID))

# Shared T5 text encoder — both variants use google/flan-t5-large.
# Loaded once (float32) and moved to the GPU so every pipeline can reuse it.
tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER)
text_encoder = T5EncoderModel.from_pretrained(TEXT_ENCODER, torch_dtype=torch.float32)
text_encoder = text_encoder.to("cuda")
def _load(model_dir):
    """Assemble a CUDA-resident MiniT2I pipeline for one checkpoint variant.

    Args:
        model_dir: Sub-directory of the downloaded snapshot ``root`` that holds
            the variant (one of the values of ``MODELS``). It must contain a
            ``transformer`` and a ``scheduler`` folder.

    Returns:
        A ``MiniT2ITextToImagePipeline`` whose transformer is loaded in
        ``DTYPE``, which reuses the module-level ``tokenizer`` and
        ``text_encoder``, and which has been moved to ``"cuda"``.
    """
    variant_root = root / model_dir
    pipe = MiniT2ITextToImagePipeline(
        transformer=MiniT2IMMJiTModel.from_pretrained(
            variant_root / "transformer", torch_dtype=DTYPE
        ),
        scheduler=MiniT2IFlowMatchScheduler.from_pretrained(variant_root / "scheduler"),
        # The tokenizer and text encoder objects are shared across all variants.
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        text_encoder_name=TEXT_ENCODER,
    )
    pipe.to("cuda")
    return pipe
# Preload BOTH models in the global context, keyed by their UI label, so that
# generate() only has to look a pipeline up instead of loading it per request.
PIPES = dict(zip(MODELS, map(_load, MODELS.values())))
@spaces.GPU(duration=60)
def generate(prompt, model_label, guidance_scale, num_inference_steps, seed, randomize_seed,
             progress=gr.Progress(track_tqdm=True)):
    """Render one image for ``prompt`` with the selected MiniT2I variant.

    Args:
        prompt: Text description of the image; must contain non-whitespace text.
        model_label: Key into ``PIPES`` (a label from ``MODELS``) selecting the variant.
        guidance_scale: Guidance scale forwarded unchanged to the pipeline.
        num_inference_steps: Number of sampling steps (converted to ``int``).
        seed: Seed for the CUDA generator; replaced when ``randomize_seed`` is set.
        randomize_seed: When true, a fresh seed is drawn from ``[0, MAX_SEED]``.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio mirror
            tqdm progress bars in the UI (presumably the pipeline's, since it is
            called with ``progress=True``).

    Returns:
        ``(image, seed)``: the first image produced by the pipeline and the seed
        used for it, so the UI seed slider can be updated.

    Raises:
        gr.Error: If the prompt is missing or blank.
        KeyError: If ``model_label`` is not a key of ``PIPES``.
    """
    if not (prompt and prompt.strip()):
        raise gr.Error("Please enter a prompt.")

    seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    # A dedicated generator makes the result reproducible for a given seed.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    pipe = PIPES[model_label]
    result = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        progress=True,
    )
    return result.images[0], seed
# Sample prompts listed under the layout; clicking one fills the prompt box.
EXAMPLES = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A lonely astronaut standing on a quiet beach under two moons.",
    "A cozy cabin in a snowy forest at dusk, warm light in the windows.",
    "A bowl of ramen with steam rising, photorealistic, top-down view.",
]

# Header text rendered at the top of the page.
_HEADER_MD = """
# 🍊 MiniT2I
A minimalist text-to-image model — pick the **B/16** or **L/16** variant below.
Both are preloaded. Images are generated at 512×512.
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MD)
    with gr.Row():
        # Left column: prompt, variant picker and generation settings.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate…",
                lines=3,
            )
            # Choices are the MODELS labels; the first one is preselected.
            model_label = gr.Radio(
                choices=list(MODELS),
                value=next(iter(MODELS)),
                label="Model",
            )
            run_btn = gr.Button("Generate", variant="primary")
            with gr.Accordion("Advanced settings", open=False):
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5
                )
                num_inference_steps = gr.Slider(
                    label="Inference steps", minimum=10, maximum=150, step=1, value=100
                )
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Right column: the generated image.
        with gr.Column(scale=1):
            output = gr.Image(label="Result", height=512)
    gr.Examples(examples=EXAMPLES, inputs=prompt)

    # Order must match generate()'s positional parameters.
    inputs = [prompt, model_label, guidance_scale, num_inference_steps, seed, randomize_seed]
    # The button and Enter in the prompt box run the same handler. The seed
    # slider is also an output so it shows the seed that was actually used.
    for _event in (run_btn.click, prompt.submit):
        _event(generate, inputs=inputs, outputs=[output, seed])

# The theme is passed to launch() rather than gr.Blocks() (Gradio 6 API).
demo.launch(theme=gr.themes.Citrus())