De4u's picture
Fix: CPU torch, lazy load, preload models at build
a05d721 verified
Raw
History Blame Contribute Delete
6.46 kB
import io
import os
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
# Hub repository holding the exported checkpoints; the MODEL_REPO_ID
# environment variable overrides the default.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "De4u/cyclegan-models")

# All inference in this app is pinned to the CPU.
DEVICE = torch.device("cpu")

# Checkpoint filename -> human-readable label shown in the model selector.
PRETTY = {
    "cyclegan_export.pt": "apple2orange (яблоки ↔ апельсины)",
    "cyclegan_export_monet.pt": "monet2photo (Моне ↔ фото)",
}

# Selectable checkpoint filenames, in the insertion order of PRETTY.
MODEL_FILES = list(PRETTY)
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions with a skip connection.

    The layers live in ``self.block`` in the fixed order
    pad -> conv -> norm -> ReLU -> pad -> conv -> norm, so state-dict keys
    (``block.1.*``, ``block.5.*``, ...) depend on that order.

    Args:
        dim: Number of input and output channels of both convolutions.
        norm_layer: Callable invoked as ``norm_layer(dim)`` to build each
            normalization layer.
        use_bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, norm_layer=nn.InstanceNorm2d, use_bias=True):
        super().__init__()

        def conv_stage():
            # One "pad -> 3x3 conv -> norm" group; reflection padding of 1
            # keeps the spatial size unchanged through the 3x3 convolution.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias),
                norm_layer(dim),
            ]

        layers = conv_stage() + [nn.ReLU(inplace=True)] + conv_stage()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual
class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout of ``self.model`` (order matters for state-dict keys):
      1. stem: reflection pad 3 -> 7x7 conv -> norm -> ReLU
      2. two stride-2 3x3 convolutions, each doubling the channel count
      3. ``n_res_blocks`` residual blocks at ``4 * ngf`` channels
      4. two stride-2 transposed convolutions, each halving the channel count
      5. head: reflection pad 3 -> 7x7 conv -> Tanh

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the generated tensor.
        ngf: Channel count after the stem convolution.
        n_res_blocks: Number of residual blocks in the bottleneck.
        norm_layer: Callable invoked as ``norm_layer(channels)`` to build
            each normalization layer.
    """

    def __init__(self, in_channels=3, out_channels=3, ngf=64, n_res_blocks=6,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        use_bias = True
        n_down = 2

        # Stem: the 7x7 convolution with padding 3 keeps the spatial size.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, ngf, kernel_size=7, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: channels double at every step.
        width = ngf
        for _ in range(n_down):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(width * 2),
                nn.ReLU(inplace=True),
            ])
            width = width * 2

        # Bottleneck: residual blocks at the widest channel count.
        layers.extend([
            ResnetBlock(width, norm_layer=norm_layer, use_bias=use_bias)
            for _ in range(n_res_blocks)
        ])

        # Upsampling: mirror of the downsampling path, channels halve.
        for _ in range(n_down):
            layers.extend([
                nn.ConvTranspose2d(width, width // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(width // 2),
                nn.ReLU(inplace=True),
            ])
            width = width // 2

        # Head: project back to image channels; Tanh bounds output to [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, out_channels, kernel_size=7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full generator stack."""
        return self.model(x)
@st.cache_resource(show_spinner="Загружаю модель...")
def load_models(filename: str):
    """Download a checkpoint from the Hub and build both generators from it.

    The result is cached by Streamlit via ``st.cache_resource``.

    Args:
        filename: Checkpoint file name inside ``MODEL_REPO_ID``.

    Returns:
        A ``(gen_ab, gen_ba, meta)`` tuple: the A->B and B->A generators
        (moved to ``DEVICE`` and switched to eval mode) and a dict with
        ``mean_a``/``std_a``/``mean_b``/``std_b`` as float32 arrays plus
        ``image_size`` (default 128) and ``dataset_name`` (default
        ``"unknown"``).
    """
    local_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=filename)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so the checkpoint repository must be trusted.
    checkpoint = torch.load(local_path, map_location="cpu", weights_only=False)

    params = checkpoint["model_params"]
    generators = []
    for state_key in ("gen_a_to_b", "gen_b_to_a"):
        # "in_channels" is used for both the input and the output width,
        # exactly as the checkpoint's model_params are consumed here.
        net = ResnetGenerator(
            params["in_channels"],
            params["in_channels"],
            params["ngf"],
            params["n_res_blocks"],
        )
        net.load_state_dict(checkpoint[state_key])
        net.to(DEVICE).eval()
        generators.append(net)
    gen_ab, gen_ba = generators

    # Per-domain normalization statistics, stored as float32 arrays.
    meta = {
        stat: np.asarray(checkpoint[stat], dtype=np.float32)
        for stat in ("mean_a", "std_a", "mean_b", "std_b")
    }
    meta["image_size"] = int(checkpoint.get("image_size", 128))
    meta["dataset_name"] = checkpoint.get("dataset_name", "unknown")
    return gen_ab, gen_ba, meta
def preprocess(pil_img, mean, std, image_size):
    """Turn a PIL image into a normalized 1 x 3 x H x W float tensor on DEVICE.

    Args:
        pil_img: Source image; converted to RGB before resizing.
        mean: Value(s) subtracted from the [0, 1]-scaled pixels; must
            broadcast against an ``(image_size, image_size, 3)`` array.
        std: Value(s) the centered pixels are divided by; same broadcasting
            requirement as ``mean``.
        image_size: Side length of the square the image is resized to.

    Returns:
        A float32 tensor of shape ``(1, 3, image_size, image_size)``.
    """
    rgb = pil_img.convert("RGB")
    resized = rgb.resize((image_size, image_size), Image.BICUBIC)

    # Scale to [0, 1] first, then apply the domain normalization.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    normalized = (pixels - mean) / std

    # HWC -> CHW, then add the batch dimension.
    batch = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).float()
    return batch.to(DEVICE)
def postprocess(tensor, mean, std):
    """Convert a generator output tensor into a uint8 H x W x C image array.

    Args:
        tensor: Tensor of shape ``(1, C, H, W)``; the leading batch axis is
            squeezed away.
        mean: Value(s) added back after scaling; must broadcast against the
            ``(H, W, C)`` array.
        std: Value(s) the array is multiplied by before adding ``mean``.

    Returns:
        A ``numpy.uint8`` array of shape ``(H, W, C)`` with values in 0..255.
    """
    # CHW -> HWC on the CPU, detached from any autograd graph.
    hwc = tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()

    # Undo the normalization and clamp to the valid [0, 1] range.
    unit_range = np.clip(hwc * std + mean, 0, 1)

    # astype truncates toward zero, e.g. 127.5 -> 127.
    scaled = unit_range * 255
    return scaled.astype(np.uint8)
# ---------------------------------------------------------------------------
# Streamlit page: pick a model and a direction, upload an image, show and
# offer for download the translated result.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="CycleGAN Img2Img", layout="wide")
st.title("CycleGAN: перевод изображений между доменами A и B")
st.caption(f"Модели: `{MODEL_REPO_ID}` · CPU inference")

choice = st.sidebar.selectbox(
    "Модель",
    MODEL_FILES,
    format_func=lambda f: PRETTY.get(f, f),
)
st.sidebar.caption("Модель загрузится при первом переводе изображения.")

direction = st.radio("Направление перевода", ["A → B", "B → A"], horizontal=True)
uploaded = st.file_uploader("Загрузите изображение", type=["jpg", "jpeg", "png", "webp", "bmp"])

if uploaded is not None:
    # The upload is untrusted input: the extension filter does not guarantee
    # a decodable image. Decode eagerly so a corrupt, truncated or oversized
    # file is reported here instead of crashing later inside preprocess().
    try:
        # getvalue() returns the full payload regardless of the current
        # read position of the upload buffer.
        pil_img = Image.open(io.BytesIO(uploaded.getvalue()))
        pil_img.load()
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        st.error(f"Не удалось прочитать изображение: {e}")
        st.stop()

    # Top-level boundary: any failure while downloading or building the
    # model is shown to the user and ends this script run.
    try:
        gen_ab, gen_ba, meta = load_models(choice)
    except Exception as e:
        st.error(f"Не удалось загрузить модель {choice}: {e}")
        st.stop()

    # The input is normalized with the source-domain statistics and the
    # output is de-normalized with the target-domain statistics.
    if direction == "A → B":
        generator, src, dst = gen_ab, "a", "b"
        in_caption, out_caption = "Вход (домен A)", "Результат (домен B)"
    else:
        generator, src, dst = gen_ba, "b", "a"
        in_caption, out_caption = "Вход (домен B)", "Результат (домен A)"

    inp = preprocess(pil_img, meta[f"mean_{src}"], meta[f"std_{src}"], meta["image_size"])
    with torch.no_grad():
        out = generator(inp)
    out_img = postprocess(out, meta[f"mean_{dst}"], meta[f"std_{dst}"])

    st.caption(
        f"Датасет: {meta['dataset_name']} · размер: {meta['image_size']}px · {DEVICE.type}"
    )

    col1, col2 = st.columns(2)
    with col1:
        st.image(pil_img, caption=in_caption, use_container_width=True)
    with col2:
        st.image(out_img, caption=out_caption, use_container_width=True)

    # Encode the result as PNG in memory for the download button.
    buf = io.BytesIO()
    Image.fromarray(out_img).save(buf, format="PNG")
    st.download_button("Скачать результат", buf.getvalue(), file_name="translated.png", mime="image/png")
else:
    st.info("Загрузите изображение, чтобы увидеть результат перевода.")