uday9k commited on
Commit
3a71a31
·
verified ·
1 Parent(s): 91e4c8f

inital commit

Browse files
Files changed (2) hide show
  1. __init__.py +2 -0
  2. configuration_vae.py +21 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_vae import VAEConfig
2
+ from .modeling_vae import VAEModel
configuration_vae.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import Optional, Literal
3
+
4
+ class VAEConfig(PretrainedConfig):
5
+ model_type = "vae"
6
+
7
+ def __init__(
8
+ self,
9
+ data_dim=2,
10
+ latent_dim=2,
11
+ hidden_dim=96,
12
+ encoder_layers=2,
13
+ data_type: Optional[Literal['binary', 'continuous', 'auto']] = 'auto',
14
+ **kwargs
15
+ ):
16
+ super().__init__(**kwargs)
17
+ self.data_dim = data_dim
18
+ self.latent_dim = latent_dim
19
+ self.hidden_dim = hidden_dim
20
+ self.encoder_layers = encoder_layers
21
+ self.data_type = data_type