# ddi / src/training/feature_audit.py
# github-actions[bot]
# Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
# d29b763
"""Comprehensive feature pipeline integrity audit.
This script performs a thorough analysis of the feature pipeline to detect:
1. Dead or constant feature blocks
2. Improper normalization
3. Data leakage between feature groups
4. Train/inference consistency issues
5. Feature importance or relevance
6. Label distribution problems
"""
import sys
import json
from pathlib import Path
from typing import Dict, List, Tuple, Any
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.feature_selection import mutual_info_classif
# Repository root: parents[2] of this file's resolved path
# (<root>/src/training/feature_audit.py -> <root>).
ROOT = Path(__file__).resolve().parents[2]
# Put <root>/src first on the import path so `training.*` resolves when this
# file is run directly as a script.
sys.path.insert(0, str(ROOT / 'src'))
# NOTE(review): `_normalize` is imported but not referenced elsewhere in this file.
from training.feature_pipeline import (
build_feature_pipeline, transform_pair_features, _normalize
)
# Output locations for the audit artifacts (PNG plots, Markdown and JSON reports).
MODELS_DIR = ROOT / 'models'
REPORT_DIR = MODELS_DIR / 'reports'
# Side effect at import time: creates the report directory if it is missing.
REPORT_DIR.mkdir(parents=True, exist_ok=True)
def audit_feature_pipeline() -> Tuple[Dict[str, Any], pd.DataFrame, np.ndarray, np.ndarray, Any]:
    """Build the feature pipeline on a fixed sample and run every audit check.

    The pipeline is rebuilt via ``build_feature_pipeline`` with
    ``save_artifacts=False``, ``sample_size=5000`` and ``seed=2026``.

    Returns:
        A 5-tuple ``(audit_results, pairs_df, X, y, artifacts)``:
        ``audit_results`` maps each audit section name to its result dict,
        ``pairs_df`` and ``artifacts`` come from the pipeline build,
        ``X`` is the float32 matrix stacked from ``pairs_df['_X']`` and
        ``y`` holds integer label codes (0=unknown, 1=minor, 2=moderate,
        3=major, -1 for labels outside those categories).
    """
    print("[AUDIT] Building feature pipeline from scratch...", flush=True)
    pairs_df, artifacts = build_feature_pipeline(save_artifacts=False, sample_size=5000, seed=2026)
    X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)
    # Fixed category order keeps the codes stable across runs; pd.Categorical
    # assigns -1 to any value not in the list (including missing labels).
    y = pd.Categorical(pairs_df['label'], categories=['unknown', 'minor', 'moderate', 'major']).codes
    num_unmapped = int(np.sum(y < 0))
    if num_unmapped:
        print(
            f"[AUDIT] WARNING: {num_unmapped} rows have labels outside the known categories (code -1).",
            flush=True,
        )
    audit_results = {
        'timestamp': str(pd.Timestamp.now()),
        'dataset_size': len(pairs_df),
        'feature_dimension': X.shape[1],
        # Count only real classes; the -1 "unmapped" code is not a class.
        'num_classes': int(np.unique(y[y >= 0]).size),
        'class_distribution': _audit_label_distribution(pairs_df['label'].values, y),
        'feature_group_analysis': _audit_feature_groups(X, artifacts),
        'feature_statistics': _audit_feature_statistics(X, artifacts),
        'dead_features': _detect_dead_features(X, artifacts),
        'train_inference_consistency': _check_train_inference_consistency(pairs_df, artifacts),
        'feature_importance': _compute_feature_importance(X, y, artifacts),
        'sparsity_analysis': _compute_sparsity_by_group(X, artifacts),
    }
    return audit_results, pairs_df, X, y, artifacts
def _audit_label_distribution(labels_raw, y_encoded, *, max_ratio: float = 2.0) -> Dict[str, Any]:
    """Summarise per-class counts and flag class imbalance.

    Args:
        labels_raw: Original label values. Not used by the computation; kept
            for signature compatibility with the caller.
        y_encoded: Integer label codes. 0..3 map to
            ``unknown``/``minor``/``moderate``/``major``, larger codes become
            ``class_<n>``, and negative codes (``pd.Categorical`` uses -1 for
            values outside its categories) are pooled under ``'unmapped'``.
        max_ratio: The distribution counts as balanced when the largest class
            count divided by the smallest is strictly below this value.

    Returns:
        ``label_counts`` (name -> ``{'count', 'percentage'}``), ``is_balanced``
        and ``max_min_ratio``. The ratio only considers non-negative codes
        that actually occur. With no such codes the ratio is reported as 1.0
        (balanced) instead of raising.
    """
    label_names = ['unknown', 'minor', 'moderate', 'major']
    y_encoded = np.asarray(y_encoded)
    total = len(y_encoded)
    valid = y_encoded[y_encoded >= 0]
    unique, counts = np.unique(valid, return_counts=True)

    distribution = {}
    for label_idx, count in zip(unique.tolist(), counts.tolist()):
        name = label_names[label_idx] if label_idx < len(label_names) else f'class_{label_idx}'
        distribution[name] = {
            'count': int(count),
            'percentage': float(100.0 * count / total),
        }
    # Negative codes must not be used as list indices: label_names[-1] would
    # silently overwrite the 'major' entry.
    num_unmapped = total - len(valid)
    if num_unmapped:
        distribution['unmapped'] = {
            'count': int(num_unmapped),
            'percentage': float(100.0 * num_unmapped / total),
        }

    if counts.size == 0:
        ratio = 1.0
    else:
        ratio = float(counts.max()) / float(counts.min())
    return {
        'label_counts': distribution,
        'is_balanced': ratio < max_ratio,
        'max_min_ratio': ratio,
    }
def _audit_feature_groups(X: np.ndarray, artifacts) -> Dict[str, Dict[str, Any]]:
    """Compute summary statistics for each feature group's column block.

    Args:
        X: 2-D feature matrix; each group is the block ``X[:, start:end]``.
        artifacts: Object exposing ``group_slices``, a mapping
            ``group_name -> (start, end)`` of column ranges.

    Returns:
        ``group_name -> stats``. Each stats dict has the index range and width,
        min/max/mean/std, zero counts and sparsity (fraction of exact zeros),
        and two flags: ``is_all_zeros`` (``np.allclose`` to 0) and
        ``is_constant`` (every row close to the first row). An empty block
        (no columns or no rows) reports NaN for the value statistics and both
        flags as True, matching ``np.allclose``'s vacuous truth on empty input,
        instead of raising.
    """
    results = {}
    for group_name, (start, end) in artifacts.group_slices.items():
        block = X[:, start:end]
        zero_mask = block == 0.0
        num_zeros = int(np.count_nonzero(zero_mask))
        if block.size == 0:
            nan = float('nan')
            lo = hi = mean = std = sparsity = nan
            is_constant = True
        else:
            lo = float(np.min(block))
            hi = float(np.max(block))
            mean = float(np.mean(block))
            std = float(np.std(block))
            sparsity = float(np.mean(zero_mask))
            is_constant = bool(np.allclose(block, block[0, :]))
        results[group_name] = {
            'start_idx': int(start),
            'end_idx': int(end),
            'num_features': int(end - start),
            'min': lo,
            'max': hi,
            'mean': mean,
            'std': std,
            'sparsity': sparsity,
            'num_zeros': num_zeros,
            # Everything not exactly zero (NaN included), as `block != 0.0` counts it.
            'num_non_zeros': int(block.size - num_zeros),
            'is_all_zeros': bool(np.allclose(block, 0.0)),
            'is_constant': is_constant,
        }
    return results
def _audit_feature_statistics(X: np.ndarray, artifacts) -> Dict[str, Any]:
    """Whole-matrix statistics for the 2-D feature matrix ``X``.

    Args:
        X: Feature matrix (rows = samples, columns = feature dimensions).
        artifacts: Unused; kept for a uniform signature with the other audits.

    Returns:
        Shape counts, global mean/std/min/max, global sparsity (fraction of
        exact zeros), and the number of all-zero rows and columns. When ``X``
        has no elements the value statistics are NaN instead of raising.
    """
    num_rows, num_cols = X.shape
    zero_mask = X == 0.0
    if X.size == 0:
        nan = float('nan')
        mean = std = lo = hi = sparsity = nan
    else:
        mean = float(np.mean(X))
        std = float(np.std(X))
        lo = float(np.min(X))
        hi = float(np.max(X))
        sparsity = float(np.mean(zero_mask))
    return {
        # Despite its name this is the ROW (sample) count; kept for backward
        # compatibility of the JSON report. 'total_samples' is the clear alias.
        'total_features': int(num_rows),
        'total_samples': int(num_rows),
        'total_dimensions': int(num_cols),
        'global_mean': mean,
        'global_std': std,
        'global_min': lo,
        'global_max': hi,
        'global_sparsity': sparsity,
        'num_all_zero_rows': int(np.sum(np.all(zero_mask, axis=1))),
        'num_all_zero_cols': int(np.sum(np.all(zero_mask, axis=0))),
    }
def _detect_dead_features(X: np.ndarray, artifacts, *, max_listed: int = 50) -> Dict[str, Any]:
    """Find columns of ``X`` that carry no signal.

    A column is dead when it is all (near-)zero, constant, or has variance
    below 1e-9. Closeness uses ``np.isclose`` defaults, the same tolerance as
    ``np.allclose``.

    Args:
        X: 2-D feature matrix (rows = samples, columns = features).
        artifacts: Object exposing ``group_slices`` (``name -> (start, end)``).
            The first slice containing a column is its reported group.
        max_listed: Maximum number of dead-column records returned.

    Returns:
        ``num_dead_features`` (total count over all columns) and
        ``dead_features``: up to ``max_listed`` records in column order, each
        with index, group (None when no slice covers it), variance and reason.
        Reasons are checked most specific first: ``'all_zeros'``,
        ``'constant'``, ``'near_zero_variance'``. A matrix with no rows or
        columns has no dead features.
    """
    num_rows, num_cols = X.shape
    if num_rows == 0 or num_cols == 0:
        return {'num_dead_features': 0, 'dead_features': []}

    variances = np.var(X, axis=0)
    all_zeros = np.all(np.isclose(X, 0.0), axis=0)
    # Compare every row with the first row, column by column.
    constant = np.all(np.isclose(X, X[0]), axis=0)
    dead_indices = np.flatnonzero(all_zeros | constant | (variances < 1e-9))

    # Column -> owning group, built once; earlier slices take precedence.
    column_group = [None] * num_cols
    for name, (start, end) in artifacts.group_slices.items():
        for col_idx in range(max(start, 0), min(end, num_cols)):
            if column_group[col_idx] is None:
                column_group[col_idx] = name

    dead_features = []
    for col_idx in dead_indices[:max_listed]:
        # An all-zero column is also constant, so test the zero case first;
        # otherwise 'all_zeros' could never be reported.
        if all_zeros[col_idx]:
            reason = 'all_zeros'
        elif constant[col_idx]:
            reason = 'constant'
        else:
            reason = 'near_zero_variance'
        dead_features.append({
            'column_index': int(col_idx),
            'group': column_group[col_idx],
            'variance': float(variances[col_idx]),
            'reason': reason,
        })
    return {
        'num_dead_features': int(dead_indices.size),
        'dead_features': dead_features,
    }
def _compute_sparsity_by_group(X: np.ndarray, artifacts) -> Dict[str, float]:
    """Return, per feature group, the fraction of entries that are exactly zero.

    Args:
        X: 2-D feature matrix; each group is the block ``X[:, lo:hi]``.
        artifacts: Object exposing ``group_slices`` (``name -> (lo, hi)``).

    Returns:
        ``group_name -> zero fraction`` as a Python float.
    """
    return {
        name: float(np.mean(X[:, lo:hi] == 0.0))
        for name, (lo, hi) in artifacts.group_slices.items()
    }
def _check_train_inference_consistency(
    pairs_df: pd.DataFrame,
    artifacts,
    *,
    num_samples: int = 50,
    tolerance: float = 1e-4,
    seed: int = 2026,
) -> Dict[str, Any]:
    """Compare stored training features with features recomputed at inference.

    A random subset of pairs is re-featurised with ``transform_pair_features``
    from the fitted artifacts, and each result is compared with the stored
    training vector in ``pairs_df['_X']``.

    Args:
        pairs_df: Training pairs with ``drug_a``, ``drug_b`` and ``_X`` columns.
        artifacts: Fitted pipeline artifacts; the vocabulary, scaler, lookup
            maps, partner graph and metadata are passed to the inference path.
        num_samples: Maximum number of pairs to check.
        tolerance: Largest absolute per-element difference still treated as
            numerical noise.
        seed: Seed for choosing the sampled pairs, so repeated audits check the
            same pairs.

    Returns:
        ``num_samples_checked``, ``num_inconsistencies`` and up to 20
        ``inconsistencies`` records. Each record is either an inference failure
        (``error``), a ``dimension_mismatch`` (shapes differ) or a
        ``value_mismatch`` (difference above ``tolerance`` or non-finite).
    """
    print("[AUDIT] Checking train/inference consistency...", flush=True)
    train_X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)
    inference_artifacts = {
        'ngram_vocabulary': artifacts.ngram_vocabulary,
        'scaler': artifacts.scaler,
        'drugbank_map': artifacts.drugbank_map,
        'twosides_map': artifacts.twosides_map,
        'partner_graph': artifacts.partner_graph,
        'metadata': artifacts.metadata,
    }

    n_pairs = len(pairs_df)
    if n_pairs:
        rng = np.random.default_rng(seed)
        sample_indices = rng.choice(n_pairs, size=min(num_samples, n_pairs), replace=False)
    else:
        sample_indices = np.empty(0, dtype=np.int64)

    inconsistencies = []
    for idx in sample_indices:
        row = pairs_df.iloc[idx]
        drug_a = row['drug_a']
        drug_b = row['drug_b']
        record = {'pair_index': int(idx), 'drug_a': drug_a, 'drug_b': drug_b}
        train_feat = train_X[idx]
        try:
            # Pass a fresh dict per call (as before) in case the callee mutates it.
            inference_feat = np.asarray(
                transform_pair_features(drug_a, drug_b, dict(inference_artifacts))
            )
        except Exception as e:
            # Best effort: record the failure and keep checking the other pairs.
            inconsistencies.append({**record, 'error': str(e)})
            continue

        # Compare full shapes: e.g. (d,) vs (d, 1) would otherwise broadcast.
        if train_feat.shape != inference_feat.shape:
            inconsistencies.append({
                **record,
                'train_dim': int(train_feat.shape[0]),
                'inference_dim': int(inference_feat.shape[0]) if inference_feat.ndim else 0,
                'issue': 'dimension_mismatch',
            })
            continue

        abs_diff = np.abs(train_feat - inference_feat)
        if abs_diff.size == 0:
            continue
        max_diff = float(np.max(abs_diff))
        mean_diff = float(np.mean(abs_diff))
        # NaN never compares greater than the tolerance, so non-finite
        # differences are flagged explicitly.
        if not np.all(np.isfinite(abs_diff)) or max_diff > tolerance:
            inconsistencies.append({
                **record,
                'max_diff': max_diff,
                'mean_diff': mean_diff,
                'issue': 'value_mismatch',
            })
    return {
        'num_samples_checked': len(sample_indices),
        'num_inconsistencies': len(inconsistencies),
        'inconsistencies': inconsistencies[:20],  # Top 20
    }
def _compute_feature_importance(
    X: np.ndarray, y: np.ndarray, artifacts, *, random_state: int = 2026
) -> Dict[str, Any]:
    """Score features by mutual information (MI) with the label and by variance.

    Args:
        X: 2-D feature matrix (rows = samples).
        y: Class codes aligned with the rows of ``X``.
        artifacts: Object exposing ``group_slices`` (``name -> (start, end)``).
        random_state: Seed forwarded to ``mutual_info_classif``.

    Returns:
        The global MI mean/max/min, per-group MI and variance summaries
        (``group_importance``) and the five highest-MI features in descending
        order. A group with an empty column range gets NaN statistics instead
        of raising.
    """
    print("[AUDIT] Computing feature importance...", flush=True)
    mi_scores = mutual_info_classif(X, y, random_state=random_state)
    variances = np.var(X, axis=0)

    group_importance = {}
    for group_name, (start, end) in artifacts.group_slices.items():
        group_mi = mi_scores[start:end]
        group_var = variances[start:end]
        if group_mi.size == 0:
            # np.max/np.min raise on empty input; an empty group has no scores.
            group_importance[group_name] = dict.fromkeys(
                ('mean_mi', 'max_mi', 'min_mi', 'std_mi', 'mean_variance', 'max_variance'),
                float('nan'),
            )
            continue
        group_importance[group_name] = {
            'mean_mi': float(np.mean(group_mi)),
            'max_mi': float(np.max(group_mi)),
            'min_mi': float(np.min(group_mi)),
            'std_mi': float(np.std(group_mi)),
            'mean_variance': float(np.mean(group_var)),
            'max_variance': float(np.max(group_var)),
        }

    # Last five of the ascending argsort, reversed -> highest MI first.
    top_indices = np.argsort(mi_scores)[-5:][::-1]
    return {
        'mean_mi_all_features': float(np.mean(mi_scores)),
        'max_mi_feature': float(np.max(mi_scores)),
        'min_mi_feature': float(np.min(mi_scores)),
        'group_importance': group_importance,
        'top_5_important_features': [
            {
                'feature_index': int(idx),
                'mi_score': float(mi_scores[idx]),
                'variance': float(variances[idx]),
            }
            for idx in top_indices
        ],
    }
def _plot_group_bars(ax, labels: List[str], values: List[float], title: str, ylabel: str) -> None:
    """Draw one bar per feature group on ``ax`` with rotated group-name ticks."""
    positions = range(len(labels))
    ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_title(title)
    ax.set_ylabel(ylabel)


def _plot_label_distribution(ax, label_counts: Dict[str, Any]) -> None:
    """Draw counts for the four known labels present, each with a fixed colour."""
    label_colors = {'unknown': 'red', 'minor': 'orange', 'moderate': 'yellow', 'major': 'green'}
    # Derive names from the labels actually present, so a missing class cannot
    # shift names and colours onto another class's count.
    present = [name for name in label_colors if name in label_counts]
    if present:
        ax.bar(
            present,
            [label_counts[name]['count'] for name in present],
            color=[label_colors[name] for name in present],
        )
    ax.set_title('Label Distribution')
    ax.set_ylabel('Count')


def _create_visualizations(audit_results: Dict[str, Any], X: np.ndarray, y: np.ndarray, artifacts):
    """Save diagnostic bar charts for the audit into ``REPORT_DIR``.

    Writes ``feature_group_statistics.png`` (per-group mean, std and sparsity
    plus the label distribution) and, when ``audit_results`` contains
    ``feature_importance``, ``feature_importance.png`` (per-group mean MI and
    mean variance). All plotted values come from ``audit_results``; ``X``,
    ``y`` and ``artifacts`` are accepted for interface compatibility but not
    used.
    """
    print("[AUDIT] Creating visualizations...", flush=True)

    # 1. Feature group statistics and label distribution.
    group_stats = audit_results['feature_group_analysis']
    groups = list(group_stats)
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        _plot_group_bars(axes[0, 0], groups, [group_stats[g]['mean'] for g in groups],
                         'Mean Value by Feature Group', 'Mean')
        _plot_group_bars(axes[0, 1], groups, [group_stats[g]['std'] for g in groups],
                         'Std Dev by Feature Group', 'Std Dev')
        _plot_group_bars(axes[1, 0], groups, [group_stats[g]['sparsity'] for g in groups],
                         'Sparsity by Feature Group', 'Proportion of Zeros')
        _plot_label_distribution(axes[1, 1], audit_results['class_distribution']['label_counts'])
        fig.tight_layout()
        fig.savefig(REPORT_DIR / 'feature_group_statistics.png', dpi=150)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(" Saved: feature_group_statistics.png")

    # 2. Feature importance by group.
    if 'feature_importance' not in audit_results:
        return
    group_importance = audit_results['feature_importance']['group_importance']
    groups = list(group_importance)
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    try:
        _plot_group_bars(axes[0], groups, [group_importance[g]['mean_mi'] for g in groups],
                         'Mean Mutual Information by Group', 'MI Score')
        _plot_group_bars(axes[1], groups, [group_importance[g]['mean_variance'] for g in groups],
                         'Mean Variance by Group', 'Variance')
        fig.tight_layout()
        fig.savefig(REPORT_DIR / 'feature_importance.png', dpi=150)
    finally:
        plt.close(fig)
    print(" Saved: feature_importance.png")
def _report_critical_issues(audit_results: Dict[str, Any]) -> List[str]:
    """Executive-summary lines: a bulleted list of critical issues, or an all-clear."""
    issues = []
    num_dead = audit_results['dead_features']['num_dead_features']
    if num_dead > 0:
        issues.append(
            f"- ⚠️ **{num_dead} dead features detected** — "
            "These contribute no signal and should be removed."
        )
    class_dist = audit_results['class_distribution']
    if not class_dist['is_balanced']:
        issues.append(
            f"- ⚠️ **Severe class imbalance** — max/min ratio = {class_dist['max_min_ratio']:.1f}. "
            "This creates learning difficulty."
        )
    num_inconsistent = audit_results['train_inference_consistency']['num_inconsistencies']
    if num_inconsistent > 0:
        issues.append(
            f"- 🔴 **Train/inference consistency failures:** {num_inconsistent} "
            "mismatches found. Inference predictions may differ from training."
        )
    if issues:
        return ["### ⛔ Critical Issues", ""] + issues + [""]
    return ["✅ **No critical issues detected.**\n"]


def _report_dataset_overview(audit_results: Dict[str, Any]) -> List[str]:
    """Dataset size/shape lines followed by the label distribution."""
    lines = [
        "## Dataset Overview",
        "",
        f"- **Number of samples:** {audit_results['dataset_size']:,}",
        f"- **Feature dimension:** {audit_results['feature_dimension']}",
        f"- **Number of classes:** {audit_results['num_classes']}",
        "",
        "### Label Distribution",
        "",
    ]
    label_counts = audit_results['class_distribution']['label_counts']
    known = ['unknown', 'minor', 'moderate', 'major']
    # Known labels in severity order, then any extra entries (e.g. 'unmapped').
    ordered = [name for name in known if name in label_counts]
    ordered += [name for name in label_counts if name not in known]
    for label_name in ordered:
        info = label_counts[label_name]
        lines.append(f"- **{label_name}:** {info['count']} samples ({info['percentage']:.1f}%)")
    lines.append("")
    return lines


def _report_feature_groups(groups: Dict[str, Dict[str, Any]]) -> List[str]:
    """Per-group statistics table, sorted by group name."""
    rows = []
    for group_name in sorted(groups):
        g = groups[group_name]
        rows.append({
            'Group': group_name,
            'Dims': g['num_features'],
            'Mean': f"{g['mean']:.3f}",
            'Std': f"{g['std']:.3f}",
            'Min': f"{g['min']:.3f}",
            'Max': f"{g['max']:.3f}",
            'Sparse': f"{g['sparsity']:.1%}",
            'Dead?': '⚠️ YES' if g['is_all_zeros'] or g['is_constant'] else 'No',
        })
    table = pd.DataFrame(rows).to_string(index=False)
    # Fenced so Markdown keeps the column alignment of the plain-text table.
    return ["## Feature Group Analysis", "", "```", table, "```", ""]


def _report_dead_features(dead_info: Dict[str, Any], max_shown: int = 10) -> List[str]:
    """List up to ``max_shown`` dead columns; empty when none were found."""
    num_dead = dead_info['num_dead_features']
    if num_dead <= 0:
        return []
    listed = dead_info['dead_features']
    lines = [
        "## Dead Features Detection",
        "",
        f"Found **{num_dead} dead features** "
        "(constant, all-zero, or near-zero variance):",
        "",
    ]
    for dead in listed[:max_shown]:
        lines.append(f"- Col {dead['column_index']} ({dead['group']}): {dead['reason']}")
    # 'dead_features' may be a truncated list, so count the remainder from the total.
    remaining = num_dead - min(len(listed), max_shown)
    if remaining > 0:
        lines.append(f"- ... and {remaining} more")
    lines.append("")
    return lines


def _report_feature_importance(importance: Dict[str, Any]) -> List[str]:
    """Global MI summary, per-group MI table and the top-5 features."""
    lines = [
        "## Feature Importance Analysis",
        "",
        f"- **Mean MI across all features:** {importance['mean_mi_all_features']:.4f}",
        f"- **Max MI (single feature):** {importance['max_mi_feature']:.4f}",
        f"- **Min MI (single feature):** {importance['min_mi_feature']:.4f}",
        "",
        "### By Group:",
        "",
    ]
    rows = [
        {
            'Group': group_name,
            'Mean MI': f"{group_info['mean_mi']:.4f}",
            'Max MI': f"{group_info['max_mi']:.4f}",
            'Mean Var': f"{group_info['mean_variance']:.4f}",
        }
        for group_name, group_info in importance['group_importance'].items()
    ]
    lines += ["```", pd.DataFrame(rows).to_string(index=False), "```", ""]
    lines += ["**Top 5 Most Important Features:**", ""]
    for feat in importance['top_5_important_features']:
        lines.append(f"- Feature {feat['feature_index']}: MI={feat['mi_score']:.4f}, Var={feat['variance']:.4f}")
    lines.append("")
    return lines


def _report_consistency(consistency: Dict[str, Any], max_shown: int = 5) -> List[str]:
    """Train/inference check summary plus up to ``max_shown`` example issues."""
    lines = [
        "## Train/Inference Consistency Check",
        "",
        f"- **Samples checked:** {consistency['num_samples_checked']}",
        f"- **Inconsistencies found:** {consistency['num_inconsistencies']}",
        "",
    ]
    if consistency['num_inconsistencies'] <= 0:
        return lines + ["✅ **No consistency issues detected.**", ""]
    lines += ["**Issues detected:**", ""]
    for issue in consistency['inconsistencies'][:max_shown]:
        pair = f"{issue['drug_a']} + {issue['drug_b']}"
        kind = issue.get('issue')
        if 'error' in issue:
            lines.append(f"- {pair}: ERROR: {issue['error']}")
        elif kind == 'dimension_mismatch':
            lines.append(f"- {pair}: Train dim={issue['train_dim']}, Inference dim={issue['inference_dim']}")
        elif kind == 'value_mismatch':
            lines.append(f"- {pair}: Max diff={issue['max_diff']:.2e}, Mean diff={issue['mean_diff']:.2e}")
    lines.append("")
    return lines


def _report_recommendations(audit_results: Dict[str, Any]) -> List[str]:
    """Numbered, actionable recommendations (numbered 1..n over those that apply)."""
    recommendations = []
    if audit_results['dead_features']['num_dead_features'] > 0:
        recommendations.append(
            "**Remove dead features** — Features with zero variance waste model capacity. "
            "Recompute feature pipeline excluding these columns."
        )
    if not audit_results['class_distribution']['is_balanced']:
        recommendations.append(
            "**Address class imbalance** — The severe imbalance may cause the model to ignore "
            "rare classes. Consider reweighting losses, oversampling, or SMOTE."
        )
    if audit_results['train_inference_consistency']['num_inconsistencies'] > 0:
        recommendations.append(
            "**Fix train/inference mismatch** — The feature construction differs between training "
            "and inference. Ensure the same preprocessing is applied in both paths."
        )
    if recommendations:
        body = [f"{number}. {text}" for number, text in enumerate(recommendations, start=1)]
    else:
        body = ["✅ The feature pipeline appears to be in good shape. Proceed with retraining on larger sample."]
    return ["## Recommendations", "", ""] + body + [""]


def _generate_report(audit_results: Dict[str, Any]) -> str:
    """Render the audit results as a Markdown report.

    Args:
        audit_results: Results dict with keys ``timestamp``, ``dataset_size``,
            ``feature_dimension``, ``num_classes``, ``class_distribution``,
            ``feature_group_analysis``, ``dead_features`` and
            ``train_inference_consistency``; ``feature_importance`` is optional
            and its section is skipped when absent.

    Returns:
        The report text, lines joined with ``"\\n"``.
    """
    lines = [
        "# Feature Pipeline Integrity Audit Report",
        "",
        f"**Generated:** {audit_results['timestamp']}",
        "",
        "## Executive Summary",
        "",
    ]
    lines += _report_critical_issues(audit_results)
    lines += _report_dataset_overview(audit_results)
    lines += _report_feature_groups(audit_results['feature_group_analysis'])
    lines += _report_dead_features(audit_results['dead_features'])
    if 'feature_importance' in audit_results:
        lines += _report_feature_importance(audit_results['feature_importance'])
    lines += _report_consistency(audit_results['train_inference_consistency'])
    lines += _report_recommendations(audit_results)
    lines += [
        "## Artifacts",
        "",
        "- feature_group_statistics.png",
        "- feature_importance.png",
        "- feature_integrity_audit.md (this file)",
    ]
    return "\n".join(lines)
def main():
    """Script entry point: run the audit and write every report artifact.

    Generates the diagnostic plots, then writes ``feature_integrity_audit.md``
    and ``feature_integrity_audit.json`` into ``REPORT_DIR``. Finally it echoes
    the Markdown report to stdout between separator lines.
    """
    print("[AUDIT] Starting feature pipeline integrity audit...\n", flush=True)
    results, _pairs, features, labels, pipeline_artifacts = audit_feature_pipeline()

    _create_visualizations(results, features, labels, pipeline_artifacts)

    report_text = _generate_report(results)
    md_path = REPORT_DIR / 'feature_integrity_audit.md'
    md_path.write_text(report_text, encoding='utf-8')
    print(f"\n[AUDIT] Report saved: {md_path}", flush=True)

    # default=str stringifies any value json cannot encode natively.
    summary_json = json.dumps(results, indent=2, default=str)
    summary_path = REPORT_DIR / 'feature_integrity_audit.json'
    summary_path.write_text(summary_json, encoding='utf-8')
    print(f"[AUDIT] JSON summary saved: {summary_path}", flush=True)

    separator = "=" * 60
    print("\n" + separator)
    print(report_text)
    print(separator)
# Run the audit only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()