Spaces:
Build error
Build error
| import re | |
| import json | |
| import os | |
| import pandas as pd | |
| import argparse | |
| from datetime import datetime | |
| from src.utils.data_loader import DataLoader | |
class WTTSBuilder:
    """Turn a patient's clinical notes into a single WTTS string.

    Every sentence of every note becomes one tuple of the form
    ``[S_<n>] ("<timestamp>", "<sentence>", P_j, W)`` where ``P_j`` is the
    note's position inside the admission (0.0 = admission, 1.0 = discharge)
    and ``W`` is a keyword-derived weight. Tuples are joined with ``" | "``.
    """

    # CRLF / LF / lone CR / tab -> single space, so a tuple stays on one line.
    # A lone CR must become a space (not be deleted) or adjacent words fuse.
    _LINE_BREAK_RE = re.compile(r'\r\n?|\n|\t')
    # De-identification placeholders such as [**Name**].
    _DEID_RE = re.compile(r'\[\*\*.*?\*\*\]')
    # Whitespace that follows sentence-ending punctuation, except after
    # abbreviations shaped like "e.g." or "Dr.".
    _SENTENCE_SPLIT_RE = re.compile(
        r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s'
    )
    # Fragments of this length or shorter are discarded as noise.
    _MAX_NOISE_LEN = 5
    # P_j used when the position cannot be computed.
    _FALLBACK_POSITION = 0.5
    # W used when no pattern matches.
    _DEFAULT_WEIGHT = 0.5

    def __init__(self):
        # --- WEIGHTING RULES (W) ---
        # Patterns are searched against lower-cased text. Short acronyms are
        # word-bounded so they do not fire inside ordinary words
        # ("difficulty", "ventricular", "sober"). "acute" is skipped when
        # directly preceded by "no " so the routine rule 'no acute' below can
        # actually take effect.
        self.critical_patterns = [
            r'respiratory failure', r'seizure', r'cardiac arrest', r'intubat',
            r'abnormal', r'critical', r'hemorrhage', r'positive',
            r'emergency', r'(?<!no )acute', r'hypoxia', r'flagged', r'\bicu\b',
            r'dyspnea', r'shortness of breath', r'\bsob\b', r'mrc grade',
            r'nyha',
        ]
        self.chronic_patterns = [
            r'history of', r'chronic', r'stable', r'continued',
            r'maintained', r'diagnosed with', r'previous', r'known',
        ]
        self.routine_patterns = [
            r'routine', r'normal', r'negative', r'unremarkable',
            r'no acute', r'clear', r'regular diet', r'resting',
        ]

    def _get_normalized_time(self, event_time_str, admit_str, disch_str):
        """Return P_j: the event's relative position within the admission.

        Args:
            event_time_str: Event timestamp (anything ``pd.to_datetime`` accepts).
            admit_str: Admission timestamp.
            disch_str: Discharge timestamp.

        Returns:
            A float in [0.0, 1.0] rounded to 2 decimals. Returns 1.0 when the
            admission has zero or negative duration, and 0.5 when any of the
            three timestamps is missing (None/NaT) or cannot be parsed.
        """
        try:
            e_dt = pd.to_datetime(event_time_str)
            a_dt = pd.to_datetime(admit_str)
            d_dt = pd.to_datetime(disch_str)
            # NaT would propagate as NaN and be clamped to 1.0 by min()/max();
            # treat a missing timestamp like an unparseable one instead.
            if any(pd.isna(ts) for ts in (e_dt, a_dt, d_dt)):
                return self._FALLBACK_POSITION
            total_duration = (d_dt - a_dt).total_seconds()
            elapsed = (e_dt - a_dt).total_seconds()
        except Exception:
            # Deliberate best-effort: any parsing/arithmetic failure maps to
            # the neutral midpoint. Unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt or SystemExit.
            return self._FALLBACK_POSITION

        if total_duration <= 0:
            return 1.0
        return round(max(0.0, min(1.0, elapsed / total_duration)), 2)

    def _get_weight(self, text):
        """Return W for a sentence: 1.0 critical, 0.5 chronic, 0.1 routine.

        Tiers are checked in that order and the first tier with a matching
        pattern wins; text matching nothing gets the default 0.5.
        """
        lowered = (text or "").lower()
        tiers = (
            (self.critical_patterns, 1.0),
            (self.chronic_patterns, 0.5),
            (self.routine_patterns, 0.1),
        )
        for patterns, weight in tiers:
            if any(re.search(pattern, lowered) for pattern in patterns):
                return weight
        return self._DEFAULT_WEIGHT

    def _extract_sentences_with_ids(self, text, start_index):
        """Split a note into sentences and give each a unique ``S_<n>`` id.

        Line breaks and tabs are flattened to spaces and de-identification
        brackets are removed before splitting, so each sentence fits on one
        line of the WTTS output.

        Args:
            text: Raw note text. Non-string or empty values yield no sentences.
            start_index: First id number to assign.

        Returns:
            ``(results, next_index)`` where ``results`` is a list of
            ``(sentence_id, sentence)`` pairs and ``next_index`` is the id
            number to pass to the next call.
        """
        if not isinstance(text, str) or not text:
            return [], start_index

        # 1. Keep the tuple on one line.
        flattened = self._LINE_BREAK_RE.sub(' ', text)
        # 2. Remove de-id brackets.
        flattened = self._DEID_RE.sub('', flattened)
        # 3. Split by sentence boundaries.
        results = []
        current_idx = start_index
        for fragment in self._SENTENCE_SPLIT_RE.split(flattened):
            sentence = fragment.strip()
            # Ignore very short/empty fragments.
            if len(sentence) > self._MAX_NOISE_LEN:
                results.append((f"S_{current_idx}", sentence))
                current_idx += 1
        return results, current_idx

    def build_wtts_string(self, patient_data):
        """Build the WTTS string for one patient.

        Args:
            patient_data: Mapping with ``admission_time``, ``discharge_time``
                and ``notes``; each note provides ``timestamp`` and ``text``.
                A missing or ``None`` ``notes`` entry yields an empty string.

        Returns:
            Tuples ``[ID] ("Date", "Event", P_j, W)`` joined by ``" | "``,
            ordered by note timestamp, with sentence ids numbered
            consecutively across all notes.
        """
        admit = patient_data.get('admission_time')
        disch = patient_data.get('discharge_time')
        notes = patient_data.get('notes')
        if notes is None:
            notes = []
        # NOTE(review): notes are ordered by the raw timestamp value; this is
        # chronological only if the values compare chronologically (e.g. ISO
        # strings or datetime objects) - confirm against the loader.
        sorted_notes = sorted(notes, key=lambda note: note['timestamp'])

        tuples = []
        global_sent_idx = 0
        for note in sorted_notes:
            raw_ts = note['timestamp']
            p_j = self._get_normalized_time(raw_ts, admit, disch)
            events, global_sent_idx = self._extract_sentences_with_ids(
                note.get('text'), global_sent_idx
            )
            # Format: [ID] ("Date", "Event", P_j, W)
            tuples.extend(
                f'[{sid}] ("{raw_ts}", "{event}", {p_j}, {self._get_weight(event)})'
                for sid, event in events
            )
        return " | ".join(tuples)
# --- EXECUTION LOGIC ---
def resolve_patient_id(patient):
    """Return the identifier used to name a patient's output file.

    ``document_id`` is preferred, then ``patient_id``, then ``hadm_id``
    (first truthy value wins). Returns ``None`` when no usable identifier is
    present, so callers can skip the record instead of writing a file
    literally named ``None.jsonl``.
    """
    raw_id = (
        patient.get('document_id')
        or patient.get('patient_id')
        or patient.get('hadm_id')
    )
    if raw_id is None or raw_id == "":
        return None
    return str(raw_id)


def main(argv=None):
    """Load patients, build their WTTS strings and write one file per patient.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` when None.
    """
    parser = argparse.ArgumentParser(
        description="Process raw clinical notes into WTTS format."
    )
    # Set defaults for easier running
    parser.add_argument(
        "--input_dirs", nargs="+",
        default=[
            "data/raw/dyspnea-clinical-notes",
            "data/raw/dyspnea-crf-development",
        ],
        help="Directories containing .parquet shards (searched recursively)",
    )
    parser.add_argument(
        "--gt_file", type=str,
        default="data/raw/dev_gt.jsonl",
        help="Path to the ground truth JSONL file.",
    )
    parser.add_argument(
        "--output_dir", type=str,
        default="data/processed/materialized_ehr",
        help="Path to store processed JSONL files.",
    )
    args = parser.parse_args(argv)

    loader = DataLoader(data_folders=args.input_dirs, gt_path=args.gt_file)
    builder = WTTSBuilder()
    patients = loader.load_and_merge()
    os.makedirs(args.output_dir, exist_ok=True)

    if not patients:
        print("No patients found! Check paths.")
        return

    print(f"Materializing timelines for {len(patients)} patients...")
    skipped = 0
    for p in patients:
        # Prioritize document_id to match DataLoader logic
        pid = resolve_patient_id(p)
        if pid is None:
            # Without an id every such record would overwrite the same
            # "None.jsonl"; skip it and report the count below.
            skipped += 1
            continue
        wtts_output = builder.build_wtts_string(p)
        output_path = os.path.join(args.output_dir, f"{pid}.jsonl")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump({"person_id": pid, "text": wtts_output}, f)
            # JSON Lines records are newline-terminated.
            f.write("\n")
    if skipped:
        print(f"Skipped {skipped} patient record(s) without document_id/patient_id/hadm_id.")
    print(f"Successfully stored outputs in: {args.output_dir}")


if __name__ == "__main__":
    main()