| | """Data import scripts for AMR-Guard structured documents.""" |
| |
|
| | import pandas as pd |
| | from pathlib import Path |
| | from .database import ( |
| | get_connection, init_database, execute_many, |
| | DOCS_DIR, DB_PATH |
| | ) |
| |
|
| |
|
| | def safe_float(value): |
| | """Convert value to float; return None if the value is NaN or non-numeric.""" |
| | if pd.isna(value): |
| | return None |
| | try: |
| | return float(value) |
| | except (ValueError, TypeError): |
| | return None |
| |
|
| |
|
| | def safe_int(value): |
| | """Convert value to int via float; return None if the value is NaN or non-numeric.""" |
| | if pd.isna(value): |
| | return None |
| | try: |
| | return int(float(value)) |
| | except (ValueError, TypeError): |
| | return None |
| |
|
| |
|
| | def safe_str(value) -> str: |
| | """Convert value to string; return empty string for None or NaN.""" |
| | if value is None or pd.isna(value): |
| | return '' |
| | return str(value) |
| |
|
| |
|
| | def classify_severity(description: str) -> str: |
| | """ |
| | Classify drug interaction severity from the interaction description text. |
| | |
| | Returns 'major', 'moderate', or 'minor' based on keyword presence. |
| | Major keywords take precedence over moderate. |
| | """ |
| | if not description: |
| | return "unknown" |
| |
|
| | desc_lower = description.lower() |
| |
|
| | major_keywords = [ |
| | "cardiotoxic", "nephrotoxic", "hepatotoxic", "neurotoxic", |
| | "fatal", "death", "severe", "contraindicated", "arrhythmia", |
| | "qt prolongation", "seizure", "bleeding", "hemorrhage", |
| | "serotonin syndrome", "neuroleptic malignant", |
| | ] |
| | moderate_keywords = [ |
| | "increase", "decrease", "reduce", "enhance", "inhibit", |
| | "metabolism", "concentration", "absorption", "excretion", |
| | "therapeutic effect", "adverse effect", "toxicity", |
| | ] |
| |
|
| | if any(kw in desc_lower for kw in major_keywords): |
| | return "major" |
| | if any(kw in desc_lower for kw in moderate_keywords): |
| | return "moderate" |
| | return "minor" |
| |
|
| |
|
| | def import_eml_antibiotics() -> int: |
| | """Import WHO EML antibiotic classification data from the three AWaRe Excel files.""" |
| | print("Importing EML antibiotic data...") |
| |
|
| | eml_files = { |
| | "ACCESS": DOCS_DIR / "antibiotic_guidelines" / "EML export-ACCESS group.xlsx", |
| | "RESERVE": DOCS_DIR / "antibiotic_guidelines" / "EML export-RESERVE group.xlsx", |
| | "WATCH": DOCS_DIR / "antibiotic_guidelines" / "EML export-WATCH group.xlsx", |
| | } |
| |
|
| | records = [] |
| | for category, filepath in eml_files.items(): |
| | if not filepath.exists(): |
| | print(f" Warning: {filepath} not found, skipping...") |
| | continue |
| |
|
| | try: |
| | import openpyxl |
| | wb = openpyxl.load_workbook(filepath, read_only=True) |
| | ws = wb.active |
| |
|
| | headers = [ |
| | str(cell.value).strip().lower().replace(' ', '_') if cell.value else f'col_{i}' |
| | for i, cell in enumerate(ws[1]) |
| | ] |
| |
|
| | for row in ws.iter_rows(min_row=2, values_only=True): |
| | row_dict = dict(zip(headers, row)) |
| | medicine = str(row_dict.get('medicine_name', row_dict.get('medicine', ''))) |
| | if not medicine or medicine in ('None', 'nan'): |
| | continue |
| |
|
| | records.append(( |
| | medicine, |
| | category, |
| | safe_str(row_dict.get('eml_section', '')), |
| | safe_str(row_dict.get('formulations', '')), |
| | safe_str(row_dict.get('indication', '')), |
| | safe_str(row_dict.get('atc_codes', row_dict.get('atc_code', ''))), |
| | safe_str(row_dict.get('combined_with', '')), |
| | safe_str(row_dict.get('status', '')), |
| | )) |
| |
|
| | wb.close() |
| | print(f" Loaded {sum(1 for r in records if r[1] == category)} from {category}") |
| |
|
| | except Exception as e: |
| | print(f" Warning: Error reading {filepath}: {e}") |
| | continue |
| |
|
| | if records: |
| | execute_many( |
| | """INSERT INTO eml_antibiotics |
| | (medicine_name, who_category, eml_section, formulations, |
| | indication, atc_codes, combined_with, status) |
| | VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", |
| | records, |
| | ) |
| | print(f" Imported {len(records)} EML antibiotic records total") |
| |
|
| | return len(records) |
| |
|
| |
|
| | def import_atlas_susceptibility() -> int: |
| | """Import ATLAS antimicrobial susceptibility data.""" |
| | print("Importing ATLAS susceptibility data...") |
| |
|
| | filepath = DOCS_DIR / "pathogen_resistance" / "ATLAS Susceptibility Data Export.xlsx" |
| |
|
| | if not filepath.exists(): |
| | print(f" Warning: {filepath} not found, skipping...") |
| | return 0 |
| |
|
| | df_raw = pd.read_excel(filepath, sheet_name="Percent", header=None) |
| |
|
| | |
| | region = "Unknown" |
| | for _, row in df_raw.head(5).iterrows(): |
| | cell = str(row.iloc[0]) if pd.notna(row.iloc[0]) else "" |
| | if "from" in cell.lower(): |
| | parts = cell.split("from") |
| | if len(parts) > 1: |
| | region = parts[1].strip() |
| | break |
| |
|
| | |
| | header_row = 4 |
| | for idx, row in df_raw.head(10).iterrows(): |
| | if any('Antibacterial' in str(v) for v in row.values if pd.notna(v)): |
| | header_row = idx |
| | break |
| |
|
| | df = pd.read_excel(filepath, sheet_name="Percent", header=header_row) |
| | df.columns = [str(col).strip().lower().replace(' ', '_').replace('.', '') for col in df.columns] |
| |
|
| | records = [] |
| | for _, row in df.iterrows(): |
| | antibiotic = str(row.get('antibacterial', '')) |
| | if not antibiotic or antibiotic == 'nan' or 'omitted' in antibiotic.lower(): |
| | continue |
| | if 'in vitro' in antibiotic.lower() or 'table cells' in antibiotic.lower(): |
| | continue |
| |
|
| | n_int = safe_int(row.get('n')) |
| | s_float = safe_float(row.get('susc', row.get('susceptible'))) |
| |
|
| | if n_int is not None and s_float is not None: |
| | records.append(( |
| | "General", |
| | "", |
| | antibiotic, |
| | s_float, |
| | safe_float(row.get('int', row.get('intermediate'))), |
| | safe_float(row.get('res', row.get('resistant'))), |
| | n_int, |
| | 2024, |
| | region, |
| | "ATLAS", |
| | )) |
| |
|
| | if records: |
| | execute_many( |
| | """INSERT INTO atlas_susceptibility |
| | (species, family, antibiotic, percent_susceptible, |
| | percent_intermediate, percent_resistant, total_isolates, |
| | year, region, source) |
| | VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", |
| | records, |
| | ) |
| | print(f" Imported {len(records)} ATLAS susceptibility records from {region}") |
| |
|
| | return len(records) |
| |
|
| |
|
| | def import_mic_breakpoints() -> int: |
| | """Import EUCAST MIC breakpoint tables from the Excel file.""" |
| | print("Importing MIC breakpoint data...") |
| |
|
| | filepath = DOCS_DIR / "mic_breakpoints" / "v_16.0__BreakpointTables.xlsx" |
| | if not filepath.exists(): |
| | print(f" Warning: {filepath} not found, skipping...") |
| | return 0 |
| |
|
| | xl = pd.ExcelFile(filepath) |
| | |
| | skip_sheets = {'Content', 'Changes', 'Notes', 'Guidance', 'Dosages', |
| | 'Technical uncertainty', 'PK PD breakpoints', 'PK PD cutoffs'} |
| |
|
| | records = [] |
| | for sheet_name in xl.sheet_names: |
| | if sheet_name in skip_sheets: |
| | continue |
| | try: |
| | df = pd.read_excel(filepath, sheet_name=sheet_name, header=None) |
| | for _, row in df.iterrows(): |
| | row_values = [str(v).strip() for v in row.values if pd.notna(v)] |
| | if len(row_values) < 2: |
| | continue |
| |
|
| | potential_antibiotic = row_values[0] |
| | if any(kw in potential_antibiotic.lower() for kw in |
| | ['antibiotic', 'agent', 'note', 'disk', 'mic', 'breakpoint']): |
| | continue |
| |
|
| | |
| | mic_values = [] |
| | for v in row_values[1:]: |
| | try: |
| | mic_values.append(float(v.replace('≤', '').replace('>', '').replace('<', '').strip())) |
| | except (ValueError, AttributeError): |
| | pass |
| |
|
| | if len(mic_values) >= 2 and len(potential_antibiotic) > 2: |
| | records.append(( |
| | sheet_name, |
| | potential_antibiotic, |
| | None, |
| | mic_values[0], |
| | mic_values[1], |
| | None, None, None, |
| | "16.0", |
| | )) |
| | except Exception as e: |
| | print(f" Warning: Could not parse sheet '{sheet_name}': {e}") |
| | continue |
| |
|
| | if records: |
| | execute_many( |
| | """INSERT INTO mic_breakpoints |
| | (pathogen_group, antibiotic, route, mic_susceptible, mic_resistant, |
| | disk_susceptible, disk_resistant, notes, eucast_version) |
| | VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", |
| | records, |
| | ) |
| | print(f" Imported {len(records)} MIC breakpoint records") |
| |
|
| | return len(records) |
| |
|
| |
|
| | KAGGLE_DATASET = "mghobashy/drug-drug-interactions" |
| | KAGGLE_INPUT_DIR = Path("/kaggle/input/drug-drug-interactions") |
| | INTERACTIONS_CSV = DOCS_DIR / "drug_safety" / "db_drug_interactions.csv" |
| |
|
| |
|
| | def _resolve_interactions_csv() -> Path | None: |
| | """ |
| | Find the drug interactions CSV file. |
| | |
| | Checks in order: |
| | 1. docs/drug_safety/db_drug_interactions.csv (local) |
| | 2. /kaggle/input/drug-drug-interactions/ (Kaggle notebook with dataset attached) |
| | 3. kagglehub.dataset_download() — works with KAGGLE_USERNAME/KAGGLE_KEY env vars |
| | 4. Kaggle CLI download (legacy, requires ~/.kaggle/kaggle.json) |
| | """ |
| | if INTERACTIONS_CSV.exists(): |
| | return INTERACTIONS_CSV |
| |
|
| | if KAGGLE_INPUT_DIR.exists(): |
| | for candidate in KAGGLE_INPUT_DIR.glob("*.csv"): |
| | print(f" Found CSV in Kaggle input: {candidate}") |
| | return candidate |
| |
|
| | print(f" CSV not found — downloading from Kaggle dataset '{KAGGLE_DATASET}' via kagglehub ...") |
| | try: |
| | import kagglehub |
| | dataset_path = Path(kagglehub.dataset_download(KAGGLE_DATASET)) |
| | csvs = list(dataset_path.glob("*.csv")) |
| | if csvs: |
| | src = csvs[0] |
| | dest = INTERACTIONS_CSV.parent |
| | dest.mkdir(parents=True, exist_ok=True) |
| | import shutil |
| | shutil.copy2(src, INTERACTIONS_CSV) |
| | print(f" Downloaded via kagglehub: {src.name}") |
| | return INTERACTIONS_CSV |
| | else: |
| | print(f" kagglehub downloaded to {dataset_path} but found no CSV files.") |
| | except Exception as e: |
| | print(f" kagglehub download failed: {e}") |
| |
|
| | print(" Falling back to Kaggle CLI ...") |
| | try: |
| | import kaggle |
| | import subprocess |
| | dest = INTERACTIONS_CSV.parent |
| | dest.mkdir(parents=True, exist_ok=True) |
| | result = subprocess.run( |
| | ["kaggle", "datasets", "download", "-d", KAGGLE_DATASET, "--unzip", "-p", str(dest)], |
| | capture_output=True, text=True, |
| | ) |
| | if result.returncode == 0: |
| | for f in dest.glob("*.csv"): |
| | print(f" Downloaded via Kaggle CLI: {f.name}") |
| | return f |
| | else: |
| | print(f" Kaggle CLI download failed: {result.stderr.strip()}") |
| | except ImportError: |
| | print(" kaggle package not installed — run: uv add kaggle") |
| | except Exception as e: |
| | print(f" Could not download via CLI: {e}") |
| |
|
| | return None |
| |
|
| |
|
| | def import_drug_interactions(limit: int = None) -> int: |
| | """Import drug-drug interactions from the DDInter CSV (Kaggle dataset mghobashy/drug-drug-interactions).""" |
| | print("Importing drug interactions data...") |
| |
|
| | filepath = _resolve_interactions_csv() |
| | if filepath is None: |
| | print(" Skipping drug interactions — CSV unavailable.") |
| | print(f" To fix: attach the Kaggle dataset '{KAGGLE_DATASET}' to your notebook,") |
| | print(" or set up ~/.kaggle/kaggle.json for API access.") |
| | return 0 |
| |
|
| | total_records = 0 |
| | for chunk in pd.read_csv(filepath, chunksize=10000): |
| | chunk.columns = [col.strip().lower().replace(' ', '_') for col in chunk.columns] |
| |
|
| | records = [] |
| | for _, row in chunk.iterrows(): |
| | drug_1 = str(row.get('drug_1', row.get('drug1', row.iloc[0] if len(row) > 0 else ''))) |
| | drug_2 = str(row.get('drug_2', row.get('drug2', row.iloc[1] if len(row) > 1 else ''))) |
| | description = str(row.get('interaction_description', row.get('description', |
| | row.get('interaction', row.iloc[2] if len(row) > 2 else '')))) |
| | if drug_1 and drug_2: |
| | records.append((drug_1, drug_2, description, classify_severity(description))) |
| |
|
| | if records: |
| | execute_many( |
| | "INSERT INTO drug_interactions (drug_1, drug_2, interaction_description, severity) VALUES (?, ?, ?, ?)", |
| | records, |
| | ) |
| | total_records += len(records) |
| |
|
| | if limit and total_records >= limit: |
| | break |
| |
|
| | print(f" Imported {total_records} drug interaction records") |
| | return total_records |
| |
|
| |
|
| | def import_all_data(interactions_limit: int = None) -> dict: |
| | """Initialize the database and import all structured data sources.""" |
| | print(f"\n{'='*50}") |
| | print("AMR-Guard Data Import") |
| | print(f"{'='*50}\n") |
| |
|
| | init_database() |
| |
|
| | with get_connection() as conn: |
| | for table in ("eml_antibiotics", "atlas_susceptibility", "mic_breakpoints", "drug_interactions"): |
| | conn.execute(f"DELETE FROM {table}") |
| | conn.commit() |
| | print("Cleared existing data\n") |
| |
|
| | results = { |
| | "eml_antibiotics": import_eml_antibiotics(), |
| | "atlas_susceptibility": import_atlas_susceptibility(), |
| | "mic_breakpoints": import_mic_breakpoints(), |
| | "drug_interactions": import_drug_interactions(limit=interactions_limit), |
| | } |
| |
|
| | print(f"\n{'='*50}") |
| | print("Import Summary:") |
| | for table, count in results.items(): |
| | print(f" {table}: {count} records") |
| | print(f"{'='*50}\n") |
| |
|
| | return results |
| |
|
| |
|
| | if __name__ == "__main__": |
| | import_all_data(interactions_limit=50000) |
| |
|