Spaces:
Runtime error
Runtime error
File size: 2,388 Bytes
a0ae60c 8ccf632 a0ae60c 81b26b5 06f0278 8ccf632 4ea3b6f 8ccf632 06f0278 8ccf632 54192f0 8ccf632 a0ae60c 06f0278 a0ae60c e0a4dd7 a0ae60c 8ccf632 a32bb1a 8ccf632 a0ae60c 8ccf632 e2944a6 8ccf632 a0ae60c 6ebb7df 4ea3b6f a0ae60c 8ccf632 a0ae60c 8ccf632 a0ae60c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import random
import gradio as gr
import numpy as np
import spaces
import torch
from dataset_viber import CollectorInterface
from diffusers import DiffusionPipeline
# Largest seed the UI offers and that random seed draws may produce (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 2048

# Load FLUX.1 [schnell] once at import time in bfloat16, placing it on the GPU
# when CUDA is available and on the CPU otherwise.
dtype = torch.bfloat16
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)
@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate a single image for *prompt* with the FLUX.1 [schnell] pipeline.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        seed: Seed for the torch random generator; replaced by a random draw
            when *randomize_seed* is true.
        randomize_seed: If true, draw a fresh seed uniformly from [0, MAX_SEED].
        width: Output image width in pixels.
        height: Output image height in pixels.
        num_inference_steps: Number of denoising steps run by the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets the UI
            mirror tqdm progress bars emitted during generation.

    Returns:
        The first entry of the pipeline output's ``images`` (presumably a
        PIL image — confirm against the diffusers pipeline in use).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may hand numeric values over as floats; torch's
    # Generator.manual_seed only accepts integers, and the pipeline's
    # size/step arguments are integer quantities, so normalise them here.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        # Guidance disabled for schnell (as in the upstream demo) — presumably
        # because the model is not meant to run with guidance; confirm.
        guidance_scale=0.0,
    ).images[0]
    return image
# Clickable example rows. Each row supplies, in order: prompt, seed,
# randomize_seed, width, height, num_inference_steps — the same order as
# infer's parameters and the interface's inputs + additional_inputs.
examples = [
    [prompt, 0, True, 1024, 1024, 4]
    for prompt in (
        "a tiny astronaut hatching from an egg on the moon",
        "a cat holding a sign that says hello world",
        "an anime illustration of a wiener schnitzel",
    )
]

# Custom stylesheet that centres the #col-container element and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 520px;\n"
    "}\n"
)

# Markdown header rendered above the interface.
description = (
    "# FLUX.1 [schnell]\n"
    "12B param rectified flow transformer distilled from "
    "[FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation\n"
    "[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] "
    "[[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]\n"
)
# Build the app with dataset_viber's CollectorInterface, which takes gr.Interface-
# style arguments plus `dataset_name` (presumably the dataset that collected
# interactions are written to — confirm against dataset_viber's docs).
interface = CollectorInterface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt"),
    ],
    outputs=[
        gr.Image(label="Result"),
    ],
    # Passed to infer after `prompt`, in this order: seed, randomize_seed,
    # width, height, num_inference_steps. Each row of `examples` follows the
    # same layout, so keep the three in sync.
    additional_inputs=[
        gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0),
        gr.Checkbox(label="Randomize seed", value=True),
        gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
        gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
        gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4),
    ],
    title="FLUX.1 [schnell] - with Dataset Viber data collection",
    description=description,
    examples=examples,
    css=css,
    dataset_name="image-generation-flux1-schnell",
)
# Start the server when the module runs (the stray trailing "|" that made this
# line a SyntaxError has been removed).
interface.launch()