ai-response-validator / scripts /kaggle_to_yaml.py
below-threshold's picture
Add 15 drug profiles to pharma KB from Kaggle dataset
2a3badd
"""
Convert a Kaggle drug CSV into features.yaml entries and append them to a domain's knowledge base (the pharma KB by default).
Supported CSV formats:
- Drug Labels & Side Effects (drug_name, medical_condition, side_effects)
- Drug Dataset: Uses, Side Effects & User Reviews (same columns)
- Medicine Side-Effects Analysis (drug, use, side_effects / adverse_effects)
Usage:
python scripts/kaggle_to_yaml.py path/to/drugs.csv [--max 15] [--dry-run] [--domain pharma]
Options:
--max N Max number of drugs to import (default: 15)
--dry-run Print YAML to stdout instead of appending to features.yaml
--domain D Target domain (default: pharma)
"""
import argparse
import csv
import json
import re
import sys
from pathlib import Path
# Root of the knowledge base, resolved relative to this script's own location
# (two directories up, then "knowledge"). convert() reads and appends to
# KNOWLEDGE_ROOT / <domain> / "features.yaml".
KNOWLEDGE_ROOT = Path(__file__).parent.parent / "knowledge"
# Candidate CSV header names for each field, in PRIORITY ORDER (order matters).
# _find_col() tries each candidate verbatim, then case-insensitively with
# surrounding whitespace ignored, and returns the first header that matches.
# Drug-name column: required; convert() exits with an error if none matches.
_DRUG_COLS = ["drug_name", "Drug Name", "name", "drug", "drugName"]
# Indication / medical-condition column (optional).
_USE_COLS = ["indications", "medical_condition", "condition", "use", "uses",
             "indication", "Medical Condition"]
# Side-effect / adverse-event column (optional).
_EFFECT_COLS = ["side_effects", "Side_Effects", "Side Effects", "sideEffects",
                "adverse_effects", "adverse_events"]
# Contraindications column (optional).
_CONTRA_COLS = ["contraindications", "contraindication", "Contraindications"]
# Warnings / precautions column (optional).
_WARN_COLS = ["warnings", "warning", "Warnings", "precautions"]
def _find_col(headers: list[str], candidates: list[str]) -> str | None:
h_lower = {h.lower().strip(): h for h in headers}
for c in candidates:
if c in headers:
return c
if c.lower().strip() in h_lower:
return h_lower[c.lower().strip()]
return None
def _slugify(s: str) -> str:
return re.sub(r"[^a-z0-9]+", "-", s.lower()).strip("-")
def _truncate(text: str, max_items: int = 8) -> str:
sep = "," if text.count(",") >= text.count(";") else ";"
items = [i.strip() for i in text.split(sep) if i.strip()]
if len(items) > max_items:
return ", ".join(items[:max_items]) + ", and others"
return ", ".join(items)
def _to_yaml_block(doc_id: str, title: str, content: str, tags: list[str]) -> str:
safe_content = content.replace('"', '\\"')
tags_yaml = ", ".join(f'"{t}"' for t in tags)
return (
f"\n - id: {doc_id}\n"
f" title: \"{title}\"\n"
f" content: >\n"
+ "".join(f" {line}\n" for line in content.splitlines())
+ f" tags: [{tags_yaml}]\n"
)
def _cell(row: dict[str, str | None], col: str | None) -> str:
    """Return the whitespace-stripped value of *col* in *row*, or "" if absent.

    csv.DictReader fills the missing trailing fields of a short row with None
    (its default ``restval``), so ``row.get(col, "")`` alone is not safe.
    """
    if not col:
        return ""
    return (row.get(col) or "").strip()


def _sentence_part(text: str) -> str:
    """Drop trailing periods/spaces so *text* can precede the template's own '.'."""
    return text.rstrip(". ")


def _build_content(drug: str, condition: str, effects: str, contra: str,
                   warnings: str) -> str:
    """Compose the prose body of one drug profile from its (stripped) CSV fields."""
    condition_str = _sentence_part(condition) or "the indicated condition"
    effects_str = _sentence_part(_truncate(effects)) or "not listed"
    parts = [
        f"{drug} is indicated for the treatment or management of {condition_str}.",
        f"Known adverse events include: {effects_str}.",
    ]
    contra = _sentence_part(contra)
    if contra:
        parts.append(f"Contraindicated in patients with: {contra}.")
    warnings = _sentence_part(warnings)
    if warnings:
        parts.append(f"Prescriber warning: {warnings}.")
    parts.append(
        "Serious unexpected adverse events must be reported to the regulatory "
        "authority within 15 days."
    )
    return " ".join(parts)


def _next_doc_number(existing_yaml: str, id_prefix: str) -> int:
    """Return one more than the highest ``<id_prefix>NNN`` id in *existing_yaml*.

    Repeated imports therefore continue the numbering instead of re-issuing
    ids that are already in the knowledge base. Returns 1 when none exist.
    """
    pattern = rf"\bid:\s*[\"']?{re.escape(id_prefix)}(\d+)"
    return max((int(n) for n in re.findall(pattern, existing_yaml)), default=0) + 1


def convert(csv_path: Path, max_drugs: int, dry_run: bool, domain: str) -> None:
    """Turn up to *max_drugs* unique drugs from *csv_path* into KB documents.

    Columns are auto-detected with _find_col() and the ``_*_COLS`` candidate
    lists. Each distinct drug name becomes one YAML document; names are
    compared case-insensitively and the first row wins. Documents get ids
    ``<domain>_drug_NNN``, numbered sequentially after the highest such id
    already present in the domain's ``features.yaml``.

    Args:
        csv_path: CSV file to read (UTF-8, optional BOM).
        max_drugs: Maximum number of documents to generate.
        dry_run: If true, print the YAML to stdout and leave the KB untouched.
        domain: Sub-directory of ``KNOWLEDGE_ROOT`` whose ``features.yaml``
            receives the documents.

    Exits with status 1 via ``sys.exit`` in three cases: the target
    ``features.yaml`` is missing (non-dry-run only), no drug-name column is
    found, or no documents are generated.
    """
    features_path = KNOWLEDGE_ROOT / domain / "features.yaml"
    if features_path.is_file():
        # Explicit UTF-8: titles contain "—", which the locale codec may reject.
        existing = features_path.read_text(encoding="utf-8")
    elif dry_run:
        existing = ""
    else:
        # Fail before parsing the CSV rather than after all the work is done.
        print(f"ERROR: knowledge file not found: {features_path}", file=sys.stderr)
        sys.exit(1)
    id_prefix = f"{domain}_drug_"
    first_number = _next_doc_number(existing, id_prefix)

    blocks: list[str] = []
    with csv_path.open(newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        headers = list(reader.fieldnames or [])
        drug_col = _find_col(headers, _DRUG_COLS)
        use_col = _find_col(headers, _USE_COLS)
        effect_col = _find_col(headers, _EFFECT_COLS)
        contra_col = _find_col(headers, _CONTRA_COLS)
        warn_col = _find_col(headers, _WARN_COLS)
        if not drug_col:
            print(f"ERROR: no drug name column found. Headers: {headers}", file=sys.stderr)
            sys.exit(1)
        # Diagnostics go to stderr so `--dry-run > out.yaml` captures only YAML.
        print(f"Detected columns — drug: {drug_col!r}, use: {use_col!r}, "
              f"effects: {effect_col!r}, contra: {contra_col!r}, warnings: {warn_col!r}",
              file=sys.stderr)

        seen: set[str] = set()
        for row in reader:
            if len(blocks) >= max_drugs:
                break
            drug = _cell(row, drug_col).title()
            key = drug.lower()
            if not drug or key in seen:
                continue
            seen.add(key)
            condition = _cell(row, use_col)
            content = _build_content(
                drug,
                condition,
                _cell(row, effect_col),
                _cell(row, contra_col),
                _cell(row, warn_col),
            )
            # Drop empty slugs (e.g. all-non-ASCII names) and duplicates, keeping order.
            tags = list(dict.fromkeys(
                tag
                for tag in (_slugify(drug), _slugify(condition), "adverse-event", "drug-profile")
                if tag
            ))
            # Number by documents emitted (not CSV row index) so skipped rows leave no gaps.
            doc_id = f"{id_prefix}{first_number + len(blocks):03d}"
            blocks.append(_to_yaml_block(doc_id, f"{drug} — Drug Profile", content, tags))

    if not blocks:
        print("No documents generated. Check the CSV format.", file=sys.stderr)
        sys.exit(1)
    yaml_str = "".join(blocks)
    if dry_run:
        print(yaml_str)
        print(f"\n[dry-run] {len(blocks)} drug entries would be appended.", file=sys.stderr)
        return
    # Trim all trailing whitespace, then add one newline. Each block begins with
    # its own newline, leaving one blank line before the new entries.
    # NOTE(review): appending at EOF assumes the document list is the last
    # top-level node in features.yaml. Confirm.
    features_path.write_text(existing.rstrip() + "\n" + yaml_str, encoding="utf-8")
    print(f"Appended {len(blocks)} drug entries to {features_path}")
def main() -> None:
    """Parse command-line arguments, validate them, and run convert()."""
    parser = argparse.ArgumentParser(description="Convert Kaggle drug CSV to features.yaml entries")
    parser.add_argument("csv", type=Path, help="Path to Kaggle drug CSV")
    parser.add_argument("--max", type=int, default=15, dest="max_drugs",
                        help="max number of drugs to import (default: %(default)s)")
    parser.add_argument("--dry-run", action="store_true",
                        help="print YAML to stdout instead of appending to features.yaml")
    parser.add_argument("--domain", default="pharma",
                        help="target knowledge domain (default: %(default)s)")
    args = parser.parse_args()
    if args.max_drugs < 1:
        # A non-positive cap can never produce output; reject it up front.
        parser.error("--max must be a positive integer")
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and fail later when opened as a CSV.
    if not args.csv.is_file():
        print(f"ERROR: file not found: {args.csv}", file=sys.stderr)
        sys.exit(1)
    convert(args.csv, args.max_drugs, args.dry_run, args.domain)


if __name__ == "__main__":
    main()