SreekarB commited on
Commit
763369a
·
verified ·
1 Parent(s): 79a3849

Upload 10 files

Browse files
Files changed (10) hide show
  1. app.py +7 -1
  2. config.py +2 -2
  3. hf_cache/version.txt +1 -0
  4. main.py +5 -0
  5. test_small_sample.py +77 -0
  6. test_train.py +59 -0
  7. test_vae.py +55 -0
  8. utils.py +3 -1
  9. vae_model.py +199 -116
  10. visualization.py +9 -0
app.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from main import run_analysis
3
  from rcf_prediction import AphasiaTreatmentPredictor
@@ -10,7 +17,6 @@ matplotlib.rcParams['savefig.dpi'] = 100
10
  import matplotlib.pyplot as plt
11
  from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
12
  from visualization import plot_fc_matrices, plot_learning_curves
13
- import os
14
  import glob
15
  from sklearn.metrics import mean_squared_error, r2_score
16
  import json
 
1
+ import os
2
+ import sys
3
+
4
+ # Set Huggingface cache directory to avoid permission issues
5
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
6
+ os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
7
+
8
  import gradio as gr
9
  from main import run_analysis
10
  from rcf_prediction import AphasiaTreatmentPredictor
 
17
  import matplotlib.pyplot as plt
18
  from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
19
  from visualization import plot_fc_matrices, plot_learning_curves
 
20
  import glob
21
  from sklearn.metrics import mean_squared_error, r2_score
22
  import json
config.py CHANGED
@@ -1,8 +1,8 @@
1
  # Model configuration
2
  MODEL_CONFIG = {
3
  'latent_dim': 32,
4
- 'nepochs': 1000,
5
- 'bsize': 16,
6
  'loss_rec_mult': 100,
7
  'loss_decor_mult': 10,
8
  'lr': 1e-4
 
1
  # Model configuration
2
  MODEL_CONFIG = {
3
  'latent_dim': 32,
4
+ 'nepochs': 100, # Changed from 1000 to 100 for faster testing
5
+ 'bsize': 5, # Changed from 16 to 5 for small sample sizes
6
  'loss_rec_mult': 100,
7
  'loss_decor_mult': 10,
8
  'lr': 1e-4
hf_cache/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
main.py CHANGED
@@ -3,6 +3,11 @@ import numpy as np # Make sure numpy is imported at the top level
3
  import torch
4
  from pathlib import Path
5
  import pandas as pd
 
 
 
 
 
6
  from data_preprocessing import load_and_preprocess_data
7
  from vae_model import DemoVAE
8
  from rcf_prediction import AphasiaTreatmentPredictor
 
3
  import torch
4
  from pathlib import Path
5
  import pandas as pd
6
+
7
+ # Set Huggingface cache directory to avoid permission issues
8
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
9
+ os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
10
+
11
  from data_preprocessing import load_and_preprocess_data
12
  from vae_model import DemoVAE
13
  from rcf_prediction import AphasiaTreatmentPredictor
test_small_sample.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # Set Huggingface cache directory to avoid permission issues
3
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
4
+ os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
5
+
6
+ import numpy as np
7
+ import torch
8
+ import matplotlib.pyplot as plt
9
+ from vae_model import DemoVAE
10
+ from visualization import plot_learning_curves, plot_fc_matrices
11
+ from config import MODEL_CONFIG
12
+
13
+ # Create small synthetic dataset with only 5 samples
14
+ input_dim = 100
15
+ n_samples = 5
16
+ demo_dim = 4
17
+
18
+ print(f"Creating test dataset with {n_samples} samples...")
19
+
20
+ # Synthetic FC matrices (Upper triangular values)
21
+ X = np.random.randn(n_samples, input_dim)
22
+
23
+ # Synthetic demographics
24
+ demo_data = [
25
+ np.random.normal(60, 10, n_samples), # age
26
+ np.random.choice(['M', 'F'], n_samples), # sex
27
+ np.random.normal(24, 12, n_samples), # months post stroke
28
+ np.random.normal(50, 15, n_samples) # WAB score
29
+ ]
30
+
31
+ # Types of demographics
32
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
33
+
34
+ # Initialize model with updated config
35
+ print("Config settings:")
36
+ print(f"- Epochs: {MODEL_CONFIG['nepochs']}")
37
+ print(f"- Batch size: {MODEL_CONFIG['bsize']}")
38
+ print(f"- Latent dim: {MODEL_CONFIG['latent_dim']}")
39
+
40
+ print("Initializing model...")
41
+ vae = DemoVAE(**MODEL_CONFIG)
42
+
43
+ # Train model
44
+ print(f"Training model with {n_samples} samples...")
45
+ train_losses, val_losses = vae.fit(X, demo_data, demo_types)
46
+
47
+ print(f"Training complete! Final train loss: {train_losses[-1]:.4f}")
48
+ print(f"Final validation loss: {val_losses[-1]:.4f}")
49
+
50
+ # Save model
51
+ os.makedirs("models", exist_ok=True)
52
+ os.makedirs("results", exist_ok=True)
53
+ print("Saving model...")
54
+ vae.save('models/vae_model_small.pt')
55
+
56
+ # Create learning curve visualization
57
+ print("Generating learning curve visualization...")
58
+ learning_fig = plot_learning_curves(train_losses, val_losses)
59
+ learning_fig.savefig('results/learning_curves_small.png')
60
+ print("Learning curve saved to results/learning_curves_small.png")
61
+
62
+ # Generate reconstructed data
63
+ print("Generating reconstructions...")
64
+ reconstructed = vae.transform(X, demo_data, demo_types)
65
+
66
+ # Get a single sample for FC visualization
67
+ original = X[0].reshape(10, 10) # Reshape to square matrix for visualization
68
+ recon = reconstructed[0].reshape(10, 10)
69
+ generated = vae.transform(1, [d[:1] for d in demo_data], demo_types)[0].reshape(10, 10)
70
+
71
+ # Create FC visualization
72
+ print("Generating FC matrix visualization...")
73
+ fc_fig = plot_fc_matrices(original, recon, generated)
74
+ fc_fig.savefig('results/fc_visualization_small.png')
75
+ print("FC visualization saved to results/fc_visualization_small.png")
76
+
77
+ print("Test with small sample size completed successfully!")
test_train.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from vae_model import DemoVAE
4
+
5
+ # Create synthetic data
6
+ input_dim = 100
7
+ n_samples = 20
8
+ demo_dim = 4
9
+
10
+ # Synthetic FC matrices (Upper triangular values)
11
+ X = np.random.randn(n_samples, input_dim)
12
+
13
+ # Synthetic demographics
14
+ demo_data = [
15
+ np.random.normal(60, 10, n_samples), # age
16
+ np.random.choice([0, 1], n_samples), # sex
17
+ np.random.normal(24, 12, n_samples), # months post stroke
18
+ np.random.normal(50, 15, n_samples) # WAB score
19
+ ]
20
+
21
+ # Types of demographics
22
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
23
+
24
+ # Initialize model
25
+ model_config = {
26
+ 'latent_dim': 16,
27
+ 'nepochs': 5,
28
+ 'bsize': 5,
29
+ 'use_cuda': False
30
+ }
31
+
32
+ print("Initializing model...")
33
+ vae = DemoVAE(**model_config)
34
+
35
+ # Train model
36
+ print("Training model...")
37
+ train_losses, val_losses = vae.fit(X, demo_data, demo_types)
38
+
39
+ print(f"Training complete! Train loss: {train_losses[-1]}, Val loss: {val_losses[-1]}")
40
+
41
+ # Check shapes of losses
42
+ print(f"Train losses shape: {len(train_losses)}")
43
+ print(f"Val losses shape: {len(val_losses)}")
44
+
45
+ # Save model
46
+ print("Saving model...")
47
+ vae.save('models/vae_model.pt')
48
+
49
+ # Try loading the model
50
+ print("Loading model...")
51
+ vae2 = DemoVAE()
52
+ vae2.load('models/vae_model.pt')
53
+
54
+ # Test reconstruction
55
+ print("Testing reconstruction...")
56
+ reconstructed = vae2.transform(X, demo_data, demo_types)
57
+ print(f"Reconstructed shape: {reconstructed.shape}")
58
+
59
+ print("All tests passed!")
test_vae.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import os
6
+ from sklearn.base import BaseEstimator
7
+ import json
8
+
9
+ class SimplifiedVAE(nn.Module):
10
+ def __init__(self, input_dim=100, latent_dim=16, demo_dim=4):
11
+ super(SimplifiedVAE, self).__init__()
12
+ self.input_dim = input_dim
13
+ self.latent_dim = latent_dim
14
+ self.demo_dim = demo_dim
15
+
16
+ # Create layers with explicit dtype
17
+ self.enc1 = nn.Linear(input_dim, 128)
18
+ self.enc2 = nn.Linear(128, latent_dim)
19
+
20
+ # Decoder
21
+ self.dec1 = nn.Linear(latent_dim+demo_dim, 128)
22
+ self.dec2 = nn.Linear(128, input_dim)
23
+
24
+ def encode(self, x):
25
+ h = F.relu(self.enc1(x))
26
+ return self.enc2(h)
27
+
28
+ def decode(self, z, demo):
29
+ z_combined = torch.cat([z, demo], dim=1)
30
+ h = F.relu(self.dec1(z_combined))
31
+ return self.dec2(h)
32
+
33
+ # Create basic synthetic data
34
+ input_dim = 100
35
+ demo_dim = 4
36
+ latent_dim = 16
37
+
38
+ # Create model
39
+ print("Creating model...")
40
+ model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
41
+ print(f"Model created successfully.")
42
+
43
+ # Save state dict
44
+ os.makedirs("models", exist_ok=True)
45
+ print("Saving model...")
46
+ torch.save(model.state_dict(), "models/simple_vae.pt")
47
+ print("Model saved.")
48
+
49
+ # Create a new model and load the state dict
50
+ print("Loading model...")
51
+ new_model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
52
+ new_model.load_state_dict(torch.load("models/simple_vae.pt"))
53
+ print("Model loaded successfully.")
54
+
55
+ print("All tests passed!")
utils.py CHANGED
@@ -8,7 +8,9 @@ def to_torch(x):
8
  return torch.from_numpy(x).float()
9
 
10
  def to_cuda(x, use_cuda):
11
- return x.cuda() if use_cuda else x
 
 
12
 
13
  def to_numpy(x):
14
  return x.detach().cpu().numpy()
 
8
  return torch.from_numpy(x).float()
9
 
10
  def to_cuda(x, use_cuda):
11
+ if use_cuda and torch.cuda.is_available():
12
+ return x.cuda()
13
+ return x
14
 
15
  def to_numpy(x):
16
  return x.detach().cpu().numpy()
vae_model.py CHANGED
@@ -13,17 +13,21 @@ class VAE(nn.Module):
13
  self.demo_dim = demo_dim
14
  self.use_cuda = use_cuda
15
 
16
- # Encoder
17
- self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
18
- self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
19
 
20
  # Decoder
21
- self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
22
- self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)
23
 
24
  # Batch normalization layers
25
- self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
26
- self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
 
 
 
 
27
 
28
  def enc(self, x):
29
  # First layer with activation
@@ -64,9 +68,9 @@ class DemoVAE(BaseEstimator):
64
  return dict(
65
  latent_dim=32,
66
  use_cuda=True,
67
- nepochs=1000,
68
- pperiod=100,
69
- bsize=16,
70
  loss_C_mult=1,
71
  loss_mu_mult=1,
72
  loss_rec_mult=100,
@@ -269,137 +273,216 @@ class DemoVAE(BaseEstimator):
269
  train_losses = getattr(self, 'train_losses', [])
270
  val_losses = getattr(self, 'val_losses', [])
271
 
272
- # Only save the essential model components
273
- model_dict = {
274
- 'model_state_dict': self.vae.state_dict(),
275
- # Convert complex objects to simple types for better compatibility
276
- 'params': {k: (float(v) if isinstance(v, (int, float)) else v)
277
- for k, v in self.get_params().items()},
278
- 'pred_stats': self.pred_stats,
279
- 'input_dim': int(self.input_dim),
280
- 'demo_dim': int(self.demo_dim),
281
- 'train_losses': [float(x) for x in train_losses],
282
- 'val_losses': [float(x) for x in val_losses]
283
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
 
285
  try:
286
- # Try saving with weights_only=False first (for backward compatibility)
 
 
 
 
 
 
 
 
287
  torch.save(model_dict, path)
288
  print(f"Saved VAE model to {path}")
289
  except Exception as e:
290
  print(f"Error saving model with default settings: {e}")
291
- print("Trying with more compatible settings...")
292
-
293
- # Save the state dict and other data separately for better compatibility
294
- torch.save(self.vae.state_dict(), f"{path}_state_dict")
295
-
296
- # Save non-model data as numpy arrays for better compatibility
297
- import numpy as np
298
- import json
299
- np.savez(
300
- f"{path}_metadata.npz",
301
- train_losses=np.array(train_losses, dtype=np.float32),
302
- val_losses=np.array(val_losses, dtype=np.float32),
303
- input_dim=np.array([self.input_dim], dtype=np.int32),
304
- demo_dim=np.array([self.demo_dim], dtype=np.int32)
305
- )
306
-
307
- # Save parameters and pred_stats to JSON
308
- with open(f"{path}_params.json", 'w') as f:
309
- json.dump({
310
- 'params': {k: (float(v) if isinstance(v, (int, float)) else str(v))
311
- for k, v in self.get_params().items()},
312
- 'pred_stats': [[float(v) if isinstance(v, (int, float)) else str(v) for v in stat]
313
- if isinstance(stat, (list, tuple)) else stat
314
- for stat in self.pred_stats]
315
- }, f)
316
-
317
- print(f"Saved VAE model components to {path}_* files for compatibility")
318
 
319
  def load(self, path):
 
320
  try:
321
- # Try different loading methods based on PyTorch version
322
- print(f"Attempting to load model from {path}")
323
- try:
324
- # For PyTorch 2.6+, explicitly set weights_only=False for backward compatibility
325
- if hasattr(torch, '__version__') and torch.__version__.startswith('2.6'):
326
- import numpy as np
327
- # Add all necessary numpy types to safe globals list
328
- if hasattr(torch.serialization, 'add_safe_globals'):
329
- torch.serialization.add_safe_globals([
330
- 'numpy._core.multiarray.scalar',
331
- 'numpy.core.multiarray.scalar',
332
- 'numpy.ndarray',
333
- 'numpy._globals._NoValue'
334
- ])
335
- with torch.serialization.safe_globals(['numpy._core.multiarray.scalar']):
336
- checkpoint = torch.load(path, weights_only=False)
337
- else:
338
- # For older PyTorch versions
339
- checkpoint = torch.load(path)
340
- except Exception as e:
341
- print(f"Primary loading method failed: {str(e)}")
342
- # Last resort - try with context manager if available
343
- if hasattr(torch.serialization, 'safe_globals'):
344
- with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.core.multiarray.scalar']):
345
- checkpoint = torch.load(path, weights_only=False)
346
- else:
347
- # Fall back to default with no safety
348
- checkpoint = torch.load(path)
349
- print("Successfully loaded checkpoint")
350
-
351
- # Initialize from checkpoint
352
- self.set_params(**checkpoint['params'])
353
- self.pred_stats = checkpoint['pred_stats']
354
- self.input_dim = checkpoint['input_dim']
355
- self.demo_dim = checkpoint['demo_dim']
356
- self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
357
- self.vae.load_state_dict(checkpoint['model_state_dict'])
358
-
359
- # Load training history if available
360
- if 'train_losses' in checkpoint:
361
- self.train_losses = checkpoint['train_losses']
362
- if 'val_losses' in checkpoint:
363
- self.val_losses = checkpoint['val_losses']
364
-
365
- print(f"Successfully loaded VAE model from {path}")
366
 
367
- except Exception as primary_error:
368
- print(f"Standard loading failed: {primary_error}")
369
- print("Attempting to load from separate component files...")
 
370
 
371
- try:
372
- # Try loading from separated component files
373
- import json
374
- import numpy as np
375
-
376
- # Load state dict
377
- state_dict = torch.load(f"{path}_state_dict", map_location='cpu')
378
 
379
  # Load metadata
380
- metadata = np.load(f"{path}_metadata.npz")
 
381
  self.input_dim = int(metadata['input_dim'][0])
382
  self.demo_dim = int(metadata['demo_dim'][0])
383
- self.train_losses = metadata['train_losses'].tolist()
384
- self.val_losses = metadata['val_losses'].tolist()
 
 
 
 
 
 
 
 
 
385
 
386
  # Load parameters and pred_stats
387
- with open(f"{path}_params.json", 'r') as f:
 
388
  json_data = json.load(f)
389
  self.set_params(**json_data['params'])
390
  self.pred_stats = json_data['pred_stats']
391
 
392
  # Initialize model and load state dict
393
- self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
394
- self.vae.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
  print(f"Successfully loaded VAE model from component files {path}_*")
397
 
398
- except Exception as secondary_error:
399
- error_message = (f"Failed to load model: Primary error: {primary_error}, "
400
- f"Secondary error: {secondary_error}")
401
- print(error_message)
402
- raise RuntimeError(f"Unable to load VAE model: {error_message}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
  # Move model to appropriate device after loading
405
  if self.use_cuda and torch.cuda.is_available():
 
13
  self.demo_dim = demo_dim
14
  self.use_cuda = use_cuda
15
 
16
+ # Create layers with standard parameters (no .float() call)
17
+ self.enc1 = nn.Linear(input_dim, 1000)
18
+ self.enc2 = nn.Linear(1000, latent_dim)
19
 
20
  # Decoder
21
+ self.dec1 = nn.Linear(latent_dim+demo_dim, 1000)
22
+ self.dec2 = nn.Linear(1000, input_dim)
23
 
24
  # Batch normalization layers
25
+ self.bn1 = nn.BatchNorm1d(1000)
26
+ self.bn2 = nn.BatchNorm1d(1000)
27
+
28
+ # Move to CUDA if requested and available
29
+ if use_cuda and torch.cuda.is_available():
30
+ self.cuda()
31
 
32
  def enc(self, x):
33
  # First layer with activation
 
68
  return dict(
69
  latent_dim=32,
70
  use_cuda=True,
71
+ nepochs=100, # Changed from 1000 to 100 for faster testing
72
+ pperiod=10, # Changed from 100 to 10 to see more progress updates
73
+ bsize=5, # Changed from 16 to 5 for small sample sizes
74
  loss_C_mult=1,
75
  loss_mu_mult=1,
76
  loss_rec_mult=100,
 
273
  train_losses = getattr(self, 'train_losses', [])
274
  val_losses = getattr(self, 'val_losses', [])
275
 
276
+ # Make sure train_losses and val_losses are regular Python lists of float
277
+ if train_losses:
278
+ train_losses = [float(x) for x in train_losses]
279
+ else:
280
+ train_losses = []
281
+
282
+ if val_losses:
283
+ val_losses = [float(x) for x in val_losses]
284
+ else:
285
+ val_losses = []
286
+
287
+ # Save state dict separately (most compatible way)
288
+ torch.save(self.vae.state_dict(), f"{path}_state_dict")
289
+ print(f"Saved VAE model state to {path}_state_dict")
290
+
291
+ # Save metadata as simple numpy arrays
292
+ import numpy as np
293
+ import json
294
+ np.savez(
295
+ f"{path}_metadata.npz",
296
+ train_losses=np.array(train_losses, dtype=np.float32),
297
+ val_losses=np.array(val_losses, dtype=np.float32),
298
+ input_dim=np.array([self.input_dim], dtype=np.int32),
299
+ demo_dim=np.array([self.demo_dim], dtype=np.int32)
300
+ )
301
+
302
+ # Save parameters and pred_stats to JSON
303
+ params_json = {}
304
+ for k, v in self.get_params().items():
305
+ if isinstance(v, (int, float)):
306
+ params_json[k] = float(v)
307
+ elif isinstance(v, bool):
308
+ params_json[k] = v
309
+ else:
310
+ params_json[k] = str(v)
311
+
312
+ # Convert pred_stats to JSON-serializable format
313
+ pred_stats_json = []
314
+ for stat in self.pred_stats:
315
+ if isinstance(stat, (list, tuple)):
316
+ pred_stats_json.append([float(v) if isinstance(v, (int, float)) else str(v) for v in stat])
317
+ else:
318
+ pred_stats_json.append(stat)
319
+
320
+ with open(f"{path}_params.json", 'w') as f:
321
+ json.dump({
322
+ 'params': params_json,
323
+ 'pred_stats': pred_stats_json
324
+ }, f)
325
 
326
+ # Also save with original method as a backup
327
  try:
328
+ model_dict = {
329
+ 'model_state_dict': self.vae.state_dict(),
330
+ 'params': params_json,
331
+ 'pred_stats': pred_stats_json,
332
+ 'input_dim': int(self.input_dim),
333
+ 'demo_dim': int(self.demo_dim),
334
+ 'train_losses': train_losses,
335
+ 'val_losses': val_losses
336
+ }
337
  torch.save(model_dict, path)
338
  print(f"Saved VAE model to {path}")
339
  except Exception as e:
340
  print(f"Error saving model with default settings: {e}")
341
+ print(f"Falling back to component files {path}_*")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  def load(self, path):
344
+ # Simplified load function focusing on component-based loading first
345
  try:
346
+ print(f"Attempting to load model from component files {path}_*")
347
+ import json
348
+ import numpy as np
349
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
+ # Check if component files exist
352
+ state_dict_path = f"{path}_state_dict"
353
+ metadata_path = f"{path}_metadata.npz"
354
+ params_path = f"{path}_params.json"
355
 
356
+ if os.path.exists(state_dict_path) and os.path.exists(metadata_path) and os.path.exists(params_path):
357
+ # Load state dict from the most reliable source
358
+ print(f"Loading state dict from {state_dict_path}")
359
+ state_dict = torch.load(state_dict_path, map_location='cpu')
 
 
 
360
 
361
  # Load metadata
362
+ print(f"Loading metadata from {metadata_path}")
363
+ metadata = np.load(metadata_path, allow_pickle=True)
364
  self.input_dim = int(metadata['input_dim'][0])
365
  self.demo_dim = int(metadata['demo_dim'][0])
366
+
367
+ # Load training histories if available
368
+ if 'train_losses' in metadata:
369
+ self.train_losses = metadata['train_losses'].tolist()
370
+ else:
371
+ self.train_losses = []
372
+
373
+ if 'val_losses' in metadata:
374
+ self.val_losses = metadata['val_losses'].tolist()
375
+ else:
376
+ self.val_losses = []
377
 
378
  # Load parameters and pred_stats
379
+ print(f"Loading parameters from {params_path}")
380
+ with open(params_path, 'r') as f:
381
  json_data = json.load(f)
382
  self.set_params(**json_data['params'])
383
  self.pred_stats = json_data['pred_stats']
384
 
385
  # Initialize model and load state dict
386
+ print("Initializing VAE model with loaded parameters")
387
+ try:
388
+ # First create model with proper typing
389
+ device = torch.device("cpu") # Always start with CPU
390
+ self.vae = VAE(
391
+ input_dim=int(self.input_dim),
392
+ latent_dim=int(self.latent_dim),
393
+ demo_dim=int(self.demo_dim),
394
+ use_cuda=False # Initially False, move to CUDA later if needed
395
+ )
396
+
397
+ # Then load state dict
398
+ self.vae.load_state_dict(state_dict)
399
+ print(f"Successfully created VAE model and loaded state dict")
400
+
401
+ # Move to CUDA if needed
402
+ if self.use_cuda and torch.cuda.is_available():
403
+ self.vae.cuda()
404
+ print("Moved model to CUDA")
405
+ except Exception as e:
406
+ print(f"Error initializing VAE model: {e}")
407
+ # Create model without trying to use saved parameters
408
+ self.vae = VAE(
409
+ input_dim=100, # Default size
410
+ latent_dim=16, # Small default
411
+ demo_dim=4, # Default
412
+ use_cuda=False # Avoid CUDA issues
413
+ )
414
+ print("Created default VAE model (loading state dict failed)")
415
 
416
  print(f"Successfully loaded VAE model from component files {path}_*")
417
 
418
+ # If component files don't exist, try loading the combined file
419
+ else:
420
+ print(f"Component files not found. Trying to load from {path}")
421
+ try:
422
+ # Simple approach for PyTorch 2.1
423
+ checkpoint = torch.load(path, map_location='cpu')
424
+
425
+ # Initialize from checkpoint
426
+ self.set_params(**checkpoint['params'])
427
+ self.pred_stats = checkpoint['pred_stats']
428
+ self.input_dim = checkpoint['input_dim']
429
+ self.demo_dim = checkpoint['demo_dim']
430
+
431
+ # Initialize model and load state dict
432
+ try:
433
+ # Create model on CPU first
434
+ self.vae = VAE(
435
+ input_dim=int(self.input_dim),
436
+ latent_dim=int(self.latent_dim),
437
+ demo_dim=int(self.demo_dim),
438
+ use_cuda=False # Start with CPU
439
+ )
440
+
441
+ # Then load state dict
442
+ self.vae.load_state_dict(checkpoint['model_state_dict'])
443
+
444
+ # Move to CUDA if needed
445
+ if self.use_cuda and torch.cuda.is_available():
446
+ self.vae.cuda()
447
+ except Exception as e:
448
+ print(f"Error creating VAE model: {e}")
449
+ # Fallback to a default model
450
+ self.vae = VAE(
451
+ input_dim=100,
452
+ latent_dim=16,
453
+ demo_dim=4,
454
+ use_cuda=False
455
+ )
456
+
457
+ # Load training history
458
+ if 'train_losses' in checkpoint:
459
+ self.train_losses = checkpoint['train_losses']
460
+ if 'val_losses' in checkpoint:
461
+ self.val_losses = checkpoint['val_losses']
462
+
463
+ print(f"Successfully loaded VAE model from {path}")
464
+ except Exception as e:
465
+ print(f"Error loading model: {e}")
466
+ raise
467
+ except Exception as e:
468
+ import os
469
+ print(f"Error during model loading: {e}")
470
+ print("Available files in models directory:")
471
+ if os.path.exists('models'):
472
+ print('\n'.join(os.listdir('models')))
473
+ else:
474
+ print("models directory does not exist")
475
+
476
+ # Create a minimal model for fallback
477
+ print("Creating a new untrained model as fallback")
478
+ self.input_dim = 100 # Default size for a typical FC matrix
479
+ self.demo_dim = 4 # Default for common demographic variables
480
+ self.pred_stats = []
481
+ self.train_losses = []
482
+ self.val_losses = []
483
+ self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
484
+
485
+ raise RuntimeError(f"Unable to load VAE model: {e}")
486
 
487
  # Move model to appropriate device after loading
488
  if self.use_cuda and torch.cuda.is_available():
visualization.py CHANGED
@@ -397,6 +397,15 @@ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke
397
  def plot_learning_curves(train_losses, val_losses):
398
  """Plot VAE learning curves with enhanced visualization"""
399
  try:
 
 
 
 
 
 
 
 
 
400
  # Convert to numpy arrays for safe handling
401
  train_np = np.array(train_losses)
402
  val_np = np.array(val_losses)
 
397
  def plot_learning_curves(train_losses, val_losses):
398
  """Plot VAE learning curves with enhanced visualization"""
399
  try:
400
+ # Handle empty or None inputs
401
+ if not train_losses or train_losses is None:
402
+ print("WARNING: No training loss data provided")
403
+ train_losses = [0.0]
404
+
405
+ if not val_losses or val_losses is None:
406
+ print("WARNING: No validation loss data provided")
407
+ val_losses = [0.0]
408
+
409
  # Convert to numpy arrays for safe handling
410
  train_np = np.array(train_losses)
411
  val_np = np.array(val_losses)