File size: 13,647 Bytes
908ea05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
"""MEDS v0.4.1 exporter for DATASUS β€” audit-proof, interop-ready.

Verified against:
  - meds 0.4.1 schemas (DataSchema, CodeMetadataSchema)
  - https://github.com/Medical-Event-Data-Standard/meds
  - CLMBR/MOTOR/EHRSHOT/CoMET tokenization conventions

Code conventions (interop-compatible):
  - Static (time=None): GENDER//, RACE//, UF//, MUN//, ORPHA//
  - Birth/Death: MEDS_BIRTH, MEDS_DEATH (reserved)
  - Diagnoses: ICD10//<cid> (NOT CID10// β€” interop with OHDSI/Athena)
  - Hospitalization: SIH//ADM, SIH//DIS (numeric_value=LOS_days on DIS)
  - Procedures: SIGTAP//<10-digit> (Brazil-local namespace)
  - Drugs (APAC): APAC//<sigtap> (numeric_value=monthly_cost_brl)
  - Outpatient (BPA-I): BPAI//<sigtap>
  - Visits: Visit//{IP, OP, ER} (matches CLMBR convention)

Outputs canonical MEDS dataset:
  /out/
  β”œβ”€β”€ data/                       # parquet shards by subject
  β”‚   β”œβ”€β”€ shard_0.parquet
  β”‚   └── ...
  β”œβ”€β”€ metadata/
  β”‚   └── codes.parquet           # REQUIRED: every unique code with description + parent_codes
  └── dataset_metadata.json       # MEDS dataset metadata
"""
from __future__ import annotations
import os
import json
import logging
from collections import defaultdict, Counter
from datetime import datetime
from typing import Iterator

import pyarrow as pa
import pyarrow.parquet as pq
import meds

log = logging.getLogger("gemeo.cdf.meds_export")


def _parse_date(s) -> datetime | None:
    """Parse date string from various DATASUS formats."""
    if s is None: return None
    s = str(s).strip()
    if not s or s in ("0", "None", "nan"): return None
    try:
        if "-" in s:
            return datetime.strptime(s[:10], "%Y-%m-%d")
        if len(s) == 8:
            return datetime.strptime(s, "%Y%m%d")
    except ValueError:
        return None
    return None


def _ym(year, month) -> datetime | None:
    if year is None: return None
    try:
        return datetime(int(year), int(month) if month else 1, 1)
    except (ValueError, TypeError):
        return None


def _code_str(value) -> str:
    """Normalize a raw code field to a stripped string ("" when missing).

    Falsy values (None, "", 0) and the textual sentinels "None"/"nan" -- what
    ``str()`` yields for None and float NaN -- count as missing, so they never
    become codes such as ``ICD10//nan``. Non-string values (e.g. int codes)
    are converted with ``str``.
    """
    if not value:
        return ""
    text = str(value).strip()
    return "" if text in ("None", "nan") else text


def _opt_float(value) -> float | None:
    """Return ``float(value)``, or None when value is None or not numeric."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def datasus_patient_to_meds_rows(p: dict, subject_id: int) -> list[tuple]:
    """Convert one DATASUS patient trajectory to a list of MEDS rows.

    Each row is (subject_id, time, code, numeric_value, text_value); text_value
    is always None. Static rows (time=None) come first, then timed rows in
    chronological order (stable sort: rows sharing a timestamp keep their
    emission order). Events whose time cannot be determined are dropped.
    Returns rows ready to write to a parquet shard.

    Parameters
    ----------
    p : dict
        Trajectory with optional keys ``sex``, ``orphas`` (list),
        ``birth_year`` and ``events`` (list of dicts whose ``type`` is one of
        ``admission``, ``treatment``, ``outpatient_proc``, ``death``; other
        types are ignored).
    subject_id : int
        MEDS subject_id written on every row.
    """
    rows: list[tuple] = []

    def emit(time, code, numeric_value=None):
        rows.append((subject_id, time, code, numeric_value, None))

    # ---- Static (time=None) ----
    sex = _code_str(p.get("sex"))
    if sex:
        emit(None, f"GENDER//{sex}")
    # ORPHA is rare-disease specific (parallel to ICD10)
    for raw_orpha in p.get("orphas") or []:
        orpha = _code_str(raw_orpha)
        if orpha:
            emit(None, f"ORPHA//{orpha}")

    # ---- Birth (only birth_year is known, so pin it to Jan 1) ----
    # _ym returns None for missing/invalid years (NaN, 0, "") instead of raising.
    birth_dt = _ym(p.get("birth_year"), 1)
    if birth_dt:
        emit(birth_dt, "MEDS_BIRTH")

    # ---- Events ----
    for e in p.get("events") or []:
        et = e.get("type")

        if et == "admission":  # SIH-RD
            # Prefer the exact admission date (as the outpatient/death branches
            # do); fall back to the (year, month) period only when it is absent.
            t = _parse_date(e.get("admission_date")) or _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            emit(t, "SIH//ADM")
            emit(t, "Visit//IP")
            cid = _code_str(e.get("cid_princ"))
            if cid:
                emit(t, f"ICD10//{cid}")
            proc = _code_str(e.get("primary_procedure"))[:10]
            if proc:
                emit(t, f"SIGTAP//{proc}")
            # Discharge defaults to the admission time and is never placed
            # before it (possible when t fell back to a month period).
            disch_dt = max(_parse_date(e.get("discharge_date")) or t, t)
            if e.get("death_during_stay"):
                # An in-hospital death replaces the SIH//DIS row.
                emit(disch_dt, "MEDS_DEATH")
            else:
                emit(disch_dt, "SIH//DIS", _opt_float(e.get("los_days")))

        elif et == "treatment":  # APAC-SIA orphan drug
            t = _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            cid = _code_str(e.get("cid"))
            if cid:
                emit(t, f"ICD10//{cid}")
            proc = _code_str(e.get("procedure_code"))[:10]
            if proc:
                emit(t, f"APAC//{proc}", _opt_float(e.get("monthly_cost_brl")))

        elif et == "outpatient_proc":  # BPA-I
            t = _parse_date(e.get("auth_date")) or _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            cid = _code_str(e.get("cid"))
            if cid:
                emit(t, f"ICD10//{cid}")
            proc = _code_str(e.get("procedure_code"))[:10]
            if proc:
                emit(t, f"BPAI//{proc}")

        elif et == "death":  # SIM
            t = _parse_date(e.get("date_of_death")) or _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            emit(t, "MEDS_DEATH")
            # First non-missing of cause_cid, cid_princ, cid.
            cid = next(
                (c for c in (_code_str(e.get(k)) for k in ("cause_cid", "cid_princ", "cid")) if c),
                "",
            )
            if cid:
                emit(t, f"ICD10//{cid}")

    # Sort: nulls first (static), then by time
    rows.sort(key=lambda r: (r[1] is not None, r[1] or datetime(1900, 1, 1)))
    return rows


def export_to_meds(patients: list[dict], out_dir: str,
                   shard_size: int = 5000,
                   dataset_name: str = "GEMEO-DATASUS",
                   version: str = "v13"):
    """Export a list of DATASUS patient trajectories to MEDS v0.4.1 format.

    Parameters
    ----------
    patients : list of dict
        Each dict must have: patient_id, sex, birth_year, orphas (list),
        events (list of dicts with 'type', 'year', 'month', etc.)
    out_dir : str
        Output directory (will create data/ and metadata/ subdirs)
    shard_size : int
        Number of subjects per parquet shard; must be >= 1.
    dataset_name, version : str
        Recorded as ``dataset_name`` / ``dataset_version`` in the metadata.

    Returns
    -------
    dict
        The export summary, also written to ``<out_dir>/dataset_metadata.json``.

    Raises
    ------
    ValueError
        If ``shard_size`` < 1.

    Files written: ``data/shard_<i>.parquet``, ``metadata/codes.parquet``,
    ``metadata/dataset.json`` (MEDS dataset-metadata fields) and
    ``dataset_metadata.json``.
    """
    if shard_size < 1:
        raise ValueError(f"shard_size must be >= 1, got {shard_size!r}")

    data_dir = os.path.join(out_dir, "data")
    meta_dir = os.path.join(out_dir, "metadata")
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(meta_dir, exist_ok=True)

    log.info("Exporting %d patients to MEDS at %s", len(patients), out_dir)

    # Map patient_id (string hash) → int64 subject_id (MEDS requires int64).
    pid_to_sid = {p["patient_id"]: i for i, p in enumerate(patients)}
    n_duplicates = len(patients) - len(pid_to_sid)
    if n_duplicates:
        # Every duplicate maps to the index of its last occurrence, so those
        # trajectories share one subject_id and may be split across shards.
        log.warning("%d duplicate patient_id entries: their trajectories share one subject_id",
                    n_duplicates)

    # ---- Stream rows into shards of `shard_size` subjects ----
    all_codes = Counter()
    n_shards = 0  # shards written so far == index of the next shard
    shard_rows = []
    n_events = 0
    n_subjects = 0

    for p in patients:
        rows = datasus_patient_to_meds_rows(p, pid_to_sid[p["patient_id"]])
        shard_rows.extend(rows)
        n_events += len(rows)
        n_subjects += 1
        all_codes.update(r[2] for r in rows)
        # Write shard when full (a chunk with no rows carries over instead of
        # producing an empty file).
        if n_subjects % shard_size == 0 and shard_rows:
            _write_shard(shard_rows, os.path.join(data_dir, f"shard_{n_shards}.parquet"))
            n_shards += 1
            shard_rows = []

    # Write remaining
    if shard_rows:
        _write_shard(shard_rows, os.path.join(data_dir, f"shard_{n_shards}.parquet"))
        n_shards += 1

    log.info("  wrote %d data shards, %d rows, %d subjects", n_shards, n_events, n_subjects)

    # ---- codes.parquet (REQUIRED in MEDS v0.4) ----
    # parent_codes: empty for Brazil-local namespaces except SIGTAP groups;
    # ICD10 codes get chapter / 3-character category parents.
    code_rows = [
        {
            "code": code,
            "description": _get_description(code, count),
            "parent_codes": _get_parent_codes(code),
        }
        for code, count in all_codes.most_common()
    ]
    code_table = pa.Table.from_pylist(code_rows, schema=meds.CodeMetadataSchema.schema())
    pq.write_table(code_table, os.path.join(meta_dir, "codes.parquet"))
    log.info("  wrote metadata/codes.parquet (%d unique codes)", len(code_rows))

    # ---- dataset_metadata.json (export summary, historical location) ----
    md = {
        "dataset_name": dataset_name,
        "dataset_version": version,
        "etl_name": "gemeo.cdf.meds_export",
        "etl_version": "1.0.0",
        "meds_version": meds.__version__,
        "n_subjects": n_subjects,
        "n_events": n_events,
        "n_unique_codes": len(all_codes),
        "top_codes": dict(all_codes.most_common(30)),
    }
    with open(os.path.join(out_dir, "dataset_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(md, f, indent=2, default=str)
    log.info("  wrote dataset_metadata.json")

    # ---- metadata/dataset.json (MEDS-standard dataset metadata) ----
    meds_md = {
        "dataset_name": dataset_name,
        "dataset_version": version,
        "etl_name": md["etl_name"],
        "etl_version": md["etl_version"],
        "meds_version": md["meds_version"],
        "created_at": datetime.now().isoformat(),
    }
    with open(os.path.join(meta_dir, "dataset.json"), "w", encoding="utf-8") as f:
        json.dump(meds_md, f, indent=2, default=str)
    log.info("  wrote metadata/dataset.json")

    return md


def _write_shard(rows: list[tuple], path: str):
    """Write MEDS event rows to one zstd-compressed parquet shard.

    Parameters
    ----------
    rows : list of tuple
        ``(subject_id, time, code, numeric_value, text_value)`` tuples; only
        the first five elements of each tuple are read. An empty list is a
        no-op (no file is written).
    path : str
        Destination parquet path.

    The table is cast to ``meds.DataSchema.schema()`` before writing. pyarrow
    raises if a value cannot be converted to the column types built here, or
    if the cast to the MEDS schema would lose information.
    """
    if not rows:
        return
    # Build columnar arrays
    subject_id = pa.array([r[0] for r in rows], type=pa.int64())
    time = pa.array([r[1] for r in rows], type=pa.timestamp("us"))
    code = pa.array([r[2] for r in rows], type=pa.string())
    numeric_value = pa.array([r[3] for r in rows], type=pa.float32())
    text_value = pa.array([r[4] for r in rows], type=pa.large_string())
    table = pa.Table.from_arrays(
        [subject_id, time, code, numeric_value, text_value],
        names=["subject_id", "time", "code", "numeric_value", "text_value"],
    )
    # Conform to the MEDS schema. safe=True makes lossy conversions (e.g.
    # integer overflow, timestamp truncation) raise instead of silently
    # corrupting values, so the cast doubles as a validation step.
    expected_schema = meds.DataSchema.schema()
    table = table.cast(expected_schema, safe=True)
    pq.write_table(table, path, compression="zstd")


# WHO ICD-10 chapter titles keyed by the code's first letter (extend as needed).
# Some letters share a chapter (A/B, S/T, V-Y) and D and H each span two
# chapters, hence the combined titles. "U" is chapter XXII, "Codes for special
# purposes" (e.g. U07.1 COVID-19), which appears in recent DATASUS records.
ICD10_CHAPTERS = {
    "A": "Certain infectious and parasitic diseases",
    "B": "Certain infectious and parasitic diseases",
    "C": "Neoplasms",
    "D": "Neoplasms / Diseases of the blood and immune",
    "E": "Endocrine, nutritional and metabolic diseases",
    "F": "Mental, Behavioral and Neurodevelopmental disorders",
    "G": "Diseases of the nervous system",
    "H": "Diseases of the eye / ear",
    "I": "Diseases of the circulatory system",
    "J": "Diseases of the respiratory system",
    "K": "Diseases of the digestive system",
    "L": "Diseases of the skin and subcutaneous tissue",
    "M": "Diseases of the musculoskeletal system",
    "N": "Diseases of the genitourinary system",
    "O": "Pregnancy, childbirth and the puerperium",
    "P": "Certain conditions originating in the perinatal period",
    "Q": "Congenital malformations, deformations and chromosomal abnormalities",
    "R": "Symptoms, signs and abnormal clinical and laboratory findings",
    "S": "Injury, poisoning and certain other consequences of external causes",
    "T": "Injury, poisoning and certain other consequences of external causes",
    "U": "Codes for special purposes",
    "V": "External causes of morbidity",
    "W": "External causes of morbidity",
    "X": "External causes of morbidity",
    "Y": "External causes of morbidity",
    "Z": "Factors influencing health status and contact with health services",
}


def _get_description(code: str, count: int) -> str:
    """Generate a brief description for a code (used in codes.parquet)."""
    if code in ("MEDS_BIRTH",): return "Birth event (reserved)"
    if code in ("MEDS_DEATH",): return "Death event (reserved)"
    parts = code.split("//")
    if len(parts) < 2: return f"Unknown code (n={count})"
    domain, val = parts[0], "//".join(parts[1:])
    if domain == "GENDER": return f"Patient sex = {val}"
    if domain == "ORPHA": return f"Orphanet rare disease {val}"
    if domain == "ICD10":
        ch = ICD10_CHAPTERS.get(val[0], "Unknown chapter")
        return f"ICD-10 {val} ({ch})"
    if domain == "SIH": return f"SIH hospitalization {val}"
    if domain == "Visit": return f"Visit type {val}"
    if domain == "SIGTAP": return f"SIGTAP procedure {val}"
    if domain == "APAC": return f"APAC orphan-drug authorization {val}"
    if domain == "BPAI": return f"BPA-I outpatient procedure {val}"
    if domain == "UF": return f"Residence UF {val}"
    return f"{domain} code {val}"


def _get_parent_codes(code: str) -> list[str]:
    """Return parent codes for ontology hierarchy (currently minimal)."""
    parts = code.split("//")
    if len(parts) < 2: return []
    domain, val = parts[0], "//".join(parts[1:])
    parents = []
    if domain == "ICD10" and len(val) >= 3:
        # ICD-10 chapter as parent
        chapter = val[0]
        if chapter in ICD10_CHAPTERS:
            parents.append(f"ICD10//chapter_{chapter}")
        # 3-char prefix as parent (e.g., E84.0 β†’ E84)
        if "." in val:
            parents.append(f"ICD10//{val.split('.')[0]}")
        elif len(val) > 3:
            parents.append(f"ICD10//{val[:3]}")
    if domain == "SIGTAP" and len(val) >= 4:
        # 4-digit group as parent (SIGTAP 10-digit β†’ 4-digit group)
        parents.append(f"SIGTAP//group_{val[:4]}")
    return parents


def load_meds_dataset(meds_dir: str) -> dict:
    """Load a MEDS dataset back from parquet for inspection or downstream processing.

    Parameters
    ----------
    meds_dir : str
        Root directory containing ``data/``, ``metadata/codes.parquet`` and
        ``dataset_metadata.json``.

    Returns
    -------
    dict
        ``{"data": pyarrow.Table | None, "codes": pyarrow.Table, "metadata": dict}``.
        ``data`` concatenates every ``*.parquet`` file under ``data/``
        (searched recursively, in sorted path order) and is None when there
        are none.
    """
    import glob

    # "**" with recursive=True also matches zero directories, so top-level
    # shards (data/shard_0.parquet) are found as well as nested ones.
    shard_paths = sorted(
        glob.glob(os.path.join(meds_dir, "data", "**", "*.parquet"), recursive=True)
    )
    tables = [pq.read_table(path) for path in shard_paths]
    data = pa.concat_tables(tables) if tables else None
    codes = pq.read_table(os.path.join(meds_dir, "metadata", "codes.parquet"))
    with open(os.path.join(meds_dir, "dataset_metadata.json"), encoding="utf-8") as f:
        md = json.load(f)
    return {"data": data, "codes": codes, "metadata": md}


if __name__ == "__main__":
    # Quick test on real patient data
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(message)s")
    PATIENTS = "/tmp/datasus_patient_trajectories_v2.json"
    if os.path.exists(PATIENTS):
        patients = json.load(open(PATIENTS))[:50]   # 50 patients smoke test
        md = export_to_meds(patients, "/tmp/meds_smoke_test")
        print("\n=== smoke test result ===")
        print(json.dumps(md, indent=2, default=str))