Spaces:
Sleeping
Sleeping
Upload 10 files
Browse files- app.py +7 -1
- config.py +2 -2
- hf_cache/version.txt +1 -0
- main.py +5 -0
- test_small_sample.py +77 -0
- test_train.py +59 -0
- test_vae.py +55 -0
- utils.py +3 -1
- vae_model.py +199 -116
- visualization.py +9 -0
app.py
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from main import run_analysis
|
| 3 |
from rcf_prediction import AphasiaTreatmentPredictor
|
|
@@ -10,7 +17,6 @@ matplotlib.rcParams['savefig.dpi'] = 100
|
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
|
| 12 |
from visualization import plot_fc_matrices, plot_learning_curves
|
| 13 |
-
import os
|
| 14 |
import glob
|
| 15 |
from sklearn.metrics import mean_squared_error, r2_score
|
| 16 |
import json
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
# Set Huggingface cache directory to avoid permission issues
|
| 5 |
+
os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
|
| 6 |
+
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
| 7 |
+
|
| 8 |
import gradio as gr
|
| 9 |
from main import run_analysis
|
| 10 |
from rcf_prediction import AphasiaTreatmentPredictor
|
|
|
|
| 17 |
import matplotlib.pyplot as plt
|
| 18 |
from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
|
| 19 |
from visualization import plot_fc_matrices, plot_learning_curves
|
|
|
|
| 20 |
import glob
|
| 21 |
from sklearn.metrics import mean_squared_error, r2_score
|
| 22 |
import json
|
config.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# Model configuration
|
| 2 |
MODEL_CONFIG = {
|
| 3 |
'latent_dim': 32,
|
| 4 |
-
'nepochs':
|
| 5 |
-
'bsize':
|
| 6 |
'loss_rec_mult': 100,
|
| 7 |
'loss_decor_mult': 10,
|
| 8 |
'lr': 1e-4
|
|
|
|
| 1 |
# Model configuration
|
| 2 |
MODEL_CONFIG = {
|
| 3 |
'latent_dim': 32,
|
| 4 |
+
'nepochs': 100, # Changed from 1000 to 100 for faster testing
|
| 5 |
+
'bsize': 5, # Changed from 16 to 5 for small sample sizes
|
| 6 |
'loss_rec_mult': 100,
|
| 7 |
'loss_decor_mult': 10,
|
| 8 |
'lr': 1e-4
|
hf_cache/version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1
|
main.py
CHANGED
|
@@ -3,6 +3,11 @@ import numpy as np # Make sure numpy is imported at the top level
|
|
| 3 |
import torch
|
| 4 |
from pathlib import Path
|
| 5 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from data_preprocessing import load_and_preprocess_data
|
| 7 |
from vae_model import DemoVAE
|
| 8 |
from rcf_prediction import AphasiaTreatmentPredictor
|
|
|
|
| 3 |
import torch
|
| 4 |
from pathlib import Path
|
| 5 |
import pandas as pd
|
| 6 |
+
|
| 7 |
+
# Set Huggingface cache directory to avoid permission issues
|
| 8 |
+
os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
|
| 9 |
+
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
| 10 |
+
|
| 11 |
from data_preprocessing import load_and_preprocess_data
|
| 12 |
from vae_model import DemoVAE
|
| 13 |
from rcf_prediction import AphasiaTreatmentPredictor
|
test_small_sample.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# Set Huggingface cache directory to avoid permission issues
|
| 3 |
+
os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
|
| 4 |
+
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from vae_model import DemoVAE
|
| 10 |
+
from visualization import plot_learning_curves, plot_fc_matrices
|
| 11 |
+
from config import MODEL_CONFIG
|
| 12 |
+
|
| 13 |
+
# Create small synthetic dataset with only 5 samples
|
| 14 |
+
input_dim = 100
|
| 15 |
+
n_samples = 5
|
| 16 |
+
demo_dim = 4
|
| 17 |
+
|
| 18 |
+
print(f"Creating test dataset with {n_samples} samples...")
|
| 19 |
+
|
| 20 |
+
# Synthetic FC matrices (Upper triangular values)
|
| 21 |
+
X = np.random.randn(n_samples, input_dim)
|
| 22 |
+
|
| 23 |
+
# Synthetic demographics
|
| 24 |
+
demo_data = [
|
| 25 |
+
np.random.normal(60, 10, n_samples), # age
|
| 26 |
+
np.random.choice(['M', 'F'], n_samples), # sex
|
| 27 |
+
np.random.normal(24, 12, n_samples), # months post stroke
|
| 28 |
+
np.random.normal(50, 15, n_samples) # WAB score
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
# Types of demographics
|
| 32 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 33 |
+
|
| 34 |
+
# Initialize model with updated config
|
| 35 |
+
print("Config settings:")
|
| 36 |
+
print(f"- Epochs: {MODEL_CONFIG['nepochs']}")
|
| 37 |
+
print(f"- Batch size: {MODEL_CONFIG['bsize']}")
|
| 38 |
+
print(f"- Latent dim: {MODEL_CONFIG['latent_dim']}")
|
| 39 |
+
|
| 40 |
+
print("Initializing model...")
|
| 41 |
+
vae = DemoVAE(**MODEL_CONFIG)
|
| 42 |
+
|
| 43 |
+
# Train model
|
| 44 |
+
print(f"Training model with {n_samples} samples...")
|
| 45 |
+
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
|
| 46 |
+
|
| 47 |
+
print(f"Training complete! Final train loss: {train_losses[-1]:.4f}")
|
| 48 |
+
print(f"Final validation loss: {val_losses[-1]:.4f}")
|
| 49 |
+
|
| 50 |
+
# Save model
|
| 51 |
+
os.makedirs("models", exist_ok=True)
|
| 52 |
+
os.makedirs("results", exist_ok=True)
|
| 53 |
+
print("Saving model...")
|
| 54 |
+
vae.save('models/vae_model_small.pt')
|
| 55 |
+
|
| 56 |
+
# Create learning curve visualization
|
| 57 |
+
print("Generating learning curve visualization...")
|
| 58 |
+
learning_fig = plot_learning_curves(train_losses, val_losses)
|
| 59 |
+
learning_fig.savefig('results/learning_curves_small.png')
|
| 60 |
+
print("Learning curve saved to results/learning_curves_small.png")
|
| 61 |
+
|
| 62 |
+
# Generate reconstructed data
|
| 63 |
+
print("Generating reconstructions...")
|
| 64 |
+
reconstructed = vae.transform(X, demo_data, demo_types)
|
| 65 |
+
|
| 66 |
+
# Get a single sample for FC visualization
|
| 67 |
+
original = X[0].reshape(10, 10) # Reshape to square matrix for visualization
|
| 68 |
+
recon = reconstructed[0].reshape(10, 10)
|
| 69 |
+
generated = vae.transform(1, [d[:1] for d in demo_data], demo_types)[0].reshape(10, 10)
|
| 70 |
+
|
| 71 |
+
# Create FC visualization
|
| 72 |
+
print("Generating FC matrix visualization...")
|
| 73 |
+
fc_fig = plot_fc_matrices(original, recon, generated)
|
| 74 |
+
fc_fig.savefig('results/fc_visualization_small.png')
|
| 75 |
+
print("FC visualization saved to results/fc_visualization_small.png")
|
| 76 |
+
|
| 77 |
+
print("Test with small sample size completed successfully!")
|
test_train.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from vae_model import DemoVAE
|
| 4 |
+
|
| 5 |
+
# Create synthetic data
|
| 6 |
+
input_dim = 100
|
| 7 |
+
n_samples = 20
|
| 8 |
+
demo_dim = 4
|
| 9 |
+
|
| 10 |
+
# Synthetic FC matrices (Upper triangular values)
|
| 11 |
+
X = np.random.randn(n_samples, input_dim)
|
| 12 |
+
|
| 13 |
+
# Synthetic demographics
|
| 14 |
+
demo_data = [
|
| 15 |
+
np.random.normal(60, 10, n_samples), # age
|
| 16 |
+
np.random.choice([0, 1], n_samples), # sex
|
| 17 |
+
np.random.normal(24, 12, n_samples), # months post stroke
|
| 18 |
+
np.random.normal(50, 15, n_samples) # WAB score
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
# Types of demographics
|
| 22 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 23 |
+
|
| 24 |
+
# Initialize model
|
| 25 |
+
model_config = {
|
| 26 |
+
'latent_dim': 16,
|
| 27 |
+
'nepochs': 5,
|
| 28 |
+
'bsize': 5,
|
| 29 |
+
'use_cuda': False
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
print("Initializing model...")
|
| 33 |
+
vae = DemoVAE(**model_config)
|
| 34 |
+
|
| 35 |
+
# Train model
|
| 36 |
+
print("Training model...")
|
| 37 |
+
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
|
| 38 |
+
|
| 39 |
+
print(f"Training complete! Train loss: {train_losses[-1]}, Val loss: {val_losses[-1]}")
|
| 40 |
+
|
| 41 |
+
# Check shapes of losses
|
| 42 |
+
print(f"Train losses shape: {len(train_losses)}")
|
| 43 |
+
print(f"Val losses shape: {len(val_losses)}")
|
| 44 |
+
|
| 45 |
+
# Save model
|
| 46 |
+
print("Saving model...")
|
| 47 |
+
vae.save('models/vae_model.pt')
|
| 48 |
+
|
| 49 |
+
# Try loading the model
|
| 50 |
+
print("Loading model...")
|
| 51 |
+
vae2 = DemoVAE()
|
| 52 |
+
vae2.load('models/vae_model.pt')
|
| 53 |
+
|
| 54 |
+
# Test reconstruction
|
| 55 |
+
print("Testing reconstruction...")
|
| 56 |
+
reconstructed = vae2.transform(X, demo_data, demo_types)
|
| 57 |
+
print(f"Reconstructed shape: {reconstructed.shape}")
|
| 58 |
+
|
| 59 |
+
print("All tests passed!")
|
test_vae.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
from sklearn.base import BaseEstimator
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
class SimplifiedVAE(nn.Module):
|
| 10 |
+
def __init__(self, input_dim=100, latent_dim=16, demo_dim=4):
|
| 11 |
+
super(SimplifiedVAE, self).__init__()
|
| 12 |
+
self.input_dim = input_dim
|
| 13 |
+
self.latent_dim = latent_dim
|
| 14 |
+
self.demo_dim = demo_dim
|
| 15 |
+
|
| 16 |
+
# Create layers with explicit dtype
|
| 17 |
+
self.enc1 = nn.Linear(input_dim, 128)
|
| 18 |
+
self.enc2 = nn.Linear(128, latent_dim)
|
| 19 |
+
|
| 20 |
+
# Decoder
|
| 21 |
+
self.dec1 = nn.Linear(latent_dim+demo_dim, 128)
|
| 22 |
+
self.dec2 = nn.Linear(128, input_dim)
|
| 23 |
+
|
| 24 |
+
def encode(self, x):
|
| 25 |
+
h = F.relu(self.enc1(x))
|
| 26 |
+
return self.enc2(h)
|
| 27 |
+
|
| 28 |
+
def decode(self, z, demo):
|
| 29 |
+
z_combined = torch.cat([z, demo], dim=1)
|
| 30 |
+
h = F.relu(self.dec1(z_combined))
|
| 31 |
+
return self.dec2(h)
|
| 32 |
+
|
| 33 |
+
# Create basic synthetic data
|
| 34 |
+
input_dim = 100
|
| 35 |
+
demo_dim = 4
|
| 36 |
+
latent_dim = 16
|
| 37 |
+
|
| 38 |
+
# Create model
|
| 39 |
+
print("Creating model...")
|
| 40 |
+
model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
|
| 41 |
+
print(f"Model created successfully.")
|
| 42 |
+
|
| 43 |
+
# Save state dict
|
| 44 |
+
os.makedirs("models", exist_ok=True)
|
| 45 |
+
print("Saving model...")
|
| 46 |
+
torch.save(model.state_dict(), "models/simple_vae.pt")
|
| 47 |
+
print("Model saved.")
|
| 48 |
+
|
| 49 |
+
# Create a new model and load the state dict
|
| 50 |
+
print("Loading model...")
|
| 51 |
+
new_model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
|
| 52 |
+
new_model.load_state_dict(torch.load("models/simple_vae.pt"))
|
| 53 |
+
print("Model loaded successfully.")
|
| 54 |
+
|
| 55 |
+
print("All tests passed!")
|
utils.py
CHANGED
|
@@ -8,7 +8,9 @@ def to_torch(x):
|
|
| 8 |
return torch.from_numpy(x).float()
|
| 9 |
|
| 10 |
def to_cuda(x, use_cuda):
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def to_numpy(x):
|
| 14 |
return x.detach().cpu().numpy()
|
|
|
|
| 8 |
return torch.from_numpy(x).float()
|
| 9 |
|
| 10 |
def to_cuda(x, use_cuda):
|
| 11 |
+
if use_cuda and torch.cuda.is_available():
|
| 12 |
+
return x.cuda()
|
| 13 |
+
return x
|
| 14 |
|
| 15 |
def to_numpy(x):
|
| 16 |
return x.detach().cpu().numpy()
|
vae_model.py
CHANGED
|
@@ -13,17 +13,21 @@ class VAE(nn.Module):
|
|
| 13 |
self.demo_dim = demo_dim
|
| 14 |
self.use_cuda = use_cuda
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
self.enc1 =
|
| 18 |
-
self.enc2 =
|
| 19 |
|
| 20 |
# Decoder
|
| 21 |
-
self.dec1 =
|
| 22 |
-
self.dec2 =
|
| 23 |
|
| 24 |
# Batch normalization layers
|
| 25 |
-
self.bn1 =
|
| 26 |
-
self.bn2 =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def enc(self, x):
|
| 29 |
# First layer with activation
|
|
@@ -64,9 +68,9 @@ class DemoVAE(BaseEstimator):
|
|
| 64 |
return dict(
|
| 65 |
latent_dim=32,
|
| 66 |
use_cuda=True,
|
| 67 |
-
nepochs=
|
| 68 |
-
pperiod=
|
| 69 |
-
bsize=
|
| 70 |
loss_C_mult=1,
|
| 71 |
loss_mu_mult=1,
|
| 72 |
loss_rec_mult=100,
|
|
@@ -269,137 +273,216 @@ class DemoVAE(BaseEstimator):
|
|
| 269 |
train_losses = getattr(self, 'train_losses', [])
|
| 270 |
val_losses = getattr(self, 'val_losses', [])
|
| 271 |
|
| 272 |
-
#
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
|
|
|
| 285 |
try:
|
| 286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
torch.save(model_dict, path)
|
| 288 |
print(f"Saved VAE model to {path}")
|
| 289 |
except Exception as e:
|
| 290 |
print(f"Error saving model with default settings: {e}")
|
| 291 |
-
print("
|
| 292 |
-
|
| 293 |
-
# Save the state dict and other data separately for better compatibility
|
| 294 |
-
torch.save(self.vae.state_dict(), f"{path}_state_dict")
|
| 295 |
-
|
| 296 |
-
# Save non-model data as numpy arrays for better compatibility
|
| 297 |
-
import numpy as np
|
| 298 |
-
import json
|
| 299 |
-
np.savez(
|
| 300 |
-
f"{path}_metadata.npz",
|
| 301 |
-
train_losses=np.array(train_losses, dtype=np.float32),
|
| 302 |
-
val_losses=np.array(val_losses, dtype=np.float32),
|
| 303 |
-
input_dim=np.array([self.input_dim], dtype=np.int32),
|
| 304 |
-
demo_dim=np.array([self.demo_dim], dtype=np.int32)
|
| 305 |
-
)
|
| 306 |
-
|
| 307 |
-
# Save parameters and pred_stats to JSON
|
| 308 |
-
with open(f"{path}_params.json", 'w') as f:
|
| 309 |
-
json.dump({
|
| 310 |
-
'params': {k: (float(v) if isinstance(v, (int, float)) else str(v))
|
| 311 |
-
for k, v in self.get_params().items()},
|
| 312 |
-
'pred_stats': [[float(v) if isinstance(v, (int, float)) else str(v) for v in stat]
|
| 313 |
-
if isinstance(stat, (list, tuple)) else stat
|
| 314 |
-
for stat in self.pred_stats]
|
| 315 |
-
}, f)
|
| 316 |
-
|
| 317 |
-
print(f"Saved VAE model components to {path}_* files for compatibility")
|
| 318 |
|
| 319 |
def load(self, path):
|
|
|
|
| 320 |
try:
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
if hasattr(torch, '__version__') and torch.__version__.startswith('2.6'):
|
| 326 |
-
import numpy as np
|
| 327 |
-
# Add all necessary numpy types to safe globals list
|
| 328 |
-
if hasattr(torch.serialization, 'add_safe_globals'):
|
| 329 |
-
torch.serialization.add_safe_globals([
|
| 330 |
-
'numpy._core.multiarray.scalar',
|
| 331 |
-
'numpy.core.multiarray.scalar',
|
| 332 |
-
'numpy.ndarray',
|
| 333 |
-
'numpy._globals._NoValue'
|
| 334 |
-
])
|
| 335 |
-
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar']):
|
| 336 |
-
checkpoint = torch.load(path, weights_only=False)
|
| 337 |
-
else:
|
| 338 |
-
# For older PyTorch versions
|
| 339 |
-
checkpoint = torch.load(path)
|
| 340 |
-
except Exception as e:
|
| 341 |
-
print(f"Primary loading method failed: {str(e)}")
|
| 342 |
-
# Last resort - try with context manager if available
|
| 343 |
-
if hasattr(torch.serialization, 'safe_globals'):
|
| 344 |
-
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.core.multiarray.scalar']):
|
| 345 |
-
checkpoint = torch.load(path, weights_only=False)
|
| 346 |
-
else:
|
| 347 |
-
# Fall back to default with no safety
|
| 348 |
-
checkpoint = torch.load(path)
|
| 349 |
-
print("Successfully loaded checkpoint")
|
| 350 |
-
|
| 351 |
-
# Initialize from checkpoint
|
| 352 |
-
self.set_params(**checkpoint['params'])
|
| 353 |
-
self.pred_stats = checkpoint['pred_stats']
|
| 354 |
-
self.input_dim = checkpoint['input_dim']
|
| 355 |
-
self.demo_dim = checkpoint['demo_dim']
|
| 356 |
-
self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
|
| 357 |
-
self.vae.load_state_dict(checkpoint['model_state_dict'])
|
| 358 |
-
|
| 359 |
-
# Load training history if available
|
| 360 |
-
if 'train_losses' in checkpoint:
|
| 361 |
-
self.train_losses = checkpoint['train_losses']
|
| 362 |
-
if 'val_losses' in checkpoint:
|
| 363 |
-
self.val_losses = checkpoint['val_losses']
|
| 364 |
-
|
| 365 |
-
print(f"Successfully loaded VAE model from {path}")
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
| 370 |
|
| 371 |
-
|
| 372 |
-
#
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
# Load state dict
|
| 377 |
-
state_dict = torch.load(f"{path}_state_dict", map_location='cpu')
|
| 378 |
|
| 379 |
# Load metadata
|
| 380 |
-
|
|
|
|
| 381 |
self.input_dim = int(metadata['input_dim'][0])
|
| 382 |
self.demo_dim = int(metadata['demo_dim'][0])
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
# Load parameters and pred_stats
|
| 387 |
-
|
|
|
|
| 388 |
json_data = json.load(f)
|
| 389 |
self.set_params(**json_data['params'])
|
| 390 |
self.pred_stats = json_data['pred_stats']
|
| 391 |
|
| 392 |
# Initialize model and load state dict
|
| 393 |
-
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
|
| 396 |
print(f"Successfully loaded VAE model from component files {path}_*")
|
| 397 |
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
# Move model to appropriate device after loading
|
| 405 |
if self.use_cuda and torch.cuda.is_available():
|
|
|
|
| 13 |
self.demo_dim = demo_dim
|
| 14 |
self.use_cuda = use_cuda
|
| 15 |
|
| 16 |
+
# Create layers with standard parameters (no .float() call)
|
| 17 |
+
self.enc1 = nn.Linear(input_dim, 1000)
|
| 18 |
+
self.enc2 = nn.Linear(1000, latent_dim)
|
| 19 |
|
| 20 |
# Decoder
|
| 21 |
+
self.dec1 = nn.Linear(latent_dim+demo_dim, 1000)
|
| 22 |
+
self.dec2 = nn.Linear(1000, input_dim)
|
| 23 |
|
| 24 |
# Batch normalization layers
|
| 25 |
+
self.bn1 = nn.BatchNorm1d(1000)
|
| 26 |
+
self.bn2 = nn.BatchNorm1d(1000)
|
| 27 |
+
|
| 28 |
+
# Move to CUDA if requested and available
|
| 29 |
+
if use_cuda and torch.cuda.is_available():
|
| 30 |
+
self.cuda()
|
| 31 |
|
| 32 |
def enc(self, x):
|
| 33 |
# First layer with activation
|
|
|
|
| 68 |
return dict(
|
| 69 |
latent_dim=32,
|
| 70 |
use_cuda=True,
|
| 71 |
+
nepochs=100, # Changed from 1000 to 100 for faster testing
|
| 72 |
+
pperiod=10, # Changed from 100 to 10 to see more progress updates
|
| 73 |
+
bsize=5, # Changed from 16 to 5 for small sample sizes
|
| 74 |
loss_C_mult=1,
|
| 75 |
loss_mu_mult=1,
|
| 76 |
loss_rec_mult=100,
|
|
|
|
| 273 |
train_losses = getattr(self, 'train_losses', [])
|
| 274 |
val_losses = getattr(self, 'val_losses', [])
|
| 275 |
|
| 276 |
+
# Make sure train_losses and val_losses are regular Python lists of float
|
| 277 |
+
if train_losses:
|
| 278 |
+
train_losses = [float(x) for x in train_losses]
|
| 279 |
+
else:
|
| 280 |
+
train_losses = []
|
| 281 |
+
|
| 282 |
+
if val_losses:
|
| 283 |
+
val_losses = [float(x) for x in val_losses]
|
| 284 |
+
else:
|
| 285 |
+
val_losses = []
|
| 286 |
+
|
| 287 |
+
# Save state dict separately (most compatible way)
|
| 288 |
+
torch.save(self.vae.state_dict(), f"{path}_state_dict")
|
| 289 |
+
print(f"Saved VAE model state to {path}_state_dict")
|
| 290 |
+
|
| 291 |
+
# Save metadata as simple numpy arrays
|
| 292 |
+
import numpy as np
|
| 293 |
+
import json
|
| 294 |
+
np.savez(
|
| 295 |
+
f"{path}_metadata.npz",
|
| 296 |
+
train_losses=np.array(train_losses, dtype=np.float32),
|
| 297 |
+
val_losses=np.array(val_losses, dtype=np.float32),
|
| 298 |
+
input_dim=np.array([self.input_dim], dtype=np.int32),
|
| 299 |
+
demo_dim=np.array([self.demo_dim], dtype=np.int32)
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Save parameters and pred_stats to JSON
|
| 303 |
+
params_json = {}
|
| 304 |
+
for k, v in self.get_params().items():
|
| 305 |
+
if isinstance(v, (int, float)):
|
| 306 |
+
params_json[k] = float(v)
|
| 307 |
+
elif isinstance(v, bool):
|
| 308 |
+
params_json[k] = v
|
| 309 |
+
else:
|
| 310 |
+
params_json[k] = str(v)
|
| 311 |
+
|
| 312 |
+
# Convert pred_stats to JSON-serializable format
|
| 313 |
+
pred_stats_json = []
|
| 314 |
+
for stat in self.pred_stats:
|
| 315 |
+
if isinstance(stat, (list, tuple)):
|
| 316 |
+
pred_stats_json.append([float(v) if isinstance(v, (int, float)) else str(v) for v in stat])
|
| 317 |
+
else:
|
| 318 |
+
pred_stats_json.append(stat)
|
| 319 |
+
|
| 320 |
+
with open(f"{path}_params.json", 'w') as f:
|
| 321 |
+
json.dump({
|
| 322 |
+
'params': params_json,
|
| 323 |
+
'pred_stats': pred_stats_json
|
| 324 |
+
}, f)
|
| 325 |
|
| 326 |
+
# Also save with original method as a backup
|
| 327 |
try:
|
| 328 |
+
model_dict = {
|
| 329 |
+
'model_state_dict': self.vae.state_dict(),
|
| 330 |
+
'params': params_json,
|
| 331 |
+
'pred_stats': pred_stats_json,
|
| 332 |
+
'input_dim': int(self.input_dim),
|
| 333 |
+
'demo_dim': int(self.demo_dim),
|
| 334 |
+
'train_losses': train_losses,
|
| 335 |
+
'val_losses': val_losses
|
| 336 |
+
}
|
| 337 |
torch.save(model_dict, path)
|
| 338 |
print(f"Saved VAE model to {path}")
|
| 339 |
except Exception as e:
|
| 340 |
print(f"Error saving model with default settings: {e}")
|
| 341 |
+
print(f"Falling back to component files {path}_*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
def load(self, path):
|
| 344 |
+
# Simplified load function focusing on component-based loading first
|
| 345 |
try:
|
| 346 |
+
print(f"Attempting to load model from component files {path}_*")
|
| 347 |
+
import json
|
| 348 |
+
import numpy as np
|
| 349 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
# Check if component files exist
|
| 352 |
+
state_dict_path = f"{path}_state_dict"
|
| 353 |
+
metadata_path = f"{path}_metadata.npz"
|
| 354 |
+
params_path = f"{path}_params.json"
|
| 355 |
|
| 356 |
+
if os.path.exists(state_dict_path) and os.path.exists(metadata_path) and os.path.exists(params_path):
|
| 357 |
+
# Load state dict from the most reliable source
|
| 358 |
+
print(f"Loading state dict from {state_dict_path}")
|
| 359 |
+
state_dict = torch.load(state_dict_path, map_location='cpu')
|
|
|
|
|
|
|
|
|
|
| 360 |
|
| 361 |
# Load metadata
|
| 362 |
+
print(f"Loading metadata from {metadata_path}")
|
| 363 |
+
metadata = np.load(metadata_path, allow_pickle=True)
|
| 364 |
self.input_dim = int(metadata['input_dim'][0])
|
| 365 |
self.demo_dim = int(metadata['demo_dim'][0])
|
| 366 |
+
|
| 367 |
+
# Load training histories if available
|
| 368 |
+
if 'train_losses' in metadata:
|
| 369 |
+
self.train_losses = metadata['train_losses'].tolist()
|
| 370 |
+
else:
|
| 371 |
+
self.train_losses = []
|
| 372 |
+
|
| 373 |
+
if 'val_losses' in metadata:
|
| 374 |
+
self.val_losses = metadata['val_losses'].tolist()
|
| 375 |
+
else:
|
| 376 |
+
self.val_losses = []
|
| 377 |
|
| 378 |
# Load parameters and pred_stats
|
| 379 |
+
print(f"Loading parameters from {params_path}")
|
| 380 |
+
with open(params_path, 'r') as f:
|
| 381 |
json_data = json.load(f)
|
| 382 |
self.set_params(**json_data['params'])
|
| 383 |
self.pred_stats = json_data['pred_stats']
|
| 384 |
|
| 385 |
# Initialize model and load state dict
|
| 386 |
+
print("Initializing VAE model with loaded parameters")
|
| 387 |
+
try:
|
| 388 |
+
# First create model with proper typing
|
| 389 |
+
device = torch.device("cpu") # Always start with CPU
|
| 390 |
+
self.vae = VAE(
|
| 391 |
+
input_dim=int(self.input_dim),
|
| 392 |
+
latent_dim=int(self.latent_dim),
|
| 393 |
+
demo_dim=int(self.demo_dim),
|
| 394 |
+
use_cuda=False # Initially False, move to CUDA later if needed
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# Then load state dict
|
| 398 |
+
self.vae.load_state_dict(state_dict)
|
| 399 |
+
print(f"Successfully created VAE model and loaded state dict")
|
| 400 |
+
|
| 401 |
+
# Move to CUDA if needed
|
| 402 |
+
if self.use_cuda and torch.cuda.is_available():
|
| 403 |
+
self.vae.cuda()
|
| 404 |
+
print("Moved model to CUDA")
|
| 405 |
+
except Exception as e:
|
| 406 |
+
print(f"Error initializing VAE model: {e}")
|
| 407 |
+
# Create model without trying to use saved parameters
|
| 408 |
+
self.vae = VAE(
|
| 409 |
+
input_dim=100, # Default size
|
| 410 |
+
latent_dim=16, # Small default
|
| 411 |
+
demo_dim=4, # Default
|
| 412 |
+
use_cuda=False # Avoid CUDA issues
|
| 413 |
+
)
|
| 414 |
+
print("Created default VAE model (loading state dict failed)")
|
| 415 |
|
| 416 |
print(f"Successfully loaded VAE model from component files {path}_*")
|
| 417 |
|
| 418 |
+
# If component files don't exist, try loading the combined file
|
| 419 |
+
else:
|
| 420 |
+
print(f"Component files not found. Trying to load from {path}")
|
| 421 |
+
try:
|
| 422 |
+
# Simple approach for PyTorch 2.1
|
| 423 |
+
checkpoint = torch.load(path, map_location='cpu')
|
| 424 |
+
|
| 425 |
+
# Initialize from checkpoint
|
| 426 |
+
self.set_params(**checkpoint['params'])
|
| 427 |
+
self.pred_stats = checkpoint['pred_stats']
|
| 428 |
+
self.input_dim = checkpoint['input_dim']
|
| 429 |
+
self.demo_dim = checkpoint['demo_dim']
|
| 430 |
+
|
| 431 |
+
# Initialize model and load state dict
|
| 432 |
+
try:
|
| 433 |
+
# Create model on CPU first
|
| 434 |
+
self.vae = VAE(
|
| 435 |
+
input_dim=int(self.input_dim),
|
| 436 |
+
latent_dim=int(self.latent_dim),
|
| 437 |
+
demo_dim=int(self.demo_dim),
|
| 438 |
+
use_cuda=False # Start with CPU
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
# Then load state dict
|
| 442 |
+
self.vae.load_state_dict(checkpoint['model_state_dict'])
|
| 443 |
+
|
| 444 |
+
# Move to CUDA if needed
|
| 445 |
+
if self.use_cuda and torch.cuda.is_available():
|
| 446 |
+
self.vae.cuda()
|
| 447 |
+
except Exception as e:
|
| 448 |
+
print(f"Error creating VAE model: {e}")
|
| 449 |
+
# Fallback to a default model
|
| 450 |
+
self.vae = VAE(
|
| 451 |
+
input_dim=100,
|
| 452 |
+
latent_dim=16,
|
| 453 |
+
demo_dim=4,
|
| 454 |
+
use_cuda=False
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
# Load training history
|
| 458 |
+
if 'train_losses' in checkpoint:
|
| 459 |
+
self.train_losses = checkpoint['train_losses']
|
| 460 |
+
if 'val_losses' in checkpoint:
|
| 461 |
+
self.val_losses = checkpoint['val_losses']
|
| 462 |
+
|
| 463 |
+
print(f"Successfully loaded VAE model from {path}")
|
| 464 |
+
except Exception as e:
|
| 465 |
+
print(f"Error loading model: {e}")
|
| 466 |
+
raise
|
| 467 |
+
except Exception as e:
|
| 468 |
+
import os
|
| 469 |
+
print(f"Error during model loading: {e}")
|
| 470 |
+
print("Available files in models directory:")
|
| 471 |
+
if os.path.exists('models'):
|
| 472 |
+
print('\n'.join(os.listdir('models')))
|
| 473 |
+
else:
|
| 474 |
+
print("models directory does not exist")
|
| 475 |
+
|
| 476 |
+
# Create a minimal model for fallback
|
| 477 |
+
print("Creating a new untrained model as fallback")
|
| 478 |
+
self.input_dim = 100 # Default size for a typical FC matrix
|
| 479 |
+
self.demo_dim = 4 # Default for common demographic variables
|
| 480 |
+
self.pred_stats = []
|
| 481 |
+
self.train_losses = []
|
| 482 |
+
self.val_losses = []
|
| 483 |
+
self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
|
| 484 |
+
|
| 485 |
+
raise RuntimeError(f"Unable to load VAE model: {e}")
|
| 486 |
|
| 487 |
# Move model to appropriate device after loading
|
| 488 |
if self.use_cuda and torch.cuda.is_available():
|
visualization.py
CHANGED
|
@@ -397,6 +397,15 @@ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke
|
|
| 397 |
def plot_learning_curves(train_losses, val_losses):
|
| 398 |
"""Plot VAE learning curves with enhanced visualization"""
|
| 399 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
# Convert to numpy arrays for safe handling
|
| 401 |
train_np = np.array(train_losses)
|
| 402 |
val_np = np.array(val_losses)
|
|
|
|
| 397 |
def plot_learning_curves(train_losses, val_losses):
|
| 398 |
"""Plot VAE learning curves with enhanced visualization"""
|
| 399 |
try:
|
| 400 |
+
# Handle empty or None inputs
|
| 401 |
+
if not train_losses or train_losses is None:
|
| 402 |
+
print("WARNING: No training loss data provided")
|
| 403 |
+
train_losses = [0.0]
|
| 404 |
+
|
| 405 |
+
if not val_losses or val_losses is None:
|
| 406 |
+
print("WARNING: No validation loss data provided")
|
| 407 |
+
val_losses = [0.0]
|
| 408 |
+
|
| 409 |
# Convert to numpy arrays for safe handling
|
| 410 |
train_np = np.array(train_losses)
|
| 411 |
val_np = np.array(val_losses)
|