# Binarized_MNIST_VAE / configuration_vae.py
# Repository page residue (not part of the module): "uday9k's picture",
# commit message "Upload 4 files", commit 135c1d8 (verified).
from transformers import PretrainedConfig
from typing import Optional, Literal
class VAEConfig(PretrainedConfig):
    """Configuration holding the hyperparameters of a VAE model.

    The constructor forwards any extra keyword arguments to
    ``PretrainedConfig.__init__`` and then records the five VAE-specific
    settings as instance attributes of the same name.

    Args:
        data_dim: Stored as ``self.data_dim``; defaults to 784.
            (Presumably the flattened input size -- confirm against the
            model code.)
        latent_dim: Stored as ``self.latent_dim``; defaults to 20.
        hidden_dim: Stored as ``self.hidden_dim``; defaults to 1024.
        encoder_layers: Stored as ``self.encoder_layers``; defaults to 2.
        data_type: One of ``'binary'``, ``'continuous'``, ``'auto'`` or
            ``None`` per the annotation; defaults to ``'auto'``. The value
            is stored as given and is not validated here.
        **kwargs: Passed through unchanged to the base-class constructor.
    """

    # Identifier string for this configuration family.
    model_type = "vae"

    def __init__(
        self,
        data_dim: int = 784,
        latent_dim: int = 20,
        hidden_dim: int = 1024,
        encoder_layers: int = 2,
        data_type: Optional[Literal["binary", "continuous", "auto"]] = "auto",
        **kwargs,
    ) -> None:
        # Let the base class consume the generic options first; the
        # VAE-specific attributes are assigned afterwards, in declaration order.
        super().__init__(**kwargs)

        # Dimensions.
        self.data_dim = data_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        # Depth setting and data-type selector.
        self.encoder_layers = encoder_layers
        self.data_type = data_type