gendigit / src /generate.py
marlonsousa's picture
Upload 16 files
5be6b48 verified
import torch
import matplotlib.pyplot as plt
from src.generator import Gerador
import uuid
from PIL import Image
import numpy as np
import os
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
G = Gerador().to(device)
G.load_state_dict(torch.load('config/gerador_mnist.pth', map_location=device))
G.eval()
def gerar_imagem(numero, qtd=6, escala=10):
G.eval()
ruido = torch.rand(qtd, 100).to(device) * 2 - 1
# Criando vetor one-hot do número desejado
rotulo = torch.zeros(qtd, 10).to(device)
rotulo[:, numero] = 1 # Ativando a posição correspondente ao número desejado
previsao = G(ruido, rotulo).cpu().detach().numpy()
# Criar diretório "data" se não existir
os.makedirs("data", exist_ok=True)
nomes_arquivos = []
for i in range(qtd):
nome_arquivo = f"data/{uuid.uuid4()}.png"
imagem = (previsao[i] * 255).astype(np.uint8) # Normalizando para 0-255
img = Image.fromarray(imagem)
# Redimensionar para uma escala maior (exemplo: 10x maior -> 280x280)
img = img.resize((28 * escala, 28 * escala), Image.NEAREST)
img.save(nome_arquivo, quality=95)
nomes_arquivos.append(nome_arquivo)
return nomes_arquivos