Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import numpy as np | |
| from torchvision import transforms as tr | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from cycle_gan import CycleGAN, create_model_and_optimizer, val_transform_a, val_transform_b, de_normalize_a, de_normalize_b | |
# Cache the loaded model with Streamlit so the model is built and its
# checkpoint is read once per server process rather than on every script
# rerun (Streamlit re-executes the whole script on each widget interaction).
@st.cache_resource
def load_model(repo_id="igor-saprygin/ysda-anime-gan", filename="anime_gan#3.pt"):
    """Create a CycleGAN on the CPU and load pretrained weights from the Hub.

    Args:
        repo_id: Hugging Face Hub repository that holds the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.

    Returns:
        The ``CycleGAN`` model with the checkpoint's ``'model_state_dict'``
        entry loaded. Because of ``st.cache_resource`` the same instance is
        shared by every session, so callers should treat it as read-only.
    """
    # Only the model is needed for inference. The other two returned values
    # are discarded (presumably training objects such as the optimizer --
    # confirm in cycle_gan).
    model, _, _ = create_model_and_optimizer(
        model_class=CycleGAN,
        model_params={'c_in': 3, 'c_out': 3},
        lr=0.0002,
        device='cpu',
    )
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. This is tolerable only for a trusted
    # checkpoint; try weights_only=True if the file holds just tensors and
    # plain containers.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
model = load_model()

st.title("Anime sketch <--> Colored Art Converter")

# Conversion direction, picked from a drop-down menu.
style_options = ['SketchToColored', 'ColoredToSketch']
style = st.selectbox("Select style", style_options)

# Name the kind of picture the user should upload for the chosen direction.
if style == 'SketchToColored':
    query = 'sketch'
else:
    query = 'colored art'
uploaded_file = st.file_uploader(f"Upload your {query}", type=["png", "jpg", "jpeg"])
# The uploader's label depends on `style`. Streamlit identifies a widget that
# has no explicit key by its parameters, so switching the style presumably
# recreates the uploader, which then comes back empty. Remembering the last
# upload in session_state keeps the picture (and its result) across switches.
if uploaded_file is not None:
    st.session_state['uploaded_file'] = uploaded_file
uploaded_file = st.session_state.get('uploaded_file', None)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised formats and OSError for truncated data. Forget the bad
        # upload so the error is not replayed on every rerun, then end this
        # run cleanly instead of showing a traceback.
        st.session_state.pop('uploaded_file', None)
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # Center-crop to a square whose side is the shorter image dimension.
    side = min(image.size)
    image = np.array(tr.CenterCrop(side)(image))

    # Domain A is colored art and domain B is sketches: GA maps A -> B and
    # GB maps B -> A. Each generator is paired with its source domain's input
    # transform and its target domain's de-normalization.
    if style == 'ColoredToSketch':
        to_tensor, generator, to_output = val_transform_a, model.GA, de_normalize_b
    else:
        to_tensor, generator, to_output = val_transform_b, model.GB, de_normalize_a

    model.eval()
    with torch.no_grad():
        batch = to_tensor(image).to(dtype=torch.float, device='cpu').unsqueeze(0)
        # Drop only the batch dimension added above.
        generation = to_output(generator(batch).squeeze(0).detach())

    # The output depicts the square crop, so scale it back to the crop's size.
    # Resizing to the original (possibly non-square) upload size would stretch
    # the picture. Assumes de_normalize_* returns a PIL Image, as the PIL
    # resampling filter below implies -- confirm in cycle_gan.
    resized = generation.resize((side, side), Image.Resampling.LANCZOS)
    st.image(resized, caption="Your result!", use_container_width=True)