import torch import matplotlib.pyplot as plt from src.generator import Gerador import uuid from PIL import Image import numpy as np import os device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') G = Gerador().to(device) G.load_state_dict(torch.load('config/gerador_mnist.pth', map_location=device)) G.eval() def gerar_imagem(numero, qtd=6, escala=10): G.eval() ruido = torch.rand(qtd, 100).to(device) * 2 - 1 # Criando vetor one-hot do número desejado rotulo = torch.zeros(qtd, 10).to(device) rotulo[:, numero] = 1 # Ativando a posição correspondente ao número desejado previsao = G(ruido, rotulo).cpu().detach().numpy() # Criar diretório "data" se não existir os.makedirs("data", exist_ok=True) nomes_arquivos = [] for i in range(qtd): nome_arquivo = f"data/{uuid.uuid4()}.png" imagem = (previsao[i] * 255).astype(np.uint8) # Normalizando para 0-255 img = Image.fromarray(imagem) # Redimensionar para uma escala maior (exemplo: 10x maior -> 280x280) img = img.resize((28 * escala, 28 * escala), Image.NEAREST) img.save(nome_arquivo, quality=95) nomes_arquivos.append(nome_arquivo) return nomes_arquivos