# ysda-cycle-gan / app.py
# Author: igor-saprygin
# Commit message: fix
# Commit: 35552ef
import streamlit as st
import torch
import numpy as np
from torchvision import transforms as tr
from PIL import Image
from huggingface_hub import hf_hub_download
from cycle_gan import CycleGAN, create_model_and_optimizer, val_transform_a, val_transform_b, de_normalize_a, de_normalize_b
@st.cache_resource  # caching: build and load the model once, reuse it on reruns
def load_model():
    """Build the CycleGAN and load its trained weights from the Hugging Face Hub.

    A fresh ``CycleGAN`` (3 input channels, 3 output channels) is created on
    the CPU via ``create_model_and_optimizer``; the other two values that
    helper returns are discarded. The checkpoint ``anime_gan#3.pt`` is then
    downloaded from the ``igor-saprygin/ysda-anime-gan`` repository, and its
    ``'model_state_dict'`` entry is loaded into the model.

    Returns:
        The CycleGAN instance with the downloaded weights loaded.
    """
    net, _opt, _extra = create_model_and_optimizer(
        model_class=CycleGAN,
        model_params={'c_in': 3, 'c_out': 3},
        lr=0.0002,
        device='cpu',
    )
    ckpt_file = hf_hub_download(
        repo_id="igor-saprygin/ysda-anime-gan",
        filename="anime_gan#3.pt",
    )
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this trusts whoever publishes the checkpoint repo. Keep it
    # pointed at a trusted repository.
    checkpoint = torch.load(ckpt_file, map_location='cpu', weights_only=False)
    net.load_state_dict(checkpoint['model_state_dict'])
    return net
# Obtain the generator network (load_model is wrapped in st.cache_resource).
model = load_model()

st.title("Anime sketch <--> Colored Art Converter")

# Conversion direction, chosen from a drop-down menu.
style = st.selectbox(
    "Select style",
    options=['SketchToColored', 'ColoredToSketch'],
)

# Noun used in the uploader prompt: what the user is expected to upload.
if style == 'SketchToColored':
    query = 'sketch'
else:
    query = 'colored art'

uploaded_file = st.file_uploader(
    f"Upload your {query}",
    type=["png", "jpg", "jpeg"],
)
# Remember the most recent upload across reruns. Changing "Select style"
# changes the uploader's label, which presumably resets the widget (Streamlit
# derives an unkeyed widget's identity from its parameters). The stored copy
# keeps the current picture in use after such a reset.
if uploaded_file is not None:
    st.session_state['uploaded_file'] = uploaded_file
uploaded_file = st.session_state.get('uploaded_file', None)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot decode; show a message instead of a raw traceback.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # The network is fed only the central square of the picture.
    image = tr.CenterCrop(min(image.size))(image)
    # Record the size AFTER cropping. The generated square is resized back to
    # this size; using the pre-crop (non-square) size would stretch the result
    # and distort its aspect ratio.
    image_size = image.size
    image = np.array(image)

    # Pick the pipeline for the requested direction:
    # colored -> sketch uses GA; sketch -> colored uses GB.
    if style == 'ColoredToSketch':
        transform, generator, de_normalize = val_transform_a, model.GA, de_normalize_b
    else:
        transform, generator, de_normalize = val_transform_b, model.GB, de_normalize_a

    model.eval()
    with torch.no_grad():
        image = transform(image).to(dtype=torch.float, device='cpu')
        # Add a batch dimension for the generator, then drop it again.
        # No .detach() is needed: no autograd graph is built under no_grad.
        generation = generator(image.unsqueeze(0)).squeeze()
        generation = de_normalize(generation)

    # de_normalize_* must return a PIL image here, since .resize is called
    # with a PIL resampling filter — TODO confirm against cycle_gan.
    resized = generation.resize(image_size, Image.Resampling.LANCZOS)
    st.image(resized, caption="Your result!", use_container_width=True)