ddi / src /preprocessing /artifact_store.py
github-actions[bot]
Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
d29b763
from __future__ import annotations
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Iterable
import re
import xml.etree.ElementTree as ET
import pandas as pd
try: # pragma: no cover - optional during lightweight installs
from chemistry.smiles_recovery import validate_smiles
except Exception: # pragma: no cover
validate_smiles = None # type: ignore
# Project root: two directory levels above the folder that contains this module.
BASE_DIR = Path(__file__).resolve().parents[2]

# Data tree.
DATA_DIR = BASE_DIR.joinpath('data')
RAW_DIR = DATA_DIR.joinpath('raw')
PROCESSED_DIR = DATA_DIR.joinpath('processed')

# Cache tree.
CACHE_DIR = BASE_DIR.joinpath('cache')
EMBEDDINGS_CACHE_DIR = CACHE_DIR.joinpath('embeddings')
FEATURE_CACHE_DIR = CACHE_DIR.joinpath('feature_cache')

# Raw inputs.
DRUGBANK_RAW_DIR = RAW_DIR.joinpath('drugbank')
DRUGBANK_XML = DRUGBANK_RAW_DIR.joinpath('full database.xml')
DDINTER_RAW_DIR = RAW_DIR.joinpath('ddinter')

# Processed artifacts written and read by this module.
DRUGS_PATH = PROCESSED_DIR.joinpath('drugs.parquet')
INTERACTIONS_PATH = PROCESSED_DIR.joinpath('interactions.parquet')
SYNONYMS_PATH = PROCESSED_DIR.joinpath('synonyms.parquet')
SMILES_PATH = PROCESSED_DIR.joinpath('smiles.parquet')
DDI_DATASET_PATH = PROCESSED_DIR.joinpath('ddi_dataset.parquet')
DDINTER_COMBINED_PATH = PROCESSED_DIR.joinpath('ddinter_combined.parquet')
LEGACY_DDINTER_CSV = PROCESSED_DIR.joinpath('ddinter_combined.csv')

# JSON cache file kept directly under the cache directory.
DRUGBANK_CACHE_PATH = CACHE_DIR.joinpath('drugbank_name_to_smiles.json')
def ensure_artifact_dirs() -> None:
    """Create the processed-data and cache directories if they do not exist yet."""
    required = (PROCESSED_DIR, CACHE_DIR, EMBEDDINGS_CACHE_DIR, FEATURE_CACHE_DIR)
    for directory in required:
        # parents=True also creates missing ancestors; exist_ok=True makes
        # repeated calls harmless.
        directory.mkdir(parents=True, exist_ok=True)
def _local_name(tag: str) -> str:
    """Return ``tag`` without its ``{namespace}`` part and without any ``prefix:``.

    A falsy ``tag`` (empty string or ``None``) is returned unchanged.
    """
    if not tag:
        return tag
    # Keep what follows the last '}' (ElementTree's namespace delimiter) ...
    without_namespace = tag.rpartition('}')[2]
    # ... and then what follows the last ':' of that remainder.
    return without_namespace.rpartition(':')[2]
def _element_text(element: ET.Element, local_name: str) -> str:
    """Return the stripped text of the first matching element, or ``''``.

    The search covers ``element`` itself and all of its descendants in
    document order. Elements whose text is missing or whitespace-only are
    skipped.
    """
    stripped_texts = (
        node.text.strip()
        for node in element.iter()
        if _local_name(node.tag) == local_name and node.text and node.text.strip()
    )
    return next(stripped_texts, '')
def _element_texts(element: ET.Element, local_name: str) -> list[str]:
    """Return the stripped text of every matching element, in document order.

    The search covers ``element`` itself and all of its descendants.
    Elements whose text is missing or whitespace-only contribute nothing.
    """
    return [
        node.text.strip()
        for node in element.iter()
        if _local_name(node.tag) == local_name and node.text and node.text.strip()
    ]
def _iter_drugbank_drugs(xml_path: Path) -> Iterable[ET.Element]:
    """Stream the outermost ``drug`` elements of a DrugBank XML file.

    Only ``drug`` elements that are not nested inside another ``drug`` are
    yielded. A nested ``drug`` (NOTE(review): DrugBank appears to use these
    for cross-references such as pathway member lists - confirm against the
    schema in use) stays intact inside its enclosing record. It is not
    reported as a record of its own and is not cleared early.

    Each yielded element is cleared as soon as the consumer requests the next
    one, so everything needed must be read before advancing the iterator.

    Args:
        xml_path: Location of the XML file to parse.

    Yields:
        Fully parsed outermost ``drug`` elements, in document order.

    A malformed document does not raise: the parse error is printed and the
    iteration simply ends after the records parsed so far.
    """
    root = None
    open_drugs = 0  # how many <drug> elements are currently open
    try:
        for event, element in ET.iterparse(xml_path, events=('start', 'end')):
            if root is None:
                # The first 'start' event always belongs to the document root.
                root = element
            tag = element.tag
            # Same local-name rule as the rest of this module: drop the
            # '{namespace}' part and any 'prefix:' part before comparing.
            if not isinstance(tag, str) or tag.split('}')[-1].split(':')[-1] != 'drug':
                continue
            if event == 'start':
                open_drugs += 1
                continue
            open_drugs -= 1
            if open_drugs:
                # End of a nested <drug>: it is part of the enclosing record.
                continue
            yield element
            element.clear()
            if root is not element:
                # Drop the root's references to finished records so memory
                # stays bounded on large exports.
                root.clear()
    except ET.ParseError as e:
        print(f"XML parse error gracefully handled: {e}")
def _extract_drugbank_smiles(drug: ET.Element) -> tuple[str, str]:
    """Find a SMILES string for ``drug`` and report where it was found.

    Three lookups are tried in order; the first hit wins:

    1. a ``property`` element whose ``kind`` mentions "smiles" and which has a
       non-empty ``value``/``text`` child  -> source ``'property'``
    2. the first element named ``smiles`` with text -> source ``'smiles_tag'``
    3. any element whose tag name contains "smiles" -> source ``'tag'``

    Returns:
        ``(smiles, source)`` with all whitespace removed from ``smiles``, or
        ``('', '')`` when nothing was found.
    """
    # 1) kind/value pairs stored under <property> elements.
    for node in drug.iter():
        if _local_name(node.tag) == 'property':
            kind, value = '', ''
            for field in node:
                field_name = _local_name(field.tag)
                if field_name == 'kind' and field.text:
                    kind = field.text.strip().lower()
                elif field_name in {'value', 'text'}:
                    content = ''.join(field.itertext()).strip()
                    # An empty child never overwrites a value already seen.
                    value = content or value
            if kind and value and 'smiles' in kind:
                return re.sub(r'\s+', '', value), 'property'

    # 2) a dedicated <smiles> element.
    tagged = _element_text(drug, 'smiles')
    if tagged:
        return re.sub(r'\s+', '', tagged), 'smiles_tag'

    # 3) last resort: any tag whose name mentions smiles.
    for node in drug.iter():
        if 'smiles' not in _local_name(node.tag).lower():
            continue
        content = ''.join(node.itertext()).strip()
        if content:
            return re.sub(r'\s+', '', content), 'tag'

    return '', ''
def _extract_drugbank_identifiers(drug: ET.Element) -> dict[str, str]:
    """Collect the CAS number, UNII and InChI found below a ``drug`` element.

    Sources, in order of application:

    * ``cas-number`` and ``unii`` elements (first one with text);
    * ``external-identifier`` entries: an entry whose resource mentions
      "inchi" sets the InChI; an entry fills a still-empty CAS slot when its
      resource mentions "cas" or its identifier is shaped like a CAS Registry
      Number;
    * ``property`` entries whose ``kind`` mentions "inchi" set the InChI and
      override an InChI taken from the external identifiers.

    Returns:
        A dict with the keys ``'cas'``, ``'unii'`` and ``'inchi'``; missing
        values are empty strings.
    """
    # A CAS Registry Number is 2-7 digits, 2 digits and a check digit,
    # optionally written with a leading "CAS" label (e.g. "CAS: 50-00-0").
    # Matching the shape instead of the substring "cas" keeps unrelated
    # identifiers that merely contain those letters out of the CAS field.
    cas_shape = r'(?:cas[\s:#-]*)?\d{2,7}-\d{2}-\d'

    identifiers = {
        'cas': _element_text(drug, 'cas-number'),
        'unii': _element_text(drug, 'unii'),
        'inchi': '',
    }

    for ext in drug.iter():
        if _local_name(ext.tag) != 'external-identifier':
            continue
        resource = _element_text(ext, 'resource').lower()
        identifier = _element_text(ext, 'identifier')
        if not resource or not identifier:
            continue
        if 'inchi' in resource:
            identifiers['inchi'] = identifier
        # Only a fallback: a value from <cas-number> is never replaced.
        if not identifiers['cas'] and (
            'cas' in resource or re.fullmatch(cas_shape, identifier, re.IGNORECASE)
        ):
            identifiers['cas'] = identifier

    for prop in drug.iter():
        if _local_name(prop.tag) != 'property':
            continue
        kind = ''
        value = ''
        for child in prop:
            local = _local_name(child.tag)
            if local == 'kind' and child.text:
                kind = child.text.strip().lower()
            elif local in {'value', 'text'}:
                text = ''.join(child.itertext()).strip()
                if text:
                    value = text
        if kind and 'inchi' in kind and value:
            identifiers['inchi'] = value

    return identifiers
def _canonicalize_smiles(smiles: str) -> tuple[str, bool]:
    """Normalise a SMILES string and report whether it is considered valid.

    All whitespace is removed first. When the optional ``validate_smiles``
    helper is available, its result decides validity and, if it supplies one,
    the canonical form.

    Args:
        smiles: Raw SMILES text; ``None`` or empty input is accepted.

    Returns:
        ``(smiles, valid)``:

        * ``('', False)`` for empty input;
        * ``(cleaned, True)`` when no validator is installed, because the
          string cannot be checked and is accepted as-is;
        * ``(cleaned, False)`` when the validator raises;
        * ``(canonical, True)`` when the validator reports a valid structure
          and provides a canonical form;
        * otherwise ``(cleaned, valid)`` with the validator's own verdict, so
          a rejected structure is no longer reported as valid.
    """
    smiles = re.sub(r'\s+', '', smiles or '')
    if not smiles:
        return '', False
    if validate_smiles is None:
        return smiles, True
    try:
        result = validate_smiles(smiles)
    except Exception:
        return smiles, False
    is_valid = bool(result.get('valid'))
    canonical = str(result.get('canonical_smiles') or '').strip()
    if is_valid and canonical:
        return canonical, True
    return smiles, is_valid
def _extract_interactions(drug: ET.Element, drugbank_id: str, name: str) -> list[dict[str, Any]]:
    """Build one record per ``drug-interaction`` element found below ``drug``.

    Args:
        drug: Element to search (itself and all descendants).
        drugbank_id: Identifier stored as ``drugbank_id`` on every record.
        name: Name stored as ``drug_name`` on every record.

    Returns:
        Records in document order. An interaction with no partner id, no
        partner name and no description is skipped.
    """
    records: list[dict[str, Any]] = []
    entries = (node for node in drug.iter() if _local_name(node.tag) == 'drug-interaction')
    for entry in entries:
        partner_id = _element_text(entry, 'drugbank-id')
        partner_name = _element_text(entry, 'name')
        summary = _element_text(entry, 'description')
        if not (partner_id or partner_name or summary):
            continue
        records.append(
            {
                'drugbank_id': drugbank_id,
                'drug_name': name,
                'interacting_drugbank_id': partner_id,
                'interacting_drug_name': partner_name,
                'description': summary,
            }
        )
    return records
def _entry_names(drug: ET.Element, local_name: str) -> list[str]:
    """Return the label of every ``local_name`` element below ``drug``.

    The element's own text is used when it has any; otherwise the text of its
    first direct ``name`` child is used. NOTE(review): DrugBank brand and
    product entries appear to keep their label in such a ``name`` child -
    confirm against the schema version in use.
    """
    names: list[str] = []
    for node in drug.iter():
        if _local_name(node.tag) != local_name:
            continue
        label = (node.text or '').strip()
        if not label:
            for child in node:
                if _local_name(child.tag) == 'name' and child.text and child.text.strip():
                    label = child.text.strip()
                    break
        if label:
            names.append(label)
    return names


def build_drugbank_artifacts(xml_path: Path | None = None, *, force: bool = False) -> dict[str, Path]:
    """Parse the DrugBank XML and write the drug-level parquet artifacts.

    Four tables are produced: ``drugs``, ``interactions``, ``synonyms`` (all
    aliases of a drug, including its name, id, CAS, UNII and InChI) and
    ``smiles``. After writing, the frames are registered with the project's
    artifact manager on a best-effort basis.

    Args:
        xml_path: XML file to parse; defaults to ``DRUGBANK_XML``.
        force: Rebuild even when all four parquet files already exist.

    Returns:
        Mapping of artifact name to the parquet path it was written to.

    Raises:
        FileNotFoundError: If the XML file does not exist. This is checked
            before the cached artifacts are considered.
    """
    import logging

    ensure_artifact_dirs()
    xml_path = Path(xml_path) if xml_path is not None else DRUGBANK_XML
    if not xml_path.exists():
        raise FileNotFoundError(f'DrugBank XML not found at {xml_path}')

    outputs = {
        'drugs': DRUGS_PATH,
        'interactions': INTERACTIONS_PATH,
        'synonyms': SYNONYMS_PATH,
        'smiles': SMILES_PATH,
    }
    if not force and all(path.exists() for path in outputs.values()):
        return outputs

    # Explicit schemas keep the tables well-formed (and de-duplication
    # working) even when the XML yields no rows at all.
    drug_columns = [
        'drugbank_id', 'name', 'type', 'is_biologic', 'is_small_molecule',
        'raw_smiles', 'canonical_smiles', 'smiles_source', 'smiles_valid',
        'cas', 'unii', 'inchi', 'synonym_count', 'brand_count', 'product_count',
    ]
    interaction_columns = [
        'drugbank_id', 'drug_name', 'interacting_drugbank_id',
        'interacting_drug_name', 'description',
    ]
    synonym_columns = ['drugbank_id', 'canonical_name', 'alias', 'alias_type']
    smiles_columns = [
        'drugbank_id', 'canonical_name', 'raw_smiles', 'canonical_smiles',
        'smiles_source', 'smiles_valid',
    ]

    drugs: list[dict[str, Any]] = []
    interactions: list[dict[str, Any]] = []
    synonyms: list[dict[str, Any]] = []
    smiles_rows: list[dict[str, Any]] = []

    for drug in _iter_drugbank_drugs(xml_path):
        drugbank_id = _element_text(drug, 'drugbank-id') or ''
        name = _element_text(drug, 'name') or drugbank_id
        drug_type = str(drug.attrib.get('type', '')).strip().lower()
        synonyms_list = _element_texts(drug, 'synonym')
        brands = _entry_names(drug, 'international-brand')
        products = _entry_names(drug, 'product')
        identifiers = _extract_drugbank_identifiers(drug)
        raw_smiles, smiles_source = _extract_drugbank_smiles(drug)
        canonical_smiles, smiles_valid = _canonicalize_smiles(raw_smiles)

        drugs.append(
            {
                'drugbank_id': drugbank_id,
                'name': name,
                'type': drug_type,
                'is_biologic': drug_type in {'biotech', 'protein'},
                'is_small_molecule': drug_type == 'small molecule',
                'raw_smiles': raw_smiles,
                'canonical_smiles': canonical_smiles,
                'smiles_source': smiles_source,
                'smiles_valid': smiles_valid,
                'cas': identifiers.get('cas', ''),
                'unii': identifiers.get('unii', ''),
                'inchi': identifiers.get('inchi', ''),
                'synonym_count': len(set(synonyms_list)),
                'brand_count': len(set(brands)),
                'product_count': len(set(products)),
            }
        )

        # Aliases: de-duplicated names per category first, then the drug's
        # own name and identifiers. Empty aliases are dropped.
        alias_pairs: list[tuple[str, str]] = []
        for alias_type, values in (
            ('synonym', synonyms_list),
            ('brand', brands),
            ('product', products),
        ):
            alias_pairs.extend((alias, alias_type) for alias in sorted(set(values)))
        alias_pairs.extend(
            [
                (name, 'canonical'),
                (drugbank_id, 'identifier'),
                (identifiers.get('cas', ''), 'cas'),
                (identifiers.get('unii', ''), 'unii'),
                (identifiers.get('inchi', ''), 'inchi'),
            ]
        )
        synonyms.extend(
            {
                'drugbank_id': drugbank_id,
                'canonical_name': name,
                'alias': alias,
                'alias_type': alias_type,
            }
            for alias, alias_type in alias_pairs
            if alias
        )

        smiles_rows.append(
            {
                'drugbank_id': drugbank_id,
                'canonical_name': name,
                'raw_smiles': raw_smiles,
                'canonical_smiles': canonical_smiles,
                'smiles_source': smiles_source,
                'smiles_valid': smiles_valid,
            }
        )
        interactions.extend(_extract_interactions(drug, drugbank_id, name))

    frames = {
        'drugs': pd.DataFrame(drugs, columns=drug_columns).drop_duplicates(
            subset=['drugbank_id'], keep='first'
        ),
        'interactions': pd.DataFrame(interactions, columns=interaction_columns),
        'synonyms': pd.DataFrame(synonyms, columns=synonym_columns).drop_duplicates(),
        'smiles': pd.DataFrame(smiles_rows, columns=smiles_columns).drop_duplicates(
            subset=['drugbank_id'], keep='first'
        ),
    }
    for artifact_name, frame in frames.items():
        frame.to_parquet(outputs[artifact_name], index=False)

    try:
        from preprocessing.artifact_manager import manager

        for artifact_name, frame in frames.items():
            manager.register_artifact(artifact_name, frame, outputs[artifact_name])
    except Exception:
        # Registration is best effort: the parquet files are already on disk,
        # so report the problem instead of failing the build or hiding it.
        logging.getLogger(__name__).warning(
            'Could not register DrugBank artifacts with the artifact manager',
            exc_info=True,
        )

    return outputs
def _load_first_existing(paths: Iterable[Path]) -> pd.DataFrame:
    """Load the first candidate file that exists and has a supported format.

    Supported suffixes (case-insensitive) are ``.parquet``, ``.feather`` and
    ``.csv``. An existing file with any other suffix is skipped.

    Args:
        paths: Candidate files in priority order; any iterable, including a
            one-shot generator.

    Returns:
        The loaded frame.

    Raises:
        FileNotFoundError: If no candidate could be loaded. The message lists
            every candidate.
    """
    # Materialise once: the candidates are needed again for the error message
    # and ``paths`` may be a generator that can only be consumed once.
    candidates = list(paths)
    for path in candidates:
        if not path.exists():
            continue
        suffix = path.suffix.lower()
        if suffix == '.parquet':
            return pd.read_parquet(path)
        if suffix == '.feather':
            # Feather is not parquet and needs its own reader.
            return pd.read_feather(path)
        if suffix == '.csv':
            return pd.read_csv(path, low_memory=False)
    raise FileNotFoundError(f'None of the candidate data files exist: {", ".join(str(path) for path in candidates)}')
def load_ddinter_processed_frame() -> pd.DataFrame:
    """Load the combined DDInter table from the first available file.

    Candidates are tried in this order: the parquet artifact, the legacy CSV,
    then a feather file next to them in the processed-data directory.
    """
    feather_fallback = PROCESSED_DIR / 'ddinter_combined.feather'
    search_order = [DDINTER_COMBINED_PATH, LEGACY_DDINTER_CSV, feather_fallback]
    return _load_first_existing(search_order)
def build_ddinter_structured_artifacts(*, force: bool = False) -> dict[str, Path]:
    """Normalise the DDInter table and derive a one-row-per-pair dataset.

    The combined table keeps every input row and gains normalised
    ``drug_a_name``/``drug_b_name``/``severity`` columns plus ``drug_a``,
    ``drug_b`` and ``level`` aliases (and ``source`` if absent). It is written
    as parquet and as the legacy CSV.

    The pair dataset groups rows by the unordered, lower-cased drug pair. For
    each pair it keeps the most frequent severity label (ties go to the label
    seen first), the row count as ``support`` and the spelling of the first
    occurrence.

    Args:
        force: Rebuild even when both parquet outputs already exist.

    Returns:
        Mapping with the keys ``'ddinter_combined'`` and ``'ddi_dataset'``.

    Raises:
        KeyError: If the drug A, drug B or severity column cannot be found.
        FileNotFoundError: Propagated from loading when no input file exists.
    """
    import logging

    ensure_artifact_dirs()
    outputs = {'ddinter_combined': DDINTER_COMBINED_PATH, 'ddi_dataset': DDI_DATASET_PATH}
    if not force and DDINTER_COMBINED_PATH.exists() and DDI_DATASET_PATH.exists():
        return outputs

    raw = load_ddinter_processed_frame()
    # ``rename`` returns a new frame, so the loaded one is never modified.
    # Non-string column labels are left alone instead of crashing on .strip().
    combined = raw.rename(
        columns={column: column.strip() for column in raw.columns if isinstance(column, str)}
    )
    column_map = {column.lower(): column for column in combined.columns if isinstance(column, str)}
    drug_a = column_map.get('drug_a') or column_map.get('drug_a_name') or column_map.get('a')
    drug_b = column_map.get('drug_b') or column_map.get('drug_b_name') or column_map.get('b')
    severity = column_map.get('level') or column_map.get('severity') or column_map.get('label')
    if not all([drug_a, drug_b, severity]):
        raise KeyError(f'Could not identify required DDInter columns in {list(combined.columns)}')

    combined['drug_a_name'] = combined[drug_a].astype(str).str.strip()
    combined['drug_b_name'] = combined[drug_b].astype(str).str.strip()
    combined['severity'] = combined[severity].astype(str).str.strip().str.lower()
    if 'source' not in combined.columns:
        combined['source'] = 'ddinter'
    combined['drug_a'] = combined['drug_a_name']
    combined['drug_b'] = combined['drug_b_name']
    combined['level'] = combined['severity']

    pair_levels: dict[tuple[str, str], Counter[str]] = defaultdict(Counter)
    representative: dict[tuple[str, str], tuple[str, str]] = {}
    # Iterating the three columns directly avoids building a Series per row.
    for a, b, label in zip(combined['drug_a_name'], combined['drug_b_name'], combined['severity']):
        a = str(a).strip()
        b = str(b).strip()
        label = str(label).strip().lower()
        # Sorting the lower-cased names makes (A, B) and (B, A) the same pair.
        key = tuple(sorted((a.lower(), b.lower())))
        pair_levels[key][label] += 1
        representative.setdefault(key, (a, b))

    ddi_rows = []
    for key, counter in pair_levels.items():
        label, _ = counter.most_common(1)[0]
        a, b = representative[key]
        ddi_rows.append(
            {
                'drug_a_name': a,
                'drug_b_name': b,
                'severity': label,
                # Every row of a pair increments exactly one label count.
                'support': int(sum(counter.values())),
                'pair_key': '||'.join(key),
                'source': 'ddinter',
            }
        )
    # Explicit columns keep the schema stable even when there are no pairs.
    ddi_dataset = pd.DataFrame(
        ddi_rows,
        columns=['drug_a_name', 'drug_b_name', 'severity', 'support', 'pair_key', 'source'],
    )

    combined.to_parquet(DDINTER_COMBINED_PATH, index=False)
    ddi_dataset.to_parquet(DDI_DATASET_PATH, index=False)
    combined.to_csv(LEGACY_DDINTER_CSV, index=False)

    try:
        from preprocessing.artifact_manager import manager

        manager.register_artifact('ddinter_combined', combined, DDINTER_COMBINED_PATH)
        manager.register_artifact('ddi_dataset', ddi_dataset, DDI_DATASET_PATH)
    except Exception:
        # Registration is best effort: the files are already written, so
        # report the problem instead of failing the build or hiding it.
        logging.getLogger(__name__).warning(
            'Could not register DDInter artifacts with the artifact manager',
            exc_info=True,
        )

    return outputs
def ensure_structured_data(*, force_rebuild: bool = False) -> dict[str, Path]:
    """Build (or reuse) every structured artifact this module knows about.

    The DDInter artifacts are always built. The DrugBank artifacts are built
    only when the DrugBank XML file is present. The TWOSIDES build is
    attempted on a best-effort basis and its outputs are not part of the
    returned mapping.

    Args:
        force_rebuild: Passed as ``force`` to each builder.

    Returns:
        Mapping of artifact name to path for the DDInter and, when built, the
        DrugBank artifacts.
    """
    import logging

    ensure_artifact_dirs()
    outputs: dict[str, Path] = {}
    outputs.update(build_ddinter_structured_artifacts(force=force_rebuild))
    if DRUGBANK_XML.exists():
        outputs.update(build_drugbank_artifacts(force=force_rebuild))
    try:
        from preprocessing.twosides_builder import build_twosides_artifacts

        build_twosides_artifacts(force=force_rebuild)
    except Exception:
        # Still optional, but no longer silent: a missing builder or a failed
        # build is reported so it can be diagnosed.
        logging.getLogger(__name__).warning(
            'TWOSIDES artifacts were not built',
            exc_info=True,
        )
    return outputs
def load_structured_dataframe(name: str) -> pd.DataFrame:
    """Return the artifact registered under ``name`` via the artifact manager.

    The manager is imported inside the function; the result is whatever its
    ``load_artifact`` returns for ``name``.
    """
    from preprocessing.artifact_manager import manager as artifact_manager

    frame = artifact_manager.load_artifact(name)
    return frame