| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
def Generator():
    """Build the (untrained) image generator network.

    The returned ``nn.Sequential`` maps a batch of 256-dim latent vectors,
    shape (N, 256), to images of shape (N, 3, 96, 96) with values in [0, 1].
    Two linear layers expand the latent into a 1024x3x3 feature map. Five
    upsampling stages then each double the spatial size (3 -> 96) while
    narrowing the channels 1024 -> 512 -> 256 -> 128 -> 64 -> 3, and a final
    sigmoid squashes the result.

    The flat layer list is built in a fixed order on purpose. Its positional
    indices become the ``state_dict`` keys ("0.weight", "7.bias", ...), so
    previously saved weights only load if that order is unchanged.
    """

    def upsample_stage(in_ch, out_ch):
        # Bilinear 2x upsample, then three (LeakyReLU, 3x3 conv) pairs.
        # Every conv reads in_ch channels; only the last one emits out_ch.
        stage = [nn.Upsample(scale_factor=2, mode='bilinear')]
        for conv_out in (in_ch, in_ch, out_ch):
            stage += [
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(in_ch, conv_out, 3, padding=1),
            ]
        return stage

    layers = [
        nn.Linear(256, 1024),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Linear(1024, 1024 * 3 * 3),
        # Parameter-free normalization of the flat 9216-dim feature vector.
        nn.LayerNorm(1024 * 3 * 3, eps=1e-6, elementwise_affine=False),
        nn.Unflatten(1, (1024, 3, 3)),
    ]
    widths = (1024, 512, 256, 128, 64, 3)
    for in_ch, out_ch in zip(widths, widths[1:]):
        layers += upsample_stage(in_ch, out_ch)
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)
# Build the generator as a frozen, inference-only module and load its weights.
model = Generator().requires_grad_(False).eval()
# map_location='cpu': the app runs the model on CPU tensors and converts the
# output with .numpy(), so a checkpoint saved from GPU tensors must still load
# on a CPU-only host. weights_only=True refuses arbitrary pickled objects; the
# file only needs to contain a plain state_dict for load_state_dict anyway.
model.load_state_dict(
    torch.load('weights.pt', map_location='cpu', weights_only=True)
)
# Modulus of the Lehmer (MINSTD-style) PRNG used to draw latents: the Mersenne
# prime 2**31 - 1.
p = 2147483647
def gen(state):
    """Render one image from a pseudo-random seed.

    Advances a Lehmer LCG (multiplier 48271, modulus ``p``) 256 times starting
    from ``state``. Each draw, scaled into the open interval (0, 1), is mapped
    through the inverse standard-normal CDF to form a (1, 256) latent vector,
    which the module-level ``model`` turns into an image.

    Args:
        state: Seed value. Non-integers are rounded, values below 1 are
            clamped to 1, and larger values are reduced modulo ``p``. Reducing
            mod ``p`` yields the same draws, because the recurrence only
            depends on the state mod ``p``.

    Returns:
        ``(next_state, img)``. ``next_state`` is the generator state after the
        256th draw, so passing it back in continues the sequence. ``img`` is
        an (H, W, C) uint8 array built from the first batch item.
    """
    state = max(round(state), 1)
    # A multiple of p would make every draw 0, and ndtri(0) = -inf turns the
    # model output into NaN garbage. Reduce mod p (draws are unchanged for all
    # other seeds) and map the degenerate 0 to 1, matching the lower clamp.
    state %= p
    if state == 0:
        state = 1

    draws = []
    for _ in range(256):
        state = state * 48271 % p
        # state is in [1, p - 1], so the draw is strictly inside (0, 1).
        draws.append(state / p)

    # Inverse normal CDF in float64 for accuracy near the tails, then downcast
    # to the float32 dtype the model's parameters use.
    uniform = torch.tensor([draws], dtype=torch.float64)
    x = torch.special.ndtri(uniform).float()
    y = model(x).mul(255).round().byte()
    img = y[0].permute(1, 2, 0).numpy()
    return state, img
# Minimal UI: a seed slider, an image pane and a button. Each click renders the
# image for the current slider value, then writes gen()'s returned state back
# into the slider, so repeated clicks step through successive images.
with gr.Blocks() as demo:
    state_slider = gr.Slider(
        minimum=1,
        maximum=p - 1,
        value=4,
        step=1,
        label='PRNG State',
    )
    img_output = gr.Image(format='png', label="Generated Image")
    click_btn = gr.Button(value='Generate')
    click_btn.click(
        gen,
        inputs=state_slider,
        outputs=[state_slider, img_output],
    )

demo.launch()