klasser commited on
Commit
2f8fc6b
·
verified ·
1 Parent(s): 10842a1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from model import CycleGAN # Подтягиваем архитектуру из вашего model.py
6
+
7
+ st.set_page_config(page_title="CycleGAN Demo", layout="wide")
8
+ st.title("🔄 CycleGAN: Трансфер стилей")
9
+
10
+ # Кешируем загрузку, чтобы не грузить 200 МБ при каждом нажатии
11
+ @st.cache_resource
12
+ def load_model():
13
+ # На бесплатном сервере HF нет видеокарт, заставляем работать на CPU
14
+ device = torch.device("cpu")
15
+ model = CycleGAN()
16
+ # Загружаем ваш файл (впишите точное имя файла, который вы загрузили!)
17
+ checkpoint = torch.load("cycle_gan_fixed.pt", map_location=device, weights_only=False)
18
+ model.load_state_dict(checkpoint['model_state_dict'])
19
+ model.eval()
20
+ return model, device
21
+
22
+ model, device = load_model()
23
+
24
+ # Трансформации (размер должен быть таким же, как при обучении, например 128)
25
+ IMG_SIZE = 128
26
+ transform = transforms.Compose([
27
+ transforms.Resize((IMG_SIZE, IMG_SIZE), Image.BICUBIC),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
30
+ ])
31
+
32
+ def de_normalize(tensor):
33
+ tensor = tensor.cpu().squeeze(0)
34
+ tensor = tensor * 0.5 + 0.5
35
+ tensor = torch.clamp(tensor, 0, 1)
36
+ return tensor.permute(1, 2, 0).numpy()
37
+
38
+ # Интерфейс
39
+ col1, col2 = st.columns(2)
40
+
41
+ with col1:
42
+ st.header("Домен A ➡️ Домен B")
43
+ file_a = st.file_uploader("Загрузить фото A", type=["jpg", "png", "jpeg"], key="a")
44
+ if file_a:
45
+ img_a = Image.open(file_a).convert("RGB")
46
+ st.image(img_a, caption="Оригинал")
47
+ if st.button("Преобразовать", key="btn_a"):
48
+ with st.spinner("Генерация..."):
49
+ tensor = transform(img_a).unsqueeze(0)
50
+ with torch.no_grad():
51
+ res = model.G_A2B(tensor) # Перевод из A в B
52
+ st.image(de_normalize(res), caption="Результат")
53
+
54
+ with col2:
55
+ st.header("Домен B ➡️ Домен A")
56
+ file_b = st.file_uploader("Загрузить фото B", type=["jpg", "png", "jpeg"], key="b")
57
+ if file_b:
58
+ img_b = Image.open(file_b).convert("RGB")
59
+ st.image(img_b, caption="Оригинал")
60
+ if st.button("Преобразовать", key="btn_b"):
61
+ with st.spinner("Генерация..."):
62
+ tensor = transform(img_b).unsqueeze(0)
63
+ with torch.no_grad():
64
+ res = model.G_B2A(tensor) # Перевод из B в A
65
+ st.image(de_normalize(res), caption="Результат")