Hugging Face Space (status at scrape time: runtime error)
| import gradio as gr | |
| from .model_loader import Model | |
| from PIL import Image | |
| import cv2 | |
| import io | |
# Checkpoint files live under pretrained_models/latent_transformer,
# one .pt file per domain, named after the domain itself.
_CKPT_DIR = "pretrained_models/latent_transformer"
models_files = {
    name: f"{_CKPT_DIR}/{name}.pt"
    for name in ("anime", "car", "cat", "church", "ffhq")
}

# Eagerly load every checkpoint once at import time so callbacks can
# look models up by name without reloading.
models = {}
for _name, _path in models_files.items():
    models[_name] = Model(_path)
def cv_to_pil(img):
    """Convert a BGR OpenCV image array into an RGB PIL image."""
    rgb = cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
def random_sample(model_name: str):
    """Sample a fresh image from the chosen pretrained model.

    Returns the image as PIL, the model name (stored into gradio state),
    and the latents that produced the image (also stored into state).
    """
    sampled_img, latents = models[model_name].random_sample()
    return cv_to_pil(sampled_img), model_name, latents
def zoom(dx, dy, dz, model_state, latents_state):
    """Zoom the current image by dz around a fixed anchor point.

    dx and dy are accepted only so this callback shares the slider input
    signature with `translate`; zooming uses dz alone.
    Returns the updated PIL image and the new latents state.
    """
    # Removed dead no-op self-assignments (dx = dx, ...) and a stale
    # commented-out duplicate of the model.zoom call.
    model = models[model_state]
    sx, sy = 100, 100  # fixed zoom anchor — TODO confirm against model API
    img, latents_state = model.zoom(
        latents_state, dz, sxsy=[sx, sy], stop_points=[]
    )
    return cv_to_pil(img), latents_state
def translate(dx, dy, dz, model_state, latents_state):
    """Translate the current image by (dx, dy) around a fixed anchor.

    dz is accepted only so this callback shares the slider input
    signature with `zoom`; translation ignores it.
    Returns the updated PIL image and the new latents state.
    """
    # Removed dead no-op self-assignments (dx = dx, ...) and single-use
    # constant locals that only obscured the call.
    model = models[model_state]
    sx, sy = 128, 128  # fixed translation anchor — TODO confirm against model API
    img, latents_state = model.translate(
        latents_state,
        [dx, dy],
        sxsy=[sx, sy],
        stop_points=[],
        zoom_in=False,
        zoom_out=False,
    )
    return cv_to_pil(img), latents_state
def change_style(image: Image.Image, model_state, latents_state):
    """Re-style the current latents with the selected model.

    `image` is unused here but must stay in the signature: the gradio
    click binding passes the image component as the first input.
    """
    styled, latents_state = models[model_state].change_style(latents_state)
    return cv_to_pil(styled), latents_state
def reset(model_state, latents_state):
    """Reset the latents to their initial value and re-render the image."""
    restored, latents_state = models[model_state].reset(latents_state)
    return cv_to_pil(restored), latents_state
# UI layout: one row with a controls column (model picker, buttons,
# dx/dy/dz sliders) on the left and the rendered image on the right.
with gr.Blocks() as block:
    model_state = gr.State(value="cat")
    latents_state = gr.State({})
    gr.Markdown("# UserControllableLT: User controllable latent transformer")
    gr.Markdown("## Select model")
    with gr.Row():
        with gr.Column():
            model_name = gr.Dropdown(
                choices=list(models_files.keys()),
                label="Select Pretrained Model",
                value="cat",
            )
            with gr.Row():
                button = gr.Button("Random sample")
                reset_btn = gr.Button("Reset")
            # FIX: gr.Slider has no `step_size` keyword — the parameter is
            # `step`. Passing step_size raises TypeError at startup, which
            # matches the Space's "Runtime error" status.
            dx = gr.Slider(minimum=-128, maximum=128, step=0.1, label="dx", value=0.0)
            dy = gr.Slider(minimum=-128, maximum=128, step=0.1, label="dy", value=0.0)
            dz = gr.Slider(minimum=-128, maximum=128, step=0.1, label="dz", value=0.0)
            with gr.Row():
                change_style_bt = gr.Button("Change style")
        with gr.Column():
            image = gr.Image(type="pil", label="")

    # Event wiring: sampling replaces image + both state slots; the other
    # callbacks update the image and latents only.
    button.click(
        random_sample,
        inputs=[model_name],
        outputs=[image, model_state, latents_state],
    )
    reset_btn.click(
        reset,
        inputs=[model_state, latents_state],
        outputs=[image, latents_state],
    )
    change_style_bt.click(
        change_style,
        inputs=[image, model_state, latents_state],
        outputs=[image, latents_state],
    )
    # dx/dy slider moves trigger translation; dz triggers zoom.
    dx.change(
        translate,
        inputs=[dx, dy, dz, model_state, latents_state],
        outputs=[image, latents_state],
        show_progress=False,
    )
    dy.change(
        translate,
        inputs=[dx, dy, dz, model_state, latents_state],
        outputs=[image, latents_state],
        show_progress=False,
    )
    dz.change(
        zoom,
        inputs=[dx, dy, dz, model_state, latents_state],
        outputs=[image, latents_state],
        show_progress=False,
    )

block.launch()