# ZYI-1.0 / app.py
# Last commit by caikybaldo999: 2740e28 "Update app.py" (verified)
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
# Dimensionality of the latent noise vector fed to the generator.
z_dim = 100
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Path to the trained generator weights (generator64.pth). Set the
# GENERATOR_PATH environment variable to point elsewhere without editing this
# file; otherwise the Kaggle default below is used.
model_path = os.environ.get("GENERATOR_PATH", "/kaggle/working/generator64.pth")
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent noise tensor to a 64x64 image.

    The input is expected to be (N, latent_dim, 1, 1). It is upsampled by four
    ConvTranspose2d -> BatchNorm2d -> ReLU stages (1x1 -> 4x4 -> 8x8 -> 16x16
    -> 32x32) and a final ConvTranspose2d + Tanh (32x32 -> 64x64). The output
    is an (N, img_channels, 64, 64) tensor with values in [-1, 1].

    Args:
        latent_dim: Size of the input noise vector. When None, the
            module-level ``z_dim`` is used (the original behavior).
        img_channels: Number of output image channels (3 = RGB).
        base_channels: Width of the last hidden stage; the earlier stages use
            8x, 4x and 2x this value. The default 64 reproduces the original
            512/256/128/64 layout, which existing checkpoints need in order to
            load via ``load_state_dict``.
    """

    def __init__(self, latent_dim=None, img_channels=3, base_channels=64):
        super().__init__()
        if latent_dim is None:
            latent_dim = z_dim  # fall back to the module-level latent size
        width = base_channels
        # Layer order and nesting are unchanged, so state_dict keys
        # (model.0.0.weight, ...) stay compatible with saved weights.
        self.model = nn.Sequential(
            self._block(latent_dim, width * 8, 4, 1, 0),  # 1x1 -> 4x4
            self._block(width * 8, width * 4, 4, 2, 1),  # 4x4 -> 8x8
            self._block(width * 4, width * 2, 4, 2, 1),  # 8x8 -> 16x16
            self._block(width * 2, width, 4, 2, 1),  # 16x16 -> 32x32
            nn.ConvTranspose2d(width, img_channels, 4, 2, 1),  # 32x32 -> 64x64
            nn.Tanh(),  # squash pixel values into [-1, 1]
        )

    def _block(self, in_c, out_c, k, s, p):
        """Return one ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage.

        Args:
            in_c: Input channels.
            out_c: Output channels.
            k: Kernel size.
            s: Stride.
            p: Padding.
        """
        return nn.Sequential(
            # No conv bias: the BatchNorm2d that follows has its own shift term.
            nn.ConvTranspose2d(in_c, out_c, k, s, p, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Map noise ``x`` (N, latent_dim, 1, 1) to images (N, C, 64, 64)."""
        return self.model(x)
# Build the generator on the selected device, then load the trained weights.
G = Generator().to(device)
if not os.path.exists(model_path):
    # Fail fast with a clear message instead of letting torch.load raise.
    raise FileNotFoundError(f"❌ File not found: {model_path}")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source (or pass
# weights_only=True on torch >= 1.13, since this is a plain state_dict).
G.load_state_dict(torch.load(model_path, map_location=device))
# Inference mode: BatchNorm2d layers use their running statistics.
G.eval()
print("✅ Model loaded successfully!")
# -----------------------------------------------------------------
# Generate a single image from one random latent vector.
# -----------------------------------------------------------------
with torch.no_grad():
    # One latent sample shaped (1, z_dim, 1, 1), created on the CPU and
    # moved to the model's device.
    noise = torch.randn(1, z_dim, 1, 1).to(device)
    # Autograd is already disabled by no_grad, so no detach() is needed.
    # Bring the result to the CPU and drop the batch axis -> (C, H, W).
    fake_img = G(noise).cpu().squeeze(0)
# Reorder to (H, W, C) for matplotlib and rescale Tanh's [-1, 1] to [0, 1].
img = fake_img.permute(1, 2, 0).add(1).div(2)
# =======================
# DISPLAY THE IMAGE
# =======================
plt.figure(figsize=(3, 3))
plt.axis("off")
# Plain title: the previous emoji title was mis-encoded (it rendered as
# garbled characters), and emoji glyphs are typically missing from
# matplotlib's default DejaVu Sans font anyway.
plt.title("Generated Image")
# img is (H, W, C) with values in [0, 1], which imshow accepts as RGB.
plt.imshow(img)
plt.show()