uday9k commited on
Commit
135c1d8
·
verified ·
1 Parent(s): 3ac0a3f

Upload 4 files

Browse files
Files changed (4) hide show
  1. __init__.py +2 -0
  2. config.json +6 -1
  3. configuration_vae.py +21 -0
  4. modeling_vae.py +143 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_vae import VAEConfig
2
+ from .modeling_vae import VAEModel
config.json CHANGED
@@ -6,9 +6,13 @@
6
  "data_type": "auto",
7
  "model_type": "vae",
8
  "architectures": [
9
- "VAE"
10
  ],
11
  "_name_or_path": "uday9k/Binarized_MNIST_VAE",
 
 
 
 
12
  "dataset": "mnist",
13
  "image_size": 28,
14
  "channels": 1,
@@ -21,6 +25,7 @@
21
  },
22
  "torch_dtype": "float32",
23
  "framework": "pytorch",
 
24
  "license": "mit",
25
  "description": "Variational Autoencoder for MNIST digit generation"
26
  }
 
6
  "data_type": "auto",
7
  "model_type": "vae",
8
  "architectures": [
9
+ "VAEModel"
10
  ],
11
  "_name_or_path": "uday9k/Binarized_MNIST_VAE",
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_vae.VAEConfig",
14
+ "AutoModel": "modeling_vae.VAEModel"
15
+ },
16
  "dataset": "mnist",
17
  "image_size": 28,
18
  "channels": 1,
 
25
  },
26
  "torch_dtype": "float32",
27
  "framework": "pytorch",
28
+ "transformers_version": "4.36.0",
29
  "license": "mit",
30
  "description": "Variational Autoencoder for MNIST digit generation"
31
  }
configuration_vae.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import PretrainedConfig
from typing import Optional, Literal


class VAEConfig(PretrainedConfig):
    """Configuration for the fully-connected VAE in ``modeling_vae``.

    Args:
        data_dim: Flattened input size (784 = 28x28 for MNIST).
        latent_dim: Dimensionality of the latent code ``z``.
        hidden_dim: Width of the first hidden layer; subsequent encoder
            layers use ``hidden_dim // 2``.
        encoder_layers: Number of hidden layers in the encoder MLP.
        data_type: Decoder likelihood — ``'binary'`` (Bernoulli),
            ``'continuous'`` (Gaussian), or ``'auto'`` to detect from data.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "vae"

    def __init__(
        self,
        data_dim=784,
        latent_dim=20,
        hidden_dim=1024,
        encoder_layers=2,
        data_type: Optional[Literal["binary", "continuous", "auto"]] = "auto",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Plain attribute assignment; PretrainedConfig serializes these
        # into config.json automatically.
        self.data_dim = data_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.encoder_layers = encoder_layers
        self.data_type = data_type
modeling_vae.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# modeling_vae.py
import json
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel

# Fix: VAEConfig was referenced below but never imported (NameError at import).
from .configuration_vae import VAEConfig


class VAEModel(PreTrainedModel):
    """Fully-connected VAE with two decoder heads.

    A Bernoulli (sigmoid) decoder is used for binary data and a Gaussian
    (mean + log-variance) decoder for continuous data; which head is used is
    controlled by ``config.data_type`` ('binary' / 'continuous' / 'auto').
    """

    config_class = VAEConfig

    def __init__(self, config):
        super().__init__(config)

        self.latent_dim = config.latent_dim
        self.encoder_layers = config.encoder_layers
        self.data_type = config.data_type
        self.data_dim = config.data_dim
        self.hidden_dim = config.hidden_dim

        # Encoder MLP: data_dim -> hidden_dim -> hidden_dim//2 -> ... (Tanh
        # between layers). Only the first layer is full width.
        current_dim = self.data_dim
        layers = []
        for i in range(self.encoder_layers):
            next_dim = self.hidden_dim if i == 0 else self.hidden_dim // 2
            layers.append(nn.Linear(current_dim, next_dim))
            layers.append(nn.Tanh())
            current_dim = next_dim
        self.encodeLayers = nn.Sequential(*layers)
        self.fc_mu = nn.Linear(current_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(current_dim, self.latent_dim)

        # Decoder for binary data (Bernoulli likelihood).
        current_dim = self.latent_dim
        layers = []
        for _ in range(self.encoder_layers - 1):
            layers.append(nn.Linear(current_dim, self.hidden_dim))
            layers.append(nn.Tanh())
            current_dim = self.hidden_dim
        # Fix: the output layer previously hard-coded hidden_dim as its input
        # width, which crashes when encoder_layers <= 1; current_dim is
        # identical to hidden_dim for all other configs.
        layers.append(nn.Linear(current_dim, self.data_dim))
        layers.append(nn.Sigmoid())
        self.decoder_bernoulli = nn.Sequential(*layers)

        # Decoder for continuous data (Gaussian likelihood): shared trunk,
        # separate mean / log-variance heads.
        current_dim = self.latent_dim
        layers = []
        for _ in range(self.encoder_layers):
            layers.append(nn.Linear(current_dim, self.hidden_dim))
            layers.append(nn.Tanh())
            current_dim = self.hidden_dim
        self.decoder_gaussian_layers = nn.Sequential(*layers)
        self.decoder_gaussian_mean = nn.Linear(current_dim, self.data_dim)
        self.decoder_gaussian_logvar = nn.Linear(current_dim, self.data_dim)

        # Standard-normal prior parameters. Registered as non-persistent
        # buffers (fix) so they follow .to(device) without changing the
        # checkpoint's state_dict keys.
        self.register_buffer("prior_mean", torch.zeros(self.latent_dim), persistent=False)
        self.register_buffer("prior_std", torch.ones(self.latent_dim), persistent=False)

    def detect_data_type(self, x: torch.Tensor) -> str:
        """Heuristically classify *x* as 'binary' or 'continuous'.

        Looks at the unique values of the first two samples only; <= 2
        distinct values is taken to mean binarized data.
        """
        unique_vals = torch.unique(x[0:2].flatten())
        if len(unique_vals) <= 2:
            print(f"Auto-detected: Binary data (unique values: {unique_vals.tolist()})")
            return 'binary'
        else:
            print(f"Auto-detected: Continuous data ({len(unique_vals)} unique values)")
            return 'continuous'

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map flattened input *x* to posterior parameters (mu, logvar)."""
        h = self.encodeLayers(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Decode latent *z*; returns (mean, logvar) — logvar is None for binary.

        NOTE(review): when data_type is still 'auto' here, detection runs on
        the *latent* z (which is approximately Gaussian), so it will almost
        always pick 'continuous'. Callers should go through forward() first or
        set config.data_type explicitly — confirm intended behavior.
        """
        if self.data_type is None or self.data_type == 'auto':
            self.data_type = self.detect_data_type(z)
        if self.data_type == 'binary':
            return self.decoder_bernoulli(z), None
        else:
            h = self.decoder_gaussian_layers(z)
            return self.decoder_gaussian_mean(h), self.decoder_gaussian_logvar(h)

    def sample_prior(self, num_samples: int) -> torch.Tensor:
        """Draw num_samples latent vectors from the N(0, I) prior."""
        # Allocate on the model's device (fix: was always CPU).
        return torch.randn(num_samples, self.latent_dim, device=self.prior_mean.device)

    def forward(self, x: torch.Tensor, data_type: Optional[str] = None) -> Tuple[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor, torch.Tensor]:
        """Run the full encode -> sample -> decode pass.

        Args:
            x: Flattened input batch of shape (batch, data_dim).
            data_type: Optional override ('binary'/'continuous'); fix — this
                parameter was previously accepted but silently ignored.

        Returns:
            (recon_x, mu, logvar) where recon_x is the tuple produced by
            decode() — (probs, None) for binary, (mean, logvar_x) otherwise.
        """
        if data_type is not None:
            self.data_type = data_type
        if self.data_type is None or self.data_type == 'auto':
            self.data_type = self.detect_data_type(x)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

    def reconstruction_loss(self, x: torch.Tensor, recon_output, mu: torch.Tensor,
                            logvar: torch.Tensor, data_type: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the negative ELBO terms.

        Returns (total_loss, recon_loss, kl_loss), summed over the batch.
        (Fix: the return annotation previously claimed a single tensor.)
        """
        if data_type is None:
            data_type = self.data_type
        # Analytic KL between the diagonal-Gaussian posterior and N(0, I).
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        if data_type == 'binary':
            if isinstance(recon_output, tuple):
                recon_output = recon_output[0]
            recon_loss = nn.functional.binary_cross_entropy(recon_output, x, reduction='sum')
        else:  # 'continuous'
            mean, logvar_x = recon_output
            var_x = torch.exp(logvar_x)
            # Gaussian negative log-likelihood with learned per-pixel variance.
            recon_loss = 0.5 * torch.sum(torch.log(2 * torch.pi * var_x) + (x - mean).pow(2) / var_x)
        return recon_loss + kl_loss, recon_loss, kl_loss

    def generate(self, num_samples: int = 1, z: Optional[torch.Tensor] = None):
        """Decode samples from the prior (or a caller-supplied z) to data space."""
        if z is None:
            z = self.sample_prior(num_samples)
        recon_x = self.decode(z)
        if isinstance(recon_x, tuple):
            # Return only the mean/probabilities head.
            return recon_x[0]
        return recon_x

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config + weights from a local directory.

        NOTE(review): unlike the base-class implementation, this only handles
        local paths (no Hub download) and expects pytorch_model.bin — confirm
        that is intended.
        """
        config_path = f"{pretrained_model_name_or_path}/config.json"
        model_path = f"{pretrained_model_name_or_path}/pytorch_model.bin"

        # Load the raw config dict; VAEConfig accepts extra keys via **kwargs.
        with open(config_path, 'r') as f:
            config_dict = json.load(f)

        config = VAEConfig(**config_dict)
        model = cls(config)

        # Security fix: checkpoints are pickle files — restrict to tensor
        # payloads instead of executing arbitrary pickled objects.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        if 'model_state_dict' in state_dict:
            # Support training checkpoints that wrap the weights.
            state_dict = state_dict['model_state_dict']

        model.load_state_dict(state_dict)
        return model