AxionLab-official commited on
Commit
f486416
·
verified ·
1 Parent(s): 76df893

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -5,13 +5,21 @@ import math
5
  from dataclasses import dataclass
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
 
 
 
 
 
 
 
 
8
 
9
  # Otimizações para CPU no servidor do Hugging Face
10
  torch.set_num_threads(4)
11
  torch.backends.mkldnn.enabled = True
12
 
13
  # ───────────────────────────────────────────────
14
- # 1. Configuração e Arquitetura da U-Net (Cópia do seu treino)
15
  # ───────────────────────────────────────────────
16
  @dataclass
17
  class Config:
@@ -153,7 +161,7 @@ class DDPMScheduler:
153
  # 2. Carregando o Modelo do seu Repositório
154
  # ───────────────────────────────────────────────
155
  REPO_ID = "AxionLab-Co/PokePixels1-9M"
156
- FILENAME = "model.pt" # ← Substitua se o nome exato do arquivo no seu repo for outro (ex: checkpoint_epoch100.pt)
157
 
158
  print("Baixando e carregando o modelo...")
159
  model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
@@ -162,7 +170,6 @@ cfg = Config()
162
  model = StudentUNet(cfg)
163
  ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
164
 
165
- # Trata se você salvou apenas o state_dict ou o dicionário inteiro de treino
166
  if "model_state" in ckpt:
167
  model.load_state_dict(ckpt["model_state"])
168
  else:
@@ -180,7 +187,6 @@ def generate_fakemons(num_images, progress=gr.Progress()):
180
  device = "cpu"
181
  x = torch.randn(num_images, 3, cfg.image_size, cfg.image_size, device=device)
182
 
183
- # Passa pelos 1000 passos e atualiza a barra na tela do usuário
184
  for t_val in progress.tqdm(reversed(range(scheduler.T)), total=scheduler.T, desc="Removendo ruído (DDPM)"):
185
  t = torch.full((num_images,), t_val, device=device, dtype=torch.long)
186
  noise_pred = model(x, t)
@@ -197,7 +203,6 @@ def generate_fakemons(num_images, progress=gr.Progress()):
197
  else:
198
  x = scheduler.predict_x0(x, noise_pred, t)
199
 
200
- # Converte o Tensor para uma lista de imagens PIL para o Gradio
201
  x = x.clamp(-1, 1)
202
  x = (x + 1) / 2
203
  x = (x * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
@@ -205,6 +210,7 @@ def generate_fakemons(num_images, progress=gr.Progress()):
205
  images = [Image.fromarray(img) for img in x]
206
  return images
207
 
 
208
  # ───────────────────────────────────────────────
209
  # 4. Interface Web (Gradio)
210
  # ───────────────────────────────────────────────
 
5
  from dataclasses import dataclass
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
+ import pathlib
9
+ import platform
10
+
11
+ # ───────────────────────────────────────────────
12
+ # Correção para carregar no Linux modelos salvos no Windows
13
+ # ───────────────────────────────────────────────
14
+ if platform.system() == 'Linux':
15
+ pathlib.WindowsPath = pathlib.PosixPath
16
 
17
  # Otimizações para CPU no servidor do Hugging Face
18
  torch.set_num_threads(4)
19
  torch.backends.mkldnn.enabled = True
20
 
21
  # ───────────────────────────────────────────────
22
+ # 1. Configuração e Arquitetura da U-Net
23
  # ───────────────────────────────────────────────
24
  @dataclass
25
  class Config:
 
161
  # 2. Carregando o Modelo do seu Repositório
162
  # ───────────────────────────────────────────────
163
  REPO_ID = "AxionLab-Co/PokePixels1-9M"
164
+ FILENAME = "model.pt"
165
 
166
  print("Baixando e carregando o modelo...")
167
  model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
 
170
  model = StudentUNet(cfg)
171
  ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
172
 
 
173
  if "model_state" in ckpt:
174
  model.load_state_dict(ckpt["model_state"])
175
  else:
 
187
  device = "cpu"
188
  x = torch.randn(num_images, 3, cfg.image_size, cfg.image_size, device=device)
189
 
 
190
  for t_val in progress.tqdm(reversed(range(scheduler.T)), total=scheduler.T, desc="Removendo ruído (DDPM)"):
191
  t = torch.full((num_images,), t_val, device=device, dtype=torch.long)
192
  noise_pred = model(x, t)
 
203
  else:
204
  x = scheduler.predict_x0(x, noise_pred, t)
205
 
 
206
  x = x.clamp(-1, 1)
207
  x = (x + 1) / 2
208
  x = (x * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
 
210
  images = [Image.fromarray(img) for img in x]
211
  return images
212
 
213
+
214
  # ───────────────────────────────────────────────
215
  # 4. Interface Web (Gradio)
216
  # ───────────────────────────────────────────────