"""Gradio demo that streams pixel-art sprites generated by a flow-matching model.

The PixelArtUNet weights are downloaded from the Hugging Face Hub
(``mradovic38/sprite-flow``) when this module is imported. Each request draws a
noise sample, simulates ``UnguidedVectorFieldODE`` with ``EulerSimulator`` over a
user-chosen number of time points, and yields a preview image (via
``make_large``) after every simulation step.
"""
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from models.unet import PixelArtUNet
from sampling.conditional_probability_path import GaussianConditionalProbabilityPath
from sampling.noise_scheduling import LinearAlpha, LinearBeta
from diff_eq.ode_sde import UnguidedVectorFieldODE
from diff_eq.simulator import EulerSimulator
from utils import tensor_to_rgba_image, normalize_to_unit, make_large

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Setup model ---
# These hyperparameters must match the checkpoint published on the Hub, otherwise
# load_state_dict below fails on missing/unexpected keys or shape mismatches.
model = PixelArtUNet(
    channels=[128, 256, 512, 1024],
    num_residual_layers=2,
    t_embed_dim=128,
    midcoder_dropout_p=0.2,
).to(device)

repo_id = "mradovic38/sprite-flow"
filename = "model.safetensors"
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
checkpoint = load_file(file_path)
# load_state_dict copies the weights into the parameters that already live on
# `device`, so no second .to(device) is needed afterwards.
model.load_state_dict(checkpoint)
model.eval()  # inference mode for layers such as dropout


# Decorator form (not a `with` block inside the body): torch re-enters no_grad
# around every resume of the generator, so it stays correct even if Gradio
# advances the generator from different worker threads, and it does not leak
# grad-disabled state to other code between yields. Without it, every step of
# the simulation could extend the autograd graph (model.eval() alone does not
# disable gradient tracking).
@torch.no_grad()
def generate_image_stream(num_timesteps: int = 200):
    """Sample one sprite, yielding a preview image after each simulation step.

    Args:
        num_timesteps: Number of points in the time grid spanning [0, 1].
            Coerced to ``int`` because ``torch.linspace`` requires an integer
            step count and a Gradio slider may deliver a float.

    Yields:
        The output of ``make_large`` for the current simulation state, which is
        displayed by the streaming ``gr.Image`` output.

    Raises:
        ValueError: If ``num_timesteps`` is less than 2, since a time grid with
            fewer than two points has no interval to step across.
    """
    num_timesteps = int(num_timesteps)
    if num_timesteps < 2:
        raise ValueError(f"num_timesteps must be at least 2, got {num_timesteps}")

    # Setup path; only its p_simple (noise) distribution is used below.
    path = GaussianConditionalProbabilityPath(
        p_data=None,
        p_simple_shape=[4, 128, 128],
        alpha=LinearAlpha(),
        beta=LinearBeta(),
    ).to(device)
    path.eval()

    # Time grid of shape (1, num_timesteps, 1, 1, 1): batch of one, with trailing
    # singleton dims presumably so each t broadcasts over (C, H, W) — TODO confirm
    # against EulerSimulator.simulate. Built on CPU then moved, as before.
    ts = torch.linspace(0, 1, num_timesteps).view(1, -1, 1, 1, 1).to(device)
    x0 = path.p_simple.sample(1).to(device)  # (1, 4, 128, 128)

    ode = UnguidedVectorFieldODE(model)
    simulator = EulerSimulator(ode)

    # Yield images at each step
    for x in simulator.simulate(x0, ts):
        img = normalize_to_unit(x)
        # assumes tensor_to_rgba_image returns one image per batch element;
        # the batch has a single sample, so take the first.
        img = tensor_to_rgba_image(img)[0]
        yield make_large(img)


# --- Create Gradio interface ---
css = """
.gradio-container {
    max-width: 700px !important;
    margin: 0 auto !important; /* centers the container */
    text-align: center; /* centers text and components inside */
}

#component-0 {
    display: inline-block; /* make the image inline-block */
    width: 100% !important;
}
"""

iface = gr.Interface(
    fn=generate_image_stream,
    inputs=[gr.Slider(50, 500, step=10, value=200, label="Number of Steps")],
    outputs=gr.Image(type="pil", streaming=True),
    title="Flow-Matching Pixel Art Sprite Generation",
    description="Generate pixel art sprites with sprite-flow.",
    css=css,
)

if __name__ == "__main__":
    iface.launch()