Spaces:
Running
Running
"""Comprehensive feature pipeline integrity audit.

This script analyses the feature pipeline and reports on:

1. Dead or constant feature blocks
2. Improper normalization (surfaced through per-group mean/std/min/max summaries)
3. Data leakage between feature groups (no dedicated check is implemented yet)
4. Train/inference consistency issues
5. Feature importance and relevance (mutual information with the label)
6. Label distribution problems
"""
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.feature_selection import mutual_info_classif
from sklearn.preprocessing import StandardScaler

# Repository root: two directories above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Put <ROOT>/src first on the import path so the project import below resolves.
sys.path.insert(0, str(ROOT / 'src'))

from training.feature_pipeline import (  # noqa: E402 - needs the sys.path tweak above
    _normalize,
    build_feature_pipeline,
    transform_pair_features,
)

# Destination for every plot / report this script writes; created eagerly at
# import time so the writers further down can assume it exists.
MODELS_DIR = ROOT / 'models'
REPORT_DIR = MODELS_DIR / 'reports'
REPORT_DIR.mkdir(parents=True, exist_ok=True)
def audit_feature_pipeline(
    sample_size: int = 5000,
    seed: int = 2026,
) -> Tuple[Dict[str, Any], pd.DataFrame, np.ndarray, np.ndarray, Any]:
    """Build the feature pipeline from scratch and run every audit check on it.

    Args:
        sample_size: Forwarded to ``build_feature_pipeline``.
        seed: Forwarded to ``build_feature_pipeline``.

    Returns:
        ``(audit_results, pairs_df, X, y, artifacts)`` where ``X`` is the
        float32 matrix stacked from ``pairs_df['_X']`` and ``y`` holds integer
        label codes for ``['unknown', 'minor', 'moderate', 'major']`` (``-1``
        for any label outside that list).
    """
    print("[AUDIT] Building feature pipeline from scratch...", flush=True)
    pairs_df, artifacts = build_feature_pipeline(
        save_artifacts=False, sample_size=sample_size, seed=seed
    )
    X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)
    # Categorical.codes is -1 for values not among the categories (and for NaN).
    y = pd.Categorical(
        pairs_df['label'], categories=['unknown', 'minor', 'moderate', 'major']
    ).codes

    num_unmapped = int(np.sum(y < 0))
    if num_unmapped:
        print(
            f"[AUDIT] WARNING: {num_unmapped} labels fall outside the expected "
            "categories and are encoded as -1.",
            flush=True,
        )

    audit_results = {
        'timestamp': str(pd.Timestamp.now()),
        'dataset_size': len(pairs_df),
        'feature_dimension': X.shape[1],
        # Count real classes only; the -1 "unmapped" code is not a class.
        'num_classes': int(len(np.unique(y[y >= 0]))),
        'class_distribution': _audit_label_distribution(pairs_df['label'].values, y),
        'feature_group_analysis': _audit_feature_groups(X, artifacts),
        'feature_statistics': _audit_feature_statistics(X, artifacts),
        'dead_features': _detect_dead_features(X, artifacts),
        'train_inference_consistency': _check_train_inference_consistency(pairs_df, artifacts),
        'feature_importance': _compute_feature_importance(X, y, artifacts),
        'sparsity_analysis': _compute_sparsity_by_group(X, artifacts),
    }
    return audit_results, pairs_df, X, y, artifacts
def _audit_label_distribution(
    labels_raw,
    y_encoded,
    label_names=('unknown', 'minor', 'moderate', 'major'),
) -> Dict[str, Any]:
    """Count samples per encoded class and summarise class balance.

    Args:
        labels_raw: Original label values (accepted for interface
            compatibility; not read).
        y_encoded: 1-D integer class codes. Negative codes (pandas uses -1 for
            out-of-category labels) are reported under ``'unmapped'`` instead of
            wrapping around to the last label name.
        label_names: Display names indexed by code; codes past the end are
            reported as ``'class_<code>'``.

    Returns:
        ``label_counts`` (name -> count/percentage), ``is_balanced`` (largest /
        smallest present-class count < 2) and ``max_min_ratio``. Empty input
        yields no counts, ``is_balanced=True`` and a NaN ratio.
    """
    y_encoded = np.asarray(y_encoded)
    if y_encoded.size == 0:
        return {'label_counts': {}, 'is_balanced': True, 'max_min_ratio': float('nan')}

    unique, counts = np.unique(y_encoded, return_counts=True)
    total = len(y_encoded)

    def _name(code: int) -> str:
        if code < 0:
            return 'unmapped'
        if code < len(label_names):
            return label_names[code]
        return f'class_{code}'

    distribution = {
        _name(int(code)): {
            'count': int(count),
            'percentage': float(100.0 * count / total),
        }
        for code, count in zip(unique, counts)
    }
    ratio = float(counts.max()) / float(counts.min())
    return {
        'label_counts': distribution,
        'is_balanced': ratio < 2.0,
        'max_min_ratio': ratio,
    }
def _audit_feature_groups(X: np.ndarray, artifacts) -> Dict[str, Dict[str, Any]]:
    """Summarise each feature group (column slice of ``X``) independently.

    Args:
        X: 2-D feature matrix, one row per sample.
        artifacts: Pipeline artifacts; only ``artifacts.group_slices`` (group
            name -> ``(start, end)`` column bounds) is read.

    Returns:
        Group name -> bounds, value statistics, sparsity, zero counts and the
        ``is_all_zeros`` / ``is_constant`` flags. An empty slice (zero width or
        zero rows) gets NaN statistics and is flagged as all-zero/constant
        instead of raising.
    """
    results = {}
    for group_name, (start, end) in artifacts.group_slices.items():
        block = X[:, start:end]
        is_zero = block == 0.0
        num_zeros = int(np.sum(is_zero))

        if block.size == 0:
            # np.min/np.max raise on empty arrays and block[0, :] needs a row.
            nan = float('nan')
            value_stats = {'min': nan, 'max': nan, 'mean': nan, 'std': nan, 'sparsity': nan}
            is_all_zeros = is_constant = True
        else:
            value_stats = {
                'min': float(np.min(block)),
                'max': float(np.max(block)),
                'mean': float(np.mean(block)),
                'std': float(np.std(block)),
                'sparsity': float(np.mean(is_zero)),
            }
            is_all_zeros = bool(np.allclose(block, 0.0))
            # Every row matching the first row means every column is constant.
            is_constant = bool(np.allclose(block, block[0, :]))

        results[group_name] = {
            'start_idx': int(start),
            'end_idx': int(end),
            'num_features': int(end - start),
            **value_stats,
            'num_zeros': num_zeros,
            # Every element is either == 0 or != 0 (NaN counts as != 0).
            'num_non_zeros': int(block.size - num_zeros),
            'is_all_zeros': is_all_zeros,
            'is_constant': is_constant,
        }
    return results
def _audit_feature_statistics(X: np.ndarray, artifacts) -> Dict[str, Any]:
    """Global statistics over the whole feature matrix.

    Args:
        X: 2-D feature matrix, one row per sample.
        artifacts: Accepted for interface symmetry with the other audits; not read.

    Returns:
        Sample/feature counts, global mean/std/min/max, overall sparsity and the
        number of all-zero rows and columns. Value statistics are NaN when ``X``
        has no elements (``np.min`` would otherwise raise).
    """
    has_values = X.size > 0
    nan = float('nan')
    return {
        'total_samples': int(X.shape[0]),
        # Previously held X.shape[0] (the row count); it now reports the column count.
        'total_features': int(X.shape[1]),
        'total_dimensions': int(X.shape[1]),
        'global_mean': float(np.mean(X)) if has_values else nan,
        'global_std': float(np.std(X)) if has_values else nan,
        'global_min': float(np.min(X)) if has_values else nan,
        'global_max': float(np.max(X)) if has_values else nan,
        'global_sparsity': float(np.mean(X == 0.0)) if has_values else nan,
        'num_all_zero_rows': int(np.sum(np.all(X == 0.0, axis=1))),
        'num_all_zero_cols': int(np.sum(np.all(X == 0.0, axis=0))),
    }
def _detect_dead_features(X: np.ndarray, artifacts, max_reported: int = 50) -> Dict[str, Any]:
    """Find columns that are all-zero, constant, or have near-zero variance.

    Args:
        X: 2-D feature matrix, one row per sample.
        artifacts: Pipeline artifacts; ``artifacts.group_slices`` (group name ->
            ``(start, end)``) is used to label each dead column with its group.
        max_reported: Maximum number of dead-column records returned (the first
            ones by column index); ``num_dead_features`` is always the full count.

    Returns:
        ``num_dead_features`` and ``dead_features``, a list of dicts with
        ``column_index``, ``group`` (``None`` if no slice covers the column),
        ``variance`` and ``reason`` (``'all_zeros'``, ``'constant'`` or
        ``'near_zero_variance'``).
    """
    group_slices = artifacts.group_slices

    def _owning_group(col_idx: int):
        # First matching slice wins, as in a plain linear scan.
        for name, (start, end) in group_slices.items():
            if start <= col_idx < end:
                return name
        return None

    if X.shape[0] == 0:
        # No rows: nothing to measure (and col[0] below would raise).
        return {'num_dead_features': 0, 'dead_features': []}

    dead_features = []
    for col_idx in range(X.shape[1]):
        col = X[:, col_idx]
        variance = float(np.var(col))
        # all_zeros is tested first: an all-zero column is also constant, so
        # checking 'constant' first would make 'all_zeros' unreachable.
        if np.allclose(col, 0.0):
            reason = 'all_zeros'
        elif np.allclose(col, col[0]):
            reason = 'constant'
        elif variance < 1e-9:
            reason = 'near_zero_variance'
        else:
            continue
        dead_features.append({
            'column_index': int(col_idx),
            'group': _owning_group(col_idx),
            'variance': variance,
            'reason': reason,
        })

    return {
        'num_dead_features': len(dead_features),
        'dead_features': dead_features[:max_reported],  # first N by column index
    }
def _compute_sparsity_by_group(X: np.ndarray, artifacts) -> Dict[str, float]:
    """Return, for each feature group, the fraction of entries that are exactly zero.

    ``artifacts.group_slices`` maps group name -> ``(start, end)`` column bounds.
    """
    return {
        name: float(np.mean(X[:, lo:hi] == 0.0))
        for name, (lo, hi) in artifacts.group_slices.items()
    }
def _check_train_inference_consistency(
    pairs_df: pd.DataFrame,
    artifacts,
    *,
    num_samples: int = 50,
    seed: int = 2026,
    atol: float = 1e-4,
) -> Dict[str, Any]:
    """Re-featurise a random sample of pairs through the inference path and
    compare the result with the training features stored in ``pairs_df['_X']``.

    Args:
        pairs_df: Training pairs; the ``drug_a``, ``drug_b`` and ``_X`` columns are read.
        artifacts: Fitted pipeline artifacts; their attributes are forwarded to
            ``transform_pair_features``.
        num_samples: Maximum number of pairs to re-featurise.
        seed: Seed for the sampling RNG, so repeated audits check the same pairs.
        atol: Largest absolute per-element difference still treated as equal.

    Returns:
        ``num_samples_checked``, ``num_inconsistencies`` and the first 20
        ``inconsistencies`` (inference errors, dimension mismatches or value
        mismatches).
    """
    print("[AUDIT] Checking train/inference consistency...", flush=True)

    n_pairs = len(pairs_df)
    if n_pairs == 0:
        sample_indices = np.empty(0, dtype=np.int64)
    else:
        # Seeded generator: the unseeded global RNG made every audit check different pairs.
        rng = np.random.default_rng(seed)
        sample_indices = rng.choice(n_pairs, size=min(num_samples, n_pairs), replace=False)

    # Identical for every pair, so build it once.
    inference_artifacts = {
        'ngram_vocabulary': artifacts.ngram_vocabulary,
        'scaler': artifacts.scaler,
        'drugbank_map': artifacts.drugbank_map,
        'twosides_map': artifacts.twosides_map,
        'partner_graph': artifacts.partner_graph,
        'metadata': artifacts.metadata,
    }

    inconsistencies = []
    for idx in sample_indices:
        row = pairs_df.iloc[idx]
        drug_a, drug_b = row['drug_a'], row['drug_b']
        train_feat = np.asarray(row['_X'], dtype=np.float32).ravel()

        try:
            raw_feat = transform_pair_features(drug_a, drug_b, inference_artifacts)
        except Exception as e:  # best-effort audit: record the failure, keep checking
            inconsistencies.append({
                'pair_index': int(idx),
                'drug_a': drug_a,
                'drug_b': drug_b,
                'error': str(e),
            })
            continue
        # ravel() tolerates a (1, D) return; a genuinely different length is still reported.
        inference_feat = np.asarray(raw_feat, dtype=np.float32).ravel()

        if train_feat.shape[0] != inference_feat.shape[0]:
            inconsistencies.append({
                'pair_index': int(idx),
                'drug_a': drug_a,
                'drug_b': drug_b,
                'train_dim': int(train_feat.shape[0]),
                'inference_dim': int(inference_feat.shape[0]),
                'issue': 'dimension_mismatch',
            })
            continue

        # |a - b| <= atol per element. NaN on only one side is a mismatch;
        # `max_diff > 1e-4` alone silently passed NaN differences.
        matches = np.isclose(train_feat, inference_feat, rtol=0.0, atol=atol, equal_nan=True)
        if not matches.all():
            abs_diff = np.abs(train_feat - inference_feat)
            inconsistencies.append({
                'pair_index': int(idx),
                'drug_a': drug_a,
                'drug_b': drug_b,
                'max_diff': float(np.max(abs_diff)),
                'mean_diff': float(np.mean(abs_diff)),
                'num_mismatched': int(np.sum(~matches)),
                'issue': 'value_mismatch',
            })

    return {
        'num_samples_checked': len(sample_indices),
        'num_inconsistencies': len(inconsistencies),
        'inconsistencies': inconsistencies[:20],  # first 20 recorded
    }
def _compute_feature_importance(X: np.ndarray, y: np.ndarray, artifacts) -> Dict[str, Any]:
    """Score each column by mutual information (MI) with the label and by variance.

    Args:
        X: 2-D feature matrix, one row per sample.
        y: Integer class codes aligned with the rows of ``X``.
        artifacts: Pipeline artifacts; only ``artifacts.group_slices`` (group
            name -> ``(start, end)``) is read.

    Returns:
        Global MI mean/max/min, a per-group MI and variance summary, and the
        five highest-MI columns (highest first).
    """
    print("[AUDIT] Computing feature importance...", flush=True)
    # Fixed random_state keeps the MI estimate reproducible between runs.
    scores = mutual_info_classif(X, y, random_state=2026)
    column_var = np.var(X, axis=0)

    def _summarise(lo, hi):
        mi_part, var_part = scores[lo:hi], column_var[lo:hi]
        return {
            'mean_mi': float(np.mean(mi_part)),
            'max_mi': float(np.max(mi_part)),
            'min_mi': float(np.min(mi_part)),
            'std_mi': float(np.std(mi_part)),
            'mean_variance': float(np.mean(var_part)),
            'max_variance': float(np.max(var_part)),
        }

    per_group = {
        name: _summarise(lo, hi)
        for name, (lo, hi) in artifacts.group_slices.items()
    }

    # Descending MI order; take the first five.
    best_columns = np.argsort(scores)[::-1][:5]
    top_features = []
    for col in best_columns:
        top_features.append({
            'feature_index': int(col),
            'mi_score': float(scores[col]),
            'variance': float(column_var[col]),
        })

    return {
        'mean_mi_all_features': float(np.mean(scores)),
        'max_mi_feature': float(np.max(scores)),
        'min_mi_feature': float(np.min(scores)),
        'group_importance': per_group,
        'top_5_important_features': top_features,
    }
def _bar_by_group(ax, groups, values, title, ylabel):
    """Draw one bar per feature group, with rotated group names on the x axis."""
    positions = range(len(groups))
    ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(groups, rotation=45, ha='right')
    ax.set_title(title)
    ax.set_ylabel(ylabel)


def _create_visualizations(audit_results: Dict[str, Any], X: np.ndarray, y: np.ndarray, artifacts):
    """Write diagnostic bar charts into ``REPORT_DIR``.

    Produces ``feature_group_statistics.png`` (per-group mean, std, sparsity
    and the label histogram). When ``audit_results`` contains
    ``'feature_importance'``, it also produces ``feature_importance.png``
    (per-group mean MI and mean variance).

    ``X``, ``y`` and ``artifacts`` are accepted for interface compatibility but
    are not read; everything is plotted from ``audit_results``.
    """
    print("[AUDIT] Creating visualizations...", flush=True)

    # 1. Feature group statistics + label distribution
    group_stats = audit_results['feature_group_analysis']
    groups = list(group_stats)
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        _bar_by_group(axes[0, 0], groups, [group_stats[g]['mean'] for g in groups],
                      'Mean Value by Feature Group', 'Mean')
        _bar_by_group(axes[0, 1], groups, [group_stats[g]['std'] for g in groups],
                      'Std Dev by Feature Group', 'Std Dev')
        _bar_by_group(axes[1, 0], groups, [group_stats[g]['sparsity'] for g in groups],
                      'Sparsity by Feature Group', 'Proportion of Zeros')

        # Pair each bar with its own name and colour. Slicing label_names by the
        # number of present classes mislabeled bars whenever a class was absent.
        label_names = ['unknown', 'minor', 'moderate', 'major']
        label_colors = dict(zip(label_names, ['red', 'orange', 'yellow', 'green']))
        label_counts = audit_results['class_distribution']['label_counts']
        present = [name for name in label_names if name in label_counts]
        ax = axes[1, 1]
        ax.bar(present, [label_counts[name]['count'] for name in present],
               color=[label_colors[name] for name in present])
        ax.set_title('Label Distribution')
        ax.set_ylabel('Count')

        fig.tight_layout()
        fig.savefig(REPORT_DIR / 'feature_group_statistics.png', dpi=150)
    finally:
        plt.close(fig)  # release the figure even if saving fails
    print(" Saved: feature_group_statistics.png")

    # 2. Feature importance by group
    if 'feature_importance' in audit_results:
        group_importance = audit_results['feature_importance']['group_importance']
        imp_groups = list(group_importance)
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            _bar_by_group(axes[0], imp_groups,
                          [group_importance[g]['mean_mi'] for g in imp_groups],
                          'Mean Mutual Information by Group', 'MI Score')
            _bar_by_group(axes[1], imp_groups,
                          [group_importance[g]['mean_variance'] for g in imp_groups],
                          'Mean Variance by Group', 'Variance')
            fig.tight_layout()
            fig.savefig(REPORT_DIR / 'feature_importance.png', dpi=150)
        finally:
            plt.close(fig)
        print(" Saved: feature_importance.png")
def _generate_report(audit_results: Dict[str, Any]) -> str:
    """Render the audit results as a Markdown report.

    Args:
        audit_results: Mapping built by ``audit_feature_pipeline``. Keys read:
            ``timestamp``, ``dataset_size``, ``feature_dimension``,
            ``num_classes``, ``class_distribution``, ``feature_group_analysis``,
            ``dead_features``, ``train_inference_consistency`` and, if present,
            ``feature_importance``.

    Returns:
        The report as one newline-joined Markdown string.
    """
    known_labels = ['unknown', 'minor', 'moderate', 'major']
    dead = audit_results['dead_features']
    num_dead = dead['num_dead_features']
    class_dist = audit_results['class_distribution']
    consistency = audit_results['train_inference_consistency']
    num_inconsistent = consistency['num_inconsistencies']

    def _fenced(text: str) -> List[str]:
        # DataFrame.to_string() is plain text; fence it so Markdown keeps the alignment.
        return ["```", text, "```"]

    lines = [
        "# Feature Pipeline Integrity Audit Report",
        "",
        f"**Generated:** {audit_results['timestamp']}",
        "",
        "## Executive Summary",
        "",
    ]

    # Critical issues
    critical_issues = []
    if num_dead > 0:
        critical_issues.append(
            f"- ⚠️ **{num_dead} dead features detected** — "
            "These contribute no signal and should be removed."
        )
    if not class_dist['is_balanced']:
        critical_issues.append(
            f"- ⚠️ **Severe class imbalance** — max/min ratio = {class_dist['max_min_ratio']:.1f}. "
            "This creates learning difficulty."
        )
    if num_inconsistent > 0:
        critical_issues.append(
            f"- 🔴 **Train/inference consistency failures:** {num_inconsistent} "
            "mismatches found. Inference predictions may differ from training."
        )
    if critical_issues:
        lines.append("### ❌ Critical Issues")
        lines.extend([""] + critical_issues + [""])
    else:
        lines.append("✅ **No critical issues detected.**\n")

    # Dataset overview
    lines.extend([
        "## Dataset Overview",
        "",
        f"- **Number of samples:** {audit_results['dataset_size']:,}",
        f"- **Feature dimension:** {audit_results['feature_dimension']}",
        f"- **Number of classes:** {audit_results['num_classes']}",
        "",
    ])

    # Label distribution: known labels in fixed order, then any extra buckets
    # (e.g. unmapped codes) so no counted samples are dropped from the report.
    lines.extend(["### Label Distribution", ""])
    label_dist = class_dist['label_counts']
    ordered_labels = [n for n in known_labels if n in label_dist]
    ordered_labels += [n for n in label_dist if n not in known_labels]
    for label_name in ordered_labels:
        info = label_dist[label_name]
        lines.append(f"- **{label_name}:** {info['count']} samples ({info['percentage']:.1f}%)")
    lines.append("")

    # Feature group analysis
    lines.extend(["## Feature Group Analysis", ""])
    groups = audit_results['feature_group_analysis']
    group_rows = []
    for group_name in sorted(groups):
        g = groups[group_name]
        group_rows.append({
            'Group': group_name,
            'Dims': g['num_features'],
            'Mean': f"{g['mean']:.3f}",
            'Std': f"{g['std']:.3f}",
            'Min': f"{g['min']:.3f}",
            'Max': f"{g['max']:.3f}",
            'Sparse': f"{g['sparsity']:.1%}",
            'Dead?': '⚠️ YES' if g['is_all_zeros'] or g['is_constant'] else 'No',
        })
    lines.extend(_fenced(pd.DataFrame(group_rows).to_string(index=False)))
    lines.append("")

    # Dead features
    if num_dead > 0:
        lines.extend([
            "## Dead Features Detection",
            "",
            f"Found **{num_dead} dead features** "
            "(constant, all-zero, or near-zero variance):",
            "",
        ])
        shown = dead['dead_features'][:10]
        for item in shown:
            lines.append(f"- Col {item['column_index']} ({item['group']}): {item['reason']}")
        # Use the true total: the stored list is already truncated upstream.
        remaining = num_dead - len(shown)
        if remaining > 0:
            lines.append(f"- ... and {remaining} more")
        lines.append("")

    # Feature importance
    if 'feature_importance' in audit_results:
        importance = audit_results['feature_importance']
        lines.extend([
            "## Feature Importance Analysis",
            "",
            f"**Mean MI across all features:** {importance['mean_mi_all_features']:.4f}",
            f"**Max MI (single feature):** {importance['max_mi_feature']:.4f}",
            f"**Min MI (single feature):** {importance['min_mi_feature']:.4f}",
            "",
            "### By Group:",
            "",
        ])
        importance_rows = [
            {
                'Group': group_name,
                'Mean MI': f"{info['mean_mi']:.4f}",
                'Max MI': f"{info['max_mi']:.4f}",
                'Mean Var': f"{info['mean_variance']:.4f}",
            }
            for group_name, info in importance['group_importance'].items()
        ]
        lines.extend(_fenced(pd.DataFrame(importance_rows).to_string(index=False)))
        lines.extend(["", "**Top 5 Most Important Features:**", ""])
        for feat in importance['top_5_important_features']:
            lines.append(
                f"- Feature {feat['feature_index']}: MI={feat['mi_score']:.4f}, Var={feat['variance']:.4f}"
            )
        lines.append("")

    # Train/inference consistency
    lines.extend([
        "## Train/Inference Consistency Check",
        "",
        f"**Samples checked:** {consistency['num_samples_checked']}",
        f"**Inconsistencies found:** {num_inconsistent}",
        "",
    ])
    if num_inconsistent > 0:
        lines.extend(["**Issues detected:**", ""])
        shown_issues = consistency['inconsistencies'][:5]
        for issue in shown_issues:
            pair = f"{issue['drug_a']} + {issue['drug_b']}"
            kind = issue.get('issue', '')
            if 'error' in issue:
                lines.append(f"- {pair}: ERROR: {issue['error']}")
            elif kind == 'dimension_mismatch':
                lines.append(
                    f"- {pair}: Train dim={issue['train_dim']}, Inference dim={issue['inference_dim']}"
                )
            elif kind == 'value_mismatch':
                lines.append(
                    f"- {pair}: Max diff={issue['max_diff']:.2e}, Mean diff={issue['mean_diff']:.2e}"
                )
        if num_inconsistent > len(shown_issues):
            lines.append(f"- ... and {num_inconsistent - len(shown_issues)} more")
        lines.append("")
    else:
        lines.extend(["✅ **No consistency issues detected.**", ""])

    # Recommendations, numbered by position so only applicable items appear (no gaps).
    lines.extend(["## Recommendations", ""])
    recommendations = []
    if num_dead > 0:
        recommendations.append(
            "**Remove dead features** — Features with zero variance waste model capacity. "
            "Recompute feature pipeline excluding these columns."
        )
    if not class_dist['is_balanced']:
        recommendations.append(
            "**Address class imbalance** — The severe imbalance may cause the model to ignore "
            "rare classes. Consider reweighting losses, oversampling, or SMOTE."
        )
    if num_inconsistent > 0:
        recommendations.append(
            "**Fix train/inference mismatch** — The feature construction differs between training "
            "and inference. Ensure the same preprocessing is applied in both paths."
        )
    if recommendations:
        body = [f"{i}. {text}" for i, text in enumerate(recommendations, start=1)]
    else:
        body = [
            "✅ The feature pipeline appears to be in good shape. "
            "Proceed with retraining on larger sample."
        ]
    lines.extend(body + [""])

    lines.extend([
        "## Artifacts",
        "",
        "- feature_group_statistics.png",
        "- feature_importance.png",
        "- feature_integrity_audit.md (this file)",
    ])
    return "\n".join(lines)
def main():
    """Run the full audit and persist its outputs.

    Writes the diagnostic PNGs, ``feature_integrity_audit.md`` and
    ``feature_integrity_audit.json`` under ``REPORT_DIR``, then echoes the
    Markdown report to stdout between two separator rules.
    """
    print("[AUDIT] Starting feature pipeline integrity audit...\n", flush=True)
    results, _pairs, features, labels, artifacts = audit_feature_pipeline()

    _create_visualizations(results, features, labels, artifacts)

    markdown = _generate_report(results)
    md_path = REPORT_DIR / 'feature_integrity_audit.md'
    md_path.write_text(markdown, encoding='utf-8')
    print(f"\n[AUDIT] Report saved: {md_path}", flush=True)

    # default=str stringifies any value json cannot encode natively.
    summary_json = json.dumps(results, indent=2, default=str)
    json_path = REPORT_DIR / 'feature_integrity_audit.json'
    json_path.write_text(summary_json, encoding='utf-8')
    print(f"[AUDIT] JSON summary saved: {json_path}", flush=True)

    rule = "=" * 60
    for chunk in ("\n" + rule, markdown, rule):
        print(chunk)


if __name__ == '__main__':
    main()