"""Gradio demo: turn a PRNG state into a latent vector and decode it into an image."""

import gradio as gr
import torch
import torch.nn as nn


def Generator():
    """Build the image generator network.

    The network maps a (batch, 256) latent tensor to a (batch, 3, 96, 96)
    image with values in [0, 1]. Two linear layers reach 1024 x 3 x 3
    features, and five 2x upsampling blocks grow this to 96 x 96. A sigmoid
    produces the final output.

    NOTE: the layer order must not change. The weights are restored with
    ``load_state_dict``, whose keys are the positional indices of this
    ``nn.Sequential``.
    """

    def up_conv_block(c_in, c_out):
        # Double the spatial size, then refine with three 3x3 convolutions.
        # Only the last convolution changes the channel count.
        return [
            nn.Upsample(size=None, scale_factor=2, mode='bilinear'),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(c_in, c_in, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(c_in, c_in, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
        ]

    return nn.Sequential(
        nn.Linear(256, 1024),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Linear(1024, 9216),
        nn.LayerNorm(9216, eps=1e-6, elementwise_affine=False),
        nn.Unflatten(1, (1024, 3, 3)),  # 9216 == 1024 * 3 * 3
        *up_conv_block(1024, 512),  # 3x3   -> 6x6
        *up_conv_block(512, 256),   # 6x6   -> 12x12
        *up_conv_block(256, 128),   # 12x12 -> 24x24
        *up_conv_block(128, 64),    # 24x24 -> 48x48
        *up_conv_block(64, 3),      # 48x48 -> 96x96, RGB
        nn.Sigmoid(),
    )


model = Generator().requires_grad_(False).eval()
# weights_only=True refuses arbitrary pickled objects, since a plain state
# dict of tensors is all that is needed here. map_location='cpu' lets
# weights saved from a GPU load on CPU-only hosts.
model.load_state_dict(
    torch.load('weights.pt', map_location='cpu', weights_only=True)
)

# Modulus of the MINSTD Lehmer generator (2**31 - 1, a Mersenne prime).
p = 2147483647
# MINSTD multiplier (Park-Miller revised constant).
MINSTD_MULTIPLIER = 48271


def gen(state):
    """Advance the PRNG and decode the resulting latent vector into an image.

    The state is stepped 256 times with the Lehmer recurrence
    ``state = state * 48271 mod p``. Each draw is mapped to a uniform in
    (0, 1) and then to a standard normal via the inverse normal CDF
    (``ndtri``). The resulting 256 normals form the latent input.

    Args:
        state: Current PRNG state (any real number; it is rounded).

    Returns:
        ``(new_state, img)``. ``new_state`` is the state after the 256 steps,
        so feeding it back continues the sequence. ``img`` is a
        (96, 96, 3) uint8 numpy array.
    """
    state = max(round(state), 1)
    # The recurrence depends only on state mod p, so reducing here leaves
    # every non-degenerate input's sequence unchanged. A residue of 0 would
    # stay 0 forever, making ndtri(0) = -inf and the image NaN, so map it to 1.
    state %= p
    if state == 0:
        state = 1

    uniforms = []
    for _ in range(256):
        state = state * MINSTD_MULTIPLIER % p
        # state lies in [1, p - 1], so the uniform lies strictly inside (0, 1)
        # and ndtri stays finite.
        uniforms.append(state / p)

    x = torch.tensor([uniforms], dtype=torch.float64)
    x = torch.special.ndtri(x).float()

    y = model(x).mul(255).round().byte()
    img = y[0].permute(1, 2, 0).numpy()  # CHW -> HWC for display
    return state, img


with gr.Blocks() as demo:
    state_slider = gr.Slider(
        minimum=1, maximum=p - 1, value=4, step=1, label='PRNG State'
    )
    img_output = gr.Image(label="Generated Image", format='png')
    click_btn = gr.Button('Generate')
    # Writing the new state back into the slider makes repeated clicks walk
    # through successive latent vectors.
    click_btn.click(
        fn=gen, inputs=state_slider, outputs=[state_slider, img_output]
    )

demo.launch()