Spaces:
Running
Running
| import torch | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from models.unet import PixelArtUNet | |
| from sampling.conditional_probability_path import GaussianConditionalProbabilityPath | |
| from sampling.noise_scheduling import LinearAlpha, LinearBeta | |
| from diff_eq.ode_sde import UnguidedVectorFieldODE | |
| from diff_eq.simulator import EulerSimulator | |
| from utils import tensor_to_rgba_image, normalize_to_unit, make_large | |
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the pretrained weights on the Hugging Face Hub.
repo_id = "mradovic38/sprite-flow"
filename = "model.safetensors"

# Architecture hyperparameters; these must match the checkpoint below,
# because load_state_dict is called with its default (strict) settings.
model = PixelArtUNet(
    channels=[128, 256, 512, 1024],
    num_residual_layers=2,
    t_embed_dim=128,
    midcoder_dropout_p=0.2,
)

# Download the safetensors file (hf_hub_download reuses its local cache when
# possible), read it into a state dict, and copy the weights into the model.
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
checkpoint = load_file(file_path)
model.load_state_dict(checkpoint)

# Module.to() and Module.eval() both act in place and return the module, so a
# single chained call moves the weights and switches to evaluation mode.
model.to(device).eval()
# model.eval() alone does not turn off autograd. EulerSimulator may already do
# so, but decorating here guarantees no autograd graph is recorded across the
# sampling loop. As a decorator, torch.no_grad() supports generator functions
# and re-enters the no-grad context around every resume of the generator, so
# it stays effective even if Gradio advances the generator from another thread.
@torch.no_grad()
def generate_image_stream(num_timesteps: int = 200):
    """Sample one sprite and stream an image after every simulation step.

    Draws an initial sample ``x0`` from the path's simple distribution
    (``p_simple_shape=[4, 128, 128]``). It then integrates the model's vector
    field with ``EulerSimulator`` over a uniform time grid from 0 to 1. Every
    state the simulator produces is normalized with ``normalize_to_unit``,
    converted with ``tensor_to_rgba_image`` and enlarged with ``make_large``
    before being yielded. This lets the UI show the sprite as it forms.

    Args:
        num_timesteps: Number of points in the time grid. UI sliders may
            deliver this as a float, so it is converted to ``int`` here.

    Yields:
        The image returned by ``make_large`` for each simulator step.

    Raises:
        ValueError: If fewer than two time points are requested.
    """
    # torch.linspace only accepts an integer step count.
    steps = int(num_timesteps)
    if steps < 2:
        # A time grid needs at least a start point and an end point.
        raise ValueError(f"num_timesteps must be at least 2, got {num_timesteps}")

    # Within this function the path is used only to sample the initial x0.
    path = GaussianConditionalProbabilityPath(
        p_data=None,
        p_simple_shape=[4, 128, 128],
        alpha=LinearAlpha(),
        beta=LinearBeta(),
    ).to(device)
    path.eval()

    # Time grid of shape (1, steps, 1, 1, 1) covering [0, 1].
    ts = torch.linspace(0, 1, steps).view(1, -1, 1, 1, 1).to(device)
    x0 = path.p_simple.sample(1).to(device)  # (1, 4, 128, 128)

    simulator = EulerSimulator(UnguidedVectorFieldODE(model))

    for x in simulator.simulate(x0, ts):
        # [0] presumably selects the single sample of the batch — confirm
        # against tensor_to_rgba_image's return type.
        img = tensor_to_rgba_image(normalize_to_unit(x))[0]
        yield make_large(img)
# --- Gradio interface ---

# Page styling handed verbatim to Gradio: a narrow, centred container.
# NOTE(review): `#component-0` targets whichever element Gradio assigns id 0;
# the inline CSS comment says this is the image — confirm against the
# rendered page, since ids come from component construction order.
css = """
.gradio-container {
max-width: 700px !important;
margin: 0 auto !important; /* centers the container */
text-align: center; /* centers text and components inside */
}
#component-0 {
display: inline-block; /* make the image inline-block */
width: 100% !important;
}
"""

# Components are created in the same order as before (slider, then image),
# so their auto-assigned ids are unchanged.
steps_slider = gr.Slider(50, 500, step=10, value=200, label="Number of Steps")
sprite_output = gr.Image(type="pil", streaming=True)

# generate_image_stream is a generator, so each yielded frame updates the
# output image while sampling runs.
iface = gr.Interface(
    fn=generate_image_stream,
    inputs=[steps_slider],
    outputs=sprite_output,
    title="Flow-Matching Pixel Art Sprite Generation",
    description="Generate pixel art sprites with sprite-flow.",
    css=css,
)

if __name__ == "__main__":
    iface.launch()