Spaces:
Running
Running
| """Integrity validation for the multisource feature pipeline. | |
| Produces JSON and Markdown reports covering: | |
| - feature-group statistics | |
| - sparsity / dead-dimension diagnostics | |
| - mapping coverage | |
| - train/inference parity | |
| - label distribution checks | |
| """ | |
from __future__ import annotations

import json
import sys
from collections import Counter
from collections.abc import Callable
from dataclasses import asdict
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
| ROOT = Path(__file__).resolve().parents[2] | |
| sys.path.insert(0, str(ROOT)) | |
| from .feature_pipeline_multisource import build_feature_pipeline, transform_pair_features | |
# Report output locations under the project root. Creating the directory here
# is an import-time side effect: importing this module ensures it exists.
MODELS_DIR = ROOT.joinpath('models')
REPORT_DIR = MODELS_DIR.joinpath('reports')
REPORT_DIR.mkdir(exist_ok=True, parents=True)
| def _markdown_table(df: pd.DataFrame) -> str: | |
| headers = list(df.columns) | |
| rows = [headers] | |
| for _, row in df.iterrows(): | |
| rows.append([str(row[col]) for col in headers]) | |
| widths = [max(len(cell) for cell in column) for column in zip(*rows)] | |
| lines = [] | |
| lines.append('| ' + ' | '.join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + ' |') | |
| lines.append('| ' + ' | '.join('-' * widths[idx] for idx in range(len(headers))) + ' |') | |
| for row in rows[1:]: | |
| lines.append('| ' + ' | '.join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)) + ' |') | |
| return '\n'.join(lines) | |
| def _group_statistics(X: np.ndarray, group_slices: dict[str, tuple[int, int]]) -> tuple[pd.DataFrame, dict[str, Any]]: | |
| rows = [] | |
| diagnostics: dict[str, Any] = {} | |
| for group_name, (start, end) in group_slices.items(): | |
| block = X[:, start:end] | |
| variances = np.var(block, axis=0) | |
| zero_var = int(np.sum(variances <= 1e-10)) | |
| rows.append( | |
| { | |
| 'group': group_name, | |
| 'dims': int(end - start), | |
| 'mean': float(np.mean(block)), | |
| 'std': float(np.std(block)), | |
| 'min': float(np.min(block)), | |
| 'max': float(np.max(block)), | |
| 'sparsity': float(np.mean(block == 0.0)), | |
| 'non_zero_rate': float(np.mean(block != 0.0)), | |
| 'zero_var_dims': zero_var, | |
| } | |
| ) | |
| diagnostics[group_name] = { | |
| 'dims': int(end - start), | |
| 'zero_var_dims': zero_var, | |
| 'sparsity': float(np.mean(block == 0.0)), | |
| 'non_zero_rate': float(np.mean(block != 0.0)), | |
| 'all_zero': bool(np.allclose(block, 0.0)), | |
| } | |
| return pd.DataFrame(rows), diagnostics | |
| def _consistency_check(pairs_df: pd.DataFrame, artifacts: dict[str, Any], sample_size: int = 50) -> dict[str, Any]: | |
| rng = np.random.default_rng(2026) | |
| sample_size = min(sample_size, len(pairs_df)) | |
| indices = rng.choice(len(pairs_df), size=sample_size, replace=False) | |
| mismatches = [] | |
| max_diff = 0.0 | |
| for idx in indices: | |
| row = pairs_df.iloc[int(idx)] | |
| train_vector = np.asarray(row['_X'], dtype=np.float32) | |
| inference_vector = transform_pair_features(row['drug_a'], row['drug_b'], artifacts) | |
| if train_vector.shape != inference_vector.shape: | |
| mismatches.append( | |
| { | |
| 'index': int(idx), | |
| 'drug_a': row['drug_a'], | |
| 'drug_b': row['drug_b'], | |
| 'issue': 'dimension_mismatch', | |
| 'train_dim': int(train_vector.shape[0]), | |
| 'inference_dim': int(inference_vector.shape[0]), | |
| } | |
| ) | |
| continue | |
| diff = np.abs(train_vector - inference_vector) | |
| sample_max = float(np.max(diff)) | |
| max_diff = max(max_diff, sample_max) | |
| if sample_max > 1e-6: | |
| mismatches.append( | |
| { | |
| 'index': int(idx), | |
| 'drug_a': row['drug_a'], | |
| 'drug_b': row['drug_b'], | |
| 'issue': 'value_mismatch', | |
| 'max_diff': sample_max, | |
| 'mean_diff': float(np.mean(diff)), | |
| } | |
| ) | |
| return { | |
| 'samples_checked': int(sample_size), | |
| 'mismatches': mismatches, | |
| 'mismatch_count': int(len(mismatches)), | |
| 'max_diff': float(max_diff), | |
| } | |
def _plot_statistics(
    group_df: pd.DataFrame,
    label_counts: pd.Series,
    consistency: dict[str, Any] | None = None,
) -> tuple[Path, Path]:
    """Write the diagnostic figures into REPORT_DIR and return their paths.

    Two PNGs are written. The first is a 2x2 grid of group dimensions, group
    sparsity, zero-variance dims and the label distribution. The second is a
    train/inference consistency chart. Both returned paths therefore point to
    files that exist.

    Args:
        group_df: Per-group table with ``group``, ``dims``, ``sparsity`` and
            ``zero_var_dims`` columns.
        label_counts: Count per label; the index holds the labels.
        consistency: Optional result dict with ``samples_checked``,
            ``mismatches`` and ``max_diff`` keys. When omitted, the consistency
            figure only states that no results were supplied.

    Returns:
        ``(stats_path, consistency_path)``.
    """
    stats_path = REPORT_DIR / 'feature_group_statistics_multisource.png'
    consistency_path = REPORT_DIR / 'feature_consistency_multisource.png'

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        axes = axes.ravel()
        panels = (
            ('dims', 'Group Dimensions', '#4C78A8'),
            ('sparsity', 'Group Sparsity', '#F58518'),
            ('zero_var_dims', 'Zero-Variance Dims', '#E45756'),
        )
        for ax, (column, title, color) in zip(axes, panels):
            ax.bar(group_df['group'], group_df[column], color=color)
            ax.set_title(title)
            ax.tick_params(axis='x', rotation=20)
        # Stringify labels so integer class ids are drawn as categories,
        # not as bars at numeric x-positions.
        axes[3].bar([str(label) for label in label_counts.index], label_counts.values.tolist(), color='#72B7B2')
        axes[3].set_title('Label Distribution')
        axes[3].tick_params(axis='x', rotation=20)
        fig.tight_layout()
        fig.savefig(stats_path, dpi=160)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        if consistency is None:
            ax.text(0.5, 0.5, 'No consistency results supplied', ha='center', va='center')
            ax.set_axis_off()
        else:
            issue_counts = Counter(str(m.get('issue', 'unknown')) for m in consistency.get('mismatches', []))
            issues = sorted(issue_counts)
            checked = int(consistency.get('samples_checked', 0))
            matched = max(checked - sum(issue_counts.values()), 0)
            ax.bar(
                ['matched', *issues],
                [matched, *(issue_counts[issue] for issue in issues)],
                color=['#54A24B'] + ['#E45756'] * len(issues),
            )
            ax.set_ylabel('samples')
            ax.set_title(f"Train/Inference Consistency (max diff {float(consistency.get('max_diff', 0.0)):.2e})")
        fig.tight_layout()
        fig.savefig(consistency_path, dpi=160)
    finally:
        plt.close(fig)
    return stats_path, consistency_path
# Attributes of the pipeline artifact object that are forwarded, as a plain
# dict, to the train/inference consistency check.
_INFERENCE_ARTIFACT_FIELDS: tuple[str, ...] = (
    'mapper_artifact',
    'group_slices',
    'feature_names',
    'active_feature_mask',
    'semantic_dim',
    'drugbank_dim',
    'twosides_hash_dim',
    'metadata',
    'ddinter_name_to_canonical',
    'twosides_cid_to_canonical',
    'canonical_entities',
    'ddinter_adjacency',
    'twosides_pair_stats',
    'graph_scaler',
    'twosides_scaler',
    'coverage_stats',
)


def _artifacts_as_dict(artifacts_obj: Any) -> dict[str, Any]:
    """Copy the artifact attributes listed in ``_INFERENCE_ARTIFACT_FIELDS`` into a dict."""
    return {name: getattr(artifacts_obj, name) for name in _INFERENCE_ARTIFACT_FIELDS}


def _render_markdown(report: dict[str, Any], group_df: pd.DataFrame, label_counts: pd.Series) -> str:
    """Render the human-readable Markdown version of a validation *report*."""
    consistency = report['train_inference_consistency']
    coverage = report['coverage_stats']
    coverage_df = pd.DataFrame([
        coverage['ddinter'] | {'source': 'ddinter'},
        coverage['twosides'] | {'source': 'twosides'},
    ])
    label_df = pd.DataFrame([{'label': label, 'count': count} for label, count in label_counts.items()])
    lines = [
        '# Multisource Feature Integrity Report',
        '',
        f"- **Feature dimension:** {report['feature_dimension']}",
        f"- **Samples:** {report['sample_size']}",
        f"- **Train/inference mismatches:** {consistency['mismatch_count']} / {consistency['samples_checked']}",
        f"- **Max consistency diff:** {consistency['max_diff']:.6f}",
        '',
        '## Coverage',
        '',
        _markdown_table(coverage_df),
        '',
        '## Group Statistics',
        '',
        _markdown_table(group_df),
        '',
        '## Label Distribution',
        '',
        _markdown_table(label_df),
        '',
        '## Consistency Diagnostics',
        '',
    ]
    if consistency['mismatch_count'] == 0:
        lines.append('No train/inference mismatches detected.')
    else:
        lines.append(f"Detected {consistency['mismatch_count']} mismatches. See JSON for pair-level details.")
        lines.append('')
        # Only the first ten pairs are listed; the JSON report holds all of them.
        lines.extend(f"- {m['drug_a']} + {m['drug_b']}: {m['issue']}" for m in consistency['mismatches'][:10])
    lines.extend([
        '',
        '## Dead Groups',
        '',
        ', '.join(report['dead_groups']) if report['dead_groups'] else 'None detected.',
        '',
    ])
    return '\n'.join(lines)


def run_validation(sample_size: int = 1000, consistency_sample: int = 50, seed: int = 2026) -> dict[str, Any]:
    """Build the multisource features and write JSON/Markdown integrity reports.

    Args:
        sample_size: Forwarded to ``build_feature_pipeline``.
        consistency_sample: Maximum number of rows re-featurised for the
            train/inference parity check.
        seed: Forwarded to ``build_feature_pipeline``.

    Returns:
        The report dict written to ``feature_integrity_multisource.json``. An
        extra ``artifacts`` entry, added after writing, holds the JSON and
        Markdown file paths.

    Raises:
        ValueError: If the pipeline produced no feature rows.
    """
    pairs_df, artifacts_obj = build_feature_pipeline(save_artifacts=True, sample_size=sample_size, seed=seed)
    if len(pairs_df) == 0:
        # Without this guard the failure surfaces later as an opaque IndexError
        # on the 1-D empty feature matrix.
        raise ValueError('build_feature_pipeline returned no feature rows; nothing to validate.')
    artifacts = _artifacts_as_dict(artifacts_obj)
    X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)
    group_df, diagnostics = _group_statistics(X, artifacts_obj.group_slices)
    # NOTE(review): `seed` is not forwarded to _consistency_check, so the parity
    # sample is drawn independently of it.
    consistency = _consistency_check(pairs_df, artifacts, sample_size=consistency_sample)
    label_counts = pairs_df['label'].value_counts()
    report: dict[str, Any] = {
        'metadata': artifacts_obj.metadata,
        'coverage_stats': artifacts_obj.coverage_stats,
        'feature_dimension': int(X.shape[1]),
        'sample_size': int(len(pairs_df)),
        'group_statistics': group_df.to_dict(orient='records'),
        'group_diagnostics': diagnostics,
        'label_distribution': label_counts.to_dict(),
        'train_inference_consistency': consistency,
        'dead_groups': [
            name
            for name, diag in diagnostics.items()
            if diag['all_zero'] or diag['zero_var_dims'] >= diag['dims']
        ],
    }
    # Render before writing anything so a rendering error leaves no partial output.
    markdown = _render_markdown(report, group_df, label_counts)
    json_path = REPORT_DIR / 'feature_integrity_multisource.json'
    md_path = REPORT_DIR / 'feature_integrity_multisource.md'
    json_path.write_text(json.dumps(report, indent=2, default=str), encoding='utf-8')
    md_path.write_text(markdown, encoding='utf-8')
    _plot_statistics(group_df, label_counts)
    report['artifacts'] = {
        'json': str(json_path),
        'markdown': str(md_path),
    }
    return report
if __name__ == '__main__':
    validation = run_validation()
    # The metadata comes straight from the pipeline artifacts and may hold values
    # the json module cannot encode natively. Stringify them, as the JSON report
    # writer does, instead of failing after all the work is done.
    print(json.dumps(validation['metadata'], indent=2, default=str))