ma4389 commited on
Commit
edca514
·
verified ·
1 Parent(s): 6d2a8d0

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +68 -0
  2. generator_best.pth +3 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ from torchvision.utils import make_grid
5
+ import torchvision.transforms as transforms
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ # Set device
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ # Fixed Generator definition (with bias=True)
14
+ class Generator(nn.Module):
15
+ def __init__(self, z_dim=100, channels_img=3, features_g=64):
16
+ super(Generator, self).__init__()
17
+ self.net = nn.Sequential(
18
+ nn.ConvTranspose2d(z_dim, features_g * 8, 4, 1, 0, bias=True),
19
+ nn.BatchNorm2d(features_g * 8),
20
+ nn.ReLU(True),
21
+
22
+ nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=True),
23
+ nn.BatchNorm2d(features_g * 4),
24
+ nn.ReLU(True),
25
+
26
+ nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=True),
27
+ nn.BatchNorm2d(features_g * 2),
28
+ nn.ReLU(True),
29
+
30
+ nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=True),
31
+ nn.BatchNorm2d(features_g),
32
+ nn.ReLU(True),
33
+
34
+ nn.ConvTranspose2d(features_g, channels_img, 4, 2, 1, bias=True),
35
+ nn.Tanh()
36
+ )
37
+
38
+ def forward(self, x):
39
+ return self.net(x)
40
+
41
+ # Load generator
42
+ z_dim = 100
43
+ generator = Generator(z_dim=z_dim).to(device)
44
+ generator.load_state_dict(torch.load("/content/generator_best.pth", map_location=device))
45
+ generator.eval()
46
+
47
+ # Image generation function
48
+ def generate_image(seed: int = 42):
49
+ torch.manual_seed(seed)
50
+ noise = torch.randn(1, z_dim, 1, 1, device=device)
51
+ with torch.no_grad():
52
+ fake_image = generator(noise).cpu()
53
+
54
+ # Convert to PIL image
55
+ img_tensor = (fake_image + 1) / 2 # Denormalize from [-1, 1] to [0, 1]
56
+ img_tensor = img_tensor.squeeze(0)
57
+ to_pil = transforms.ToPILImage()
58
+ img_pil = to_pil(img_tensor)
59
+ return img_pil
60
+
61
+ # Gradio UI
62
+ gr.Interface(
63
+ fn=generate_image,
64
+ inputs=gr.Number(value=42, label="Random Seed"),
65
+ outputs=gr.Image(type="pil"),
66
+ title="DCGAN Image Generator",
67
+ description="Generate fake images using your trained DCGAN Generator"
68
+ ).launch()
generator_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e97248165ba1e7954f4a6e43c60c41cb23df14473160429b36527e446b67886
3
+ size 14328298
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.10
2
+ torchvision
3
+ numpy
4
+ matplotlib
5
+ Pillow
6
+ gradio>=4.0