MNIST conditional GAN (cGAN)

Class-conditional synthesis of 28×28 grayscale MNIST-style digits. The generator maps a noise vector z and a digit label y to an image; the discriminator is a projection discriminator (Miyato & Koyama, ICLR 2018) with spectral normalization.

Files in this model repo

  • mnist_cgan_generator.pth: Generator state_dict for inference (matches submission_digit_cgan.py).
  • training_<source>.pt: original checkpoint file (full training state, when applicable).
  • cgan_architecture.py: copy of digit_cgan/model.py (Generator and Discriminator definitions).
  • generator_config.json: inferred constructor kwargs and metadata.

Weights

Checkpoint: generator-only export (the epoch number is not stored in the file).

Inferred architecture (from tensors):

  • latent_dim=100, embed_dim=100, base_channels=384, num_classes=10
  • Output shape: (B, 1, 28, 28), values in [-1, 1] (tanh).

Source file: mnist_cgan_generator.pth.
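The "inferred from tensors" values above can be reproduced by listing the checkpoint's tensor shapes. A minimal sketch (the nested key names tried for full training checkpoints are common conventions, not confirmed for this repo):

```python
import torch

def state_dict_shapes(source):
    """Return {tensor_name: shape} for a state_dict or a checkpoint path.

    Accepts an already-loaded dict or a file path. The nested keys tried
    below ("generator", "model", "state_dict") are guesses for how a full
    training checkpoint might wrap the generator weights.
    """
    sd = source if isinstance(source, dict) else torch.load(
        source, map_location="cpu", weights_only=True)
    for key in ("generator", "model", "state_dict"):
        if isinstance(sd.get(key), dict):
            sd = sd[key]
            break
    # Keep only tensor-like entries (skips scalars such as step counters).
    return {name: tuple(t.shape) for name, t in sd.items()
            if hasattr(t, "shape")}
```

For example, an embedding weight of shape (10, 100) would confirm num_classes=10 and embed_dim=100.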

Load the generator (example)

import sys

import torch
from huggingface_hub import hf_hub_download

# Make the digit_cgan package importable (adjust to your local checkout).
sys.path.insert(0, "/path/to/week-06")
from digit_cgan.model import Generator

repo_id = "<YOUR_REPO_ID>"
weights = hf_hub_download(repo_id, "mnist_cgan_generator.pth")

G = Generator(
    latent_dim=100,
    embed_dim=100,
    base_channels=384,
    num_classes=10,
)
G.load_state_dict(torch.load(weights, map_location="cpu", weights_only=True))
G.eval()

with torch.no_grad():
    z = torch.randn(4, 100)           # latent vectors, one per sample
    y = torch.tensor([0, 1, 2, 3])    # digit labels to condition on
    fake = G(z, y)                    # (4, 1, 28, 28), values in [-1, 1]
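Since the generator's tanh output lies in [-1, 1], displaying or saving samples needs a rescale to uint8 pixels; a small helper sketch:

```python
import torch

def to_uint8(images: torch.Tensor) -> torch.Tensor:
    """Rescale tanh output from [-1, 1] to uint8 pixel values in [0, 255]."""
    images = images.clamp(-1.0, 1.0)
    return ((images + 1.0) * 127.5).round().to(torch.uint8)
```

For example, to_uint8(fake)[0, 0] is a 28×28 uint8 array ready for PIL.Image.fromarray.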

Architecture (cgan_architecture.py)

  • Generator: the class embedding is concatenated with z, projected by a linear layer and reshaped to 7×7 feature maps, upsampled by two ConvTranspose2d stages to 28×28, then mapped to 1 channel by a conv + tanh.
  • Discriminator: a spectrally normalized convolutional backbone, global pooling, and a linear map to a feature vector; the score is an unconditional linear term plus the inner product between the features and a class embedding (the projection term).
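The projection score described above can be sketched as a small output head. The class name, feature_dim=128, and layer names here are illustrative, not the repo's actual code; the real dimensions come from generator_config.json:

```python
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Sketch of a projection-discriminator output layer
    (Miyato & Koyama, ICLR 2018)."""

    def __init__(self, feature_dim: int = 128, num_classes: int = 10):
        super().__init__()
        self.psi = nn.Linear(feature_dim, 1)                 # unconditional term
        self.embed = nn.Embedding(num_classes, feature_dim)  # class embedding

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # score = psi(phi(x)) + <embed(y), phi(x)>
        uncond = self.psi(features)
        proj = (self.embed(labels) * features).sum(dim=1, keepdim=True)
        return uncond + proj
```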

See T. Miyato & M. Koyama, cGANs with Projection Discriminator, ICLR 2018.

Training (typical)

Train with python -m digit_cgan.train: hinge loss, Adam, and optional EMA on the generator for sampling; the best-FID checkpoints use the EMA weights in best_generator.pth.
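The hinge objective mentioned above is the standard formulation used with spectrally normalized GAN discriminators; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def d_hinge_loss(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> torch.Tensor:
    """Discriminator hinge loss: E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))]."""
    return F.relu(1.0 - real_scores).mean() + F.relu(1.0 + fake_scores).mean()

def g_hinge_loss(fake_scores: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss: -E[D(fake)]."""
    return -fake_scores.mean()
```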

CLI defaults in train.py include latent_dim=100 and embed_dim=100; base_channels_g, base_channels_d, and feature_dim may differ per run, so always consult generator_config.json or infer the values from the weights as shown above.
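The EMA sampling weights mentioned above can be maintained with a simple in-place update after each generator step; decay=0.999 is a typical value, not necessarily what this repo uses:

```python
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module,
               decay: float = 0.999) -> None:
    """Blend live weights into the EMA copy: ema <- decay * ema + (1 - decay) * w."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```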

Limitations

MNIST is a simple benchmark; generalization to out-of-distribution digit styles is not guaranteed.
