Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,13 +5,21 @@ import math
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Otimizações para CPU no servidor do Hugging Face
|
| 10 |
torch.set_num_threads(4)
|
| 11 |
torch.backends.mkldnn.enabled = True
|
| 12 |
|
| 13 |
# ───────────────────────────────────────────────
|
| 14 |
-
# 1. Configuração e Arquitetura da U-Net
|
| 15 |
# ───────────────────────────────────────────────
|
| 16 |
@dataclass
|
| 17 |
class Config:
|
|
@@ -153,7 +161,7 @@ class DDPMScheduler:
|
|
| 153 |
# 2. Carregando o Modelo do seu Repositório
|
| 154 |
# ───────────────────────────────────────────────
|
| 155 |
REPO_ID = "AxionLab-Co/PokePixels1-9M"
|
| 156 |
-
FILENAME = "model.pt"
|
| 157 |
|
| 158 |
print("Baixando e carregando o modelo...")
|
| 159 |
model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
|
|
@@ -162,7 +170,6 @@ cfg = Config()
|
|
| 162 |
model = StudentUNet(cfg)
|
| 163 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
|
| 164 |
|
| 165 |
-
# Trata se você salvou apenas o state_dict ou o dicionário inteiro de treino
|
| 166 |
if "model_state" in ckpt:
|
| 167 |
model.load_state_dict(ckpt["model_state"])
|
| 168 |
else:
|
|
@@ -180,7 +187,6 @@ def generate_fakemons(num_images, progress=gr.Progress()):
|
|
| 180 |
device = "cpu"
|
| 181 |
x = torch.randn(num_images, 3, cfg.image_size, cfg.image_size, device=device)
|
| 182 |
|
| 183 |
-
# Passa pelos 1000 passos e atualiza a barra na tela do usuário
|
| 184 |
for t_val in progress.tqdm(reversed(range(scheduler.T)), total=scheduler.T, desc="Removendo ruído (DDPM)"):
|
| 185 |
t = torch.full((num_images,), t_val, device=device, dtype=torch.long)
|
| 186 |
noise_pred = model(x, t)
|
|
@@ -197,7 +203,6 @@ def generate_fakemons(num_images, progress=gr.Progress()):
|
|
| 197 |
else:
|
| 198 |
x = scheduler.predict_x0(x, noise_pred, t)
|
| 199 |
|
| 200 |
-
# Converte o Tensor para uma lista de imagens PIL para o Gradio
|
| 201 |
x = x.clamp(-1, 1)
|
| 202 |
x = (x + 1) / 2
|
| 203 |
x = (x * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
|
|
@@ -205,6 +210,7 @@ def generate_fakemons(num_images, progress=gr.Progress()):
|
|
| 205 |
images = [Image.fromarray(img) for img in x]
|
| 206 |
return images
|
| 207 |
|
|
|
|
| 208 |
# ───────────────────────────────────────────────
|
| 209 |
# 4. Interface Web (Gradio)
|
| 210 |
# ───────────────────────────────────────────────
|
|
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
from PIL import Image
|
| 8 |
+
import pathlib
|
| 9 |
+
import platform
|
| 10 |
+
|
| 11 |
+
# ───────────────────────────────────────────────
|
| 12 |
+
# Correção para carregar no Linux modelos salvos no Windows
|
| 13 |
+
# ───────────────────────────────────────────────
|
| 14 |
+
if platform.system() == 'Linux':
|
| 15 |
+
pathlib.WindowsPath = pathlib.PosixPath
|
| 16 |
|
| 17 |
# Otimizações para CPU no servidor do Hugging Face
|
| 18 |
torch.set_num_threads(4)
|
| 19 |
torch.backends.mkldnn.enabled = True
|
| 20 |
|
| 21 |
# ───────────────────────────────────────────────
|
| 22 |
+
# 1. Configuração e Arquitetura da U-Net
|
| 23 |
# ───────────────────────────────────────────────
|
| 24 |
@dataclass
|
| 25 |
class Config:
|
|
|
|
| 161 |
# 2. Carregando o Modelo do seu Repositório
|
| 162 |
# ───────────────────────────────────────────────
|
| 163 |
REPO_ID = "AxionLab-Co/PokePixels1-9M"
|
| 164 |
+
FILENAME = "model.pt"
|
| 165 |
|
| 166 |
print("Baixando e carregando o modelo...")
|
| 167 |
model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
|
|
|
|
| 170 |
model = StudentUNet(cfg)
|
| 171 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
|
| 172 |
|
|
|
|
| 173 |
if "model_state" in ckpt:
|
| 174 |
model.load_state_dict(ckpt["model_state"])
|
| 175 |
else:
|
|
|
|
| 187 |
device = "cpu"
|
| 188 |
x = torch.randn(num_images, 3, cfg.image_size, cfg.image_size, device=device)
|
| 189 |
|
|
|
|
| 190 |
for t_val in progress.tqdm(reversed(range(scheduler.T)), total=scheduler.T, desc="Removendo ruído (DDPM)"):
|
| 191 |
t = torch.full((num_images,), t_val, device=device, dtype=torch.long)
|
| 192 |
noise_pred = model(x, t)
|
|
|
|
| 203 |
else:
|
| 204 |
x = scheduler.predict_x0(x, noise_pred, t)
|
| 205 |
|
|
|
|
| 206 |
x = x.clamp(-1, 1)
|
| 207 |
x = (x + 1) / 2
|
| 208 |
x = (x * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
|
|
|
|
| 210 |
images = [Image.fromarray(img) for img in x]
|
| 211 |
return images
|
| 212 |
|
| 213 |
+
|
| 214 |
# ───────────────────────────────────────────────
|
| 215 |
# 4. Interface Web (Gradio)
|
| 216 |
# ───────────────────────────────────────────────
|