Spaces:
Sleeping
Sleeping
| import xml.etree.ElementTree as ET | |
| import json | |
| import sys | |
| import os | |
| from ..services.indexing import create_symptom_index | |
| import logging | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Filesystem layout. Paths are anchored at the parent of this file's
# directory (made absolute first) so they do not depend on the current
# working directory when __file__ happens to be relative.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, "data")
# Directory holding the ICD-10-CM tabular XML (per its name, the 2025 release).
ICD_DIR = os.path.join(DATA_DIR, "icd10cm_tabular_2025")
DEFAULT_XML_PATH = os.path.join(ICD_DIR, "icd10cm_tabular_2025.xml")
# Output directory; main() writes icd_to_description.json here.
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
def _parse_icd_descriptions(xml_path):
    """Return a ``{code: description}`` mapping parsed from a tabular XML file.

    Every ``<diag>`` element in the document is visited, at any nesting depth.
    A ``<diag>`` contributes an entry only when both its direct ``<name>`` and
    ``<desc>`` children exist and contain non-blank text; surrounding
    whitespace is stripped. If a code appears more than once, the last
    occurrence wins.

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        OSError: if the file cannot be read.
    """
    root = ET.parse(xml_path).getroot()
    icd_to_description = {}
    for diag in root.iter("diag"):
        # findtext() returns None for a missing child and "" for an empty one;
        # ``or ""`` folds both into the blank-string check below.
        code = (diag.findtext("name") or "").strip()
        description = (diag.findtext("desc") or "").strip()
        if code and description:
            icd_to_description[code] = description
    return icd_to_description


def main(xml_path=DEFAULT_XML_PATH):
    """Extract ICD code descriptions from the tabular XML into a JSON file.

    Reads ``xml_path``, collects ``{code: description}`` pairs from every
    ``<diag>`` element, and writes them as indented UTF-8 JSON to
    ``PROCESSED_DIR/icd_to_description.json``. ``PROCESSED_DIR`` is created
    if needed. A summary line is printed on success.

    Args:
        xml_path: Path to the tabular XML file; defaults to ``DEFAULT_XML_PATH``.

    Exits the process with status 1 (via ``sys.exit``) when the file does not
    exist or is not well-formed XML; the reason is printed to stderr.
    """
    # Validate the input before touching the filesystem, so a bad path does
    # not leave an empty output directory behind.
    if not os.path.isfile(xml_path):
        print(f"ERROR: cannot find tabular XML at '{xml_path}'", file=sys.stderr)
        sys.exit(1)

    try:
        icd_to_description = _parse_icd_descriptions(xml_path)
    except ET.ParseError as err:
        # Report malformed XML the same way as a missing file instead of
        # surfacing a raw traceback.
        print(f"ERROR: cannot parse tabular XML at '{xml_path}': {err}", file=sys.stderr)
        sys.exit(1)

    os.makedirs(PROCESSED_DIR, exist_ok=True)
    out_path = os.path.join(PROCESSED_DIR, "icd_to_description.json")
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(icd_to_description, fp, indent=2, ensure_ascii=False)
    print(f"Wrote {len(icd_to_description)} code entries to {out_path}")
# Module-level handle to the symptom index. It stays None on import and is
# only populated when this file is executed as a script (see below).
symptom_index = None

if __name__ == "__main__":
    # An optional first CLI argument overrides the default tabular XML path.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()  # Use default path

    symptom_index = create_symptom_index()

    # Smoke-test the index with a few representative symptom descriptions.
    test_queries = [
        "persistent cough with fever",
        "severe headache with nausea",
        "lower back pain",
        "difficulty breathing",
    ]

    print("\nTesting symptom matching:")
    print("-" * 50)
    # Build the query engine once; it does not depend on the individual query.
    query_engine = symptom_index.as_query_engine()
    for query in test_queries:
        response = query_engine.query(query)
        print(f"\nQuery: {query}")
        print("Relevant ICD-10 codes:")
        print(str(response))
        print("-" * 50)