File size: 5,495 Bytes
a5b6ba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import re
import json
import os
import pandas as pd
import argparse
from datetime import datetime
from src.utils.data_loader import DataLoader

class WTTSBuilder:
    """Builds Weighted Temporally-Tagged Sentence (WTTS) strings from raw
    clinical notes.

    Each sentence of each note becomes one tuple of the form
    ``[S_i] ("timestamp", "sentence", P_j, W)`` where ``P_j`` is the note's
    normalized position within the admission window (0.0 - 1.0) and ``W`` is a
    keyword-derived clinical salience weight (1.0 critical, 0.5 chronic /
    unknown, 0.1 routine).
    """

    def __init__(self):
        # --- WEIGHTING RULES (W) ---
        # The raw pattern lists stay public so they can be inspected or
        # extended; the compiled forms below are what the hot path uses.
        self.critical_patterns = [
            r'respiratory failure', r'seizure', r'cardiac arrest', r'intubat', 
            r'abnormal', r'critical', r'hemorrhage', r'positive', 
            r'emergency', r'acute', r'hypoxia', r'flagged', r'icu',
            r'dyspnea', r'shortness of breath', r'sob', r'mrc grade', r'nyha'
        ]
        self.chronic_patterns = [
            r'history of', r'chronic', r'stable', r'continued', 
            r'maintained', r'diagnosed with', r'previous', r'known'
        ]
        self.routine_patterns = [
            r'routine', r'normal', r'negative', r'unremarkable', 
            r'no acute', r'clear', r'regular diet', r'resting'
        ]
        # FIX: compile every pattern once here instead of re-scanning raw
        # strings on each _get_weight / _extract_sentences_with_ids call.
        self._critical_re = [re.compile(p) for p in self.critical_patterns]
        self._chronic_re = [re.compile(p) for p in self.chronic_patterns]
        self._routine_re = [re.compile(p) for p in self.routine_patterns]
        # De-identification brackets, e.g. "[**Name**]".
        self._deid_re = re.compile(r'\[\*\*.*?\*\*\]')
        # Sentence boundary: ./?/! followed by whitespace, excluding common
        # abbreviation shapes (e.g. "e.g." or "Dr.").
        self._sentence_re = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')

    def _get_normalized_time(self, event_time_str, admit_str, disch_str):
        """Return P_j in [0.0, 1.0]: the event's fractional position in the
        admission window.

        Falls back to 0.5 (midpoint) when any timestamp is missing or
        unparseable; returns 1.0 when the window is zero or negative.
        """
        try:
            e_dt = pd.to_datetime(event_time_str)
            a_dt = pd.to_datetime(admit_str)
            d_dt = pd.to_datetime(disch_str)

            # FIX: pd.to_datetime(None) yields NaT without raising, which
            # previously let NaN escape this method; treat it as "unknown".
            if pd.isna(e_dt) or pd.isna(a_dt) or pd.isna(d_dt):
                return 0.5

            total_duration = (d_dt - a_dt).total_seconds()
            elapsed = (e_dt - a_dt).total_seconds()

            if total_duration <= 0: return 1.0
            return round(max(0.0, min(1.0, elapsed / total_duration)), 2)
        except Exception:  # FIX: bare except also trapped KeyboardInterrupt/SystemExit
            return 0.5

    def _get_weight(self, text):
        """Return salience weight W for *text* (first matching tier wins)."""
        t = text.lower()
        if any(p.search(t) for p in self._critical_re): return 1.0
        if any(p.search(t) for p in self._chronic_re): return 0.5
        if any(p.search(t) for p in self._routine_re): return 0.1
        return 0.5 

    def _extract_sentences_with_ids(self, text, start_index):
        """
        Splits notes and assigns UNIQUE IDs.

        Sanitizes newlines so each WTTS tuple stays on one line, strips
        de-identification brackets, then splits on sentence boundaries.
        Returns ``(list of (sid, sentence), next_free_index)``.
        """
        # 1. Replace newlines/tabs with spaces to keep tuple on one line
        text = text.replace('\n', ' ').replace('\r', '').replace('\t', ' ')
        
        # 2. Remove de-id brackets
        text = self._deid_re.sub('', text) 
        
        # 3. Split by sentence boundaries
        sentences = self._sentence_re.split(text)
        
        results = []
        current_idx = start_index
        for s in sentences:
            clean_s = s.strip()
            # Ignore very short/empty fragments
            if len(clean_s) > 5:
                sid = f"S_{current_idx}"
                results.append((sid, clean_s))
                current_idx += 1
        return results, current_idx

    def build_wtts_string(self, patient_data):
        """Return the full WTTS string for one patient record.

        *patient_data* is a dict with optional ``admission_time`` /
        ``discharge_time`` and a ``notes`` list of ``{timestamp, text}``
        dicts. Notes are processed in timestamp order; sentence IDs are
        unique across the whole patient timeline.
        """
        tuples = []
        admit = patient_data.get('admission_time')
        disch = patient_data.get('discharge_time')
        
        sorted_notes = sorted(patient_data.get('notes', []), key=lambda x: x['timestamp'])
        
        global_sent_idx = 0 

        for note in sorted_notes:
            raw_ts = note['timestamp']
            p_j = self._get_normalized_time(raw_ts, admit, disch)
            
            events, global_sent_idx = self._extract_sentences_with_ids(note['text'], global_sent_idx)
            
            for (sid, event) in events:
                w = self._get_weight(event)
                # Format: [ID] ("Date", "Event", P_j, W)
                tuples.append(f'[{sid}] ("{raw_ts}", "{event}", {p_j}, {w})')
        
        return " | ".join(tuples)

# --- EXECUTION LOGIC ---
def _parse_args():
    """Build and parse the CLI arguments for the materialization script."""
    parser = argparse.ArgumentParser(description="Process raw clinical notes into WTTS format.")
    # Set defaults for easier running
    parser.add_argument("--input_dirs", nargs="+",
                        default=[
                            "data/raw/dyspnea-clinical-notes",
                            "data/raw/dyspnea-crf-development",
                        ],
                        help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file", type=str, 
                        default="data/raw/dev_gt.jsonl",
                        help="Path to the ground truth JSONL file.")
    parser.add_argument("--output_dir", type=str, 
                        default="data/processed/materialized_ehr",
                        help="Path to store processed JSONL files.")
    return parser.parse_args()


def main():
    """Load patients, build one WTTS timeline each, and write them as JSONL."""
    args = _parse_args()

    loader = DataLoader(data_folders=args.input_dirs, gt_path=args.gt_file)
    builder = WTTSBuilder()

    patients = loader.load_and_merge()
    os.makedirs(args.output_dir, exist_ok=True)

    if not patients:
        print("No patients found! Check paths.")
        return

    print(f"Materializing timelines for {len(patients)} patients...")
    for p in patients:
        wtts_output = builder.build_wtts_string(p)

        # Prioritize document_id to match DataLoader logic.
        pid = str(p.get('document_id') or p.get('patient_id') or p.get('hadm_id'))

        output_path = os.path.join(args.output_dir, f"{pid}.jsonl")
        # FIX: explicit encoding — the default depends on the platform locale.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump({"person_id": pid, "text": wtts_output}, f)

    print(f"Successfully stored outputs in: {args.output_dir}")


if __name__ == "__main__":
    main()