| """ | |
| Transformer-based models for Parkinson's disease classification. | |
| This module implements pretrained transformer models (DistilBERT, BioBERT, PubMedBERT) | |
| adapted for tabular data with RAG integration. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix | |
| from sklearn.model_selection import StratifiedKFold | |
| import joblib | |
| import os | |
| from typing import Tuple, Dict, Any, List, Optional | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from transformers import ( | |
| AutoModel, | |
| AutoTokenizer, | |
| DistilBertModel, | |
| DistilBertTokenizer, | |
| BertModel, | |
| BertTokenizer | |
| ) | |


class TabularDataset(Dataset):
    """Custom dataset for tabular data."""

    def __init__(self, X, y, feature_names=None):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.feature_names = feature_names or [f"feature_{i}" for i in range(X.shape[1])]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

    def get_feature_description(self, idx):
        """Get feature values with names for text representation."""
        sample = self.X[idx].numpy()
        return {name: float(val) for name, val in zip(self.feature_names, sample)}
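
# Usage sketch (illustrative only; the array shapes and labels below are
# assumptions for demonstration, not taken from the project's real data):
#
#     X = np.random.randn(100, 22).astype(np.float32)   # 100 samples, 22 features
#     y = np.random.randint(0, 2, size=100)             # binary PD labels
#     ds = TabularDataset(X, y)
#     xb, yb = ds[0]                                    # one (features, label) pair
#     ds.get_feature_description(0)                     # {"feature_0": ..., ...}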


class DistilBERTForTabular(nn.Module):
    """DistilBERT model adapted for tabular data with RAG integration."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.model_name = "distilbert-base-uncased"

        # Load pretrained DistilBERT model and tokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_name)
        self.bert = DistilBertModel.from_pretrained(self.model_name)

        # Freeze BERT parameters to speed up training
        for param in self.bert.parameters():
            param.requires_grad = False

        # Project tabular features to the BERT hidden size (768)
        self.feature_projection = nn.Linear(input_dim, 768)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, text_input=None):
        # Project tabular features into the BERT embedding space
        tabular_features = self.feature_projection(x)

        if text_input is not None:
            # Tokenize retrieved text if available (for RAG integration)
            inputs = self.tokenizer(text_input, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512)
            inputs = {k: v.to(x.device) for k, v in inputs.items()}

            # Get frozen BERT embeddings; use the [CLS] token as the text summary
            with torch.no_grad():
                outputs = self.bert(**inputs)
            text_features = outputs.last_hidden_state[:, 0, :]

            # Combine tabular and text features by element-wise addition
            combined_features = tabular_features + text_features
        else:
            # Fall back to tabular features alone
            combined_features = tabular_features

        # Classification
        return self.classifier(combined_features)


class BioBERTForTabular(nn.Module):
    """BioBERT model adapted for tabular data with RAG integration."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.model_name = "dmis-lab/biobert-v1.1"

        # Load pretrained BioBERT model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.bert = AutoModel.from_pretrained(self.model_name)

        # Freeze BERT parameters to speed up training
        for param in self.bert.parameters():
            param.requires_grad = False

        # Project tabular features to the BERT hidden size (768)
        self.feature_projection = nn.Linear(input_dim, 768)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, text_input=None):
        # Project tabular features into the BERT embedding space
        tabular_features = self.feature_projection(x)

        if text_input is not None:
            # Tokenize retrieved text if available (for RAG integration)
            inputs = self.tokenizer(text_input, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512)
            inputs = {k: v.to(x.device) for k, v in inputs.items()}

            # Get frozen BERT embeddings; use the [CLS] token as the text summary
            with torch.no_grad():
                outputs = self.bert(**inputs)
            text_features = outputs.last_hidden_state[:, 0, :]

            # Combine tabular and text features by element-wise addition
            combined_features = tabular_features + text_features
        else:
            # Fall back to tabular features alone
            combined_features = tabular_features

        # Classification
        return self.classifier(combined_features)


class PubMedBERTForTabular(nn.Module):
    """PubMedBERT model adapted for tabular data with RAG integration."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

        # Load pretrained PubMedBERT model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.bert = AutoModel.from_pretrained(self.model_name)

        # Freeze BERT parameters to speed up training
        for param in self.bert.parameters():
            param.requires_grad = False

        # Project tabular features to the BERT hidden size (768)
        self.feature_projection = nn.Linear(input_dim, 768)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, text_input=None):
        # Project tabular features into the BERT embedding space
        tabular_features = self.feature_projection(x)

        if text_input is not None:
            # Tokenize retrieved text if available (for RAG integration)
            inputs = self.tokenizer(text_input, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512)
            inputs = {k: v.to(x.device) for k, v in inputs.items()}

            # Get frozen BERT embeddings; use the [CLS] token as the text summary
            with torch.no_grad():
                outputs = self.bert(**inputs)
            text_features = outputs.last_hidden_state[:, 0, :]

            # Combine tabular and text features by element-wise addition
            combined_features = tabular_features + text_features
        else:
            # Fall back to tabular features alone
            combined_features = tabular_features

        # Classification
        return self.classifier(combined_features)
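
# Usage sketch (illustrative; the batch shape and retrieved snippet are made up,
# and one text string per sample is assumed so the element-wise addition of
# tabular and text features lines up):
#
#     model = DistilBERTForTabular(input_dim=22, num_classes=2)
#     x = torch.randn(4, 22)                       # batch of 4 tabular samples
#     logits = model(x)                            # tabular-only path
#     snippets = ["Parkinson's disease affects motor control."] * 4
#     logits = model(x, text_input=snippets)       # tabular + retrieved text
#
# BioBERTForTabular and PubMedBERTForTabular expose the same interface; only
# the underlying pretrained checkpoint differs.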


class FeedForwardNetwork(nn.Module):
    """Simple feed-forward network for comparison."""

    def __init__(self, input_dim: int, num_classes: int,
                 hidden_dims: Tuple[int, ...] = (256, 128, 64),
                 dropout: float = 0.3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
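
# Usage sketch (illustrative): a drop-in baseline with the same call shape as
# the BERT models' tabular-only path.
#
#     ffn = FeedForwardNetwork(input_dim=22, num_classes=2)
#     logits = ffn(torch.randn(4, 22))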


class TransformerModels:
    """Class to handle training and evaluation of transformer models."""

    def __init__(self, device: str = None):
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.training_history = {}

    def create_model(self, model_type: str, input_dim: int, num_classes: int, **kwargs):
        """Create a model of the specified type."""
        if model_type == 'transformer':
            if 'd_model' in kwargs:
                # Backward compatibility with the old TabularTransformer checkpoints:
                # map d_model onto an equivalent feed-forward stack (nhead and
                # num_layers are accepted but unused by this fallback)
                d_model = kwargs.get('d_model', 256)
                dropout = kwargs.get('dropout', 0.1)
                model = FeedForwardNetwork(input_dim, num_classes,
                                           hidden_dims=(d_model, d_model // 2, d_model // 4),
                                           dropout=dropout)
            else:
                # New pretrained transformer models
                model_name = kwargs.get('model_name', 'distilbert')
                dropout = kwargs.get('dropout', 0.1)
                if model_name == 'biobert':
                    model = BioBERTForTabular(input_dim, num_classes, dropout=dropout)
                elif model_name == 'pubmedbert':
                    model = PubMedBERTForTabular(input_dim, num_classes, dropout=dropout)
                else:
                    # 'distilbert' and any unrecognized name default to DistilBERT
                    model = DistilBERTForTabular(input_dim, num_classes, dropout=dropout)
        elif model_type == 'feedforward':
            model = FeedForwardNetwork(input_dim, num_classes, **kwargs)
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        return model.to(self.device)
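
    # Usage sketch (illustrative; the feature count is a placeholder):
    #
    #     tm = TransformerModels()
    #     model = tm.create_model('transformer', input_dim=22, num_classes=2,
    #                             model_name='biobert', dropout=0.1)
    #     baseline = tm.create_model('feedforward', input_dim=22, num_classes=2)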

    def train_model(self, model, train_loader, val_loader, epochs: int = 100,
                    lr: float = 0.001, weight_decay: float = 1e-4,
                    class_weights: torch.Tensor = None):
        """Train a model with LR scheduling and early stopping."""
        # Loss function, optionally weighted to counter class imbalance
        if class_weights is not None:
            criterion = nn.CrossEntropyLoss(weight=class_weights.to(self.device))
        else:
            criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10
        )

        train_losses = []
        val_losses = []
        val_accuracies = []
        best_val_loss = float('inf')
        patience_counter = 0
        early_stopping_patience = 20

        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            for batch_x, batch_y in train_loader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # Validation phase
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_x, batch_y in val_loader:
                    batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                    outputs = model(batch_x)
                    loss = criterion(outputs, batch_y)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs, 1)
                    total += batch_y.size(0)
                    correct += (predicted == batch_y).sum().item()

            train_loss /= len(train_loader)
            val_loss /= len(val_loader)
            val_accuracy = 100 * correct / total

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

            scheduler.step(val_loss)

            if epoch % 10 == 0:
                print(f'Epoch [{epoch}/{epochs}], Train Loss: {train_loss:.4f}, '
                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

            # Early stopping on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print(f'Early stopping at epoch {epoch}')
                    break

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_accuracies': val_accuracies,
        }
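
    # Usage sketch (illustrative; batch size, epoch count, and split names are
    # placeholders):
    #
    #     train_loader = DataLoader(TabularDataset(X_train, y_train), batch_size=64, shuffle=True)
    #     val_loader = DataLoader(TabularDataset(X_val, y_val), batch_size=64, shuffle=False)
    #     history = tm.train_model(model, train_loader, val_loader, epochs=50)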

    def evaluate_model(self, model, test_loader):
        """Evaluate model on the test set."""
        model.eval()
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                outputs = model(batch_x)
                _, predicted = torch.max(outputs, 1)
                all_predictions.extend(predicted.cpu().numpy())
                all_targets.extend(batch_y.cpu().numpy())

        accuracy = accuracy_score(all_targets, all_predictions)
        report = classification_report(all_targets, all_predictions)
        cm = confusion_matrix(all_targets, all_predictions)
        return {
            'accuracy': accuracy,
            'predictions': all_predictions,
            'targets': all_targets,
            'classification_report': report,
            'confusion_matrix': cm,
        }
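
    # Usage sketch (illustrative):
    #
    #     results = tm.evaluate_model(model, test_loader)
    #     print(results['accuracy'])
    #     print(results['classification_report'])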

    def cross_validate(self, model_type: str, X, y, cv_folds: int = 5, **model_kwargs):
        """Perform stratified k-fold cross-validation."""
        skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
        cv_scores = []

        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            print(f"\nFold {fold + 1}/{cv_folds}")
            X_train_fold, X_val_fold = X[train_idx], X[val_idx]
            y_train_fold, y_val_fold = y[train_idx], y[val_idx]

            # Create datasets and loaders
            train_dataset = TabularDataset(X_train_fold, y_train_fold)
            val_dataset = TabularDataset(X_val_fold, y_val_fold)
            train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

            # Create a fresh model for this fold
            model = self.create_model(model_type, X.shape[1], len(np.unique(y)), **model_kwargs)

            # Balanced class weights for this fold: n_samples / (n_classes * count)
            class_counts = np.bincount(y_train_fold)
            class_weights = torch.FloatTensor(
                len(y_train_fold) / (len(class_counts) * class_counts)
            )

            # Train, then score on the held-out fold
            history = self.train_model(model, train_loader, val_loader,
                                       epochs=50, class_weights=class_weights)
            results = self.evaluate_model(model, val_loader)
            cv_scores.append(results['accuracy'])
            print(f"Fold {fold + 1} Accuracy: {results['accuracy']:.4f}")

        return cv_scores
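
    # Usage sketch (illustrative):
    #
    #     scores = tm.cross_validate('feedforward', X, y, cv_folds=5)
    #     print(f"CV accuracy: {np.mean(scores):.4f} +/- {np.std(scores):.4f}")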

    def save_model(self, model, model_name: str, save_dir: str = "models/saved"):
        """Save trained model weights."""
        os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, f"{model_name}_transformer.pth")
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    def load_model(self, model_type: str, model_name: str, input_dim: int,
                   num_classes: int, save_dir: str = "models/saved", **model_kwargs):
        """Load trained weights into a freshly constructed model."""
        model_path = os.path.join(save_dir, f"{model_name}_transformer.pth")
        model = self.create_model(model_type, input_dim, num_classes, **model_kwargs)
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model
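
    # Usage sketch (illustrative; "distilbert_pd" is an arbitrary placeholder name):
    #
    #     tm.save_model(model, "distilbert_pd")
    #     restored = tm.load_model('transformer', "distilbert_pd", input_dim=22,
    #                              num_classes=2, model_name='distilbert')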

    def plot_training_history(self, history: Dict, model_name: str, save_dir: str = "notebooks"):
        """Plot training history."""
        os.makedirs(save_dir, exist_ok=True)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot losses
        ax1.plot(history['train_losses'], label='Train Loss')
        ax1.plot(history['val_losses'], label='Validation Loss')
        ax1.set_title(f'{model_name} - Training and Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)

        # Plot accuracy
        ax2.plot(history['val_accuracies'], label='Validation Accuracy', color='green')
        ax2.set_title(f'{model_name} - Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f'{model_name}_training_history.png'),
                    dpi=300, bbox_inches='tight')
        plt.close()

    def plot_confusion_matrix(self, cm, model_name: str, class_names: list = None,
                              save_dir: str = "notebooks"):
        """Plot confusion matrix."""
        os.makedirs(save_dir, exist_ok=True)
        plt.figure(figsize=(8, 6))
        if class_names is None:
            class_names = [f'Class {i}' for i in range(len(cm))]
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title(f'{model_name} - Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f'{model_name}_transformer_confusion_matrix.png'),
                    dpi=300, bbox_inches='tight')
        plt.close()
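

if __name__ == "__main__":
    # Smoke test on synthetic data (illustrative only: the feature count, labels,
    # and split below are assumptions; real inputs come from the project's
    # preprocessing pipeline, which is not part of this module).
    rng = np.random.default_rng(42)
    X = rng.normal(size=(200, 22)).astype(np.float32)
    y = rng.integers(0, 2, size=200)

    tm = TransformerModels()
    model = tm.create_model('feedforward', input_dim=22, num_classes=2)

    # Simple 80/20 split into train and validation loaders
    train_loader = DataLoader(TabularDataset(X[:160], y[:160]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TabularDataset(X[160:], y[160:]), batch_size=32, shuffle=False)

    history = tm.train_model(model, train_loader, val_loader, epochs=5)
    results = tm.evaluate_model(model, val_loader)
    print(f"Smoke-test accuracy: {results['accuracy']:.4f}")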