Update README.md
Browse files
README.md
CHANGED
|
@@ -62,7 +62,6 @@ from torch import nn
|
|
| 62 |
from huggingface_hub import hf_hub_download
|
| 63 |
import matplotlib.pyplot as plt
|
| 64 |
|
| 65 |
-
# 1. Define the Generator class
|
| 66 |
class CGAN_Generator(nn.Module):
|
| 67 |
def __init__(self, noise_dim=100, num_classes=10, img_size=28):
|
| 68 |
super().__init__()
|
|
@@ -91,22 +90,18 @@ class CGAN_Generator(nn.Module):
|
|
| 91 |
img = self.conv_blocks(out)
|
| 92 |
return img
|
| 93 |
|
| 94 |
-
|
| 95 |
-
noise_dim = 100
|
| 96 |
-
model = CGAN_Generator(noise_dim=noise_dim)
|
| 97 |
|
| 98 |
weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
|
| 99 |
model.load_state_dict(torch.load(weights_path))
|
| 100 |
model.eval()
|
| 101 |
|
| 102 |
-
# 3. Generate a digit (e.g., the number '7')
|
| 103 |
target_label = torch.tensor([7])
|
| 104 |
eval_noise = torch.randn(1, noise_dim)
|
| 105 |
|
| 106 |
with torch.no_grad():
|
| 107 |
generated_img = model(eval_noise, target_label)
|
| 108 |
|
| 109 |
-
# 4. Denormalize and plot
|
| 110 |
generated_img = (generated_img + 1) / 2.0
|
| 111 |
img_numpy = generated_img.squeeze().cpu().numpy()
|
| 112 |
|
|
|
|
| 62 |
from huggingface_hub import hf_hub_download
|
| 63 |
import matplotlib.pyplot as plt
|
| 64 |
|
|
|
|
| 65 |
class CGAN_Generator(nn.Module):
|
| 66 |
def __init__(self, noise_dim=100, num_classes=10, img_size=28):
|
| 67 |
super().__init__()
|
|
|
|
| 90 |
img = self.conv_blocks(out)
|
| 91 |
return img
|
| 92 |
|
| 93 |
+
model = CGAN_Generator(noise_dim=100)
|
|
|
|
|
|
|
| 94 |
|
| 95 |
weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
|
| 96 |
model.load_state_dict(torch.load(weights_path))
|
| 97 |
model.eval()
|
| 98 |
|
|
|
|
| 99 |
target_label = torch.tensor([7])
|
| 100 |
eval_noise = torch.randn(1, noise_dim)
|
| 101 |
|
| 102 |
with torch.no_grad():
|
| 103 |
generated_img = model(eval_noise, target_label)
|
| 104 |
|
|
|
|
| 105 |
generated_img = (generated_img + 1) / 2.0
|
| 106 |
img_numpy = generated_img.squeeze().cpu().numpy()
|
| 107 |
|