SreekarB committed
Commit e4a8a19 · verified · 1 Parent(s): e88139d

Upload 4 files

Files changed (4)
  1. app.py +0 -0
  2. requirements.txt +1 -7
  3. test_huggingface.py +35 -0
  4. vae_model.py +312 -452
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,14 +1,8 @@
 torch>=1.9.0
 numpy>=1.19.2
 pandas>=1.2.4
-nilearn>=0.8.1
-nibabel>=3.2.1
 scikit-learn>=0.24.2
 matplotlib>=3.4.2
-gradio>=2.0.0
-datasets>=1.11.0
-huggingface_hub>=0.15.0
-transformers>=4.15.0
-seaborn>=0.11.2
+gradio>=3.0.0
 joblib>=1.0.1
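
For a quick local check that the trimmed dependency list is satisfied after installation, something like the sketch below can be run (a minimal, hypothetical helper; the package names mirror the requirements above and importlib.metadata needs Python 3.8+):

# Hypothetical check: report installed versions of the packages kept in requirements.txt.
from importlib.metadata import version, PackageNotFoundError

for pkg in ["torch", "numpy", "pandas", "scikit-learn", "matplotlib", "gradio", "joblib"]:
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED")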
test_huggingface.py ADDED
@@ -0,0 +1,35 @@
+"""
+Simple test script to verify the Huggingface app works locally.
+This will run the app with synthetic data.
+"""
+import numpy as np
+import pandas as pd
+import os
+
+# Ensure directories exist
+os.makedirs('results', exist_ok=True)
+os.makedirs('models', exist_ok=True)
+
+# Create synthetic data
+print("Creating synthetic test data...")
+n_samples = 10
+n_features = 100
+
+# Create FC matrix data
+fc_data = np.random.randn(n_samples, n_features)
+np.save('results/test_fc.npy', fc_data)
+print(f"Saved FC matrix data to results/test_fc.npy with shape {fc_data.shape}")
+
+# Create demographics data
+demo_df = pd.DataFrame({
+    'age': np.random.normal(60, 10, n_samples),
+    'sex': np.random.choice(['M', 'F'], n_samples),
+    'months_post_stroke': np.random.normal(24, 12, n_samples),
+    'wab_score': np.random.normal(65, 15, n_samples)
+})
+demo_df.to_csv('results/test_demographics.csv', index=False)
+print(f"Saved demographics data to results/test_demographics.csv with shape {demo_df.shape}")
+
+print("\nTest data created successfully!")
+print("\nNow you can run: python app.py")
+print("Then upload the test files to train a model.")
vae_model.py CHANGED
@@ -1,495 +1,355 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from utils import to_torch, to_cuda, to_numpy, demo_to_torch
 from sklearn.base import BaseEstimator

-class VAE(nn.Module):
-    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
-        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
-        self.use_cuda = use_cuda

-        # Create layers with standard parameters (no .float() call)
-        self.enc1 = nn.Linear(input_dim, 1000)
-        self.enc2 = nn.Linear(1000, latent_dim)

-        # Decoder
-        self.dec1 = nn.Linear(latent_dim+demo_dim, 1000)
-        self.dec2 = nn.Linear(1000, input_dim)

-        # Batch normalization layers
-        self.bn1 = nn.BatchNorm1d(1000)
-        self.bn2 = nn.BatchNorm1d(1000)

-        # Move to CUDA if requested and available
-        if use_cuda and torch.cuda.is_available():
-            self.cuda()

-    def enc(self, x):
-        # First layer with activation
-        h = self.enc1(x)
-        h = F.relu(h)

-        # Apply batch norm - handle training vs eval mode automatically
-        h = self.bn1(h)

-        # Output layer
-        z = self.enc2(h)
-        return z

-    def gen(self, n):
-        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

-    def dec(self, z, demo):
-        # Concatenate latent code with demographic data
-        z_combined = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)

-        # First decoder layer with activation
-        h = self.dec1(z_combined)
-        h = F.relu(h)

-        # Apply batch norm - handle training vs eval mode automatically
-        h = self.bn2(h)

-        # Output layer
-        x = self.dec2(h)
-        return x

-class DemoVAE(BaseEstimator):
-    def __init__(self, **params):
-        self.set_params(**params)

-    @staticmethod
-    def get_default_params():
-        return dict(
-            latent_dim=32,
-            use_cuda=True,
-            nepochs=100,  # Changed from 1000 to 100 for faster testing
-            pperiod=10,   # Changed from 100 to 10 to see more progress updates
-            bsize=5,      # Changed from 16 to 5 for small sample sizes
-            loss_C_mult=1,
-            loss_mu_mult=1,
-            loss_rec_mult=100,
-            loss_decor_mult=10,
-            loss_pred_mult=0.001,
-            alpha=100,
-            LR_C=100,
-            lr=1e-4,
-            weight_decay=0
-        )

-    def get_params(self, deep=True):
-        return {k: getattr(self, k) for k in self.get_default_params().keys()}

-    def set_params(self, **params):
-        for k, v in self.get_default_params().items():
-            setattr(self, k, params.get(k, v))
-        return self

-    def fit(self, x, demo, demo_types):
-        from utils import train_vae

-        # Calculate demo_dim
-        demo_dim = 0
-        for d, t in zip(demo, demo_types):
-            if t == 'continuous':
-                demo_dim += 1
-            elif t == 'categorical':
-                demo_dim += len(set(d))
-            else:
-                raise ValueError(f'Demographic type "{t}" not supported')

-        # Initialize VAE
-        self.input_dim = x.shape[1]
-        self.demo_dim = demo_dim
-        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

-        # Train VAE
-        train_losses, val_losses = train_vae(
-            self.vae, x, demo, demo_types,
-            self.nepochs, self.pperiod, self.bsize,
-            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
-            self.loss_decor_mult, self.loss_pred_mult,
-            self.lr, self.weight_decay, self.alpha, self.LR_C,
-            self
-        )

-        # Store the losses for later visualization
-        self.train_losses = train_losses
-        self.val_losses = val_losses

-        # Return the losses for immediate use
-        return train_losses, val_losses

-    def transform(self, x, demo, demo_types):
-        """
-        Transform data through the VAE model.
-
-        Args:
-            x: Either an integer (to generate samples) or input data to encode/decode
-            demo: Demographic data
-            demo_types: Types of demographic variables
-
-        Returns:
-            Transformed data (reconstructions or generations)
-        """
-        print(f"VAE transform called - Input type: {type(x)}")
-        if not isinstance(x, int):
-            print(f"Input data shape: {np.array(x).shape}")
-        print(f"Demo data: {[len(d) for d in demo]}, Types: {demo_types}")

-        # Set model to evaluation mode to handle batch norm with batch size of 1
        self.vae.eval()

-        try:
-            # Use torch.no_grad to disable gradient calculation during inference
-            with torch.no_grad():
-                # Generate latent vectors or encode inputs
-                if isinstance(x, int):
-                    print(f"Generating {x} random latent vectors...")
-                    z = self.vae.gen(x)
-                    print(f"Generated latent vectors shape: {z.shape}")
-                else:
-                    print("Encoding input data to latent space...")
-                    x_tensor = to_cuda(to_torch(x), self.vae.use_cuda)
-                    print(f"Input tensor shape: {x_tensor.shape}")
-                    z = self.vae.enc(x_tensor)
-                    print(f"Encoded latent vectors shape: {z.shape}")

-                # Convert demographics to tensors
-                print("Converting demographics to tensors...")
-                try:
-                    demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
-                    print(f"Demographic tensor shape: {demo_t.shape}")
-                except Exception as demo_err:
-                    print(f"Error in demographic conversion: {demo_err}")
-                    raise

-                # Handle batch size of 1 for batch normalization
-                print(f"Decoding with batch size: {z.size(0)}")
-                if z.size(0) == 1:
-                    print("Using special handling for batch size=1...")
-                    # If batch size is 1, we need to be careful with batch norm
-                    # Clone and repeat the input to create a fake batch if needed
-                    if hasattr(self.vae, 'bn1') or hasattr(self.vae, 'bn2'):
-                        print("Batch normalization layers detected")
-                        try:
-                            # Try normal decoding first
-                            print("Attempting normal decoding...")
-                            y = self.vae.dec(z, demo_t)
-                            print("Normal decoding succeeded")
-                        except Exception as e:
-                            # If it fails, use a workaround for batch norm
-                            print(f"Normal decoding failed: {e}")
-                            print("Using batch norm workaround (repeating batch)...")
-                            # Create a batch by repeating the input
-                            z_batch = z.repeat(2, 1)
-                            demo_t_batch = demo_t.repeat(2, 1)
-                            # Get the output and use only the first element
-                            print(f"Created batch with shapes - z: {z_batch.shape}, demo: {demo_t_batch.shape}")
-                            y_batch = self.vae.dec(z_batch, demo_t_batch)
-                            print(f"Batch decoding succeeded, extracting first item from {y_batch.shape}")
-                            y = y_batch[0:1]
-                    else:
-                        # No batch norm, proceed normally
-                        print("No batch norm, proceeding normally...")
-                        y = self.vae.dec(z, demo_t)
-                else:
-                    # Normal batch size, proceed as usual
-                    print("Normal batch size, proceeding with standard decoding...")
-                    y = self.vae.dec(z, demo_t)

-                print(f"Decoding complete, output tensor shape: {y.shape}")

-                # Convert to numpy
-                result = to_numpy(y)
-                print(f"Final output shape: {result.shape}")

-                # Check for NaN values in the result
-                if np.any(np.isnan(result)):
-                    print("WARNING: Result contains NaN values")
-                    result = np.nan_to_num(result)
-                    print("NaN values replaced with zeros")

-                return result

-        except Exception as e:
-            import traceback
-            print(f"Error in VAE transform: {e}")
-            print(f"Traceback: {traceback.format_exc()}")

-            # Create a fallback output with appropriate shape
-            if isinstance(x, int):
-                # Generate empty latent vectors with the right shape
-                n_features = self.input_dim
-                fallback = np.zeros((x, n_features))
-            else:
-                # Return empty array with same shape as input
-                fallback = np.zeros_like(np.array(x))

-            print(f"Returning fallback output with shape: {fallback.shape}")
-            return fallback

-    def encode(self, x):
-        """Alias for get_latents method - to provide compatibility with some interfaces"""
-        return self.get_latents(x)

-    def get_latents(self, x):
        # Set model to evaluation mode
        self.vae.eval()

-        # Use torch.no_grad for inference
-        with torch.no_grad():
-            try:
-                # Convert to torch tensor and move to CUDA if needed
-                x_tensor = to_cuda(to_torch(x), self.vae.use_cuda)

-                # Get latent representation
-                z = self.vae.enc(x_tensor)
-            except Exception as e:
-                print(f"Error in encoder: {e}")
-                # Try workaround for batch norm if needed
-                if x.shape[0] == 1 and (hasattr(self.vae, 'bn1') or hasattr(self.vae, 'bn2')):
-                    print("Using batch normalization workaround for single sample")
-                    # Repeat the input to create a batch of size 2
-                    if len(x.shape) == 2:
-                        x_batch = np.repeat(x, 2, axis=0)
-                    else:
-                        x_batch = np.array([x[0], x[0]])

-                    # Process the batch
-                    x_tensor = to_cuda(to_torch(x_batch), self.vae.use_cuda)
-                    z_batch = self.vae.enc(x_tensor)

-                    # Extract just the first sample's latent representation
-                    z = z_batch[0:1]
                else:
-                    # Re-raise if we can't handle it
-                    raise

-        return to_numpy(z)

-    def save(self, path):
-        train_losses = getattr(self, 'train_losses', [])
-        val_losses = getattr(self, 'val_losses', [])

-        # Make sure train_losses and val_losses are regular Python lists of float
-        if train_losses:
-            train_losses = [float(x) for x in train_losses]
-        else:
-            train_losses = []

-        if val_losses:
-            val_losses = [float(x) for x in val_losses]
-        else:
-            val_losses = []

-        # Save state dict separately (most compatible way)
-        torch.save(self.vae.state_dict(), f"{path}_state_dict")
-        print(f"Saved VAE model state to {path}_state_dict")

-        # Save metadata as simple numpy arrays
-        import numpy as np
-        import json
-        np.savez(
-            f"{path}_metadata.npz",
-            train_losses=np.array(train_losses, dtype=np.float32),
-            val_losses=np.array(val_losses, dtype=np.float32),
-            input_dim=np.array([self.input_dim], dtype=np.int32),
-            demo_dim=np.array([self.demo_dim], dtype=np.int32)
-        )

-        # Save parameters and pred_stats to JSON
-        params_json = {}
-        for k, v in self.get_params().items():
-            if isinstance(v, (int, float)):
-                params_json[k] = float(v)
-            elif isinstance(v, bool):
-                params_json[k] = v
-            else:
-                params_json[k] = str(v)

-        # Convert pred_stats to JSON-serializable format
-        pred_stats_json = []
-        for stat in self.pred_stats:
-            if isinstance(stat, (list, tuple)):
-                pred_stats_json.append([float(v) if isinstance(v, (int, float)) else str(v) for v in stat])
-            else:
-                pred_stats_json.append(stat)

-        with open(f"{path}_params.json", 'w') as f:
-            json.dump({
-                'params': params_json,
-                'pred_stats': pred_stats_json
-            }, f)

-        # Also save with original method as a backup
-        try:
-            model_dict = {
-                'model_state_dict': self.vae.state_dict(),
-                'params': params_json,
-                'pred_stats': pred_stats_json,
-                'input_dim': int(self.input_dim),
-                'demo_dim': int(self.demo_dim),
-                'train_losses': train_losses,
-                'val_losses': val_losses
-            }
-            torch.save(model_dict, path)
-            print(f"Saved VAE model to {path}")
-        except Exception as e:
-            print(f"Error saving model with default settings: {e}")
-            print(f"Falling back to component files {path}_*")

-    def load(self, path):
-        # Simplified load function focusing on component-based loading first
-        try:
-            print(f"Attempting to load model from component files {path}_*")
-            import json
-            import numpy as np
-            import os

-            # Check if component files exist
-            state_dict_path = f"{path}_state_dict"
-            metadata_path = f"{path}_metadata.npz"
-            params_path = f"{path}_params.json"

-            if os.path.exists(state_dict_path) and os.path.exists(metadata_path) and os.path.exists(params_path):
-                # Load state dict from the most reliable source
-                print(f"Loading state dict from {state_dict_path}")
-                state_dict = torch.load(state_dict_path, map_location='cpu')

-                # Load metadata
-                print(f"Loading metadata from {metadata_path}")
-                metadata = np.load(metadata_path, allow_pickle=True)
-                self.input_dim = int(metadata['input_dim'][0])
-                self.demo_dim = int(metadata['demo_dim'][0])

-                # Load training histories if available
-                if 'train_losses' in metadata:
-                    self.train_losses = metadata['train_losses'].tolist()
                else:
-                    self.train_losses = []

-                if 'val_losses' in metadata:
-                    self.val_losses = metadata['val_losses'].tolist()
                else:
-                    self.val_losses = []

-                # Load parameters and pred_stats
-                print(f"Loading parameters from {params_path}")
-                with open(params_path, 'r') as f:
-                    json_data = json.load(f)
-                    self.set_params(**json_data['params'])
-                    self.pred_stats = json_data['pred_stats']

-                # Initialize model and load state dict
-                print("Initializing VAE model with loaded parameters")
-                try:
-                    # First create model with proper typing
-                    device = torch.device("cpu")  # Always start with CPU
-                    self.vae = VAE(
-                        input_dim=int(self.input_dim),
-                        latent_dim=int(self.latent_dim),
-                        demo_dim=int(self.demo_dim),
-                        use_cuda=False  # Initially False, move to CUDA later if needed
-                    )

-                    # Then load state dict
-                    self.vae.load_state_dict(state_dict)
-                    print(f"Successfully created VAE model and loaded state dict")

-                    # Move to CUDA if needed
-                    if self.use_cuda and torch.cuda.is_available():
-                        self.vae.cuda()
-                        print("Moved model to CUDA")
-                except Exception as e:
-                    print(f"Error initializing VAE model: {e}")
-                    # Create model without trying to use saved parameters
-                    self.vae = VAE(
-                        input_dim=100,  # Default size
-                        latent_dim=16,  # Small default
-                        demo_dim=4,     # Default
-                        use_cuda=False  # Avoid CUDA issues
-                    )
-                    print("Created default VAE model (loading state dict failed)")

-                print(f"Successfully loaded VAE model from component files {path}_*")

-            # If component files don't exist, try loading the combined file
-            else:
-                print(f"Component files not found. Trying to load from {path}")
-                try:
-                    # Simple approach for PyTorch 2.1
-                    checkpoint = torch.load(path, map_location='cpu')

-                    # Initialize from checkpoint
-                    self.set_params(**checkpoint['params'])
-                    self.pred_stats = checkpoint['pred_stats']
-                    self.input_dim = checkpoint['input_dim']
-                    self.demo_dim = checkpoint['demo_dim']

-                    # Initialize model and load state dict
-                    try:
-                        # Create model on CPU first
-                        self.vae = VAE(
-                            input_dim=int(self.input_dim),
-                            latent_dim=int(self.latent_dim),
-                            demo_dim=int(self.demo_dim),
-                            use_cuda=False  # Start with CPU
-                        )

-                        # Then load state dict
-                        self.vae.load_state_dict(checkpoint['model_state_dict'])

-                        # Move to CUDA if needed
-                        if self.use_cuda and torch.cuda.is_available():
-                            self.vae.cuda()
-                    except Exception as e:
-                        print(f"Error creating VAE model: {e}")
-                        # Fallback to a default model
-                        self.vae = VAE(
-                            input_dim=100,
-                            latent_dim=16,
-                            demo_dim=4,
-                            use_cuda=False
-                        )

-                    # Load training history
-                    if 'train_losses' in checkpoint:
-                        self.train_losses = checkpoint['train_losses']
-                    if 'val_losses' in checkpoint:
-                        self.val_losses = checkpoint['val_losses']

-                    print(f"Successfully loaded VAE model from {path}")
-                except Exception as e:
-                    print(f"Error loading model: {e}")
-                    raise
-        except Exception as e:
-            import os
-            print(f"Error during model loading: {e}")
-            print("Available files in models directory:")
-            if os.path.exists('models'):
-                print('\n'.join(os.listdir('models')))
-            else:
-                print("models directory does not exist")

-            # Create a minimal model for fallback
-            print("Creating a new untrained model as fallback")
-            self.input_dim = 100  # Default size for a typical FC matrix
-            self.demo_dim = 4     # Default for common demographic variables
-            self.pred_stats = []
-            self.train_losses = []
-            self.val_losses = []
-            self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)

-            raise RuntimeError(f"Unable to load VAE model: {e}")

-        # Move model to appropriate device after loading
-        if self.use_cuda and torch.cuda.is_available():
-            self.vae.cuda()
        else:
-            self.vae.cpu()
+ """
2
+ Simplified VAE implementation with explicit loss tracking.
3
+ """
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import numpy as np
8
+ import os
9
+ import matplotlib.pyplot as plt
10
  from sklearn.base import BaseEstimator
11
 
12
+ class SimpleVAE(nn.Module):
13
+ def __init__(self, input_dim, latent_dim, demo_dim):
14
+ super(SimpleVAE, self).__init__()
15
+ # Store dimensions
16
  self.input_dim = input_dim
17
  self.latent_dim = latent_dim
18
  self.demo_dim = demo_dim
 
19
 
20
+ # Encoder (FC data latent)
21
+ self.enc1 = nn.Linear(input_dim, 256)
22
+ self.enc2 = nn.Linear(256, latent_dim)
23
 
24
+ # Decoder (latent + demographics → FC reconstruction)
25
+ self.dec1 = nn.Linear(latent_dim + demo_dim, 256)
26
+ self.dec2 = nn.Linear(256, input_dim)
27
 
28
+ def encode(self, x):
29
+ """Encode FC data to latent space"""
30
+ h = F.relu(self.enc1(x))
31
+ return self.enc2(h)
32
+
33
+ def decode(self, z, demo):
34
+ """Decode from latent space to FC reconstruction"""
35
+ # Combine latent with demographics
36
+ z_combined = torch.cat([z, demo], dim=1)
37
+ h = F.relu(self.dec1(z_combined))
38
+ return self.dec2(h)
39
 
40
+ def forward(self, x, demo):
41
+ """Full forward pass"""
42
+ z = self.encode(x)
43
+ return self.decode(z, demo)
44
 
45
+ class DemoVAE:
46
+ def __init__(self, nepochs=50, batch_size=8, latent_dim=16, lr=1e-3):
47
+ """Simple VAE model with demographic conditioning"""
48
+ self.nepochs = nepochs
49
+ self.batch_size = batch_size
50
+ self.latent_dim = latent_dim
51
+ self.lr = lr
52
+ self.vae = None
53
+ self.train_losses = []
54
+ self.val_losses = []
55
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
 
57
+ def preprocess_demo(self, demo_data, demo_types, n_samples=None):
58
+ """Process demographic data into one-hot encoded tensors"""
59
+ if n_samples is None:
60
+ n_samples = len(demo_data[0])
61
+
62
+ processed_demos = []
63
+ total_dims = 0
64
 
65
+ # Process each demographic variable
66
+ for i, (data, dtype) in enumerate(zip(demo_data, demo_types)):
67
+ if dtype == 'continuous':
68
+ # For continuous variables, just normalize
69
+ data_np = np.array(data).reshape(-1, 1)
70
+ mean, std = np.mean(data_np), np.std(data_np)
71
+ if std == 0: # Handle constant values
72
+ normalized = np.zeros_like(data_np)
73
+ else:
74
+ normalized = (data_np - mean) / std
75
+ processed_demos.append(normalized)
76
+ total_dims += 1
77
+ elif dtype == 'categorical':
78
+ # For categorical, create one-hot encoding
79
+ data_list = list(data)
80
+ categories = sorted(list(set(data_list)))
81
+
82
+ # Create one-hot vectors
83
+ one_hot = np.zeros((len(data_list), len(categories)))
84
+ for j, val in enumerate(data_list):
85
+ idx = categories.index(val)
86
+ one_hot[j, idx] = 1
87
+
88
+ processed_demos.append(one_hot)
89
+ total_dims += len(categories)
90
 
91
+ # Combine all demographics
92
+ demo_tensor = np.hstack(processed_demos)
93
+ return torch.tensor(demo_tensor, dtype=torch.float32), total_dims
94
+
95
+ def fit(self, X, demo_data, demo_types):
96
+ """Train the VAE model"""
97
+ # Convert to numpy arrays if needed
98
+ X = np.array(X)
99
 
100
+ # Process demographics
101
+ print("Processing demographics...")
102
+ demo_tensor, demo_dim = self.preprocess_demo(demo_data, demo_types)
103
 
104
+ # Initialize model
105
+ input_dim = X.shape[1]
106
+ print(f"Creating model with input_dim={input_dim}, latent_dim={self.latent_dim}, demo_dim={demo_dim}")
107
+ self.vae = SimpleVAE(input_dim, self.latent_dim, demo_dim)
108
+ self.vae.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # Convert data to tensors
111
+ X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
112
+ demo_tensor = demo_tensor.to(self.device)
113
 
114
+ # Initialize optimizer
115
+ optimizer = torch.optim.Adam(self.vae.parameters(), lr=self.lr)
116
+
117
+ # Training loop
118
+ n_samples = X.shape[0]
119
+ batch_size = min(self.batch_size, n_samples)
120
+
121
+ # Clear any old losses
122
+ self.train_losses = []
123
+ self.val_losses = []
124
+
125
+ # Initial validation loss
 
 
 
 
 
 
 
 
 
126
  self.vae.eval()
127
+ with torch.no_grad():
128
+ reconstructed = self.vae(X_tensor, demo_tensor)
129
+ init_val_loss = F.mse_loss(reconstructed, X_tensor).item()
130
+ self.val_losses.append(init_val_loss)
131
+ print(f"Initial validation loss: {init_val_loss:.4f}")
132
 
133
+ # Main training loop
134
+ for epoch in range(self.nepochs):
135
+ epoch_losses = []
136
+ self.vae.train()
137
+
138
+ # Process in batches
139
+ for i in range(0, n_samples, batch_size):
140
+ # Get batch
141
+ end = min(i + batch_size, n_samples)
142
+ x_batch = X_tensor[i:end]
143
+ demo_batch = demo_tensor[i:end]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ # Forward pass
146
+ optimizer.zero_grad()
147
+ reconstructed = self.vae(x_batch, demo_batch)
148
 
149
+ # Calculate loss
150
+ loss = F.mse_loss(reconstructed, x_batch)
 
 
 
151
 
152
+ # Backward pass
153
+ loss.backward()
154
+ optimizer.step()
155
 
156
+ # Record loss
157
+ epoch_losses.append(loss.item())
 
 
158
 
159
+ # End of epoch
160
+ avg_loss = np.mean(epoch_losses)
161
+ self.train_losses.append(avg_loss)
162
+
163
+ # Validation
164
+ self.vae.eval()
165
+ with torch.no_grad():
166
+ reconstructed = self.vae(X_tensor, demo_tensor)
167
+ val_loss = F.mse_loss(reconstructed, X_tensor).item()
168
+ self.val_losses.append(val_loss)
169
+
170
+ # Print progress every few epochs
171
+ if (epoch + 1) % 5 == 0 or epoch == 0:
172
+ print(f"Epoch {epoch+1}/{self.nepochs} - "
173
+ f"Train loss: {avg_loss:.4f}, Val loss: {val_loss:.4f}")
174
+
175
+ print(f"Training complete! Final loss: {self.train_losses[-1]:.4f}")
176
+ print(f"Loss history: {len(self.train_losses)} train, {len(self.val_losses)} validation")
177
+
178
+ return self.train_losses, self.val_losses
179
+
180
+ def transform(self, X, demo_data, demo_types):
181
+ """Generate reconstructions or synthetic samples"""
182
+ # Check if model is available
183
+ if self.vae is None:
184
+ raise ValueError("Model not trained or loaded yet")
185
 
 
186
  # Set model to evaluation mode
187
  self.vae.eval()
188
 
189
+ # Check if we're generating or reconstructing
190
+ if isinstance(X, int):
191
+ # Generating n random samples
192
+ n_samples = X
193
+
194
+ # Process demo data (repeat single values if needed)
195
+ demo_list = []
196
+ for d in demo_data:
197
+ if not isinstance(d, (list, np.ndarray)):
198
+ # Single value, repeat for all samples
199
+ demo_list.append([d] * n_samples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  else:
201
+ demo_list.append(d)
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ print(f"Generating {n_samples} samples with demo data: {demo_list}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ # Process demographics
206
+ demo_tensor, demo_dim = self.preprocess_demo(demo_list, demo_types, n_samples)
 
 
207
 
208
+ # Generate random latent vectors
209
+ z = torch.randn(n_samples, self.latent_dim).to(self.device)
210
+
211
+ else:
212
+ # Reconstructing existing data
213
+ X = np.array(X)
214
+ n_samples = X.shape[0]
215
+
216
+ # Process demo data (repeat single values if needed)
217
+ demo_list = []
218
+ for d in demo_data:
219
+ if not isinstance(d, (list, np.ndarray)) or len(d) != n_samples:
220
+ # Single value, repeat for all samples
221
+ demo_list.append([d] * n_samples)
222
  else:
223
+ demo_list.append(d)
224
+
225
+ # Process demographics
226
+ demo_tensor, demo_dim = self.preprocess_demo(demo_list, demo_types)
227
+
228
+ # Encode input data
229
+ X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
230
+ z = self.vae.encode(X_tensor)
231
+
232
+ # Print shapes for debugging
233
+ print(f"Latent shape: {z.shape}, Demo tensor shape: {demo_tensor.shape}")
234
+
235
+ # Decode to get output
236
+ demo_tensor = demo_tensor.to(self.device)
237
+ with torch.no_grad():
238
+ # Make sure demo_tensor has the right dimensions
239
+ if demo_tensor.shape[1] != self.vae.demo_dim:
240
+ print(f"WARNING: Demo dimension mismatch. Expected {self.vae.demo_dim}, got {demo_tensor.shape[1]}")
241
+ # Use demographic dimension from the model
242
+ if demo_tensor.shape[1] > self.vae.demo_dim:
243
+ # Trim extra dimensions
244
+ demo_tensor = demo_tensor[:, :self.vae.demo_dim]
245
  else:
246
+ # Pad with zeros
247
+ padding = torch.zeros(demo_tensor.shape[0], self.vae.demo_dim - demo_tensor.shape[1]).to(self.device)
248
+ demo_tensor = torch.cat([demo_tensor, padding], dim=1)
249
+ print(f"Adjusted demo tensor shape: {demo_tensor.shape}")
250
 
251
+ output = self.vae.decode(z, demo_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
+ # Convert to numpy
254
+ return output.cpu().numpy()
255
+
256
+ def get_latents(self, X):
257
+ """Encode data to latent representations"""
258
+ X = np.array(X)
259
+ X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
260
+
261
+ with torch.no_grad():
262
+ z = self.vae.encode(X_tensor)
263
+
264
+ return z.cpu().numpy()
265
+
266
+ def save(self, path):
267
+ """Save the model and training history"""
268
+ # Ensure the directory exists
269
+ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
270
+
271
+ # Create state dict with all necessary info
272
+ state = {
273
+ 'vae_state': self.vae.state_dict(),
274
+ 'input_dim': self.vae.input_dim,
275
+ 'latent_dim': self.latent_dim,
276
+ 'demo_dim': self.vae.demo_dim,
277
+ 'train_losses': self.train_losses,
278
+ 'val_losses': self.val_losses,
279
+ 'nepochs': self.nepochs,
280
+ 'batch_size': self.batch_size,
281
+ 'lr': self.lr
282
+ }
283
+
284
+ # Save the model
285
+ torch.save(state, path)
286
+ print(f"Model saved to {path}")
287
+
288
+ # Print info about saved losses
289
+ print(f"Saved loss data: {len(self.train_losses)} train, {len(self.val_losses)} validation")
290
+
291
+ def load(self, path):
292
+ """Load the model from a file"""
293
+ if not os.path.exists(path):
294
+ raise FileNotFoundError(f"Model file not found: {path}")
295
 
296
+ # Load state dict
297
+ state = torch.load(path, map_location=self.device)
298
+
299
+ # Set attributes
300
+ self.latent_dim = state['latent_dim']
301
+ self.nepochs = state.get('nepochs', 50)
302
+ self.batch_size = state.get('batch_size', 8)
303
+ self.lr = state.get('lr', 1e-3)
304
+ self.train_losses = state.get('train_losses', [])
305
+ self.val_losses = state.get('val_losses', [])
306
 
307
+ # Create model
308
+ self.vae = SimpleVAE(
309
+ input_dim=state['input_dim'],
310
+ latent_dim=self.latent_dim,
311
+ demo_dim=state['demo_dim']
312
+ )
313
+
314
+ # Load weights
315
+ self.vae.load_state_dict(state['vae_state'])
316
+ self.vae.to(self.device)
317
+
318
+ print(f"Model loaded from {path}")
319
+ print(f"Loaded loss data: {len(self.train_losses)} train, {len(self.val_losses)} validation")
320
+
321
+ def plot_learning_curves(train_losses, val_losses):
322
+ """Plot training and validation loss curves"""
323
+ # Create figure
324
+ plt.figure(figsize=(10, 6))
325
+
326
+ # Check if we have loss data
327
+ if not train_losses:
328
+ plt.text(0.5, 0.5, "No training loss data available",
329
+ ha='center', va='center', transform=plt.gca().transAxes,
330
+ fontsize=14, color='red')
331
+ plt.axis('off')
332
+ return plt.gcf()
333
+
334
+ # Plot losses
335
+ epochs = range(1, len(train_losses) + 1)
336
+ plt.plot(epochs, train_losses, 'b-', label='Training loss')
337
+
338
+ if val_losses:
339
+ # Adjust validation epochs if lengths differ
340
+ if len(val_losses) == len(train_losses) + 1:
341
+ # Initial validation + epoch validations
342
+ val_epochs = [0] + list(epochs)
343
  else:
344
+ val_epochs = epochs[:len(val_losses)]
345
+
346
+ plt.plot(val_epochs, val_losses, 'r-', label='Validation loss')
347
+
348
+ # Add labels
349
+ plt.title('VAE Training and Validation Loss')
350
+ plt.xlabel('Epoch')
351
+ plt.ylabel('Loss')
352
+ plt.legend()
353
+ plt.grid(True, alpha=0.3)
354
+
355
+ return plt.gcf()
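
For orientation, here is a minimal usage sketch of the rewritten DemoVAE and plot_learning_curves. It is not part of the commit; the synthetic shapes, demographic values, and save path are illustrative assumptions borrowed from test_huggingface.py:

# Hypothetical end-to-end sketch of the simplified API (not part of this commit).
import numpy as np
from vae_model import DemoVAE, plot_learning_curves

X = np.random.randn(10, 100)                          # synthetic FC features
demo = [np.random.normal(60, 10, 10),                 # age (continuous)
        np.random.choice(['M', 'F'], 10)]             # sex (categorical)
demo_types = ['continuous', 'categorical']

model = DemoVAE(nepochs=10, batch_size=4, latent_dim=8)
train_losses, val_losses = model.fit(X, demo, demo_types)

recon = model.transform(X, demo, demo_types)          # reconstruct the training inputs
samples = model.transform(5, [[55, 60, 65, 70, 75], ['M', 'F', 'M', 'F', 'M']], demo_types)
fig = plot_learning_curves(train_losses, val_losses)  # returns a matplotlib Figure

model.save('models/demo_vae.pt')                      # hypothetical output path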