GenerateDigit / app.py
farid678's picture
Update app.py
7ac6299 verified
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
# --------------------------
# Settings
# --------------------------
latent_dim = 100  # length of the noise vector z fed to the generator
model_path = "generator.pth"  # Generator state_dict checkpoint (read with torch.load below)
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# --------------------------
# Generator definition
# --------------------------
class Generator(nn.Module):
    """Fully connected GAN generator that maps noise vectors to 1x28x28 images.

    Layers: latent_dim -> 256 -> 512 -> 784, with in-place ReLU after the first
    two Linear layers and Tanh after the last, so every output lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        sizes = (latent_dim, 256, 512, 784)
        last = len(sizes) - 2
        stack = []
        # Modules are appended in the same order as a hand-written Sequential,
        # so the state_dict keys (model.0, model.2, model.4) match the saved
        # checkpoint that load_state_dict restores.
        for idx, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.Tanh() if idx == last else nn.ReLU(True))
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Map noise ``z`` to a batch of images shaped (N, 1, 28, 28).

        ``z`` is expected to have a last dimension of ``latent_dim``
        (generate_images passes shape (num_images, latent_dim)).
        """
        flat = self.model(z)
        return flat.view(-1, 1, 28, 28)
# --------------------------
# Load the Generator model
# --------------------------
G = Generator()
G.to(device)  # nn.Module.to moves parameters in place
# Restore the trained weights; map_location keeps tensors on the chosen device.
G.load_state_dict(torch.load(model_path, map_location=device))
G.train(False)  # inference only (same as G.eval())
# --------------------------
# Image generation function
# --------------------------
def generate_images(seed=42, num_images=4):
    """Generate ``num_images`` digit images from the module-level generator ``G``.

    Args:
        seed: Seed for the latent noise; the same seed yields the same images.
            UI sliders / API clients may send it as a float, so it is cast
            to ``int``.
        num_images: Number of images to produce; also cast to ``int``.

    Returns:
        A list of ``num_images`` 28x28 8-bit grayscale PIL images, or an empty
        list when ``num_images`` is not positive.
    """
    seed = int(seed)
    num_images = int(num_images)
    if num_images <= 0:
        # Generator.forward reshapes with view(-1, 1, 28, 28), which cannot
        # infer the -1 dimension for an empty batch, so stop here.
        return []

    # A private CPU generator produces the same noise as torch.manual_seed(seed)
    # followed by a CPU torch.randn, without reseeding the process-wide RNG.
    rng = torch.Generator().manual_seed(seed)
    z = torch.randn(num_images, latent_dim, generator=rng).to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        imgs = G(z).cpu().numpy()

    # Tanh output [-1, 1] -> [0, 1] -> [0, 255], truncated to uint8.
    pixels = ((imgs + 1) / 2 * 255).astype(np.uint8)
    # pixels has shape (num_images, 1, 28, 28); drop the channel axis for PIL.
    return [Image.fromarray(frame) for frame in pixels[:, 0]]
# --------------------------
# Gradio interface
# --------------------------
# The description below is Persian for: "A GAN model for generating several
# images of MNIST handwritten digits."
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        # step=1 is pinned explicitly. Both values must be whole numbers
        # (an RNG seed and an image count). Without a step, Gradio derives
        # one from the slider range, which can be fractional (e.g. 0.1 for
        # 1..16) or coarse (e.g. 100 for 0..10000).
        gr.Slider(0, 10000, value=42, step=1, label="Seed"),
        gr.Slider(1, 16, value=4, step=1, label="Number of Images"),
    ],
    outputs=gr.Gallery(label="Generated MNIST Images", columns=4, type="pil"),
    title="MNIST GAN Generator",
    description="یک مدل GAN برای تولید چند تصویر اعداد دست‌نویس MNIST",
)
iface.launch()