# diabetesLLM/core/data_loader.py (first commit, f27bb68)
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import yaml
from .config import DATA_DIR
from .models import ClarifyingChoice, ClarifyingQuestion, IntentMeta, PageMeta, Passage, Section
logger = logging.getLogger(__name__)
def _load_yaml(path: Path):
    """Parse *path* as YAML and return the document, or [] when the file is empty/null."""
    if not path.exists():
        raise FileNotFoundError(f"Required data file not found: {path}")
    with path.open("r", encoding="utf-8") as handle:
        document = yaml.safe_load(handle)
    # An empty file parses to None; normalise that to an empty list for callers.
    return document if document else []
def load_sections(data_dir: Path = DATA_DIR) -> Dict[str, Section]:
    """Read sections.yaml and return a mapping of section_id -> Section."""
    result: Dict[str, Section] = {}
    for entry in _load_yaml(data_dir / "sections.yaml"):
        section = Section(
            section_id=str(entry["section_id"]),
            title=str(entry["title"]),
            parent_id=entry.get("parent_id"),
        )
        key = section.section_id
        # Later rows win on collision; we log so bad data is visible.
        if key in result:
            logger.warning("Duplicate section id found: %s", key)
        result[key] = section
    return result
def load_page_index(data_dir: Path = DATA_DIR) -> Dict[str, PageMeta]:
    """Read page_index.yaml and return a mapping of page_id -> PageMeta."""
    index: Dict[str, PageMeta] = {}
    for entry in _load_yaml(data_dir / "page_index.yaml"):
        meta = PageMeta(
            page_id=str(entry["page_id"]),
            guideline_id=str(entry["guideline_id"]),
            section_id=str(entry["section_id"]),
            title=str(entry["title"]),
            summary=str(entry["summary"]),
            intent_ids=list(entry.get("intent_ids", [])),
        )
        key = meta.page_id
        # Later rows win on collision; we log so bad data is visible.
        if key in index:
            logger.warning("Duplicate page id found: %s", key)
        index[key] = meta
    return index
def load_passages(data_dir: Path = DATA_DIR) -> Dict[str, Passage]:
    """Read passages.yaml and return a mapping of passage_id -> Passage."""
    result: Dict[str, Passage] = {}
    for entry in _load_yaml(data_dir / "passages.yaml"):
        record = Passage(
            passage_id=str(entry["passage_id"]),
            guideline_id=str(entry["guideline_id"]),
            page_id=str(entry["page_id"]),
            section_id=str(entry["section_id"]),
            order_in_section=int(entry["order_in_section"]),
            text=str(entry["text"]),
            source_page=int(entry["source_page"]),
            source_lines=list(entry.get("source_lines", [])),
            tags=list(entry.get("tags", [])),
        )
        key = record.passage_id
        # Later rows win on collision; we log so bad data is visible.
        if key in result:
            logger.warning("Duplicate passage id found: %s", key)
        result[key] = record
    return result
def load_intents(data_dir: Path = DATA_DIR) -> Dict[str, IntentMeta]:
    """Read intents.yaml and return a mapping of intent_id -> IntentMeta."""
    result: Dict[str, IntentMeta] = {}
    for entry in _load_yaml(data_dir / "intents.yaml"):
        meta = IntentMeta(
            intent_id=str(entry["intent_id"]),
            name=str(entry["name"]),
            description=str(entry["description"]),
            topic_group=str(entry["topic_group"]),
            guideline_ids=list(entry.get("guideline_ids", [])),
            primary_section_ids=list(entry.get("primary_section_ids", [])),
            example_questions=list(entry.get("example_questions", [])),
        )
        key = meta.intent_id
        # Later rows win on collision; we log so bad data is visible.
        if key in result:
            logger.warning("Duplicate intent id found: %s", key)
        result[key] = meta
    return result
def load_clarifying(path: Path) -> List[ClarifyingQuestion]:
    """Parse a clarifying-questions YAML file into ClarifyingQuestion objects.

    Each row may carry a list of choices; every choice optionally bundles an
    ``update_state`` mapping applied when the user picks it.
    """
    parsed: List[ClarifyingQuestion] = []
    for entry in _load_yaml(path):
        options: List[ClarifyingChoice] = []
        for option in entry.get("choices", []):
            options.append(
                ClarifyingChoice(
                    id=str(option["id"]),
                    text=str(option["text"]),
                    update_state=dict(option.get("update_state", {})),
                )
            )
        question = ClarifyingQuestion(
            question_id=str(entry["question_id"]),
            applies_to_topic_groups=list(entry.get("applies_to_topic_groups", [])),
            text=str(entry["text"]),
            choices=options,
        )
        parsed.append(question)
    return parsed
def load_state_schema(data_dir: Path = DATA_DIR) -> Tuple[dict, dict]:
    """Load state_schema.yaml and return ``(fields, defaults)``.

    Defaults per field type:
      * ``"str"``  -> first enum entry if an ``enum`` list is present, else ``""``.
      * ``"bool"`` and any other type -> ``None`` (i.e. "unknown", deliberately
        distinguishable from an explicit ``False``).
    """
    schema_path = data_dir / "state_schema.yaml"
    raw = _load_yaml(schema_path)
    # The schema may be authored either as a single mapping or as a
    # one-element list containing that mapping.
    data = raw[0] if isinstance(raw, list) and raw else raw
    fields = data.get("fields", {}) if isinstance(data, dict) else {}
    defaults: dict = {}
    for key, spec in fields.items():
        field_type = spec.get("type")
        if field_type == "str":
            enum = spec.get("enum")
            defaults[key] = enum[0] if enum else ""
        else:
            # "bool" and unrecognized types start unset (None), not False.
            defaults[key] = None
    return fields, defaults