farid678 commited on
Commit
f2b428b
·
verified ·
1 Parent(s): 3daffc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import gradio as gr
4
  import numpy as np
 
5
 
6
  # --------------------------
7
  # تنظیمات
@@ -37,25 +38,33 @@ G.load_state_dict(torch.load(model_path, map_location=device))
37
  G.eval()
38
 
39
  # --------------------------
40
- # تابع تولید تصویر
41
  # --------------------------
42
- def generate_image(seed=42):
43
  torch.manual_seed(seed)
44
- z = torch.randn(1, latent_dim).to(device)
45
- img = G(z).detach().cpu().squeeze().numpy()
46
- img = (img + 1) / 2 # تبدیل [-1,1] به [0,1]
47
- img = (img * 255).astype(np.uint8) # تبدیل به uint8
48
- return img
 
 
 
 
 
49
 
50
  # --------------------------
51
- # ایجاد رابط Gradio
52
  # --------------------------
53
  iface = gr.Interface(
54
- fn=generate_image,
55
- inputs=gr.Slider(0, 10000, value=42, label="Seed"),
56
- outputs=gr.Image(type="numpy"), # مهم: از type="numpy" استفاده کنید
 
 
 
57
  title="MNIST GAN Generator",
58
- description="یک مدل GAN ساده برای تولید تصاویر اعداد دست‌نویس MNIST"
59
  )
60
 
61
  iface.launch()
 
2
  import torch.nn as nn
3
  import gradio as gr
4
  import numpy as np
5
+ from PIL import Image
6
 
7
  # --------------------------
8
  # تنظیمات
 
38
  G.eval()
39
 
40
  # --------------------------
41
+ # تابع تولید چند تصویر
42
  # --------------------------
43
+ def generate_images(seed=42, num_images=4):
44
  torch.manual_seed(seed)
45
+ z = torch.randn(num_images, latent_dim).to(device)
46
+ imgs = G(z).detach().cpu().numpy() # خروجی: (num_images,1,28,28)
47
+
48
+ pil_images = []
49
+ for i in range(num_images):
50
+ img = (imgs[i].squeeze() + 1) / 2 # [-1,1] -> [0,1]
51
+ img = (img * 255).astype(np.uint8)
52
+ pil_images.append(Image.fromarray(img))
53
+
54
+ return pil_images
55
 
56
  # --------------------------
57
+ # رابط Gradio
58
  # --------------------------
59
  iface = gr.Interface(
60
+ fn=generate_images,
61
+ inputs=[
62
+ gr.Slider(0, 10000, value=42, label="Seed"),
63
+ gr.Slider(1, 16, value=4, label="Number of Images")
64
+ ],
65
+ outputs=gr.Gallery(label="Generated MNIST Images").style(grid=[4,4]),
66
  title="MNIST GAN Generator",
67
+ description="یک مدل GAN برای تولید چند تصویر اعداد دست‌نویس MNIST"
68
  )
69
 
70
  iface.launch()