Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files
app.py
CHANGED
|
@@ -1763,23 +1763,67 @@ def create_learning_figure(vae):
|
|
| 1763 |
"""Helper function to create VAE learning curve figure"""
|
| 1764 |
plt.close('all') # Close previous figures
|
| 1765 |
|
| 1766 |
-
if
|
| 1767 |
-
|
| 1768 |
-
|
| 1769 |
-
|
| 1770 |
-
|
| 1771 |
-
|
| 1772 |
-
|
| 1773 |
else:
|
| 1774 |
-
logger.warning("No loss data found in VAE model
|
| 1775 |
-
|
| 1776 |
-
|
| 1777 |
-
|
| 1778 |
-
|
| 1779 |
-
|
| 1780 |
-
|
| 1781 |
-
|
| 1782 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1783 |
|
| 1784 |
def find_real_nifti_files(max_samples=2):
|
| 1785 |
"""Find real NIfTI files in the dataset or local directories, limited to the specified number"""
|
|
|
|
| 1763 |
"""Helper function to create VAE learning curve figure"""
|
| 1764 |
plt.close('all') # Close previous figures
|
| 1765 |
|
| 1766 |
+
# First check if loss data exists in the VAE object
|
| 1767 |
+
has_train_losses = hasattr(vae, 'train_losses') and isinstance(vae.train_losses, (list, tuple)) and len(vae.train_losses) > 0
|
| 1768 |
+
has_val_losses = hasattr(vae, 'val_losses') and isinstance(vae.val_losses, (list, tuple)) and len(vae.val_losses) > 0
|
| 1769 |
+
|
| 1770 |
+
# Log the status for debugging
|
| 1771 |
+
if has_train_losses:
|
| 1772 |
+
logger.info(f"Found training losses: {len(vae.train_losses)} points")
|
| 1773 |
else:
|
| 1774 |
+
logger.warning("No training loss data found in VAE model")
|
| 1775 |
+
|
| 1776 |
+
if has_val_losses:
|
| 1777 |
+
logger.info(f"Found validation losses: {len(vae.val_losses)} points")
|
| 1778 |
+
else:
|
| 1779 |
+
logger.warning("No validation loss data found in VAE model")
|
| 1780 |
+
|
| 1781 |
+
# If we have both train and validation losses, create the learning curve
|
| 1782 |
+
if has_train_losses and has_val_losses:
|
| 1783 |
+
logger.info(f"Creating learning curve with {len(vae.train_losses)} loss points")
|
| 1784 |
+
try:
|
| 1785 |
+
fig = plot_learning_curves(vae.train_losses, vae.val_losses)
|
| 1786 |
+
# Force rendering
|
| 1787 |
+
fig.canvas.draw()
|
| 1788 |
+
logger.info("Successfully created learning curve figure")
|
| 1789 |
+
return fig
|
| 1790 |
+
except Exception as e:
|
| 1791 |
+
logger.error(f"Error creating learning curve: {e}")
|
| 1792 |
+
# Fall through to the default figure below
|
| 1793 |
+
|
| 1794 |
+
# If we're missing one type of loss data but have the other, we can create a partial plot
|
| 1795 |
+
elif has_train_losses:
|
| 1796 |
+
logger.info("Creating learning curve with training losses only")
|
| 1797 |
+
try:
|
| 1798 |
+
# Create dummy validation losses (same as training but offset)
|
| 1799 |
+
dummy_val = [t * 1.1 for t in vae.train_losses]
|
| 1800 |
+
fig = plot_learning_curves(vae.train_losses, dummy_val)
|
| 1801 |
+
plt.title("VAE Learning Curve (Training Only)")
|
| 1802 |
+
plt.figtext(0.5, 0.01, "Note: Validation data unavailable",
|
| 1803 |
+
ha='center', fontsize=10, color='red')
|
| 1804 |
+
fig.canvas.draw()
|
| 1805 |
+
logger.info("Created partial learning curve with training data only")
|
| 1806 |
+
return fig
|
| 1807 |
+
except Exception as e:
|
| 1808 |
+
logger.error(f"Error creating partial learning curve: {e}")
|
| 1809 |
+
# Fall through to the default figure below
|
| 1810 |
+
|
| 1811 |
+
# Create a default figure if no loss data is available or plotting failed
|
| 1812 |
+
logger.warning("No complete loss data found - creating placeholder learning figure")
|
| 1813 |
+
fig = plt.figure(figsize=(10, 6))
|
| 1814 |
+
plt.title("VAE Learning Curve Data Unavailable", color='darkred')
|
| 1815 |
+
plt.xlabel("Epoch")
|
| 1816 |
+
plt.ylabel("Loss")
|
| 1817 |
+
plt.text(0.5, 0.5, "Learning curves will appear here after training",
|
| 1818 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 1819 |
+
fontsize=14)
|
| 1820 |
+
plt.text(0.5, 0.4, "Try using more training epochs to see learning progress",
|
| 1821 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 1822 |
+
fontsize=12, color='darkblue')
|
| 1823 |
+
plt.grid(True, alpha=0.3)
|
| 1824 |
+
plt.axis('on')
|
| 1825 |
+
fig.canvas.draw()
|
| 1826 |
+
return fig
|
| 1827 |
|
| 1828 |
def find_real_nifti_files(max_samples=2):
|
| 1829 |
"""Find real NIfTI files in the dataset or local directories, limited to the specified number"""
|
main.py
CHANGED
|
@@ -129,10 +129,21 @@ def run_analysis(data_dir="data",
|
|
| 129 |
|
| 130 |
# Format demographics for predictor and results
|
| 131 |
demographics = {}
|
|
|
|
|
|
|
| 132 |
demo_keys = ['age_at_stroke', 'sex', 'months_post_stroke', 'wab_score']
|
|
|
|
|
|
|
|
|
|
| 133 |
for i, key in enumerate(demo_keys):
|
| 134 |
if i < len(demo_data):
|
| 135 |
demographics[key] = demo_data[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# Generate reconstructions and synthetic FC
|
| 138 |
try:
|
|
@@ -185,19 +196,52 @@ def run_analysis(data_dir="data",
|
|
| 185 |
|
| 186 |
# Learning curves
|
| 187 |
try:
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
else:
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
except Exception as e:
|
|
|
|
| 197 |
print(f"Error creating learning curve plot: {e}")
|
|
|
|
|
|
|
|
|
|
| 198 |
learning_fig = plt.figure(figsize=(10, 6))
|
| 199 |
-
plt.text(0.5, 0.5, "Error creating learning curves",
|
| 200 |
-
ha='center', va='center', transform=plt.gca().transAxes
|
|
|
|
|
|
|
| 201 |
plt.tight_layout()
|
| 202 |
|
| 203 |
# Initialize results dictionary
|
|
|
|
| 129 |
|
| 130 |
# Format demographics for predictor and results
|
| 131 |
demographics = {}
|
| 132 |
+
|
| 133 |
+
# Define both standard and alternative keys
|
| 134 |
demo_keys = ['age_at_stroke', 'sex', 'months_post_stroke', 'wab_score']
|
| 135 |
+
alternate_keys = {'age_at_stroke': 'age', 'months_post_stroke': 'mpo', 'wab_score': 'wab_aq'}
|
| 136 |
+
|
| 137 |
+
# Map demographic data to consistent keys
|
| 138 |
for i, key in enumerate(demo_keys):
|
| 139 |
if i < len(demo_data):
|
| 140 |
demographics[key] = demo_data[i]
|
| 141 |
+
# Also add alternate versions of the key for compatibility
|
| 142 |
+
if key in alternate_keys:
|
| 143 |
+
demographics[alternate_keys[key]] = demo_data[i]
|
| 144 |
+
|
| 145 |
+
# Print the keys available in demographics for debugging
|
| 146 |
+
print(f"Demographics keys available: {list(demographics.keys())}")
|
| 147 |
|
| 148 |
# Generate reconstructions and synthetic FC
|
| 149 |
try:
|
|
|
|
| 196 |
|
| 197 |
# Learning curves
|
| 198 |
try:
|
| 199 |
+
print("Creating learning curve visualization...")
|
| 200 |
+
|
| 201 |
+
# Check if losses are stored in the VAE object first (most reliable source)
|
| 202 |
+
if hasattr(vae, 'train_losses') and hasattr(vae, 'val_losses'):
|
| 203 |
+
if len(vae.train_losses) > 0 and len(vae.val_losses) > 0:
|
| 204 |
+
print(f"Using learning curves from VAE object: {len(vae.train_losses)} train, {len(vae.val_losses)} validation points")
|
| 205 |
+
learning_fig = plot_learning_curves(vae.train_losses, vae.val_losses)
|
| 206 |
+
else:
|
| 207 |
+
# Fall back to the losses passed directly
|
| 208 |
+
if train_losses and val_losses:
|
| 209 |
+
print(f"Using passed learning curves: {len(train_losses)} train, {len(val_losses)} validation points")
|
| 210 |
+
learning_fig = plot_learning_curves(train_losses, val_losses)
|
| 211 |
+
else:
|
| 212 |
+
# Create a placeholder
|
| 213 |
+
print("No training history available for learning curves")
|
| 214 |
+
learning_fig = plt.figure(figsize=(10, 6))
|
| 215 |
+
plt.text(0.5, 0.5, "Learning curve data unavailable",
|
| 216 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 217 |
+
fontsize=14, color='darkred')
|
| 218 |
+
plt.axis('off')
|
| 219 |
+
plt.tight_layout()
|
| 220 |
else:
|
| 221 |
+
# Fall back to the losses passed directly
|
| 222 |
+
if train_losses and val_losses:
|
| 223 |
+
print(f"Using passed learning curves: {len(train_losses)} train, {len(val_losses)} validation points")
|
| 224 |
+
learning_fig = plot_learning_curves(train_losses, val_losses)
|
| 225 |
+
else:
|
| 226 |
+
# Create a placeholder
|
| 227 |
+
print("No training history available for learning curves")
|
| 228 |
+
learning_fig = plt.figure(figsize=(10, 6))
|
| 229 |
+
plt.text(0.5, 0.5, "Learning curve data unavailable",
|
| 230 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 231 |
+
fontsize=14, color='darkred')
|
| 232 |
+
plt.axis('off')
|
| 233 |
+
plt.tight_layout()
|
| 234 |
except Exception as e:
|
| 235 |
+
import traceback
|
| 236 |
print(f"Error creating learning curve plot: {e}")
|
| 237 |
+
print(f"Traceback: {traceback.format_exc()}")
|
| 238 |
+
|
| 239 |
+
# Create a more informative error display
|
| 240 |
learning_fig = plt.figure(figsize=(10, 6))
|
| 241 |
+
plt.text(0.5, 0.5, f"Error creating learning curves: {str(e)}",
|
| 242 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 243 |
+
fontsize=12, color='darkred')
|
| 244 |
+
plt.axis('off')
|
| 245 |
plt.tight_layout()
|
| 246 |
|
| 247 |
# Initialize results dictionary
|
utils.py
CHANGED
|
@@ -129,6 +129,17 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
|
|
| 129 |
ce = torch.nn.CrossEntropyLoss()
|
| 130 |
optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
for e in range(nepochs):
|
| 133 |
epoch_losses = []
|
| 134 |
vae.train()
|
|
@@ -162,15 +173,27 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
|
|
| 162 |
# Print progress for every epoch
|
| 163 |
print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
|
| 164 |
|
| 165 |
-
# Validation step
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
| 174 |
print(f' Validation - Val Loss: {val_loss:.4f}')
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
return train_losses, val_losses
|
|
|
|
| 129 |
ce = torch.nn.CrossEntropyLoss()
|
| 130 |
optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
|
| 131 |
|
| 132 |
+
# Calculate initial validation loss
|
| 133 |
+
print("Calculating initial validation metrics...")
|
| 134 |
+
vae.eval()
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
z_val = vae.enc(x)
|
| 137 |
+
y_val = vae.dec(z_val, demo_t)
|
| 138 |
+
initial_val_loss = rmse(x, y_val).item()
|
| 139 |
+
val_losses.append(initial_val_loss)
|
| 140 |
+
print(f"Initial validation loss: {initial_val_loss:.4f}")
|
| 141 |
+
|
| 142 |
+
# Main training loop
|
| 143 |
for e in range(nepochs):
|
| 144 |
epoch_losses = []
|
| 145 |
vae.train()
|
|
|
|
| 173 |
# Print progress for every epoch
|
| 174 |
print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
|
| 175 |
|
| 176 |
+
# Validation step (perform at every epoch to have full data for plotting)
|
| 177 |
+
vae.eval()
|
| 178 |
+
with torch.no_grad():
|
| 179 |
+
z = vae.enc(x)
|
| 180 |
+
y = vae.dec(z, demo_t)
|
| 181 |
+
val_loss = rmse(x, y).item()
|
| 182 |
+
val_losses.append(val_loss)
|
| 183 |
+
|
| 184 |
+
# Only print detailed validation logs at pperiod intervals
|
| 185 |
+
if (e + 1) % pperiod == 0:
|
| 186 |
print(f' Validation - Val Loss: {val_loss:.4f}')
|
| 187 |
|
| 188 |
+
# Make sure losses are converted to regular Python lists (for serialization)
|
| 189 |
+
train_losses = [float(loss) for loss in train_losses]
|
| 190 |
+
val_losses = [float(loss) for loss in val_losses]
|
| 191 |
+
|
| 192 |
+
print(f"Training complete - Final train loss: {train_losses[-1]:.4f}, Val loss: {val_losses[-1]:.4f}")
|
| 193 |
+
print(f"Loss history recorded: {len(train_losses)} train points, {len(val_losses)} validation points")
|
| 194 |
+
|
| 195 |
+
# Store the losses in the return object for future reference
|
| 196 |
+
ret_obj.train_losses = train_losses
|
| 197 |
+
ret_obj.val_losses = val_losses
|
| 198 |
+
|
| 199 |
return train_losses, val_losses
|
visualization.py
CHANGED
|
@@ -312,15 +312,110 @@ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke
|
|
| 312 |
return fig
|
| 313 |
|
| 314 |
def plot_learning_curves(train_losses, val_losses):
|
| 315 |
-
"""Plot VAE learning curves"""
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
return fig
|
| 313 |
|
| 314 |
def plot_learning_curves(train_losses, val_losses):
|
| 315 |
+
"""Plot VAE learning curves with enhanced visualization"""
|
| 316 |
+
try:
|
| 317 |
+
# Convert to numpy arrays for safe handling
|
| 318 |
+
train_np = np.array(train_losses)
|
| 319 |
+
val_np = np.array(val_losses)
|
| 320 |
+
|
| 321 |
+
# Check for NaN values
|
| 322 |
+
if np.any(np.isnan(train_np)) or np.any(np.isnan(val_np)):
|
| 323 |
+
print("WARNING: Learning curves contain NaN values, replacing with zeros")
|
| 324 |
+
train_np = np.nan_to_num(train_np)
|
| 325 |
+
val_np = np.nan_to_num(val_np)
|
| 326 |
+
|
| 327 |
+
# Create figure
|
| 328 |
+
fig = plt.figure(figsize=(12, 6))
|
| 329 |
+
|
| 330 |
+
# Add improved styling
|
| 331 |
+
plt.rcParams['font.size'] = 12
|
| 332 |
+
|
| 333 |
+
# Check if train and val lengths match
|
| 334 |
+
if len(train_np) != len(val_np):
|
| 335 |
+
print(f"Training and validation loss lengths don't match: {len(train_np)} vs {len(val_np)}")
|
| 336 |
+
if len(train_np) > len(val_np):
|
| 337 |
+
# Validation might be evaluated less frequently
|
| 338 |
+
# Create epoch indices for each
|
| 339 |
+
train_epochs = np.arange(len(train_np))
|
| 340 |
+
val_factor = len(train_np) / len(val_np)
|
| 341 |
+
val_epochs = np.arange(0, len(train_np), val_factor)[:len(val_np)]
|
| 342 |
+
|
| 343 |
+
plt.plot(train_epochs, train_np, 'b-', linewidth=2, label='Training Loss')
|
| 344 |
+
plt.plot(val_epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
|
| 345 |
+
else:
|
| 346 |
+
# This is unusual, but handle it anyway
|
| 347 |
+
plt.plot(train_np, 'b-', linewidth=2, label='Training Loss')
|
| 348 |
+
plt.plot(val_np[:len(train_np)], 'r-', linewidth=2, label='Validation Loss')
|
| 349 |
+
else:
|
| 350 |
+
# Standard case - equal length arrays
|
| 351 |
+
epochs = np.arange(len(train_np))
|
| 352 |
+
plt.plot(epochs, train_np, 'b-', linewidth=2, label='Training Loss')
|
| 353 |
+
plt.plot(epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
|
| 354 |
+
|
| 355 |
+
# Add shaded confidence region
|
| 356 |
+
if len(train_np) > 5: # Only if we have enough points
|
| 357 |
+
# Calculate moving average for smoother trend lines
|
| 358 |
+
window_size = min(5, len(train_np) // 5)
|
| 359 |
+
if window_size > 1:
|
| 360 |
+
avg_train = np.convolve(train_np, np.ones(window_size)/window_size, mode='valid')
|
| 361 |
+
avg_val = np.convolve(val_np, np.ones(window_size)/window_size, mode='valid')
|
| 362 |
+
avg_epochs = epochs[window_size-1:]
|
| 363 |
+
plt.plot(avg_epochs, avg_train, 'b--', linewidth=1, alpha=0.6)
|
| 364 |
+
plt.plot(avg_epochs, avg_val, 'r--', linewidth=1, alpha=0.6)
|
| 365 |
+
|
| 366 |
+
# Calculate improvement from start to end
|
| 367 |
+
if len(train_np) > 1:
|
| 368 |
+
train_improvement = ((train_np[0] - train_np[-1]) / train_np[0]) * 100
|
| 369 |
+
if len(val_np) > 1:
|
| 370 |
+
val_improvement = ((val_np[0] - val_np[-1]) / val_np[0]) * 100
|
| 371 |
+
plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement, Validation: {val_improvement:.1f}% improvement')
|
| 372 |
+
else:
|
| 373 |
+
plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement')
|
| 374 |
+
else:
|
| 375 |
+
plt.title('VAE Learning Curves')
|
| 376 |
+
|
| 377 |
+
# Add min/max annotations
|
| 378 |
+
if len(train_np) > 0:
|
| 379 |
+
min_train = np.min(train_np)
|
| 380 |
+
min_train_epoch = np.argmin(train_np)
|
| 381 |
+
plt.annotate(f'Min: {min_train:.4f}', xy=(min_train_epoch, min_train),
|
| 382 |
+
xytext=(min_train_epoch+5, min_train+0.05),
|
| 383 |
+
arrowprops=dict(facecolor='blue', shrink=0.05, alpha=0.5),
|
| 384 |
+
color='blue', fontsize=10)
|
| 385 |
+
|
| 386 |
+
if len(val_np) > 0:
|
| 387 |
+
min_val = np.min(val_np)
|
| 388 |
+
min_val_epoch = np.argmin(val_np)
|
| 389 |
+
plt.annotate(f'Min: {min_val:.4f}', xy=(min_val_epoch, min_val),
|
| 390 |
+
xytext=(min_val_epoch+5, min_val+0.05),
|
| 391 |
+
arrowprops=dict(facecolor='red', shrink=0.05, alpha=0.5),
|
| 392 |
+
color='red', fontsize=10)
|
| 393 |
+
|
| 394 |
+
# Styling
|
| 395 |
+
plt.xlabel('Epoch')
|
| 396 |
+
plt.ylabel('Loss')
|
| 397 |
+
plt.legend(loc='upper right')
|
| 398 |
+
plt.grid(True, alpha=0.3)
|
| 399 |
+
|
| 400 |
+
# Set reasonable y-axis limits
|
| 401 |
+
all_losses = np.concatenate([train_np, val_np])
|
| 402 |
+
y_min = max(0, np.min(all_losses) * 0.9) # Don't go below zero
|
| 403 |
+
y_max = np.percentile(all_losses, 95) * 1.1 # Exclude outliers
|
| 404 |
+
plt.ylim(y_min, y_max)
|
| 405 |
+
|
| 406 |
+
plt.tight_layout()
|
| 407 |
+
return fig
|
| 408 |
+
|
| 409 |
+
except Exception as e:
|
| 410 |
+
import traceback
|
| 411 |
+
print(f"Error in plot_learning_curves: {e}")
|
| 412 |
+
print(f"Traceback: {traceback.format_exc()}")
|
| 413 |
+
|
| 414 |
+
# Create a simple error figure
|
| 415 |
+
fig = plt.figure(figsize=(10, 6))
|
| 416 |
+
plt.text(0.5, 0.5, f"Learning curves error: {str(e)}",
|
| 417 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 418 |
+
fontsize=12, color='red')
|
| 419 |
+
plt.axis('off')
|
| 420 |
+
plt.tight_layout()
|
| 421 |
+
return fig
|