violetar commited on
Commit
96672ae
·
verified ·
1 Parent(s): a7902cb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -6
README.md CHANGED
@@ -62,7 +62,6 @@ from torch import nn
62
  from huggingface_hub import hf_hub_download
63
  import matplotlib.pyplot as plt
64
 
65
- # 1. Define the Generator class
66
  class CGAN_Generator(nn.Module):
67
  def __init__(self, noise_dim=100, num_classes=10, img_size=28):
68
  super().__init__()
@@ -91,22 +90,18 @@ class CGAN_Generator(nn.Module):
91
  img = self.conv_blocks(out)
92
  return img
93
 
94
- # 2. Load model weights
95
- noise_dim = 100
96
- model = CGAN_Generator(noise_dim=noise_dim)
97
 
98
  weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
99
  model.load_state_dict(torch.load(weights_path))
100
  model.eval()
101
 
102
- # 3. Generate a digit (e.g., the number '7')
103
  target_label = torch.tensor([7])
104
  eval_noise = torch.randn(1, noise_dim)
105
 
106
  with torch.no_grad():
107
  generated_img = model(eval_noise, target_label)
108
 
109
- # 4. Denormalize and plot
110
  generated_img = (generated_img + 1) / 2.0
111
  img_numpy = generated_img.squeeze().cpu().numpy()
112
 
 
62
  from huggingface_hub import hf_hub_download
63
  import matplotlib.pyplot as plt
64
 
 
65
  class CGAN_Generator(nn.Module):
66
  def __init__(self, noise_dim=100, num_classes=10, img_size=28):
67
  super().__init__()
 
90
  img = self.conv_blocks(out)
91
  return img
92
 
93
+ model = CGAN_Generator(noise_dim=100)
 
 
94
 
95
  weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
96
  model.load_state_dict(torch.load(weights_path))
97
  model.eval()
98
 
 
99
  target_label = torch.tensor([7])
100
  eval_noise = torch.randn(1, noise_dim)
101
 
102
  with torch.no_grad():
103
  generated_img = model(eval_noise, target_label)
104
 
 
105
  generated_img = (generated_img + 1) / 2.0
106
  img_numpy = generated_img.squeeze().cpu().numpy()
107