# ddi/src/preprocessing/unified_dataset_builder.py
# github-actions[bot]
# Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
# d29b763
"""Unified DDI dataset builder for multi-source integration.
Sources supported via adapters:
- DDInter
- DrugBank
- TWOSIDES
- SIDER
- FAERS
- ChEMBL
- PubChem
- KEGG
The output schema is immutable and reproducible.
"""
from __future__ import annotations
import argparse
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List
import pandas as pd
from preprocessing.artifact_manager import manager
# Canonical severity labels, ordered from least to most severe.
SEVERITY_LEVELS = ['unknown', 'minor', 'moderate', 'major']
# Label -> position in SEVERITY_LEVELS; a higher rank means a more severe label.
SEVERITY_RANK = {v: i for i, v in enumerate(SEVERITY_LEVELS)}
# Per-source weight applied when dedupe_and_resolve votes on a pair's severity.
# Sources not listed here fall back to 0.6 at the lookup site.
SOURCE_RELIABILITY = {
'drugbank': 1.0,
'ddinter': 0.95,
'kegg': 0.9,
'chembl': 0.85,
'pubchem': 0.8,
'twosides': 0.75,
'sider': 0.7,
'faers': 0.65,
}
@dataclass(frozen=True)
class UnifiedSchema:
    """Frozen description of the unified output: a version tag plus required columns.

    Instances are immutable (``frozen=True``); ``main`` checks that every name
    in ``columns`` is present in the merged frame before writing it out.
    """

    # Identifier written into the stats file as ``schema_version``.
    version: str = 'ddi_unified_v1'
    # Column names that must exist in the unified dataset.
    columns: tuple[str, ...] = ('drug_a', 'drug_b', 'severity', 'source', 'support', 'evidence')
def normalize_drug_name(v: str) -> str:
    """Return *v* as a lower-cased string with whitespace runs collapsed to one space.

    The value is passed through ``str()`` first, so non-string inputs are
    accepted and normalised via their string form.
    """
    words = str(v).strip().lower().split()
    return ' '.join(words)
def canonical_pair(a: str, b: str) -> tuple[str, str]:
    """Return the two normalised drug names as an alphabetically ordered pair.

    Sorting makes the key order-independent, so (a, b) and (b, a) map to the
    same tuple.
    """
    first, second = sorted((normalize_drug_name(a), normalize_drug_name(b)))
    return (first, second)
def normalize_severity(v: str) -> str:
    """Map a raw severity label onto one of the canonical severity levels.

    The input is stringified, stripped and lower-cased. Canonical labels pass
    through unchanged, known synonyms are translated, and anything else
    becomes ``'unknown'``.
    """
    label = str(v).strip().lower()
    if label in SEVERITY_RANK:
        return label
    # Synonyms seen in source data -> canonical level.
    synonyms = {
        'severe': 'major',
        'contraindicated': 'major',
        'high': 'major',
        'medium': 'moderate',
        'moderate risk': 'moderate',
        'low': 'minor',
        'mild': 'minor',
    }
    return synonyms.get(label, 'unknown')
def ingest_ddinter(path: Path) -> pd.DataFrame:
    """Load the DDInter table and reshape it into the unified column layout.

    NOTE(review): ``path`` is accepted but not read; the data comes from the
    artifact manager under the key ``'ddinter_combined'``. Confirm this is
    intentional.

    Returns:
        A frame with ``drug_a``, ``drug_b``, ``severity``, ``source``,
        ``support`` and ``evidence`` columns. ``evidence`` is taken from the
        ``Description`` column when present, otherwise it is the empty string.
    """
    raw = manager.load_artifact('ddinter_combined')

    unified = {}
    unified['drug_a'] = raw['Drug_A'].astype(str)
    unified['drug_b'] = raw['Drug_B'].astype(str)
    unified['severity'] = raw['Level'].astype(str).map(normalize_severity)
    unified['source'] = 'ddinter'
    unified['support'] = 1
    if 'Description' in raw.columns:
        unified['evidence'] = raw['Description'].astype(str)
    else:
        unified['evidence'] = ''
    return pd.DataFrame(unified)
def ingest_generic(path: Path, source: str, mapping: Dict[str, str]) -> pd.DataFrame:
    """Load one extra CSV source and reshape it into the unified column layout.

    Args:
        path: Location of the source CSV file.
        source: Name recorded in the ``source`` column of every row.
        mapping: Unified column name -> column name in the CSV. ``drug_a``,
            ``drug_b`` and ``severity`` are required; ``evidence`` is optional.

    Returns:
        A frame with ``drug_a``, ``drug_b``, ``severity``, ``source``,
        ``support`` and ``evidence`` columns.

    Raises:
        ValueError: If ``mapping`` lacks an entry for a required column.
        KeyError: If a mapped column does not exist in the CSV.
    """
    # Read the file this source points at. Previously the DDInter artifact was
    # loaded here regardless of ``path``, so every extra source re-ingested
    # DDInter rows under the wrong source label.
    df = pd.read_csv(path)

    def col(name: str) -> str:
        # Resolve a unified column name to the source's own column name.
        if name not in mapping:
            raise ValueError(f'Missing mapping for {name} in source {source}')
        return mapping[name]

    out = pd.DataFrame(
        {
            'drug_a': df[col('drug_a')].astype(str),
            'drug_b': df[col('drug_b')].astype(str),
            'severity': df[col('severity')].astype(str).map(normalize_severity),
            'source': source,
            'support': 1,
            'evidence': df[col('evidence')].astype(str) if 'evidence' in mapping else '',
        }
    )
    return out
def dedupe_and_resolve(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse duplicate drug pairs into one row each and resolve severity conflicts.

    Rows are grouped by their canonical (order-independent, normalised) drug
    pair. Within a group each row votes for its severity with weight
    ``reliability(source) * max(1, support)``; the severity with the largest
    total wins, ties going to the more severe label. If the group's rows
    disagree and the accumulated weight for ``'major'`` is at least 0.9, the
    result is forced to ``'major'``.

    Args:
        df: Frame with ``drug_a``, ``drug_b``, ``severity`` and ``source``
            columns; ``support`` and ``evidence`` are optional.

    Returns:
        One row per drug pair, sorted by ``drug_a`` then ``drug_b``, with the
        columns ``drug_a``, ``drug_b``, ``severity``, ``source`` (pipe-joined,
        sorted), ``support`` (summed), ``evidence`` (up to five non-empty
        entries joined by ``' || '``), ``conflict`` (0/1) and
        ``max_observed_severity``. An empty input yields an empty frame that
        still carries these columns.
    """
    output_columns = [
        'drug_a',
        'drug_b',
        'severity',
        'source',
        'support',
        'evidence',
        'conflict',
        'max_observed_severity',
    ]

    # Group normalised records under their canonical pair key.
    buckets: Dict[tuple[str, str], List[dict]] = {}
    for _, row in df.iterrows():
        key = canonical_pair(row['drug_a'], row['drug_b'])
        buckets.setdefault(key, []).append(
            {
                'severity': normalize_severity(row['severity']),
                'source': str(row['source']),
                'support': int(row.get('support', 1)),
                'evidence': str(row.get('evidence', '')),
            }
        )

    # Without any rows the merged frame would have no columns and the final
    # sort would raise KeyError; return a correctly shaped empty frame instead.
    if not buckets:
        return pd.DataFrame(columns=output_columns)

    merged = []
    for (a, b), rows in buckets.items():
        # Reliability-aware conservative merge.
        severity_support = {level: 0.0 for level in SEVERITY_LEVELS}
        for r in rows:
            reliability = SOURCE_RELIABILITY.get(r['source'].strip().lower(), 0.6)
            severity_support[r['severity']] += reliability * max(1, r['support'])

        # Highest accumulated weight wins; ties go to the more severe level.
        chosen_severity = max(
            severity_support.items(),
            key=lambda item: (item[1], SEVERITY_RANK.get(item[0], 0)),
        )[0]

        observed = {r['severity'] for r in rows}
        disagreement = len(observed) > 1
        # Safety-first tie break: if signals conflict and strong major evidence exists, keep major.
        if disagreement and severity_support['major'] >= 0.9:
            chosen_severity = 'major'

        merged.append(
            {
                'drug_a': a,
                'drug_b': b,
                'severity': chosen_severity,
                'source': '|'.join(sorted({r['source'] for r in rows})),
                'support': int(sum(r['support'] for r in rows)),
                'evidence': ' || '.join([r['evidence'] for r in rows if r['evidence']][:5]),
                'conflict': int(disagreement),
                'max_observed_severity': max(observed, key=lambda s: SEVERITY_RANK.get(s, 0)),
            }
        )

    out = pd.DataFrame(merged)
    return out.sort_values(['drug_a', 'drug_b']).reset_index(drop=True)
def dataset_stats(df: pd.DataFrame) -> dict:
    """Summarise a unified dataset frame as a JSON-friendly dictionary.

    Args:
        df: Frame with ``drug_a``, ``drug_b``, ``severity`` and ``source``
            columns; ``conflict`` is optional. ``source`` values may hold
            several names joined by ``'|'``.

    Returns:
        A dict with ``rows``, ``unique_drugs``, ``severity_distribution``,
        ``conflict_rows`` (0 when there is no ``conflict`` column),
        ``sources`` (sorted distinct source names) and ``checksum`` (SHA-256
        of the frame's CSV rendering without the index).
    """
    # Split each row's pipe-joined source list and drop empty tokens, so an
    # empty frame reports no sources rather than a single '' entry.
    source_names = {
        name
        for joined in df['source'].tolist()
        for name in str(joined).split('|')
        if name
    }
    csv_bytes = df.to_csv(index=False).encode('utf-8')
    return {
        'rows': int(len(df)),
        'unique_drugs': int(len(set(df['drug_a']).union(set(df['drug_b'])))),
        'severity_distribution': df['severity'].value_counts().to_dict(),
        'conflict_rows': int(df['conflict'].sum()) if 'conflict' in df.columns else 0,
        'sources': sorted(source_names),
        'checksum': hashlib.sha256(csv_bytes).hexdigest(),
    }
def main() -> None:
    """Command-line entry point: ingest sources, merge them, write the CSV and stats.

    Reads DDInter plus any extra sources listed in the optional JSON config,
    merges them with ``dedupe_and_resolve``, checks that every column named
    by ``UnifiedSchema`` is present, then writes the unified CSV and a JSON
    stats file.

    Raises:
        ValueError: If the merged frame lacks a column required by the schema.
    """
    parser = argparse.ArgumentParser(description='Build unified DDI dataset from multi-source inputs')
    parser.add_argument('--ddinter', type=str, required=True)
    parser.add_argument('--extra-config', type=str, default=None, help='JSON config listing extra CSV sources and column mappings')
    parser.add_argument('--out-csv', type=str, required=True)
    parser.add_argument('--out-stats', type=str, required=True)
    args = parser.parse_args()

    frames: List[pd.DataFrame] = [ingest_ddinter(Path(args.ddinter))]
    if args.extra_config:
        cfg = json.loads(Path(args.extra_config).read_text(encoding='utf-8'))
        for source in cfg.get('sources', []):
            frames.append(
                ingest_generic(
                    path=Path(source['path']),
                    source=str(source['name']).lower(),
                    mapping=source['mapping'],
                )
            )

    all_df = pd.concat(frames, ignore_index=True)
    unified = dedupe_and_resolve(all_df)

    schema = UnifiedSchema()
    missing = [c for c in schema.columns if c not in unified.columns]
    if missing:
        raise ValueError(f'Unified schema mismatch, missing: {missing}')

    out_csv = Path(args.out_csv)
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    unified.to_csv(out_csv, index=False)

    stats = dataset_stats(unified)
    stats['schema_version'] = schema.version
    # Create the stats directory too, mirroring the CSV output, so a fresh
    # output location does not make the final write fail.
    out_stats = Path(args.out_stats)
    out_stats.parent.mkdir(parents=True, exist_ok=True)
    out_stats.write_text(json.dumps(stats, indent=2), encoding='utf-8')
# Run the builder only when executed as a script, not when imported.
if __name__ == '__main__':
    main()