File size: 14,731 Bytes
85020ae
793d027
 
 
 
 
 
 
 
 
 
ad9e267
793d027
 
 
 
 
 
 
 
 
ad9e267
793d027
 
 
 
 
 
 
 
ad9e267
 
 
 
 
 
 
793d027
ad9e267
 
 
 
 
 
793d027
 
 
 
 
 
 
 
 
ad9e267
793d027
 
 
 
ad9e267
793d027
 
ad9e267
 
 
 
793d027
 
 
 
ad9e267
793d027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad9e267
 
 
 
793d027
ad9e267
793d027
 
ad9e267
793d027
 
 
 
 
 
 
 
 
 
 
 
 
 
ad9e267
793d027
 
 
 
 
 
ad9e267
 
 
 
 
 
 
793d027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad9e267
793d027
ad9e267
793d027
 
 
 
 
 
 
ad9e267
 
793d027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad9e267
 
793d027
 
 
ad9e267
 
793d027
 
ad9e267
 
793d027
ad9e267
793d027
ad9e267
793d027
 
 
ad9e267
 
 
 
 
 
 
 
793d027
 
 
 
 
 
ad9e267
793d027
 
 
 
 
 
 
 
ad9e267
793d027
 
 
 
 
 
 
 
 
ad9e267
793d027
ad9e267
 
 
 
 
 
 
793d027
ad9e267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793d027
 
 
 
 
ad9e267
 
 
 
 
 
 
793d027
 
 
 
 
82a7e73
 
 
 
 
 
 
ad9e267
82a7e73
ad9e267
 
 
837c265
 
82a7e73
 
 
 
ad9e267
 
 
 
82a7e73
837c265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82a7e73
ad9e267
 
82a7e73
 
 
ad9e267
82a7e73
 
 
 
837c265
82a7e73
 
837c265
82a7e73
 
 
837c265
82a7e73
 
 
 
793d027
ad9e267
793d027
 
82a7e73
 
 
 
 
793d027
 
 
ad9e267
793d027
 
 
 
 
 
 
 
 
ad9e267
793d027
 
ad9e267
 
 
 
793d027
 
 
 
 
 
 
 
 
 
ad9e267
793d027
85020ae
793d027
 
 
 
 
ad9e267
 
793d027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
"""Data import scripts for AMR-Guard structured documents."""

import pandas as pd
from pathlib import Path
from .database import (
    get_connection, init_database, execute_many,
    DOCS_DIR, DB_PATH
)


def safe_float(value):
    """Coerce *value* to float, yielding None for NaN or unconvertible input."""
    if pd.isna(value):
        return None
    try:
        result = float(value)
    except (ValueError, TypeError):
        result = None
    return result


def safe_int(value):
    """Coerce *value* to int by way of float; None for NaN or non-numeric input."""
    if pd.isna(value):
        return None
    try:
        as_float = float(value)
    except (ValueError, TypeError):
        return None
    return int(as_float)


def safe_str(value) -> str:
    """Render *value* as a string, mapping None and NaN to the empty string."""
    if value is not None and not pd.isna(value):
        return str(value)
    return ''


def classify_severity(description: str) -> str:
    """
    Derive a drug-interaction severity label from free-text description.

    Returns "major", "moderate", or "minor" based on keyword presence,
    or "unknown" for an empty description. Major keywords win over
    moderate ones.
    """
    if not description:
        return "unknown"

    text = description.lower()

    # Ordered from most to least severe so the first match wins.
    severity_keywords = (
        ("major", (
            "cardiotoxic", "nephrotoxic", "hepatotoxic", "neurotoxic",
            "fatal", "death", "severe", "contraindicated", "arrhythmia",
            "qt prolongation", "seizure", "bleeding", "hemorrhage",
            "serotonin syndrome", "neuroleptic malignant",
        )),
        ("moderate", (
            "increase", "decrease", "reduce", "enhance", "inhibit",
            "metabolism", "concentration", "absorption", "excretion",
            "therapeutic effect", "adverse effect", "toxicity",
        )),
    )

    for label, keywords in severity_keywords:
        if any(kw in text for kw in keywords):
            return label
    return "minor"


def import_eml_antibiotics() -> int:
    """Import WHO EML antibiotic classification data from the three AWaRe Excel files."""
    print("Importing EML antibiotic data...")

    eml_files = {
        "ACCESS": DOCS_DIR / "antibiotic_guidelines" / "EML export-ACCESS group.xlsx",
        "RESERVE": DOCS_DIR / "antibiotic_guidelines" / "EML export-RESERVE group.xlsx",
        "WATCH": DOCS_DIR / "antibiotic_guidelines" / "EML export-WATCH group.xlsx",
    }

    records = []
    for category, filepath in eml_files.items():
        if not filepath.exists():
            print(f"  Warning: {filepath} not found, skipping...")
            continue

        try:
            import openpyxl
            workbook = openpyxl.load_workbook(filepath, read_only=True)
            sheet = workbook.active

            # Normalize the first row into snake_case column keys; unnamed
            # cells fall back to positional col_<i> names.
            headers = []
            for i, cell in enumerate(sheet[1]):
                if cell.value:
                    headers.append(str(cell.value).strip().lower().replace(' ', '_'))
                else:
                    headers.append(f'col_{i}')

            category_count = 0
            for row in sheet.iter_rows(min_row=2, values_only=True):
                data = dict(zip(headers, row))
                medicine = str(data.get('medicine_name', data.get('medicine', '')))
                # Skip empty cells and the stringified null placeholders.
                if not medicine or medicine in ('None', 'nan'):
                    continue

                category_count += 1
                records.append((
                    medicine,
                    category,
                    safe_str(data.get('eml_section', '')),
                    safe_str(data.get('formulations', '')),
                    safe_str(data.get('indication', '')),
                    safe_str(data.get('atc_codes', data.get('atc_code', ''))),
                    safe_str(data.get('combined_with', '')),
                    safe_str(data.get('status', '')),
                ))

            workbook.close()
            print(f"  Loaded {category_count} from {category}")

        except Exception as e:
            print(f"  Warning: Error reading {filepath}: {e}")
            continue

    if records:
        execute_many(
            """INSERT INTO eml_antibiotics
               (medicine_name, who_category, eml_section, formulations,
                indication, atc_codes, combined_with, status)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            records,
        )
        print(f"  Imported {len(records)} EML antibiotic records total")

    return len(records)


def import_atlas_susceptibility() -> int:
    """Import ATLAS antimicrobial susceptibility data."""
    print("Importing ATLAS susceptibility data...")

    filepath = DOCS_DIR / "pathogen_resistance" / "ATLAS Susceptibility Data Export.xlsx"

    if not filepath.exists():
        print(f"  Warning: {filepath} not found, skipping...")
        return 0

    preview = pd.read_excel(filepath, sheet_name="Percent", header=None)

    # The title row reads "Percentage Susceptibility from <Country>";
    # take whatever follows the first "from" as the region.
    region = "Unknown"
    for _, raw_row in preview.head(5).iterrows():
        first_cell = str(raw_row.iloc[0]) if pd.notna(raw_row.iloc[0]) else ""
        if "from" in first_cell.lower():
            pieces = first_cell.split("from")
            if len(pieces) > 1:
                region = pieces[1].strip()
            break

    # Locate the real header row — the one containing "Antibacterial".
    header_row = 4
    for idx, raw_row in preview.head(10).iterrows():
        if any('Antibacterial' in str(v) for v in raw_row.values if pd.notna(v)):
            header_row = idx
            break

    df = pd.read_excel(filepath, sheet_name="Percent", header=header_row)
    df.columns = [
        str(col).strip().lower().replace(' ', '_').replace('.', '')
        for col in df.columns
    ]

    records = []
    for _, row in df.iterrows():
        antibiotic = str(row.get('antibacterial', ''))
        lowered = antibiotic.lower()
        # Skip blanks, stringified NaN, and footnote/legend rows.
        if not antibiotic or antibiotic == 'nan' or 'omitted' in lowered:
            continue
        if 'in vitro' in lowered or 'table cells' in lowered:
            continue

        isolate_count = safe_int(row.get('n'))
        pct_susceptible = safe_float(row.get('susc', row.get('susceptible')))

        if isolate_count is None or pct_susceptible is None:
            continue

        records.append((
            "General",
            "",
            antibiotic,
            pct_susceptible,
            safe_float(row.get('int', row.get('intermediate'))),
            safe_float(row.get('res', row.get('resistant'))),
            isolate_count,
            2024,
            region,
            "ATLAS",
        ))

    if records:
        execute_many(
            """INSERT INTO atlas_susceptibility
               (species, family, antibiotic, percent_susceptible,
                percent_intermediate, percent_resistant, total_isolates,
                year, region, source)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            records,
        )
        print(f"  Imported {len(records)} ATLAS susceptibility records from {region}")

    return len(records)


def import_mic_breakpoints() -> int:
    """
    Import EUCAST MIC breakpoint tables from the Excel file.

    Each non-metadata sheet is scanned heuristically: the first non-empty
    cell of a row is taken as the antibiotic name and subsequent numeric
    cells (with inequality signs stripped) as breakpoint values. Rows with
    fewer than two numeric values are skipped.

    Returns the number of records inserted.
    """
    print("Importing MIC breakpoint data...")

    filepath = DOCS_DIR / "mic_breakpoints" / "v_16.0__BreakpointTables.xlsx"
    if not filepath.exists():
        print(f"  Warning: {filepath} not found, skipping...")
        return 0

    xl = pd.ExcelFile(filepath)
    # These sheets contain metadata/guidance, not pathogen-specific breakpoints
    skip_sheets = {'Content', 'Changes', 'Notes', 'Guidance', 'Dosages',
                   'Technical uncertainty', 'PK PD breakpoints', 'PK PD cutoffs'}

    # One-pass removal of all inequality signs. Fix: '≥' was previously not
    # stripped, so breakpoints formatted as "≥x" failed to parse and those
    # rows could be dropped for lack of numeric values.
    strip_signs = str.maketrans('', '', '≤≥<>')

    records = []
    for sheet_name in xl.sheet_names:
        if sheet_name in skip_sheets:
            continue
        try:
            df = pd.read_excel(filepath, sheet_name=sheet_name, header=None)
            for _, row in df.iterrows():
                row_values = [str(v).strip() for v in row.values if pd.notna(v)]
                if len(row_values) < 2:
                    continue

                potential_antibiotic = row_values[0]
                # Skip header and annotation rows.
                if any(kw in potential_antibiotic.lower() for kw in
                       ['antibiotic', 'agent', 'note', 'disk', 'mic', 'breakpoint']):
                    continue

                # Extract numeric MIC values; strip inequality signs
                mic_values = []
                for v in row_values[1:]:
                    try:
                        mic_values.append(float(v.translate(strip_signs).strip()))
                    except (ValueError, AttributeError):
                        pass

                if len(mic_values) >= 2 and len(potential_antibiotic) > 2:
                    records.append((
                        sheet_name,          # pathogen_group
                        potential_antibiotic,
                        None,                # route
                        mic_values[0],       # S breakpoint
                        mic_values[1],       # R breakpoint
                        None, None, None,    # disk S, disk R, notes
                        "16.0",
                    ))
        except Exception as e:
            # Best-effort: a malformed sheet should not abort the whole import.
            print(f"  Warning: Could not parse sheet '{sheet_name}': {e}")
            continue

    if records:
        execute_many(
            """INSERT INTO mic_breakpoints
               (pathogen_group, antibiotic, route, mic_susceptible, mic_resistant,
                disk_susceptible, disk_resistant, notes, eucast_version)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            records,
        )
        print(f"  Imported {len(records)} MIC breakpoint records")

    return len(records)


# Kaggle dataset slug holding the drug-drug interactions CSV.
KAGGLE_DATASET = "mghobashy/drug-drug-interactions"
# Mount point used when the dataset is attached to a Kaggle notebook.
KAGGLE_INPUT_DIR = Path("/kaggle/input/drug-drug-interactions")
# Preferred local location for the interactions CSV (also the download target).
INTERACTIONS_CSV = DOCS_DIR / "drug_safety" / "db_drug_interactions.csv"


def _resolve_interactions_csv() -> Path | None:
    """
    Locate the drug-drug interactions CSV.

    Search order:
    1. docs/drug_safety/db_drug_interactions.csv (local copy)
    2. /kaggle/input/drug-drug-interactions/ (Kaggle notebook with dataset attached)
    3. kagglehub.dataset_download() — works with KAGGLE_USERNAME/KAGGLE_KEY env vars
    4. Kaggle CLI download (legacy, requires ~/.kaggle/kaggle.json)

    Returns the CSV path, or None when every strategy fails.
    """
    if INTERACTIONS_CSV.exists():
        return INTERACTIONS_CSV

    if KAGGLE_INPUT_DIR.exists():
        for candidate in KAGGLE_INPUT_DIR.glob("*.csv"):
            print(f"  Found CSV in Kaggle input: {candidate}")
            return candidate

    print(f"  CSV not found — downloading from Kaggle dataset '{KAGGLE_DATASET}' via kagglehub ...")
    try:
        import kagglehub
        import shutil
        downloaded = Path(kagglehub.dataset_download(KAGGLE_DATASET))
        csv_files = list(downloaded.glob("*.csv"))
        if not csv_files:
            print(f"  kagglehub downloaded to {downloaded} but found no CSV files.")
        else:
            INTERACTIONS_CSV.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(csv_files[0], INTERACTIONS_CSV)
            print(f"  Downloaded via kagglehub: {csv_files[0].name}")
            return INTERACTIONS_CSV
    except Exception as exc:
        print(f"  kagglehub download failed: {exc}")

    print("  Falling back to Kaggle CLI ...")
    try:
        import kaggle  # noqa: F401 — triggers credential check
        import subprocess
        dest = INTERACTIONS_CSV.parent
        dest.mkdir(parents=True, exist_ok=True)
        completed = subprocess.run(
            ["kaggle", "datasets", "download", "-d", KAGGLE_DATASET, "--unzip", "-p", str(dest)],
            capture_output=True, text=True,
        )
        if completed.returncode != 0:
            print(f"  Kaggle CLI download failed: {completed.stderr.strip()}")
        else:
            for csv_path in dest.glob("*.csv"):
                print(f"  Downloaded via Kaggle CLI: {csv_path.name}")
                return csv_path
    except ImportError:
        print("  kaggle package not installed — run: uv add kaggle")
    except Exception as exc:
        print(f"  Could not download via CLI: {exc}")

    return None


def import_drug_interactions(limit: int | None = None) -> int:
    """
    Import drug-drug interactions from the DDInter CSV
    (Kaggle dataset mghobashy/drug-drug-interactions).

    The CSV is streamed in 10k-row chunks. When *limit* is given, the import
    stops after the first chunk that brings the total to or past the limit,
    so the final count may slightly overshoot it.

    Args:
        limit: optional soft cap on the number of rows imported.

    Returns:
        Number of interaction records inserted.
    """
    print("Importing drug interactions data...")

    filepath = _resolve_interactions_csv()
    if filepath is None:
        print("  Skipping drug interactions — CSV unavailable.")
        print(f"  To fix: attach the Kaggle dataset '{KAGGLE_DATASET}' to your notebook,")
        print("  or set up ~/.kaggle/kaggle.json for API access.")
        return 0

    total_records = 0
    for chunk in pd.read_csv(filepath, chunksize=10000):
        chunk.columns = [col.strip().lower().replace(' ', '_') for col in chunk.columns]

        records = []
        for _, row in chunk.iterrows():
            # Column names vary across dataset versions; fall back to position.
            drug_1 = str(row.get('drug_1', row.get('drug1', row.iloc[0] if len(row) > 0 else '')))
            drug_2 = str(row.get('drug_2', row.get('drug2', row.iloc[1] if len(row) > 1 else '')))
            description = str(row.get('interaction_description', row.get('description',
                             row.get('interaction', row.iloc[2] if len(row) > 2 else ''))))
            # Fix: missing cells stringify to 'nan'/'None' and previously
            # passed the truthiness check; skip them like the other importers do.
            if not drug_1 or not drug_2 or drug_1 in ('nan', 'None') or drug_2 in ('nan', 'None'):
                continue
            records.append((drug_1, drug_2, description, classify_severity(description)))

        if records:
            execute_many(
                "INSERT INTO drug_interactions (drug_1, drug_2, interaction_description, severity) VALUES (?, ?, ?, ?)",
                records,
            )
            total_records += len(records)

        if limit and total_records >= limit:
            break

    print(f"  Imported {total_records} drug interaction records")
    return total_records


def import_all_data(interactions_limit: int | None = None) -> dict[str, int]:
    """
    Initialize the database and import all structured data sources.

    Existing rows in the four target tables are deleted first, so re-running
    the import does not duplicate data.

    Args:
        interactions_limit: optional soft cap on the number of drug-interaction
            rows imported (None imports everything).

    Returns:
        Mapping of table name to the number of records imported into it.
    """
    print(f"\n{'='*50}")
    print("AMR-Guard Data Import")
    print(f"{'='*50}\n")

    init_database()

    # Clear previous contents so the import is idempotent. Table names come
    # from a fixed tuple, so the f-string SQL is safe here.
    with get_connection() as conn:
        for table in ("eml_antibiotics", "atlas_susceptibility", "mic_breakpoints", "drug_interactions"):
            conn.execute(f"DELETE FROM {table}")
        conn.commit()
    print("Cleared existing data\n")

    results = {
        "eml_antibiotics": import_eml_antibiotics(),
        "atlas_susceptibility": import_atlas_susceptibility(),
        "mic_breakpoints": import_mic_breakpoints(),
        "drug_interactions": import_drug_interactions(limit=interactions_limit),
    }

    print(f"\n{'='*50}")
    print("Import Summary:")
    for table, count in results.items():
        print(f"  {table}: {count} records")
    print(f"{'='*50}\n")

    return results


if __name__ == "__main__":
    # Script entry point: run the full import, capping drug interactions
    # at 50,000 rows.
    import_all_data(interactions_limit=50000)