Spaces:
Paused
Paused
| import io | |
| import os | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
# Hugging Face Hub repository that stores the exported checkpoints; can be
# overridden through the MODEL_REPO_ID environment variable.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "De4u/cyclegan-models")

# Inference is pinned to the CPU.
DEVICE = torch.device("cpu")

# Checkpoint filename -> human-readable label shown in the model selector.
PRETTY = {
    "cyclegan_export.pt": "apple2orange (яблоки ↔ апельсины)",
    "cyclegan_export_monet.pt": "monet2photo (Моне ↔ фото)",
}

# Selectable checkpoint filenames, in the declaration order of PRETTY.
MODEL_FILES = [*PRETTY]
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip connection.

    Both convolutions map ``dim`` channels to ``dim`` channels and the
    reflection padding of 1 offsets the 3x3 kernel, so the block's output has
    the same shape as its input.

    Args:
        dim: Number of input (and output) channels.
        norm_layer: Callable building a normalization layer from a channel count.
        use_bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, norm_layer=nn.InstanceNorm2d, use_bias=True):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "bias": use_bias}
        # The position of every layer inside the Sequential is part of the
        # state_dict key names ("block.1.weight", "block.5.weight", ...), so
        # the order below must not change.
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, **conv_kwargs),
            norm_layer(dim),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, **conv_kwargs),
            norm_layer(dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual
class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions,
    ``n_res_blocks`` residual blocks, two stride-2 transposed convolutions and
    a 7x7 output convolution followed by ``tanh``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        ngf: Number of filters produced by the stem convolution.
        n_res_blocks: Number of residual blocks at the bottleneck.
        norm_layer: Callable building a normalization layer from a channel count.
    """

    def __init__(self, in_channels=3, out_channels=3, ngf=64, n_res_blocks=6,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        with_bias = True
        num_sampling = 2

        # NOTE: layer positions inside the Sequential become state_dict key
        # names ("model.1.weight", ...), so the order must stay as is.

        # Stem: reflection-padded 7x7 convolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, ngf, kernel_size=7, bias=with_bias),
            norm_layer(ngf),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step doubles the channel count with a stride-2 conv.
        width = ngf
        for _ in range(num_sampling):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=with_bias),
                norm_layer(width * 2),
                nn.ReLU(inplace=True),
            ])
            width = width * 2

        # Bottleneck: residual blocks at the widest channel count.
        layers.extend(
            ResnetBlock(width, norm_layer=norm_layer, use_bias=with_bias)
            for _ in range(n_res_blocks)
        )

        # Decoder: each step halves the channel count with a stride-2
        # transposed conv.
        for _ in range(num_sampling):
            layers.extend([
                nn.ConvTranspose2d(width, width // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=with_bias),
                norm_layer(width // 2),
                nn.ReLU(inplace=True),
            ])
            width = width // 2

        # Head: reflection-padded 7x7 convolution squashed by tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, out_channels, kernel_size=7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.model(x)
@st.cache_resource
def load_models(filename: str):
    """Download a checkpoint and build both CycleGAN generators from it.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: each checkpoint is
    downloaded, unpickled and turned into modules once per ``filename``
    instead of once per rerun. The cached objects are shared between reruns
    and sessions and must be treated as read-only by callers.

    Args:
        filename: Name of the checkpoint file inside ``MODEL_REPO_ID``.

    Returns:
        A ``(gen_ab, gen_ba, meta)`` tuple: the A->B generator, the B->A
        generator (both on ``DEVICE`` and in eval mode) and a dict with the
        keys ``mean_a``, ``std_a``, ``mean_b``, ``std_b`` (float32 arrays),
        ``image_size`` (int, 128 when the checkpoint does not provide one)
        and ``dataset_name`` ("unknown" when the checkpoint does not provide
        one).

    Raises:
        Whatever ``hf_hub_download``, ``torch.load`` or ``load_state_dict``
        raise, plus a lookup error when a required checkpoint entry is
        missing. Callers are expected to handle these.
    """
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=filename)

    # SECURITY: weights_only=False runs the full pickle machinery and can
    # execute arbitrary code embedded in the file. It is kept because the
    # checkpoint stores more than tensors (presumably NumPy statistics --
    # TODO confirm), but MODEL_REPO_ID is configurable through the
    # environment, so it must only ever point at a trusted repository.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    mp = ckpt["model_params"]

    def _new_generator():
        # The checkpoint only records "in_channels"; it is used for both the
        # input and the output channel count.
        return ResnetGenerator(
            mp["in_channels"], mp["in_channels"], mp["ngf"], mp["n_res_blocks"]
        )

    gen_ab = _new_generator()
    gen_ba = _new_generator()
    gen_ab.load_state_dict(ckpt["gen_a_to_b"])
    gen_ba.load_state_dict(ckpt["gen_b_to_a"])
    gen_ab.to(DEVICE).eval()
    gen_ba.to(DEVICE).eval()

    meta = {
        "mean_a": np.asarray(ckpt["mean_a"], dtype=np.float32),
        "std_a": np.asarray(ckpt["std_a"], dtype=np.float32),
        "mean_b": np.asarray(ckpt["mean_b"], dtype=np.float32),
        "std_b": np.asarray(ckpt["std_b"], dtype=np.float32),
        "image_size": int(ckpt.get("image_size", 128)),
        "dataset_name": ckpt.get("dataset_name", "unknown"),
    }
    return gen_ab, gen_ba, meta
def preprocess(pil_img, mean, std, image_size):
    """Turn a PIL image into a normalized single-image batch tensor.

    Args:
        pil_img: Source PIL image; it is converted to RGB.
        mean: Value(s) subtracted from the [0, 1]-scaled pixels.
        std: Value(s) the centred pixels are divided by.
        image_size: Side length of the square the image is resized to.

    Returns:
        A float tensor on ``DEVICE`` laid out as (1, channels, height, width).
    """
    target_size = (image_size, image_size)
    rgb = pil_img.convert("RGB")
    resized = rgb.resize(target_size, Image.BICUBIC)

    # Scale 8-bit pixel values to [0, 1] before normalizing.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    normalized = (pixels - mean) / std

    # Height-width-channel -> channel-height-width, then add the batch axis.
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    batch = chw.unsqueeze(0).float()
    return batch.to(DEVICE)
def postprocess(tensor, mean, std):
    """Convert a generator output tensor into an 8-bit image array.

    Args:
        tensor: Tensor laid out as (1, channels, height, width); the leading
            batch axis is squeezed away.
        mean: Value(s) added back after scaling by ``std``.
        std: Value(s) the tensor is multiplied by.

    Returns:
        A ``uint8`` NumPy array laid out as (height, width, channels) with
        values in [0, 255].
    """
    arr = tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()
    # Undo the normalization and clamp to the displayable [0, 1] range.
    arr = np.clip(arr * std + mean, 0.0, 1.0)
    # Round to the nearest level instead of truncating: a plain uint8 cast
    # floors the value, so float round-off (e.g. 254.99997) would lose a full
    # intensity level and the whole image would be biased darker.
    return np.rint(arr * 255.0).astype(np.uint8)
# ---------------------------------------------------------------------------
# Streamlit page: pick a model and a direction, upload an image, show the
# translated result next to the input and offer it as a PNG download.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="CycleGAN Img2Img", layout="wide")
st.title("CycleGAN: перевод изображений между доменами A и B")
st.caption(f"Модели: `{MODEL_REPO_ID}` · CPU inference")

choice = st.sidebar.selectbox(
    "Модель",
    MODEL_FILES,
    format_func=lambda f: PRETTY.get(f, f),
)
st.sidebar.caption("Модель загрузится при первом переводе изображения.")

direction = st.radio("Направление перевода", ["A → B", "B → A"], horizontal=True)
uploaded = st.file_uploader("Загрузите изображение", type=["jpg", "jpeg", "png", "webp", "bmp"])

if uploaded is not None:
    # The file extension filter does not guarantee decodable content, so a
    # broken upload is reported to the user instead of surfacing a traceback.
    try:
        # getvalue() returns the whole buffer regardless of the current
        # stream position, unlike read().
        pil_img = Image.open(io.BytesIO(uploaded.getvalue()))
        # Image.open is lazy; force decoding now so that corrupt data fails
        # here rather than later during preprocessing or display.
        pil_img.load()
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Не удалось прочитать изображение: {e}")
        st.stop()

    # Top-level boundary: any loading failure is shown to the user.
    try:
        gen_ab, gen_ba, meta = load_models(choice)
    except Exception as e:
        st.error(f"Не удалось загрузить модель {choice}: {e}")
        st.stop()

    # Select the generator plus the source/target domain statistics once, so
    # the inference code below is shared by both directions.
    if direction == "A → B":
        generator, src, dst = gen_ab, "a", "b"
        in_caption, out_caption = "Вход (домен A)", "Результат (домен B)"
    else:
        generator, src, dst = gen_ba, "b", "a"
        in_caption, out_caption = "Вход (домен B)", "Результат (домен A)"

    # Normalize with the source-domain statistics, de-normalize with the
    # target-domain ones.
    inp = preprocess(pil_img, meta[f"mean_{src}"], meta[f"std_{src}"], meta["image_size"])
    with torch.no_grad():
        out = generator(inp)
    out_img = postprocess(out, meta[f"mean_{dst}"], meta[f"std_{dst}"])

    st.caption(
        f"Датасет: {meta['dataset_name']} · размер: {meta['image_size']}px · {DEVICE.type}"
    )

    col1, col2 = st.columns(2)
    with col1:
        st.image(pil_img, caption=in_caption, use_container_width=True)
    with col2:
        st.image(out_img, caption=out_caption, use_container_width=True)

    # Encode the result as PNG in memory for the download button.
    buf = io.BytesIO()
    Image.fromarray(out_img).save(buf, format="PNG")
    st.download_button("Скачать результат", buf.getvalue(), file_name="translated.png", mime="image/png")
else:
    st.info("Загрузите изображение, чтобы увидеть результат перевода.")