import re
import json
import os
import pandas as pd
import argparse
from datetime import datetime


class WTTSBuilder:
    """Converts a patient's timestamped clinical notes into a Weighted
    Temporal Tuple Sequence (WTTS) string.

    Each sentence of each note becomes one tuple of the form
        [S_k] ("<timestamp>", "<sentence>", P_j, W)
    where P_j is the note's normalized position in the admission window
    (0.0 = admission, 1.0 = discharge) and W is a keyword-derived
    salience weight (1.0 critical, 0.5 chronic/default, 0.1 routine).
    Tuples are joined with " | " into a single line of text.
    """

    def __init__(self):
        # --- WEIGHTING RULES (W) ---
        # Raw regex fragments, matched against lowercased text.  Left as
        # plain strings (public attributes); the `re` module caches
        # compiled patterns, so repeated re.search calls do not recompile.
        self.critical_patterns = [
            r'respiratory failure', r'seizure', r'cardiac arrest', r'intubat',
            r'abnormal', r'critical', r'hemorrhage', r'positive', r'emergency',
            r'acute', r'hypoxia', r'flagged', r'icu', r'dyspnea',
            r'shortness of breath', r'sob', r'mrc grade', r'nyha'
        ]
        self.chronic_patterns = [
            r'history of', r'chronic', r'stable', r'continued', r'maintained',
            r'diagnosed with', r'previous', r'known'
        ]
        self.routine_patterns = [
            r'routine', r'normal', r'negative', r'unremarkable', r'no acute',
            r'clear', r'regular diet', r'resting'
        ]

    def _get_normalized_time(self, event_time_str, admit_str, disch_str):
        """Calculates P_j (0.0 to 1.0): fraction of the stay elapsed at event time.

        Returns 1.0 when the admission window is empty or inverted, and
        falls back to 0.5 (mid-stay) when any timestamp fails to parse.
        """
        try:
            e_dt = pd.to_datetime(event_time_str)
            a_dt = pd.to_datetime(admit_str)
            d_dt = pd.to_datetime(disch_str)
            total_duration = (d_dt - a_dt).total_seconds()
            elapsed = (e_dt - a_dt).total_seconds()
            if total_duration <= 0:
                return 1.0
            # Clamp to [0, 1] so events outside the window don't leak out of range.
            return round(max(0.0, min(1.0, elapsed / total_duration)), 2)
        except Exception:
            # FIX: was a bare `except:` which also swallows SystemExit /
            # KeyboardInterrupt.  Unparseable timestamps default to mid-stay.
            return 0.5

    def _get_weight(self, text):
        """Returns salience weight W for a sentence.

        Precedence: critical (1.0) > chronic (0.5) > routine (0.1);
        sentences matching no rule default to 0.5.
        """
        t = text.lower()
        if any(re.search(p, t) for p in self.critical_patterns):
            return 1.0
        if any(re.search(p, t) for p in self.chronic_patterns):
            return 0.5
        if any(re.search(p, t) for p in self.routine_patterns):
            return 0.1
        return 0.5

    def _extract_sentences_with_ids(self, text, start_index):
        """
        Splits notes and assigns UNIQUE IDs.
        CRITICAL FIX: Sanitizes newlines to preserve WTTS structure.

        Returns (list of (sentence_id, sentence) pairs, next free index) so
        IDs stay globally unique across a patient's notes.  Fragments of
        5 characters or fewer are dropped as noise.
        """
        # 1. Replace newlines/tabs with spaces to keep tuple on one line
        text = text.replace('\n', ' ').replace('\r', '').replace('\t', ' ')
        # 2. Remove de-id brackets (MIMIC-style [** ... **] placeholders)
        text = re.sub(r'\[\*\*.*?\*\*\]', '', text)
        # 3. Split by sentence boundaries (lookbehind keeps the punctuation
        #    attached to the sentence it terminates)
        sentences = re.split(r'(?<=[.!?])\s+', text)
        results = []
        current_idx = start_index
        for s in sentences:
            clean_s = s.strip()
            if len(clean_s) > 5:
                sid = f"S_{current_idx}"
                results.append((sid, clean_s))
                current_idx += 1
        return results, current_idx

    def build_wtts_string(self, patient_data):
        """Serialize one patient's notes into a single WTTS string.

        patient_data: dict with 'admission_time', 'discharge_time' and a
        'notes' list of {'timestamp', 'text'} dicts.  Notes are processed
        in chronological order; sentence IDs run consecutively across notes.
        """
        tuples = []
        admit = patient_data.get('admission_time')
        disch = patient_data.get('discharge_time')
        sorted_notes = sorted(patient_data.get('notes', []), key=lambda x: x['timestamp'])
        global_sent_idx = 0
        for note in sorted_notes:
            raw_ts = note['timestamp']
            p_j = self._get_normalized_time(raw_ts, admit, disch)
            events, global_sent_idx = self._extract_sentences_with_ids(note['text'], global_sent_idx)
            for (sid, event) in events:
                w = self._get_weight(event)
                # Format: [ID] ("Date", "Event", P_j, W)
                tuples.append(f'[{sid}] ("{raw_ts}", "{event}", {p_j}, {w})')
        return " | ".join(tuples)


# --- EXECUTION LOGIC ---
if __name__ == "__main__":
    # FIX: deferred project-local import so WTTSBuilder can be imported
    # without `src` on sys.path (only the CLI path needs DataLoader).
    from src.utils.data_loader import DataLoader

    parser = argparse.ArgumentParser(description="Process raw clinical notes into WTTS format.")
    # Set defaults for easier running
    parser.add_argument("--input_dirs", nargs="+", default=[
        "data/raw/dyspnea-clinical-notes",
        "data/raw/dyspnea-crf-development",
    ], help="Directories containing .parquet shards (searched recursively)")
    parser.add_argument("--gt_file", type=str, default="data/raw/dev_gt.jsonl",
                        help="Path to the ground truth JSONL file.")
    parser.add_argument("--output_dir", type=str, default="data/processed/materialized_ehr",
                        help="Path to store processed JSONL files.")
    args = parser.parse_args()

    loader = DataLoader(data_folders=args.input_dirs, gt_path=args.gt_file)
    builder = WTTSBuilder()
    patients = loader.load_and_merge()
    os.makedirs(args.output_dir, exist_ok=True)

    if not patients:
        print("No patients found! Check paths.")
    else:
        print(f"Materializing timelines for {len(patients)} patients...")
        for p in patients:
            wtts_output = builder.build_wtts_string(p)
            # FIX: Prioritize document_id to match DataLoader logic
            pid = str(p.get('document_id') or p.get('patient_id') or p.get('hadm_id'))
            output_path = os.path.join(args.output_dir, f"{pid}.jsonl")
            # FIX: explicit UTF-8 so output encoding is locale-independent
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump({"person_id": pid, "text": wtts_output}, f)
        print(f"Successfully stored outputs in: {args.output_dir}")