Spaces:
Running
Running
| """Complete dataset audit for DDI sources. | |
| Audits: | |
| - Duplicate pairs and conflicting labels | |
| - Class imbalance | |
| - Low-quality/noisy records | |
| - Normalization consistency | |
| - Source reliability metrics | |
| Output: | |
| - dataset_audit_report.json | |
| - class_balance_report.json | |
| - conflict_analysis.csv | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| from collections import Counter, defaultdict | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import numpy as np | |
| import pandas as pd | |
| from preprocessing.artifact_manager import manager | |
# Configure logging at import time: INFO level, with each record showing
# timestamp, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
)
logger = logging.getLogger('medcare_ddi.audit')

# Directory layout, resolved relative to this file. BASE_DIR is two levels
# above the directory containing this file (parents[2] of the resolved path)
# -- presumably the project root; confirm against the repository layout.
BASE_DIR = Path(__file__).resolve().parents[2]
DATA_DIR = BASE_DIR / 'data'
PROCESSED_DIR = DATA_DIR / 'processed'
RAW_DIR = DATA_DIR / 'raw'
REPORTS_DIR = BASE_DIR / 'models' / 'reports'
# NOTE: this runs on import, so merely importing the module creates the
# reports directory (and any missing parents).
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

# Recognized severity labels; SEVERITY_RANK maps each label to its index in
# this list (presumably ascending severity -- confirm with consumers).
SEVERITY_LEVELS = ['unknown', 'minor', 'moderate', 'major']
SEVERITY_RANK = {v: i for i, v in enumerate(SEVERITY_LEVELS)}

# Per-source score in [0.65, 1.0], keyed by lowercase source name.
# NOTE(review): not referenced by the code visible in this section; the
# meaning/scale of the scores should be confirmed where they are consumed.
SOURCE_RELIABILITY = {
    'drugbank': 1.0,
    'ddinter': 0.95,
    'kegg': 0.9,
    'chembl': 0.85,
    'pubchem': 0.8,
    'twosides': 0.75,
    'sider': 0.7,
    'faers': 0.65,
}
def normalize_name(v: str) -> str:
    """Return *v* lower-cased with every whitespace run collapsed to one space.

    The value is converted with ``str()`` first, so non-string inputs are
    normalized by their string form. Leading and trailing whitespace is
    dropped.
    """
    # A no-argument split() already discards leading/trailing whitespace and
    # treats any whitespace run as a single separator.
    tokens = str(v).lower().split()
    return ' '.join(tokens)
def canonical_pair(a: str, b: str) -> tuple[str, str]:
    """Return an order-independent key for a pair of drug names.

    Both names are passed through ``normalize_name`` and returned in
    ascending order, so ``(a, b)`` and ``(b, a)`` yield the same tuple.
    """
    first, second = normalize_name(a), normalize_name(b)
    if second < first:
        first, second = second, first
    return (first, second)
def load_ddinter_data(path: Path) -> pd.DataFrame | None:
    """Load the combined DDInter table and tag every row with its source.

    Returns ``None`` (after logging a warning) when *path* does not exist.
    Otherwise returns the ``'ddinter_combined'`` artifact with a ``source``
    column set to ``'ddinter'``.

    NOTE(review): *path* is only used for the existence check; the data itself
    is fetched from the artifact manager by name. Confirm both refer to the
    same file. The ``source`` column is assigned on the object the manager
    returns, so a cached artifact would be modified in place.
    """
    if path.exists():
        frame = manager.load_artifact('ddinter_combined')
        frame['source'] = 'ddinter'
        return frame

    logger.warning(f'DDInter file not found: {path}')
    return None
def _is_blank(value: Any) -> bool:
    """Return True for a missing (NaN/None/NA) or whitespace-only cell value."""
    return bool(pd.isna(value)) or not str(value).strip()


def audit_dataset(df: pd.DataFrame, source_name: str) -> Dict[str, Any]:
    """Audit a single dataset source.

    Column names are matched case-insensitively. The frame must provide
    ``drug_a``, ``drug_b`` and ``severity``; a ``level`` column is accepted as
    a fallback for ``severity`` (``main`` in this module reads the severity
    distribution from ``Level``). The caller's frame is not modified.

    Args:
        df: Interaction records, one row per drug pair.
        source_name: Label copied into the result and into each conflict.

    Returns:
        A dict. For an empty frame or missing columns it holds ``rows``,
        ``source`` and ``error``. Otherwise it holds:

        * ``rows``: total row count.
        * ``skipped_rows``: rows ignored because a drug name was missing or
          blank.
        * ``unique_drugs``: distinct normalized drug names in usable rows.
        * ``unique_pairs``: distinct unordered, normalized drug pairs.
        * ``duplicate_pairs``: usable rows minus ``unique_pairs``.
        * ``class_distribution`` / ``class_distribution_pct``: counts and
          percentages (of all rows) per normalized severity label; missing or
          blank labels are excluded.
        * ``imbalance_ratio``: ``major / max(minor + unknown, 1)``.
        * ``conflicting_pairs``: pairs carrying more than one distinct label.
        * ``sample_conflicts``: up to ten conflict records.
    """
    if df.empty:
        return {'rows': 0, 'source': source_name, 'error': 'empty_dataset'}

    logger.info('Auditing %s: %d rows', source_name, len(df))

    # rename() returns a new frame, so the caller's columns stay untouched.
    df = df.rename(columns={c: str(c).lower() for c in df.columns})
    if 'severity' not in df.columns and 'level' in df.columns:
        df = df.rename(columns={'level': 'severity'})

    required_cols = {'drug_a', 'drug_b', 'severity'}
    missing = required_cols - set(df.columns)
    if missing:
        logger.error('Missing columns: %s', missing)
        return {'rows': len(df), 'source': source_name, 'error': f'missing_columns: {missing}'}

    # Normalize labels once so that conflict detection, the class distribution
    # and the 'major'/'minor'/'unknown' lookups all see the same spelling.
    # Missing or blank labels become NaN and are ignored downstream instead of
    # turning into the literal string 'nan'.
    raw_severity = df['severity']
    severity = raw_severity.astype(str).str.strip().str.lower()
    severity = severity.where(raw_severity.notna() & (severity != ''))

    # Group usable rows by canonical (unordered, normalized) drug pair.
    pairs_dict: Dict[tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
    drug_names: set[str] = set()
    valid_rows = 0
    for raw_a, raw_b, sev in zip(df['drug_a'], df['drug_b'], severity):
        if _is_blank(raw_a) or _is_blank(raw_b):
            continue
        valid_rows += 1
        key = canonical_pair(raw_a, raw_b)
        drug_names.update(key)
        label = sev if isinstance(sev, str) else None
        pairs_dict[key].append({'severity': label, 'source': source_name})

    # A pair conflicts when its rows carry more than one distinct known label;
    # a missing label is not treated as a competing label.
    conflicts = []
    for pair_key, records in pairs_dict.items():
        severities = {r['severity'] for r in records if r['severity'] is not None}
        if len(severities) > 1:
            conflicts.append({
                'pair': pair_key,
                'severities': sorted(severities),
                'count': len(records),
                'source': source_name,
            })

    # Class distribution (value_counts drops NaN labels by default).
    class_dist = severity.value_counts().to_dict()
    total = len(df)
    class_dist_pct = {k: round(100 * v / total, 2) for k, v in class_dist.items()}

    # Imbalance ratio: major vs. (minor + unknown). An absent class counts as
    # zero; max(..., 1) alone guards against division by zero.
    major_count = class_dist.get('major', 0)
    minor_count = class_dist.get('minor', 0) + class_dist.get('unknown', 0)
    imbalance_ratio = round(major_count / max(minor_count, 1), 3)

    return {
        'source': source_name,
        'rows': len(df),
        'skipped_rows': len(df) - valid_rows,
        'unique_drugs': len(drug_names),
        'unique_pairs': len(pairs_dict),
        'class_distribution': class_dist,
        'class_distribution_pct': class_dist_pct,
        'imbalance_ratio': imbalance_ratio,
        # Only usable rows can be duplicates of one another.
        'duplicate_pairs': valid_rows - len(pairs_dict),
        'conflicting_pairs': len(conflicts),
        'sample_conflicts': conflicts[:10],
    }
def _find_column(df: pd.DataFrame, *candidates: str) -> Any:
    """Return the first column of *df* matching a candidate name, else None.

    Each candidate is tried as an exact match first, then case-insensitively.
    """
    by_lower = {str(col).lower(): col for col in df.columns}
    for name in candidates:
        if name in df.columns:
            return name
        match = by_lower.get(name.lower())
        if match is not None:
            return match
    return None


def main() -> None:
    """Audit the combined DDInter dataset and write the three report files.

    Command-line options:
        --ddinter-path      File whose existence gates the run (defaults to
                            ``PROCESSED_DIR / 'ddinter_combined.parquet'``).
        --output-audit      Destination of the JSON audit report.
        --output-balance    Destination of the JSON class-balance report.
        --output-conflicts  Destination of the conflict CSV (written only when
                            the audit returned sample conflicts).

    Returns without writing anything if the DDInter file is missing.

    NOTE(review): the path is only checked for existence; the data is loaded
    from the artifact manager by name ('ddinter_combined'). Confirm that both
    refer to the same file.
    """
    parser = argparse.ArgumentParser(description='Audit DDI datasets')
    parser.add_argument('--ddinter-path', type=str, default=None)
    parser.add_argument('--output-audit', type=str, default=str(REPORTS_DIR / 'dataset_audit_report.json'))
    parser.add_argument('--output-balance', type=str, default=str(REPORTS_DIR / 'class_balance_report.json'))
    parser.add_argument('--output-conflicts', type=str, default=str(REPORTS_DIR / 'conflict_analysis.csv'))
    args = parser.parse_args()

    # Load DDInter
    ddinter_path = Path(args.ddinter_path) if args.ddinter_path else (PROCESSED_DIR / 'ddinter_combined.parquet')
    if not ddinter_path.exists():
        logger.error(f'DDInter not found: {ddinter_path}')
        return
    ddinter_df = manager.load_artifact('ddinter_combined')
    logger.info(f'Loaded DDInter: {len(ddinter_df)} rows')

    # Audit DDInter
    audit_result = audit_dataset(ddinter_df, 'ddinter_combined')
    if 'error' in audit_result:
        # Surface the failure; otherwise the balance report below would be
        # written with empty defaults and no hint as to why.
        logger.warning('Audit of ddinter_combined reported an error: %s', audit_result['error'])

    # Overall statistics. Columns are resolved case-insensitively so that both
    # the 'Drug_A'/'Drug_B'/'Level' and the lower-cased spellings work, in
    # line with audit_dataset, instead of raising KeyError.
    drug_columns = [_find_column(ddinter_df, name) for name in ('Drug_A', 'Drug_B')]
    severity_column = _find_column(ddinter_df, 'Level', 'severity')
    all_drugs: set = set()
    for column in drug_columns:
        if column is None:
            logger.warning('Drug name column missing; overall drug count is partial')
            continue
        # dropna() keeps NaN from being counted as a drug.
        all_drugs.update(ddinter_df[column].dropna())
    if severity_column is None:
        logger.warning('Severity column missing; overall severity distribution is empty')
        severity_distribution: Dict[str, Any] = {}
    else:
        severity_distribution = ddinter_df[severity_column].value_counts().to_dict()

    # Build comprehensive report
    report = {
        'timestamp': str(pd.Timestamp.now()),
        'audits': [audit_result],
        'overall': {
            'total_rows': len(ddinter_df),
            'total_unique_drugs': len(all_drugs),
            'severity_distribution': severity_distribution,
        },
    }

    # Save audit report
    audit_path = Path(args.output_audit)
    audit_path.parent.mkdir(parents=True, exist_ok=True)
    audit_path.write_text(json.dumps(report, indent=2), encoding='utf-8')
    logger.info(f'Saved audit report: {audit_path}')

    # Save class balance report
    balance_report = {
        'source': 'ddinter_combined',
        'class_distribution': audit_result.get('class_distribution', {}),
        'class_distribution_pct': audit_result.get('class_distribution_pct', {}),
        'imbalance_ratio': audit_result.get('imbalance_ratio', 1.0),
        'recommendation': 'Apply weighted class balancing and focal loss to handle class imbalance',
    }
    balance_path = Path(args.output_balance)
    # Custom output locations may not exist yet; create them as is already
    # done for the audit report.
    balance_path.parent.mkdir(parents=True, exist_ok=True)
    balance_path.write_text(json.dumps(balance_report, indent=2), encoding='utf-8')
    logger.info(f'Saved balance report: {balance_path}')

    # Save conflict analysis (only the sample returned by audit_dataset).
    conflicts = audit_result.get('sample_conflicts', [])
    if conflicts:
        conflict_df = pd.DataFrame(conflicts)
        conflict_path = Path(args.output_conflicts)
        conflict_path.parent.mkdir(parents=True, exist_ok=True)
        conflict_df.to_csv(conflict_path, index=False)
        logger.info(f'Saved {len(conflicts)} conflicts to: {conflict_path}')
    else:
        logger.info('No significant conflicts detected')

    logger.info('✓ Dataset audit complete')


if __name__ == '__main__':
    main()