ju4nppp commited on
Commit
49c9fa6
·
verified ·
1 Parent(s): 5b23f61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -85
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import torch
2
  import torch.nn as nn
3
- import torchvision.utils as vutils
4
- import gradio as gr
5
  import numpy as np
6
- import matplotlib.pyplot as plt
7
-
 
8
 
9
- # Define Generator architecture - must match what you used during training
10
  class Generator(nn.Module):
11
- def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
12
  super(Generator, self).__init__()
13
  self.ngpu = ngpu
14
  self.main = nn.Sequential(
@@ -37,90 +36,62 @@ class Generator(nn.Module):
37
  def forward(self, input):
38
  return self.main(input)
39
 
 
 
 
40
 
41
- # Load the generator
42
- def load_model(model_path="models/netG_best.pth"):
43
- # Create the generator and load the saved weights
44
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
45
- netG = Generator(ngpu=1, nz=100, ngf=64, nc=3).to(device)
46
-
47
- try:
48
- netG.load_state_dict(torch.load(model_path, map_location=device))
49
- netG.eval() # Set to evaluation mode
50
- print(f"Model loaded successfully from {model_path}")
51
- return netG, device
52
- except Exception as e:
53
- print(f"Error loading model: {e}")
54
- return None, device
55
-
56
-
57
- # Generate images using the model
58
- def generate_images(num_images=16, seed=None, randomize=True):
59
- # Load the model (do this once when needed)
60
- global model, device
61
- if 'model' not in globals():
62
- model, device = load_model()
63
- if model is None:
64
- return np.zeros((299, 299, 3))
65
-
66
- # Set random seed for reproducibility if provided
67
- if seed is not None and not randomize:
68
- torch.manual_seed(seed)
69
- np.random.seed(seed)
70
 
71
- # Generate latent vectors
72
- nz = 100 # Size of the latent vector
73
- noise = torch.randn(num_images, nz, 1, 1, device=device)
74
 
75
- # Generate fake images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  with torch.no_grad():
77
  fake_images = model(noise).detach().cpu()
78
-
79
- # Convert to grid for display
80
- grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=int(np.sqrt(num_images)))
81
-
82
- # Convert from tensor to numpy array for Gradio
83
- grid_np = grid.numpy().transpose((1, 2, 0))
84
-
85
- # Make sure values are in 0-1 range
86
- grid_np = np.clip(grid_np, 0, 1)
87
-
88
- return grid_np
89
-
90
 
91
  # Create Gradio interface
92
- def create_gradio_app():
93
- with gr.Blocks(title="Computer Mouse Generator") as app:
94
- gr.Markdown("# Computer Mouse GAN Generator")
95
- gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")
96
-
97
- with gr.Row():
98
- with gr.Column():
99
- num_images = gr.Slider(minimum=1, maximum=64, value=16, step=1, label="Number of Images")
100
- seed = gr.Number(label="Random Seed", value=42, precision=0)
101
- randomize = gr.Checkbox(label="Use Random Seeds (ignore seed value)", value=True)
102
- generate_button = gr.Button("Generate Mice")
103
-
104
- with gr.Column():
105
- output_image = gr.Image(label="Generated Computer Mice")
106
-
107
- generate_button.click(fn=generate_images, inputs=[num_images, seed, randomize], outputs=output_image)
108
-
109
- gr.Markdown("## About")
110
- gr.Markdown("""This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.
111
-
112
- The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.
113
-
114
- The generator creates brand new, never-before-seen computer mice from random noise!""")
115
-
116
- return app
117
-
118
-
119
- # Initialize global variables
120
- model = None
121
- device = None
122
-
123
- # Launch the app if the script is run directly
124
  if __name__ == "__main__":
125
- app = create_gradio_app()
126
- app.launch()
 
1
  import torch
2
  import torch.nn as nn
 
 
3
  import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import os
7
 
8
+ # Define your Generator architecture - with ngf=128 to match your training parameters
9
  class Generator(nn.Module):
10
+ def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
11
  super(Generator, self).__init__()
12
  self.ngpu = ngpu
13
  self.main = nn.Sequential(
 
36
  def forward(self, input):
37
  return self.main(input)
38
 
39
+ # Load the model - Update path to point to the models folder
40
+ device = torch.device("cpu")
41
+ model_path = "models/netG_epoch_246.pth"
42
 
43
+ # Print file existence for debugging
44
+ print(f"Checking if model file exists: {os.path.exists(model_path)}")
45
+ print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # Initialize the model with ngf=128 to match your training parameters
48
+ model = Generator(ngf=128).to(device)
 
49
 
50
+ # Try loading with error handling
51
+ try:
52
+ model.load_state_dict(torch.load(model_path, map_location=device))
53
+ print("Model loaded successfully!")
54
+ except Exception as e:
55
+ print(f"Error loading model: {e}")
56
+ # Try alternative loading methods if the first fails
57
+ try:
58
+ model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
59
+ print("Model loaded with strict=False")
60
+ except Exception as e2:
61
+ print(f"Error with alternative loading: {e2}")
62
+
63
+ # Set model to evaluation mode
64
+ model.eval()
65
+ print(f"Model initialized: {model is not None}")
66
+
67
+ def generate_images(random_seed=42):
68
+ """Generate images using the DCGAN model"""
69
+ # Set seed for reproducibility
70
+ torch.manual_seed(random_seed)
71
+
72
+ # Generate random noise
73
+ noise = torch.randn(1, 100, 1, 1, device=device)
74
+
75
+ # Generate fake image
76
  with torch.no_grad():
77
  fake_images = model(noise).detach().cpu()
78
+
79
+ # Convert tensor to image
80
+ fake_img = fake_images * 0.5 + 0.5 # unnormalize
81
+ fake_img = fake_img.squeeze(0).permute(1, 2, 0).numpy()
82
+ fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
83
+ return Image.fromarray(fake_img)
 
 
 
 
 
 
84
 
85
  # Create Gradio interface
86
+ demo = gr.Interface(
87
+ fn=generate_images,
88
+ inputs=gr.Slider(minimum=1, maximum=100, step=1, default=42, label="Random Seed"),
89
+ outputs=gr.Image(type="pil", label="Generated Computer Mouse"),
90
+ title="DCGAN Computer Mouse Generator",
91
+ description="Generate unique computer mouse designs using a DCGAN model trained on computer mice images using ngf=128 and ndf=128.",
92
+ examples=[[42], [23], [7], [99]]
93
+ )
94
+
95
+ # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  if __name__ == "__main__":
97
+ demo.launch()