| |
|
| | """
|
| | Advanced Multi-Modal Aphasia Classification System
|
| | With Adaptive Learning Rate and Comprehensive Reporting
|
| | """
|
| |
|
| | import re
|
| | import json
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import time
|
| | import datetime
|
| | import numpy as np
|
| | import os
|
| | import random
|
| | import csv
|
| | import math
|
| | from collections import Counter, defaultdict
|
| | from typing import Dict, List, Optional, Tuple, Union
|
| | from dataclasses import dataclass
|
| |
|
| | import torch.optim as optim
|
| | from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
|
| | from transformers import (
|
| | AutoTokenizer, AutoModel, AutoConfig,
|
| | TrainingArguments, Trainer, TrainerCallback,
|
| | EarlyStoppingCallback, get_cosine_schedule_with_warmup,
|
| | default_data_collator, set_seed
|
| | )
|
| |
|
| | import seaborn as sns
|
| | import matplotlib.pyplot as plt
|
| | import pandas as pd
|
| | from sklearn.metrics import (
|
| | accuracy_score, f1_score, precision_score, recall_score,
|
| | confusion_matrix, classification_report, roc_auc_score
|
| | )
|
| | from sklearn.model_selection import StratifiedKFold
|
| | import gc
|
| | from scipy import stats
|
| |
|
| |
|
# Debugging/stability environment settings:
#  - CUDA_LAUNCH_BLOCKING makes CUDA kernel launches synchronous so errors
#    surface at the faulting op instead of a later, unrelated call.
#  - TORCH_USE_CUDA_DSA enables device-side assertions for clearer CUDA errors.
#  - TOKENIZERS_PARALLELISM silences the HF tokenizers fork-parallelism warning.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Path to the augmented aphasia training corpus (JSON).
json_file = '/workspace/SH001/aphasia_data_augmented.json'
|
| |
|
| |
|
def set_all_seeds(seed=42):
    """Seed every RNG source (hashing, python, numpy, torch, CUDA) for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
| |
|
# Seed all RNGs up-front so dataset shuffling and weight init are reproducible.
set_all_seeds(42)
|
| |
|
| |
|
@dataclass
class ModelConfig:
    """Central hyper-parameter bundle for the aphasia classification pipeline.

    Groups encoder, linguistic-feature, attention, classifier, optimization,
    adaptive-LR, and loss settings in one dataclass so they can be passed
    around as a single object.
    """

    # --- text encoder ---
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512        # BERT input length in subtokens
    hidden_size: int = 768       # expected backbone hidden size

    # --- word-level linguistic feature dimensions ---
    pos_vocab_size: int = 150    # POS-tag vocabulary (ids clamped to this range)
    pos_emb_dim: int = 64
    grammar_dim: int = 3         # grammar features are triples per word
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32

    # --- attention fusion ---
    num_attention_heads: int = 8
    attention_dropout: float = 0.3

    # --- classification head ---
    # Fix: was annotated `List[int] = None`, which is type-incorrect; an
    # Optional annotation matches the None default resolved in __post_init__.
    classifier_hidden_dims: Optional[List[int]] = None
    dropout_rate: float = 0.3
    activation_fn: str = "tanh"

    # --- optimization ---
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    batch_size: int = 10
    num_epochs: int = 500
    gradient_accumulation_steps: int = 4

    # --- adaptive learning-rate scheduler (see AdaptiveLearningRateScheduler) ---
    adaptive_lr: bool = True
    lr_patience: int = 3
    lr_factor: float = 0.8
    lr_increase_factor: float = 1.2
    min_lr: float = 1e-6
    max_lr: float = 1e-3
    oscillation_amplitude: float = 0.1

    # --- loss configuration ---
    use_focal_loss: bool = True
    focal_alpha: float = 1.0
    focal_gamma: float = 2.0
    use_mixup: bool = False
    mixup_alpha: float = 0.2
    use_label_smoothing: bool = True
    label_smoothing: float = 0.1

    def __post_init__(self):
        # None means "use the default architecture"; building the list here
        # (instead of a shared class-level default) keeps instances independent.
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
|
| |
|
| |
|
def log_message(message):
    """Append a timestamped line to ./training_log.txt and echo it to stdout."""
    line = f"{datetime.datetime.now().isoformat()}: {message}"
    with open("./training_log.txt", "a", encoding="utf-8") as sink:
        sink.write(f"{line}\n")
    print(line, flush=True)
|
| |
|
def clear_memory():
    """Run a GC pass and, when a GPU is present, release cached CUDA memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
| |
|
def normalize_type(t):
    """Canonicalize an aphasia-type label: trim and uppercase strings, pass through anything else."""
    if not isinstance(t, str):
        return t
    return t.strip().upper()
|
| |
|
| |
|
class AdaptiveLearningRateScheduler:
    """Heuristic learning-rate controller combining several strategies.

    On each call to adaptive_lr_calculation it multiplies the current LR by
    factors derived from (a) plateau / low-metric heuristics, (b) a sinusoidal
    oscillation over steps, and (c) a cosine decay, then clamps the result into
    [config.min_lr, config.max_lr] and writes it into the optimizer.
    """
    def __init__(self, optimizer, config: ModelConfig, total_steps: int):
        self.optimizer = optimizer
        self.config = config
        self.total_steps = total_steps

        # Rolling metric histories used for trend (slope) estimation.
        self.loss_history = []
        self.f1_history = []
        self.accuracy_history = []
        self.lr_history = []

        # NOTE(review): plateau_counter / best_f1 / best_loss are initialized
        # but never updated anywhere in this class — presumably leftovers from
        # an earlier design; confirm before relying on them.
        self.plateau_counter = 0
        self.best_f1 = 0.0
        self.best_loss = float('inf')
        self.step_count = 0

        self.base_lr = config.learning_rate
        self.current_lr = self.base_lr

        log_message(f"Adaptive LR Scheduler initialized with base_lr={self.base_lr}")

    def calculate_slope(self, values, window=3):
        """Least-squares slope over the last `window` values (0.0 if too few)."""
        if len(values) < window:
            return 0.0

        recent_values = values[-window:]
        x = np.arange(len(recent_values))
        slope, _, _, _, _ = stats.linregress(x, recent_values)
        return slope

    def exponential_adjustment(self, current_value, target_value, base_factor=1.1):
        """Exponential factor exp(-current/target) * base_factor.

        Shrinks toward 0 as current_value grows relative to target_value;
        falls back to ratio 1.0 when target_value is 0.
        """
        ratio = current_value / target_value if target_value != 0 else 1.0
        factor = math.exp(-ratio) * base_factor
        return factor

    def logarithmic_adjustment(self, current_value, threshold=0.1):
        """Logarithmic factor log(1 + value/threshold), clamped to [0.5, 2.0]."""
        if current_value <= 0:
            return 1.0
        factor = math.log(1 + current_value / threshold)
        return max(0.5, min(2.0, factor))

    def sinusoidal_oscillation(self, step, amplitude=None):
        """Sinusoidal modulation around 1.0; one full cycle per quarter of total_steps."""
        if amplitude is None:
            amplitude = self.config.oscillation_amplitude

        phase = 2 * math.pi * step / (self.total_steps / 4)
        oscillation = 1 + amplitude * math.sin(phase)
        return oscillation

    def cosine_decay(self, step):
        """Cosine decay from 1.0 (step 0) toward 0.0 (step == total_steps)."""
        progress = step / self.total_steps
        decay = 0.5 * (1 + math.cos(math.pi * progress))
        return decay

    def adaptive_lr_calculation(self, current_loss, current_f1, current_acc):
        """Compute and apply the next learning rate from current eval metrics.

        Called once per evaluation (see AdaptiveTrainingCallback.on_evaluate),
        so step_count counts evaluations, not optimizer steps — yet total_steps
        is derived from optimizer steps. TODO confirm this cadence mismatch is
        intended; it makes the sin/cos schedules run far slower than designed.
        """
        self.loss_history.append(current_loss)
        self.f1_history.append(current_f1)
        self.accuracy_history.append(current_acc)

        loss_slope = self.calculate_slope(self.loss_history)
        f1_slope = self.calculate_slope(self.f1_history)
        # NOTE(review): acc_slope is computed but never used below.
        acc_slope = self.calculate_slope(self.accuracy_history)

        adjustment_factor = 1.0

        # Loss heuristics: flat loss -> exponential bump; very high loss -> log factor.
        if abs(loss_slope) < 0.001:
            log_message(f"Loss plateau detected (slope: {loss_slope:.6f})")

            exp_factor = self.exponential_adjustment(abs(loss_slope), 0.01, 1.15)
            adjustment_factor *= exp_factor

        elif current_loss > 2.0:
            log_message(f"High loss detected: {current_loss:.4f}")

            log_factor = self.logarithmic_adjustment(current_loss, 1.0)
            adjustment_factor *= log_factor

        # F1 heuristics: very low F1 -> exponential boost; flat F1 -> mild 1.1x bump.
        if current_f1 < 0.3:
            log_message(f"Low F1 detected: {current_f1:.4f}")

            exp_factor = self.exponential_adjustment(0.3, current_f1, 1.2)
            adjustment_factor *= exp_factor

        elif abs(f1_slope) < 0.001:
            log_message(f"F1 plateau detected (slope: {f1_slope:.6f})")
            adjustment_factor *= 1.1

        # Periodic exploration term plus long-term decay (decay floor 0.3).
        sin_factor = self.sinusoidal_oscillation(self.step_count)

        cos_factor = self.cosine_decay(self.step_count)

        final_factor = adjustment_factor * sin_factor * (0.3 + 0.7 * cos_factor)

        new_lr = self.current_lr * final_factor

        # Clamp into the configured LR band.
        new_lr = max(self.config.min_lr, min(self.config.max_lr, new_lr))

        # Only touch the optimizer when the change is non-trivial.
        if abs(new_lr - self.current_lr) > 1e-7:
            self.current_lr = new_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr

            log_message(f"Learning rate adjusted: {new_lr:.2e} (factor: {final_factor:.3f})")
            log_message(f" - Loss slope: {loss_slope:.6f}, F1 slope: {f1_slope:.6f}")
            log_message(f" - Sin factor: {sin_factor:.3f}, Cos factor: {cos_factor:.3f}")

        self.lr_history.append(self.current_lr)
        self.step_count += 1

        return self.current_lr
|
| |
|
| |
|
class TrainingHistoryTracker:
    """Accumulates per-epoch train/eval metrics and renders them after training."""

    def __init__(self):
        metric_names = (
            'epoch',
            'train_loss', 'eval_loss',
            'train_accuracy', 'eval_accuracy',
            'train_f1', 'eval_f1',
            'learning_rate',
            'train_precision', 'eval_precision',
            'train_recall', 'eval_recall',
        )
        self.history = {name: [] for name in metric_names}

    def update(self, epoch, metrics):
        """Record one epoch: known metric keys are appended, unknown ones ignored."""
        self.history['epoch'].append(epoch)
        for name, value in metrics.items():
            bucket = self.history.get(name)
            if bucket is not None:
                bucket.append(value)

    def save_history(self, output_dir):
        """Dump the metric table to training_history.csv and return it as a DataFrame."""
        frame = pd.DataFrame(self.history)
        frame.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
        return frame

    def plot_training_curves(self, output_dir):
        """Render a 2x3 grid of training curves to training_curves.png."""
        if not self.history['epoch']:
            return

        plt.style.use('seaborn-v0_8')
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        epochs = self.history['epoch']

        # (position, train key, eval key, train label, eval label, title, y-label)
        paired_panels = [
            ((0, 0), 'train_loss', 'eval_loss', 'Train Loss', 'Eval Loss',
             'Loss Over Time', 'Loss'),
            ((0, 1), 'train_accuracy', 'eval_accuracy', 'Train Accuracy', 'Eval Accuracy',
             'Accuracy Over Time', 'Accuracy'),
            ((0, 2), 'train_f1', 'eval_f1', 'Train F1', 'Eval F1',
             'F1 Score Over Time', 'F1 Score'),
            ((1, 1), 'train_precision', 'eval_precision', 'Train Precision', 'Eval Precision',
             'Precision Over Time', 'Precision'),
            ((1, 2), 'train_recall', 'eval_recall', 'Train Recall', 'Eval Recall',
             'Recall Over Time', 'Recall'),
        ]
        for (row, col), train_key, eval_key, train_lbl, eval_lbl, title, ylabel in paired_panels:
            ax = axes[row, col]
            ax.plot(epochs, self.history[train_key], 'b-', label=train_lbl, linewidth=2)
            ax.plot(epochs, self.history[eval_key], 'r-', label=eval_lbl, linewidth=2)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # The learning-rate panel has no train/eval pair and uses a log scale.
        lr_ax = axes[1, 0]
        lr_ax.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        lr_ax.set_title('Learning Rate Over Time', fontsize=14, fontweight='bold')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "training_curves.png"), dpi=300, bbox_inches='tight')
        plt.close()
|
| |
|
| |
|
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017): cross-entropy down-weighted by (1 - p_t)^gamma.

    With gamma == 0 this reduces exactly to alpha * cross_entropy.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the softmax probability of the true class.
        prob_true_class = torch.exp(-per_sample_ce)
        scaled = self.alpha * (1 - prob_true_class) ** self.gamma * per_sample_ce

        if self.reduction == 'sum':
            return scaled.sum()
        if self.reduction == 'mean':
            return scaled.mean()
        return scaled
|
| |
|
| |
|
class StablePositionalEncoding(nn.Module):
    """Fixed sinusoidal positional table blended with a small learnable table.

    The combined signal is scaled by 0.1 so it acts as a gentle nudge on top
    of the token representations rather than dominating them.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Standard transformer sinusoid: sin on even dims, cos on odd dims.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freq = torch.exp(torch.arange(0, d_model, 2).float() *
                         (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freq)
        table[:, 1::2] = torch.cos(positions * freq)
        self.register_buffer('pe', table.unsqueeze(0))

        # Small random learnable component alongside the frozen sinusoid.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        n_positions = x.size(1)
        fixed = self.pe[:, :n_positions, :].to(x.device)
        trained = self.learnable_pe[:n_positions, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + trained)
|
| |
|
| |
|
class StableMultiHeadAttention(nn.Module):
    """Multi-head self-attention finished with a residual add + LayerNorm."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        # Heads must split the feature dimension exactly.
        assert feature_dim % num_heads == 0

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        n_batch, n_tokens, _ = x.size()

        def split_heads(projected):
            # (B, T, D) -> (B, heads, T, head_dim)
            return projected.view(n_batch, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # Scaled dot-product scores.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Broadcast a (B, T) padding mask over heads and query positions.
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)
            scores.masked_fill_(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge heads back to (B, T, D).
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(
            n_batch, n_tokens, self.feature_dim)

        return self.layer_norm(self.output_proj(merged) + x)
|
| |
|
| |
|
class StableLinguisticFeatureExtractor(nn.Module):
    """Encodes POS, grammar, duration and prosody streams into one pooled vector.

    The output dimension is (pos_emb_dim + grammar_hidden_dim +
    duration_hidden_dim + prosody_dim) // 2: the four per-token streams are
    concatenated, passed through a fusion MLP that halves the width, then
    mean-pooled over valid (attention-masked) positions.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.config = config

        # POS stream: embedding followed by self-attention over the sequence.
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Grammar stream: per-token triples projected to a hidden space.
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3)
        )

        # Duration stream: scalar per token projected up.
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim)
        )

        # Prosody stream: same-width non-linear projection.
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim)
        )

        # Fuse all four streams and halve the dimensionality.
        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate)
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        # Clamp out-of-vocabulary POS ids instead of letting the embedding fail.
        safe_pos = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        attended_pos = self.pos_attention(self.pos_embedding(safe_pos), attention_mask)

        grammar_repr = self.grammar_projection(grammar_ids.float())
        duration_repr = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_repr = self.prosody_projection(prosody_features.float())

        stacked = torch.cat(
            [attended_pos, grammar_repr, duration_repr, prosody_repr], dim=-1)
        fused = self.feature_fusion(stacked)

        # Masked mean over valid positions (assumes every mask row has at least
        # one non-zero entry — true for tokenized text that includes [CLS]).
        mask = attention_mask.unsqueeze(-1).float()
        return torch.sum(fused * mask, dim=1) / torch.sum(mask, dim=1)
|
| |
|
| |
|
class StableAphasiaClassifier(nn.Module):
    """BERT-based aphasia-type classifier fused with word-level linguistic features.

    Pipeline: PubMedBERT encoding -> extra positional encoding -> attention
    pooling -> concatenation with pooled linguistic features -> fusion MLP ->
    main classification head plus two auxiliary heads (severity, fluency).
    """
    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Pretrained biomedical BERT backbone.
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        # Freeze only the embedding layer; transformer blocks stay trainable.
        # NOTE(review): the dataset adds special tokens to the tokenizer but no
        # resize_token_embeddings call is visible here — confirm it happens
        # elsewhere before training.
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        bert_dim = self.bert_config.hidden_size
        # Must equal the pooled output size of StableLinguisticFeatureExtractor
        # (its fusion layer halves the concatenated stream width).
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        # Fuses BERT + linguistic vectors back down to bert_dim.
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        # Main aphasia-type head (MLP built from config.classifier_hidden_dims).
        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Auxiliary heads: their outputs are returned but receive no loss term
        # in _compute_loss (labels only supervise the main classifier).
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )

        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Stack Linear/LayerNorm/Tanh/Dropout blocks per config, ending in logits."""
        layers = []
        current_dim = input_dim

        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Run the full fusion pipeline.

        Returns a dict with main logits, auxiliary predictions, and (when
        labels are given) the classification loss. Extra collator keys such as
        sentence_ids are absorbed by **kwargs.
        """
        # Contextual token representations from the (partially frozen) backbone.
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        # Add the extra (sinusoidal + learnable) positional signal.
        position_enhanced = self.positional_encoder(sequence_output)

        # Content-weighted pooling over the sequence.
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        # Linguistic branch: only used when all word-level inputs are present;
        # missing prosody is padded with zeros.
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )

            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # Fall back to a zero vector of the same width so fusion still works.
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        # Concatenate both modalities and project back to bert_dim.
        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # Heads.
        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        # Loss is computed for the main head only.
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": loss
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Pool token vectors with content-derived weights.

        Salience per position is the softmax of the summed hidden activations,
        masked to valid tokens and renormalized (epsilon guards empty rows).
        """
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled

    def _compute_loss(self, logits, labels):
        """Pick the configured loss: focal, label-smoothed CE, or plain CE.

        Note: focal loss takes precedence — label smoothing is ignored when
        use_focal_loss is True.
        """
        if self.config.use_focal_loss:
            # A fresh FocalLoss per call (stateless module, so only a minor cost).
            focal_loss = FocalLoss(
                alpha=self.config.focal_alpha,
                gamma=self.config.focal_gamma,
                reduction='mean'
            )
            return focal_loss(logits, labels)
        else:
            if self.config.use_label_smoothing:
                return F.cross_entropy(
                    logits, labels,
                    label_smoothing=self.config.label_smoothing
                )
            else:
                return F.cross_entropy(logits, labels)
|
| |
|
| |
|
class StableAphasiaDataset(Dataset):
    """Dataset of aphasia dialogue samples with word-level linguistic features.

    Each JSON sentence item is flattened into one training sample: tokens from
    all PAR (participant) utterances across its dialogues, plus aligned POS
    ids, grammar-tag triples, per-word durations, and a sentence-level prosody
    vector broadcast across positions.
    """

    def __init__(self, sentences, tokenizer, aphasia_types_mapping, config: "ModelConfig"):
        self.samples = []
        self.tokenizer = tokenizer
        self.config = config
        self.aphasia_types_mapping = aphasia_types_mapping

        # Dialogue-structure markers. NOTE(review): the model's token
        # embeddings must be resized after adding these — confirm the training
        # script calls resize_token_embeddings.
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        for idx, item in enumerate(sentences):
            sentence_id = item.get("sentence_id", f"S{idx}")
            aphasia_type = normalize_type(item.get("aphasia_type", ""))

            # Items with labels outside the mapping are skipped, not errors.
            if aphasia_type not in aphasia_types_mapping:
                log_message(f"Skipping Sentence {sentence_id}: Invalid aphasia type '{aphasia_type}'")
                continue

            self._process_sentence(item, sentence_id, aphasia_type)

        if not self.samples:
            raise ValueError("No valid samples found in dataset!")

        log_message(f"Dataset created with {len(self.samples)} samples")
        self._print_class_distribution()

    def _process_sentence(self, item, sentence_id, aphasia_type):
        """Flatten all PAR utterances of one item into token/feature streams."""
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []

        for dialogue_idx, dialogue in enumerate(item.get("dialogues", [])):
            # Separate consecutive dialogues with a marker carrying null features.
            if dialogue_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)

            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    tokens = par["tokens"]
                    # Missing feature lists default to neutral values per token.
                    pos_ids = par.get("word_pos_ids", [0] * len(tokens))
                    grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
                    durations = par.get("word_durations", [0.0] * len(tokens))

                    all_tokens.extend(tokens)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(grammar_ids)
                    all_durations.extend(durations)

        if not all_tokens:
            return

        self._create_sample(all_tokens, all_pos, all_grammar, all_durations,
                            sentence_id, aphasia_type)

    def _create_sample(self, tokens, pos_ids, grammar_ids, durations,
                       sentence_id, aphasia_type):
        """Tokenize, align features to subtokens, and append one sample dict."""
        text = " ".join(tokens)
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        aligned_pos, aligned_grammar, aligned_durations = self._align_features(
            tokens, pos_ids, grammar_ids, durations, encoded
        )

        # One prosody vector for the whole sentence, repeated per position so
        # it has the same sequence layout as the token-level features.
        prosody_features = self._extract_prosodic_features(durations, tokens)
        prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
            self.config.max_length, 1
        )

        label = self.aphasia_types_mapping[aphasia_type]

        sample = {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
            "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_id
        }
        self.samples.append(sample)

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Align word-level features with BERT subtokens.

        Position 0 is reserved for [CLS] and the final position for [SEP];
        every subtoken inherits the features of the word it came from, and
        padding positions get neutral values.
        """
        subtoken_to_token = []

        for token_idx, token in enumerate(tokens):
            subtokens = self.tokenizer.tokenize(token)
            subtoken_to_token.extend([token_idx] * len(subtokens))

        aligned_pos = [0]
        aligned_grammar = [[0, 0, 0]]
        aligned_durations = [0.0]

        for subtoken_idx in range(1, self.config.max_length - 1):
            if subtoken_idx - 1 < len(subtoken_to_token):
                original_idx = subtoken_to_token[subtoken_idx - 1]
                aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[original_idx] if original_idx < len(grammar_ids) else [0, 0, 0])
                raw = durations[original_idx] if original_idx < len(durations) else 0.0
                aligned_durations.append(self._duration_value(raw))
            else:
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        aligned_pos.append(0)
        aligned_grammar.append([0, 0, 0])
        aligned_durations.append(0.0)

        return aligned_pos, aligned_grammar, aligned_durations

    @staticmethod
    def _duration_value(raw):
        """Normalize one raw duration entry to a scalar.

        Accepts an [onset, offset] integer pair (returns offset - onset), a
        one-element numeric list, or a plain number; anything else maps to 0.0.
        Fixes two defects in the original: raw[1] was indexed before the list
        length was checked (IndexError on one-element lists), and plain numeric
        durations — the form _extract_prosodic_features expects — were silently
        replaced with 0.0.
        """
        if isinstance(raw, list):
            if len(raw) >= 2 and isinstance(raw[0], int) and isinstance(raw[1], int):
                return int(raw[1]) - int(raw[0])
            if len(raw) == 1 and isinstance(raw[0], (int, float)):
                return raw[0]
            return 0.0
        if isinstance(raw, (int, float)):
            return float(raw)
        return 0.0

    def _extract_prosodic_features(self, durations, tokens):
        """Summarize word durations into a fixed-size prosody vector.

        Features: mean, std, median, and the count of "long" words (duration
        above 1.5x the mean); zero-padded/truncated to config.prosody_dim.
        """
        if not durations:
            return [0.0] * self.config.prosody_dim

        valid_durations = [d for d in durations if isinstance(d, (int, float)) and d > 0]
        if not valid_durations:
            return [0.0] * self.config.prosody_dim

        features = [
            np.mean(valid_durations),
            np.std(valid_durations),
            np.median(valid_durations),
            len([d for d in valid_durations if d > np.mean(valid_durations) * 1.5])
        ]

        while len(features) < self.config.prosody_dim:
            features.append(0.0)

        return features[:self.config.prosody_dim]

    def _print_class_distribution(self):
        """Log sample counts per aphasia class."""
        label_counts = Counter(sample["labels"].item() for sample in self.samples)
        reverse_mapping = {v: k for k, v in self.aphasia_types_mapping.items()}

        log_message("\nClass Distribution:")
        for label_id, count in sorted(label_counts.items()):
            class_name = reverse_mapping.get(label_id, f"Unknown_{label_id}")
            log_message(f"  {class_name}: {count} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
|
| |
|
| |
|
def stable_collate_fn(batch):
    """Stack a list of dataset samples into batched tensors.

    Missing optional feature tensors are replaced by zeros so the model always
    receives every modality. Returns None for empty/invalid batches or when
    stacking fails (logged), letting the caller skip the batch.
    """
    if not batch or batch[0] is None:
        return None

    try:
        pad_len = batch[0]["input_ids"].size(0)

        collated = {
            "input_ids": torch.stack([s["input_ids"] for s in batch]),
            "attention_mask": torch.stack([s["attention_mask"] for s in batch]),
            "labels": torch.stack([s["labels"] for s in batch]),
            "sentence_ids": [s.get("sentence_id", "N/A") for s in batch],
        }

        # Optional per-token features: a fresh zero tensor is built for each
        # sample that lacks the key (same shapes the dataset would emit).
        optional_specs = {
            "word_pos_ids": lambda: torch.zeros(pad_len, dtype=torch.long),
            "word_grammar_ids": lambda: torch.zeros(pad_len, 3, dtype=torch.long),
            "word_durations": lambda: torch.zeros(pad_len, dtype=torch.float),
            "prosody_features": lambda: torch.zeros(pad_len, 32, dtype=torch.float),
        }
        for key, make_default in optional_specs.items():
            collated[key] = torch.stack(
                [s[key] if key in s else make_default() for s in batch])

        return collated
    except Exception as e:
        log_message(f"Collation error: {e}")
        return None
|
| |
|
| |
|
class AdaptiveTrainingCallback(TrainerCallback):
    """HF Trainer callback: adaptive LR, metric history, and early stopping.

    Responsibilities:
      * drives AdaptiveLearningRateScheduler after every evaluation,
      * records per-epoch metrics via TrainingHistoryTracker,
      * F1-based early stopping plus a hard stop once eval accuracy > 84%.
    """
    def __init__(self, config: ModelConfig, patience=5, min_delta=0.8):
        self.config = config
        self.patience = patience
        # NOTE(review): min_delta=0.8 on a 0-1 F1 scale means eval F1 must beat
        # the best score by 0.8 to count as an improvement — in practice only
        # the patience counter drives stopping. Confirm this is intentional.
        self.min_delta = min_delta
        self.best_metric = float('-inf')
        self.patience_counter = 0

        # Created lazily in on_train_begin (needs the live optimizer).
        self.lr_scheduler = None

        self.history_tracker = TrainingHistoryTracker()

        # Latest metrics captured from Trainer log / evaluate events.
        self.current_train_metrics = {}
        self.current_eval_metrics = {}

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the adaptive LR scheduler once model and optimizer exist."""
        if self.config.adaptive_lr:
            model = kwargs.get('model')
            optimizer = kwargs.get('optimizer')
            if optimizer and model:
                # Fall back to dataloader length x epochs when max_steps is unset.
                total_steps = state.max_steps if state.max_steps > 0 else len(kwargs.get('train_dataloader', [])) * args.num_train_epochs
                self.lr_scheduler = AdaptiveLearningRateScheduler(optimizer, self.config, total_steps)
                log_message("Adaptive learning rate scheduler initialized")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Capture the latest training loss and learning rate from log events."""
        if logs:
            # NOTE(review): HF Trainer usually reports the running training loss
            # under the key 'loss'; 'train_loss' typically only appears in the
            # final summary — confirm, otherwise train loss stays 0 in history.
            if 'train_loss' in logs:
                self.current_train_metrics['loss'] = logs['train_loss']
            if 'learning_rate' in logs:
                self.current_train_metrics['lr'] = logs['learning_rate']

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        """Record eval metrics, adjust the LR, and decide whether to stop."""
        if logs is not None:
            current_metric = logs.get('eval_f1', 0)
            current_loss = logs.get('eval_loss', float('inf'))
            current_acc = logs.get('eval_accuracy', 0)

            self.current_eval_metrics = {
                'loss': current_loss,
                'f1': current_metric,
                'accuracy': current_acc,
                'precision': logs.get('eval_precision_macro', 0),
                'recall': logs.get('eval_recall_macro', 0)
            }

            # Train-side accuracy/F1/precision/recall are not computed per
            # epoch, so they are stored as 0 placeholders; only loss and LR
            # come from on_log.
            epoch_metrics = {
                'train_loss': self.current_train_metrics.get('loss', 0),
                'eval_loss': current_loss,
                'train_accuracy': 0,
                'eval_accuracy': current_acc,
                'train_f1': 0,
                'eval_f1': current_metric,
                'learning_rate': self.current_train_metrics.get('lr', self.config.learning_rate),
                'train_precision': 0,
                'eval_precision': logs.get('eval_precision_macro', 0),
                'train_recall': 0,
                'eval_recall': logs.get('eval_recall_macro', 0)
            }

            self.history_tracker.update(state.epoch, epoch_metrics)

            # Adjust the LR and check the accuracy hard stop.
            # NOTE(review): the >0.84 hard stop only runs when adaptive_lr is
            # enabled, and the returned new_lr is unused — confirm intended.
            if self.lr_scheduler and self.config.adaptive_lr:
                new_lr = self.lr_scheduler.adaptive_lr_calculation(current_loss, current_metric, current_acc)
                if current_acc > 0.84:
                    log_message(f"Target accuracy reached ({current_acc:.2%}) → stopping and saving model")
                    control.should_save = True
                    control.should_training_stop = True
                    return control

            # F1-based early stopping (see the min_delta note in __init__).
            if current_metric > self.best_metric + self.min_delta:
                self.best_metric = current_metric
                self.patience_counter = 0
                log_message(f"New best F1 score: {current_metric:.4f}")
            else:
                self.patience_counter += 1
                log_message(f"No improvement for {self.patience_counter} evaluations")

            if self.patience_counter >= self.patience:
                log_message("Early stopping triggered")
                control.should_training_stop = True

        clear_memory()

    def on_train_end(self, args, state, control, **kwargs):
        """Persist the metric history (CSV) and the training curves (PNG)."""
        output_dir = args.output_dir
        self.history_tracker.save_history(output_dir)
        self.history_tracker.plot_training_curves(output_dir)
        log_message("Training history and curves saved")
|
| |
|
| |
|
def compute_comprehensive_metrics(pred):
    """Compute comprehensive evaluation metrics for a Trainer prediction output.

    Args:
        pred: object with ``predictions`` (logits array, or a tuple whose
            first element is the logits) and ``label_ids``.

    Returns:
        dict with overall accuracy, weighted/macro F1, macro precision and
        recall, plus the standard deviation of the per-class scores as a
        rough measure of class balance.
    """
    raw_outputs = pred.predictions
    if isinstance(raw_outputs, tuple):
        raw_outputs = raw_outputs[0]
    labels = pred.label_ids
    preds = np.argmax(raw_outputs, axis=1)

    # Per-class scores, used only for their spread across classes.
    per_class = {
        name: fn(labels, preds, average=None, zero_division=0)
        for name, fn in (("f1", f1_score),
                         ("precision", precision_score),
                         ("recall", recall_score))
    }

    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average='weighted', zero_division=0),
        "f1_macro": f1_score(labels, preds, average='macro', zero_division=0),
        "precision_macro": precision_score(labels, preds, average='macro', zero_division=0),
        "recall_macro": recall_score(labels, preds, average='macro', zero_division=0),
        "f1_std": np.std(per_class["f1"]),
        "precision_std": np.std(per_class["precision"]),
        "recall_std": np.std(per_class["recall"]),
    }
|
| |
|
| |
|
def generate_comprehensive_reports(trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir):
    """Generate comprehensive analysis reports and visualizations.

    Runs the trained model over ``eval_dataset`` and writes to ``output_dir``:
    an annotated confusion matrix, per-class metric CSV + bar charts, a
    prediction-confidence histogram, a per-sample results CSV (including
    per-class probabilities, severity and fluency heads), and a JSON of
    summary statistics.

    Args:
        trainer: HuggingFace Trainer holding the trained model.
        eval_dataset: dataset to evaluate (collated by stable_collate_fn).
        aphasia_types_mapping: dict mapping class name -> integer label.
        tokenizer: tokenizer (currently unused here, kept for interface parity).
        output_dir: directory for all generated artifacts.

    Returns:
        (results_df, df_report, summary_stats) — per-sample DataFrame,
        classification-report DataFrame, and summary statistics dict.
    """
    log_message("Generating comprehensive reports...")

    # Unwrap DataParallel/DistributedDataParallel if present.
    model = trainer.model
    if hasattr(model, 'module'):
        model = model.module

    model.eval()
    device = next(model.parameters()).device

    predictions = []
    true_labels = []
    sentence_ids = []
    severity_preds = []
    fluency_preds = []
    prediction_probs = []

    dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=stable_collate_fn)

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            # stable_collate_fn may return None for unusable batches.
            if batch is None:
                continue

            # Move only the tensor fields the model consumes to the device;
            # non-tensor fields (e.g. sentence_ids) stay on CPU.
            for key in ['input_ids', 'attention_mask', 'word_pos_ids',
                        'word_grammar_ids', 'word_durations', 'labels', 'prosody_features']:
                if key in batch:
                    batch[key] = batch[key].to(device)

            outputs = model(**batch)

            logits = outputs["logits"]
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(batch["labels"].cpu().numpy())
            sentence_ids.extend(batch["sentence_ids"])
            # Auxiliary heads: severity distribution and fluency score.
            severity_preds.extend(outputs["severity_pred"].cpu().numpy())
            fluency_preds.extend(outputs["fluency_pred"].cpu().numpy())
            prediction_probs.extend(probs.cpu().numpy())

    # Integer label -> class-name lookup for readable logs/CSVs.
    reverse_mapping = {v: k for k, v in aphasia_types_mapping.items()}

    log_message("=== DETAILED PREDICTIONS (First 20) ===")
    for i in range(min(20, len(predictions))):
        true_type = reverse_mapping.get(true_labels[i], 'Unknown')
        pred_type = reverse_mapping.get(predictions[i], 'Unknown')
        severity_level = np.argmax(severity_preds[i])
        # fluency head may emit a 1-element array or a scalar; normalize to scalar.
        fluency_score = fluency_preds[i][0] if isinstance(fluency_preds[i], np.ndarray) else fluency_preds[i]
        confidence = np.max(prediction_probs[i])

        log_message(f"ID: {sentence_ids[i]} | True: {true_type} | Pred: {pred_type} | "
                    f"Confidence: {confidence:.3f} | Severity: {severity_level} | Fluency: {fluency_score:.3f}")

    cm = confusion_matrix(true_labels, predictions)

    plt.figure(figsize=(14, 12))

    # Row-normalized percentages.
    # NOTE(review): a class absent from true_labels yields a zero row sum
    # and NaN percentages here — confirm eval split always covers all classes.
    cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    # Annotate each cell with both raw count and row percentage.
    annotations = np.empty_like(cm, dtype=object)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            annotations[i, j] = f'{cm[i, j]}\n({cm_percentage[i, j]:.1f}%)'

    sns.heatmap(cm, annot=annotations, fmt='', cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()),
                cbar_kws={'label': 'Count'})

    plt.xlabel("Predicted Label", fontsize=12, fontweight='bold')
    plt.ylabel("True Label", fontsize=12, fontweight='bold')
    plt.title("Enhanced Confusion Matrix\n(Count and Percentage)", fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "enhanced_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Full sklearn classification report over all known labels (even if a
    # class has zero support in this split, thanks to labels=...).
    all_label_ids = list(aphasia_types_mapping.values())
    report_dict = classification_report(
        true_labels,
        predictions,
        labels=all_label_ids,
        target_names=list(aphasia_types_mapping.keys()),
        output_dict=True,
        zero_division=0
    )

    df_report = pd.DataFrame(report_dict).transpose()
    df_report.to_csv(os.path.join(output_dir, "comprehensive_classification_report.csv"))

    # Flatten per-class metrics into a tidy table for plotting/CSV.
    class_names = list(aphasia_types_mapping.keys())
    metrics_data = []

    for i, class_name in enumerate(class_names):
        if class_name in report_dict:
            metrics_data.append({
                'Class': class_name,
                'Precision': report_dict[class_name]['precision'],
                'Recall': report_dict[class_name]['recall'],
                'F1-Score': report_dict[class_name]['f1-score'],
                'Support': report_dict[class_name]['support']
            })

    df_metrics = pd.DataFrame(metrics_data)
    df_metrics.to_csv(os.path.join(output_dir, "per_class_metrics.csv"), index=False)

    # 2x2 grid: precision / recall / F1 / support per class.
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    axes[0, 0].bar(df_metrics['Class'], df_metrics['Precision'], color='skyblue', alpha=0.8)
    axes[0, 0].set_title('Precision by Class', fontweight='bold')
    axes[0, 0].set_ylabel('Precision')
    axes[0, 0].tick_params(axis='x', rotation=45)
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].bar(df_metrics['Class'], df_metrics['Recall'], color='lightcoral', alpha=0.8)
    axes[0, 1].set_title('Recall by Class', fontweight='bold')
    axes[0, 1].set_ylabel('Recall')
    axes[0, 1].tick_params(axis='x', rotation=45)
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].bar(df_metrics['Class'], df_metrics['F1-Score'], color='lightgreen', alpha=0.8)
    axes[1, 0].set_title('F1-Score by Class', fontweight='bold')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].tick_params(axis='x', rotation=45)
    axes[1, 0].grid(True, alpha=0.3)

    axes[1, 1].bar(df_metrics['Class'], df_metrics['Support'], color='gold', alpha=0.8)
    axes[1, 1].set_title('Support by Class', fontweight='bold')
    axes[1, 1].set_ylabel('Support (Number of Samples)')
    axes[1, 1].tick_params(axis='x', rotation=45)
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "per_class_performance.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Confidence histogram split by correct vs. incorrect predictions.
    confidences = [np.max(prob) for prob in prediction_probs]
    correct_predictions = [pred == true for pred, true in zip(predictions, true_labels)]

    plt.figure(figsize=(12, 8))

    correct_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if correct]
    incorrect_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if not correct]

    plt.hist(correct_confidences, bins=30, alpha=0.7, label='Correct Predictions', color='green', density=True)
    plt.hist(incorrect_confidences, bins=30, alpha=0.7, label='Incorrect Predictions', color='red', density=True)

    plt.xlabel('Prediction Confidence', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.title('Distribution of Prediction Confidence', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confidence_distribution.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Aggregate statistics from the auxiliary heads.
    log_message("=== FEATURE ANALYSIS ===")
    avg_severity = np.mean(severity_preds, axis=0)
    avg_fluency = np.mean(fluency_preds)
    std_fluency = np.std(fluency_preds)

    log_message(f"Average Severity Distribution: {avg_severity}")
    log_message(f"Average Fluency Score: {avg_fluency:.3f} ± {std_fluency:.3f}")

    # Per-sample results table.
    results_df = pd.DataFrame({
        'sentence_id': sentence_ids,
        'true_label': [reverse_mapping[label] for label in true_labels],
        'predicted_label': [reverse_mapping[pred] for pred in predictions],
        'prediction_confidence': confidences,
        'correct_prediction': correct_predictions,
        'severity_level': [np.argmax(severity) for severity in severity_preds],
        'fluency_score': [fluency[0] if isinstance(fluency, np.ndarray) else fluency for fluency in fluency_preds]
    })

    # One probability column per class, in mapping order.
    for i, class_name in enumerate(aphasia_types_mapping.keys()):
        results_df[f'prob_{class_name}'] = [prob[i] for prob in prediction_probs]

    results_df.to_csv(os.path.join(output_dir, "comprehensive_results.csv"), index=False)

    summary_stats = {
        'Overall Accuracy': accuracy_score(true_labels, predictions),
        'Macro F1': f1_score(true_labels, predictions, average='macro'),
        'Weighted F1': f1_score(true_labels, predictions, average='weighted'),
        'Macro Precision': precision_score(true_labels, predictions, average='macro'),
        'Macro Recall': recall_score(true_labels, predictions, average='macro'),
        'Average Confidence': np.mean(confidences),
        'Confidence Std': np.std(confidences),
        'Average Severity': avg_severity.tolist(),
        'Average Fluency': avg_fluency,
        'Fluency Std': std_fluency
    }

    # numpy scalars are not JSON-serializable; coerce them to Python floats.
    serializable_summary = {
        k: float(v) if isinstance(v, (np.floating, np.integer)) else v
        for k, v in summary_stats.items()
    }
    with open(os.path.join(output_dir, "summary_statistics.json"), "w") as f:
        json.dump(serializable_summary, f, indent=2)

    log_message("Comprehensive Classification Report:")
    log_message(df_report.to_string())
    log_message(f"Comprehensive results saved to {output_dir}")

    return results_df, df_report, summary_stats
|
| |
|
| |
|
def train_adaptive_model(json_file: str, output_dir: str = "./adaptive_aphasia_model"):
    """Main training function with adaptive learning rate.

    Loads the augmented aphasia dataset from ``json_file``, trains a
    StableAphasiaClassifier with the AdaptiveTrainingCallback, evaluates it,
    generates comprehensive reports, and saves the model/tokenizer/config
    to ``output_dir``.

    Args:
        json_file: path to a JSON file with a top-level "sentences" list.
        output_dir: directory for checkpoints, reports, and artifacts.

    Returns:
        (trainer, eval_results, results_df).

    Raises:
        Re-raises any exception from trainer.train() after logging it.
    """
    log_message("Starting Adaptive Aphasia Classification Training")
    log_message("=" * 60)

    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_message(f"Using device: {device}")

    log_message("Loading dataset...")
    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    sentences = dataset_json.get("sentences", [])

    # Canonicalize the aphasia-type strings before filtering against the mapping.
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    # Fixed class-name -> integer-label mapping (order defines label ids).
    aphasia_types_mapping = {
        "BROCA": 0,
        "TRANSMOTOR": 1,
        "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3,
        "WERNICKE": 4,
        "ANOMIC": 5,
        "GLOBAL": 6,
        "ISOLATION": 7,
        "TRANSSENSORY": 8
    }

    log_message(f"Aphasia Types Mapping: {aphasia_types_mapping}")

    num_labels = len(aphasia_types_mapping)
    log_message(f"Number of labels: {num_labels}")

    # Drop samples whose (normalized) type is not in the mapping.
    filtered_sentences = []
    for item in sentences:
        aphasia_type = item.get("aphasia_type", "")
        if aphasia_type in aphasia_types_mapping:
            filtered_sentences.append(item)
        else:
            log_message(f"Excluding sentence with invalid type: {aphasia_type}")

    log_message(f"Filtered dataset: {len(filtered_sentences)} sentences")

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Shuffle once before the 80/20 split (global seed is already set).
    random.shuffle(filtered_sentences)
    dataset_all = StableAphasiaDataset(
        filtered_sentences, tokenizer, aphasia_types_mapping, config
    )

    total_samples = len(dataset_all)
    train_size = int(0.8 * total_samples)
    eval_size = total_samples - train_size

    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset_all, [train_size, eval_size]
    )

    log_message(f"Train size: {train_size}, Eval size: {eval_size}")

    # Log the class distribution of the training split so imbalance is visible.
    # NOTE: a WeightedRandomSampler was previously built here but never passed
    # to the Trainer (which builds its own dataloader), so it had no effect;
    # wiring it in would require overriding Trainer.get_train_dataloader.
    train_labels = [dataset_all.samples[idx]["labels"].item() for idx in train_dataset.indices]
    log_message(f"Train class distribution: {dict(Counter(train_labels))}")

    def model_init():
        # Fresh model per Trainer (re)initialization, resized to the tokenizer vocab.
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    training_args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_strategy="steps",
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to=None,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=3,
        remove_unused_columns=False,
    )

    # BUG FIX: min_delta was 0.8, but the callback compares eval F1 (bounded
    # by 1.0) against best + min_delta — a 0.8 threshold made "new best"
    # effectively unreachable and triggered early stopping spuriously.
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.001)]
    )

    log_message("Starting adaptive training...")
    try:
        trainer.train()
        log_message("Training completed successfully!")
    except Exception as e:
        log_message(f"Training error: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
        raise

    log_message("Starting final evaluation...")
    eval_results = trainer.evaluate()
    log_message(f"Final evaluation results: {eval_results}")

    results_df, report_df, summary_stats = generate_comprehensive_reports(
        trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir
    )

    # Unwrap DataParallel/DDP before saving raw state dict.
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module

    torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(output_dir)

    # Persist the full run configuration alongside the model for reproducibility.
    config_dict = {
        "model_name": config.model_name,
        "num_labels": num_labels,
        "aphasia_types_mapping": aphasia_types_mapping,
        "training_args": training_args.to_dict(),
        "adaptive_lr_config": {
            "adaptive_lr": config.adaptive_lr,
            "lr_patience": config.lr_patience,
            "lr_factor": config.lr_factor,
            "lr_increase_factor": config.lr_increase_factor,
            "min_lr": config.min_lr,
            "max_lr": config.max_lr,
            "oscillation_amplitude": config.oscillation_amplitude
        }
    }

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    log_message(f"Adaptive model and comprehensive reports saved to {output_dir}")
    clear_memory()

    return trainer, eval_results, results_df
|
| |
|
| |
|
def train_adaptive_cross_validation(json_file: str, output_dir: str = "./adaptive_cv_results", n_folds: int = 5):
    """Cross-validation training with adaptive learning rate.

    Runs stratified k-fold CV over the dataset, training one model per fold
    via train_adaptive_single_fold, then aggregates metrics, writes summary
    CSV/JSON, and plots an overall confusion matrix and per-fold comparisons.

    Args:
        json_file: path to a JSON file with a top-level "sentences" list.
        output_dir: directory for per-fold and aggregate artifacts.
        n_folds: number of stratified folds.

    Returns:
        (results_df, cv_summary) — per-fold metrics DataFrame and dict of
        mean/std statistics across folds.
    """
    log_message("Starting Adaptive Cross-Validation Training")

    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    sentences = dataset_json.get("sentences", [])

    # Canonicalize the aphasia-type strings before filtering.
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    aphasia_types_mapping = {
        "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
        "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
    }

    filtered_sentences = [s for s in sentences if s.get("aphasia_type") in aphasia_types_mapping]

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    full_dataset = StableAphasiaDataset(
        filtered_sentences, tokenizer, aphasia_types_mapping, config
    )

    # Labels for stratification.
    sample_labels = [sample["labels"].item() for sample in full_dataset.samples]

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []
    all_predictions = []
    all_true_labels = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(sample_labels)), sample_labels)):
        log_message(f"\n=== Fold {fold + 1}/{n_folds} ===")

        train_subset = Subset(full_dataset, train_idx)
        val_subset = Subset(full_dataset, val_idx)

        fold_trainer, fold_results_dict, fold_predictions = train_adaptive_single_fold(
            train_subset, val_subset, config, aphasia_types_mapping,
            tokenizer, fold, output_dir
        )

        # BUG FIX: Trainer.evaluate() returns keys prefixed with "eval_"
        # (eval_accuracy, eval_f1, ...); the summary code below reads the
        # bare names ('accuracy', 'f1', ...), which previously raised
        # KeyError. Strip the prefix when collecting per-fold results.
        fold_metrics = {
            (k[len("eval_"):] if k.startswith("eval_") else k): v
            for k, v in fold_results_dict.items()
        }
        fold_results.append({
            'fold': fold + 1,
            **fold_metrics
        })

        all_predictions.extend(fold_predictions['predictions'])
        all_true_labels.extend(fold_predictions['true_labels'])

        clear_memory()

    results_df = pd.DataFrame(fold_results)
    results_df.to_csv(os.path.join(output_dir, "adaptive_cv_summary.csv"), index=False)

    # Mean/std of every headline metric across folds.
    cv_summary = {
        'mean_accuracy': results_df['accuracy'].mean(),
        'std_accuracy': results_df['accuracy'].std(),
        'mean_f1': results_df['f1'].mean(),
        'std_f1': results_df['f1'].std(),
        'mean_f1_macro': results_df['f1_macro'].mean(),
        'std_f1_macro': results_df['f1_macro'].std(),
        'mean_precision': results_df['precision_macro'].mean(),
        'std_precision': results_df['precision_macro'].std(),
        'mean_recall': results_df['recall_macro'].mean(),
        'std_recall': results_df['recall_macro'].std()
    }

    with open(os.path.join(output_dir, "cv_statistics.json"), "w") as f:
        json.dump(cv_summary, f, indent=2)

    # Confusion matrix pooled over the held-out predictions of every fold.
    overall_cm = confusion_matrix(all_true_labels, all_predictions)

    plt.figure(figsize=(12, 10))
    sns.heatmap(overall_cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()))
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Overall Confusion Matrix (All Folds)")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "overall_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 2x2 grid: accuracy / F1 / precision / recall per fold with mean lines.
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    axes[0, 0].bar(range(1, n_folds + 1), results_df['accuracy'], color='skyblue', alpha=0.8)
    axes[0, 0].axhline(y=results_df['accuracy'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["accuracy"].mean():.3f}')
    axes[0, 0].set_title('Accuracy Across Folds')
    axes[0, 0].set_xlabel('Fold')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].bar(range(1, n_folds + 1), results_df['f1'], color='lightgreen', alpha=0.8)
    axes[0, 1].axhline(y=results_df['f1'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["f1"].mean():.3f}')
    axes[0, 1].set_title('F1 Score Across Folds')
    axes[0, 1].set_xlabel('Fold')
    axes[0, 1].set_ylabel('F1 Score')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].bar(range(1, n_folds + 1), results_df['precision_macro'], color='coral', alpha=0.8)
    axes[1, 0].axhline(y=results_df['precision_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["precision_macro"].mean():.3f}')
    axes[1, 0].set_title('Precision Across Folds')
    axes[1, 0].set_xlabel('Fold')
    axes[1, 0].set_ylabel('Precision')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    axes[1, 1].bar(range(1, n_folds + 1), results_df['recall_macro'], color='gold', alpha=0.8)
    axes[1, 1].axhline(y=results_df['recall_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["recall_macro"].mean():.3f}')
    axes[1, 1].set_title('Recall Across Folds')
    axes[1, 1].set_xlabel('Fold')
    axes[1, 1].set_ylabel('Recall')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "cv_performance_comparison.png"), dpi=300, bbox_inches='tight')
    plt.close()

    log_message("\n=== Adaptive Cross-Validation Summary ===")
    log_message(results_df.to_string(index=False))

    log_message(f"\nMean F1: {results_df['f1'].mean():.4f} ± {results_df['f1'].std():.4f}")
    log_message(f"Mean Accuracy: {results_df['accuracy'].mean():.4f} ± {results_df['accuracy'].std():.4f}")
    log_message(f"Mean F1 Macro: {results_df['f1_macro'].mean():.4f} ± {results_df['f1_macro'].std():.4f}")

    return results_df, cv_summary
|
| |
|
def train_adaptive_single_fold(train_dataset, val_dataset, config, aphasia_types_mapping,
                               tokenizer, fold, output_dir):
    """Train a single fold with adaptive learning rate.

    Args:
        train_dataset: training subset for this fold.
        val_dataset: validation subset for this fold.
        config: ModelConfig instance.
        aphasia_types_mapping: dict mapping class name -> integer label.
        tokenizer: tokenizer used to size the model's embedding table.
        fold: zero-based fold index (used for the output subdirectory).
        output_dir: root directory; fold artifacts go to output_dir/fold_<k>.

    Returns:
        (trainer, eval_results, fold_predictions) where fold_predictions is
        a dict with 'predictions' and 'true_labels' lists for this fold.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_labels = len(aphasia_types_mapping)

    # Log the fold's class distribution so imbalance is visible.
    # NOTE: a WeightedRandomSampler was previously built here but never passed
    # to the Trainer (which builds its own dataloader), so it had no effect.
    train_labels = [train_dataset[i]["labels"].item() for i in range(len(train_dataset))]
    log_message(f"Fold {fold} train class distribution: {dict(Counter(train_labels))}")

    def model_init():
        # Fresh model per fold, resized to the tokenizer vocab.
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    fold_output_dir = os.path.join(output_dir, f"fold_{fold}")
    os.makedirs(fold_output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=fold_output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to=None,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=1,
        remove_unused_columns=False,
    )

    # BUG FIX: min_delta was 0.8, but the callback compares eval F1 (bounded
    # by 1.0) against best + min_delta — a 0.8 threshold made "new best"
    # effectively unreachable and triggered early stopping spuriously.
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.001)]
    )

    trainer.train()

    eval_results = trainer.evaluate()

    # Collect held-out predictions for the pooled cross-fold confusion matrix.
    predictions = trainer.predict(val_dataset)
    pred_labels = np.argmax(predictions.predictions[0] if isinstance(predictions.predictions, tuple) else predictions.predictions, axis=1)
    true_labels = predictions.label_ids

    fold_predictions = {
        'predictions': pred_labels.tolist(),
        'true_labels': true_labels.tolist()
    }

    # Unwrap DataParallel/DDP before saving raw state dict.
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module

    torch.save(model_to_save.state_dict(), os.path.join(fold_output_dir, "pytorch_model.bin"))

    return trainer, eval_results, fold_predictions
|
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Adaptive Learning Rate Aphasia Classification Training")
    parser.add_argument("--output_dir", type=str, default="./adaptive_aphasia_model", help="Output directory")
    parser.add_argument("--cross_validation", action="store_true", help="Use cross-validation")
    parser.add_argument("--n_folds", type=int, default=5, help="Number of CV folds")
    parser.add_argument("--json_file", type=str, default=json_file, help="Path to JSON dataset file")
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="Initial learning rate")
    parser.add_argument("--batch_size", type=int, default=24, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    # BUG FIX: action="store_true" with default=True made this flag a no-op
    # (always True, impossible to disable). BooleanOptionalAction (Py 3.9+)
    # keeps "--adaptive_lr" working and adds "--no-adaptive_lr" to turn it off.
    parser.add_argument("--adaptive_lr", action=argparse.BooleanOptionalAction, default=True,
                        help="Use adaptive learning rate")

    args = parser.parse_args()

    # Propagate CLI overrides onto the shared model configuration.
    config = ModelConfig()
    config.learning_rate = args.learning_rate
    config.batch_size = args.batch_size
    config.num_epochs = args.num_epochs
    config.adaptive_lr = args.adaptive_lr

    try:
        clear_memory()

        log_message(f"Starting training with adaptive_lr={config.adaptive_lr}")
        log_message(f"Config: lr={config.learning_rate}, batch_size={config.batch_size}, epochs={config.num_epochs}")

        if args.cross_validation:
            results_df, cv_summary = train_adaptive_cross_validation(args.json_file, args.output_dir, args.n_folds)
            log_message("Cross-validation training completed!")
        else:
            trainer, eval_results, results_df = train_adaptive_model(args.json_file, args.output_dir)
            log_message("Single model training completed!")

        log_message("All adaptive training completed successfully!")

    except Exception as e:
        # Top-level boundary: log the failure with its traceback rather than
        # crashing silently, then fall through to cleanup.
        log_message(f"Training failed: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
    finally:
        clear_memory()