# p14-space: src/triage_llm/data/build_datasets.py
# Last commit c9dcc3b by perachon: "Deploy CPU FastAPI stub"
from __future__ import annotations

import hashlib
import json
import random
import shutil
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from datasets import Dataset, DatasetDict
from presidio_anonymizer.entities import OperatorConfig

from triage_llm.data.anonymize import PresidioAnonymizer
from triage_llm.schemas import DPORecord, MetadataSchema, SFTRecord
from triage_llm.utils import ensure_dir, read_jsonl, write_jsonl
@dataclass
class BuildDatasetsConfig:
    """Options for ``build_datasets``.

    ``split_ratios`` is validated when the config is constructed, so an
    invalid value fails before ``build_datasets`` writes any output file.

    Raises:
        ValueError: if ``split_ratios`` is not three non-negative fractions
            summing to 1.
    """

    # Directory whose ``*.jsonl`` files are loaded.
    input_dir: str
    # Root directory for all outputs (JSONL, schema, splits, hf, eval, audit log).
    out_dir: str
    # Seed passed to ``split_rows`` for the train/val/test shuffle.
    seed: int = 42
    # (train, val, test) fractions: three non-negative numbers summing to 1.
    split_ratios: tuple[float, float, float] = (0.9, 0.05, 0.05)
    # Anonymize text fields before any dataset file is written.
    anonymize: bool = False
    # Language used for rows that carry no ``lang`` value.
    anonymize_lang_default: str = "fr"
    # Operator name and its ``new_value`` parameter, given to ``OperatorConfig``.
    anonymize_operator: str = "replace"
    anonymize_new_value: str = "<REDACTED>"
    # Also save Hugging Face ``DatasetDict``s under ``out_dir/hf``.
    export_hf: bool = True
    # Optional directory whose ``*.jsonl`` files are copied to ``out_dir/eval``.
    clinical_eval_dir: str | None = None

    def __post_init__(self) -> None:
        # Accept any sequence (e.g. a list) and store a tuple of floats.
        ratios = tuple(float(r) for r in self.split_ratios)
        # ``not x < tol`` (instead of ``x >= tol``) also rejects NaN sums.
        if len(ratios) != 3 or min(ratios) < 0 or not abs(sum(ratios) - 1.0) < 1e-9:
            raise ValueError(
                "split_ratios must be three non-negative fractions summing to 1, "
                f"got {self.split_ratios!r}"
            )
        self.split_ratios = ratios
def default_metadata_schema() -> MetadataSchema:
    """Return the schema describing the fields of the exported records.

    Keys are field names and values are their (French) descriptions.
    Slash-separated keys appear to group alternative field names across SFT
    and DPO records (e.g. ``instruction`` vs ``prompt``). The order of the
    entries is preserved in the resulting mapping.
    """
    field_docs = (
        ("id", "Identifiant unique"),
        ("instruction/prompt", "Consigne ou prompt"),
        ("input", "Contexte (optionnel)"),
        ("output/chosen/rejected", "Réponse(s)"),
        ("symptoms", "Liste de symptômes normalisés (optionnel)"),
        ("history", "Antécédents (optionnel)"),
        ("vitals", "Constantes (optionnel)"),
        ("source", "Origine du dataset"),
        ("lang", "Langue fr/en"),
        ("confidence", "Niveau de confiance (optionnel)"),
        ("pii_redacted", "PII supprimées (bool)"),
    )
    return MetadataSchema(fields=dict(field_docs))
def load_records_from_dir(input_dir: Path) -> tuple[list[SFTRecord], list[DPORecord]]:
    """Load SFT and DPO records from every ``*.jsonl`` file in ``input_dir``.

    Files are visited in sorted name order. A row containing ``instruction``
    and ``output`` is validated as an ``SFTRecord``. Otherwise, a row containing
    ``prompt``, ``chosen`` and ``rejected`` is validated as a ``DPORecord``.
    Rows matching neither key set are skipped.
    """
    sft_required = {"instruction", "output"}
    dpo_required = {"prompt", "chosen", "rejected"}
    sft_records: list[SFTRecord] = []
    dpo_records: list[DPORecord] = []
    for jsonl_file in sorted(input_dir.glob("*.jsonl")):
        for record in read_jsonl(jsonl_file):
            present = record.keys()
            if sft_required.issubset(present):
                sft_records.append(SFTRecord.model_validate(record))
            elif dpo_required.issubset(present):
                dpo_records.append(DPORecord.model_validate(record))
    return sft_records, dpo_records
def _sha256_file(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is streamed in 1 MiB chunks rather than read in one go.
    """
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
def _now_utc_iso() -> str:
    """Return the current time as a timezone-aware UTC ISO-8601 string."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
def _audit_write(path: Path, event: dict[str, Any]) -> None:
    """Append ``event`` as one JSON line to the audit log at ``path``.

    ``ensure_dir`` is called on the parent directory first. The log is opened
    in append mode with UTF-8 encoding, and non-ASCII characters are written
    verbatim (``ensure_ascii=False``).
    """
    ensure_dir(path.parent)
    with open(path, "a", encoding="utf-8") as log_file:
        print(json.dumps(event, ensure_ascii=False), file=log_file)
def _build_anonymizer_engines(
    operator: str,
    new_value: str,
) -> dict[str, PresidioAnonymizer]:
    """Create the French and English anonymizers sharing one DEFAULT operator.

    ``new_value`` is always passed as the operator's ``new_value`` parameter.
    NOTE(review): operators other than "replace" may reject this parameter;
    with the best-effort handling in ``_anonymize_rows`` such failures show up
    in the ``fields_failed`` stat rather than crashing. Confirm the intended
    operators.
    """
    op = OperatorConfig(operator, {"new_value": new_value})
    return {
        "fr": PresidioAnonymizer(language="fr", operators={"DEFAULT": op}),
        "en": PresidioAnonymizer(language="en", operators={"DEFAULT": op}),
    }


def _anonymize_rows(
    rows: list[dict[str, Any]],
    keys: tuple[str, ...],
    engines: dict[str, Any],
    lang_default: str,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Anonymize the text fields ``keys`` of each row.

    Args:
        rows: Input rows. They are not modified; shallow copies are returned.
        keys: Field names to anonymize. Missing or empty fields are skipped.
        engines: Mapping with ``"fr"`` and ``"en"`` entries, each exposing
            ``anonymize(text)`` and returning an object with ``.text`` and
            ``.entities``.
        lang_default: Language used when a row has no ``lang`` value.
            A lower-cased language of ``"fr"`` uses the French engine; any
            other value uses the English one.

    Returns:
        ``(rows, stats)``. A row is marked ``pii_redacted = True`` only when
        every one of its fields was anonymized. If the engine raised for any
        field, that field's text is kept unchanged (as ``str``) and the row
        keeps its incoming ``pii_redacted`` value. ``stats`` has ``records``,
        ``entities_detected``, ``records_failed`` and ``fields_failed``.
    """
    n_entities_total = 0
    n_fields_failed = 0
    n_records_failed = 0
    out: list[dict[str, Any]] = []
    for source_row in rows:
        row = dict(source_row)  # shallow copy: never mutate the caller's dicts
        lang = (row.get("lang") or lang_default).lower()
        engine = engines["fr"] if lang == "fr" else engines["en"]
        row_failed = False
        for key in keys:
            if not row.get(key):
                continue
            text = str(row[key])
            try:
                res = engine.anonymize(text)
                redacted, n_found = res.text, len(res.entities)
            except Exception:
                # Best-effort: keep the pipeline running, but do NOT claim the
                # row is redacted and surface the failure in the stats.
                row[key] = text
                row_failed = True
                n_fields_failed += 1
                continue
            n_entities_total += n_found
            row[key] = redacted
        if row_failed:
            n_records_failed += 1
        else:
            row["pii_redacted"] = True
        out.append(row)
    stats = {
        "records": len(rows),
        "entities_detected": n_entities_total,
        "records_failed": n_records_failed,
        "fields_failed": n_fields_failed,
    }
    return out, stats


def _anonymize_sft_rows(
    rows: list[dict[str, Any]],
    lang_default: str,
    operator: str,
    new_value: str,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Anonymize ``instruction``, ``input`` and ``output`` of SFT rows.

    ``operator``/``new_value`` configure the Presidio operator. See
    ``_anonymize_rows`` for the returned rows and stats.
    """
    return _anonymize_rows(
        rows,
        ("instruction", "input", "output"),
        _build_anonymizer_engines(operator, new_value),
        lang_default,
    )


def _anonymize_dpo_rows(
    rows: list[dict[str, Any]],
    lang_default: str,
    operator: str,
    new_value: str,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Anonymize ``prompt``, ``chosen`` and ``rejected`` of DPO rows.

    ``operator``/``new_value`` configure the Presidio operator. See
    ``_anonymize_rows`` for the returned rows and stats.
    """
    return _anonymize_rows(
        rows,
        ("prompt", "chosen", "rejected"),
        _build_anonymizer_engines(operator, new_value),
        lang_default,
    )
def split_rows(
    rows: list[dict[str, Any]],
    seed: int,
    ratios: tuple[float, float, float] = (0.9, 0.05, 0.05),
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]]]:
    """Shuffle ``rows`` deterministically and cut them into train/val/test.

    The shuffle uses ``random.Random(seed)``, so the same rows and seed always
    produce the same split. Train and val sizes are ``int(n * ratio)``
    (truncated). Whatever truncation leaves over goes to the test split.

    Raises:
        ValueError: if ``ratios`` is not three non-negative fractions summing
            to 1 (previously an ``assert``, which ``python -O`` strips).
    """
    # ``not x < tol`` (instead of ``x >= tol``) also rejects NaN sums.
    if len(ratios) != 3 or min(ratios) < 0 or not abs(sum(ratios) - 1.0) < 1e-9:
        raise ValueError(
            f"ratios must be three non-negative fractions summing to 1, got {ratios!r}"
        )
    rng = random.Random(seed)
    order = list(range(len(rows)))
    rng.shuffle(order)
    n = len(rows)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    train = [rows[i] for i in order[:n_train]]
    val = [rows[i] for i in order[n_train : n_train + n_val]]
    test = [rows[i] for i in order[n_train + n_val :]]
    return train, val, test
def _write_split_files(
    prefix: str,
    rows: list[dict[str, Any]],
    splits_path: Path,
    seed: int,
    ratios: tuple[float, float, float],
) -> tuple[dict[str, list[dict[str, Any]]], list[Path]]:
    """Split ``rows`` and write ``{prefix}_{train,val,test}.jsonl`` files.

    Returns the split rows keyed by ``"train"``/``"val"``/``"test"`` and the
    paths written, in that same order.
    """
    train, val, test = split_rows(rows, seed=seed, ratios=ratios)
    splits = {"train": train, "val": val, "test": test}
    written: list[Path] = []
    for name, split in splits.items():
        target = splits_path / f"{prefix}_{name}.jsonl"
        write_jsonl(target, split)
        written.append(target)
    return splits, written


def _save_hf_dataset(splits: dict[str, list[dict[str, Any]]], target: Path) -> None:
    """Save ``splits`` as a Hugging Face ``DatasetDict`` at ``target``.

    The ``"val"`` split is stored under the ``"validation"`` key.
    """
    DatasetDict(
        {
            "train": Dataset.from_list(splits["train"]),
            "validation": Dataset.from_list(splits["val"]),
            "test": Dataset.from_list(splits["test"]),
        }
    ).save_to_disk(str(target))


def _copy_eval_sets(eval_in: Path, eval_out: Path) -> list[str]:
    """Copy every ``*.jsonl`` of ``eval_in`` into ``eval_out``; return the targets."""
    copied: list[str] = []
    for src in sorted(eval_in.glob("*.jsonl")):
        target = eval_out / src.name
        # Byte-exact copy: a text read/write round trip would translate line
        # endings and raise on non-UTF-8 bytes.
        shutil.copyfile(src, target)
        copied.append(str(target))
    return copied


def _run_build(
    cfg: BuildDatasetsConfig,
    input_path: Path,
    out_path: Path,
    audit_path: Path,
) -> dict[str, Path]:
    """Run every build step after the ``build_start`` event has been logged."""
    sft_records, dpo_records = load_records_from_dir(input_path)
    sft_rows = [r.model_dump(mode="json") for r in sft_records]
    dpo_rows = [r.model_dump(mode="json") for r in dpo_records]
    _audit_write(
        audit_path,
        {
            "ts": _now_utc_iso(),
            "event": "loaded",
            "sft_records": len(sft_rows),
            "dpo_records": len(dpo_rows),
        },
    )

    if cfg.anonymize:
        sft_rows, sft_stats = _anonymize_sft_rows(
            sft_rows,
            lang_default=cfg.anonymize_lang_default,
            operator=cfg.anonymize_operator,
            new_value=cfg.anonymize_new_value,
        )
        dpo_rows, dpo_stats = _anonymize_dpo_rows(
            dpo_rows,
            lang_default=cfg.anonymize_lang_default,
            operator=cfg.anonymize_operator,
            new_value=cfg.anonymize_new_value,
        )
        _audit_write(
            audit_path,
            {
                "ts": _now_utc_iso(),
                "event": "anonymized",
                "sft": sft_stats,
                "dpo": dpo_stats,
                "operator": cfg.anonymize_operator,
                "new_value": cfg.anonymize_new_value,
            },
        )

    # Full (unsplit) exports plus the metadata schema.
    sft_path = out_path / "sft.jsonl"
    dpo_path = out_path / "dpo.jsonl"
    write_jsonl(sft_path, sft_rows)
    write_jsonl(dpo_path, dpo_rows)
    schema_path = out_path / "metadata_schema.json"
    with open(schema_path, "w", encoding="utf-8") as f:
        json.dump(
            default_metadata_schema().model_dump(mode="json"),
            f,
            ensure_ascii=False,
            indent=2,
        )

    # SFT and DPO are split with the same seed and ratios.
    splits_path = ensure_dir(out_path / "splits")
    sft_splits, sft_split_files = _write_split_files(
        "sft", sft_rows, splits_path, cfg.seed, cfg.split_ratios
    )
    dpo_splits, dpo_split_files = _write_split_files(
        "dpo", dpo_rows, splits_path, cfg.seed, cfg.split_ratios
    )

    if cfg.export_hf:
        hf_path = ensure_dir(out_path / "hf")
        sft_hf_path = hf_path / "sft"
        dpo_hf_path = hf_path / "dpo"
        _save_hf_dataset(sft_splits, sft_hf_path)
        _save_hf_dataset(dpo_splits, dpo_hf_path)
        _audit_write(
            audit_path,
            {
                "ts": _now_utc_iso(),
                "event": "export_hf",
                "sft_path": str(sft_hf_path),
                "dpo_path": str(dpo_hf_path),
            },
        )

    if cfg.clinical_eval_dir:
        copied = _copy_eval_sets(
            Path(cfg.clinical_eval_dir), ensure_dir(out_path / "eval")
        )
        _audit_write(
            audit_path,
            {"ts": _now_utc_iso(), "event": "eval_sets_copied", "files": copied},
        )

    _audit_write(
        audit_path,
        {
            "ts": _now_utc_iso(),
            "event": "build_end",
            "outputs": {
                "sft_jsonl": str(sft_path),
                "dpo_jsonl": str(dpo_path),
                "schema": str(schema_path),
                "splits": str(splits_path),
                "audit": str(audit_path),
            },
            "hashes": {
                "sft_jsonl_sha256": _sha256_file(sft_path),
                "dpo_jsonl_sha256": _sha256_file(dpo_path),
                "schema_sha256": _sha256_file(schema_path),
                # Fingerprint the split files too, keyed by file name.
                "splits_sha256": {
                    p.name: _sha256_file(p) for p in sft_split_files + dpo_split_files
                },
            },
        },
    )
    return {
        "sft": sft_path,
        "dpo": dpo_path,
        "schema": schema_path,
        "splits": splits_path,
        "audit": audit_path,
        # Kept for backward compatibility: out_dir itself when HF export is off.
        "hf": out_path / "hf" if cfg.export_hf else out_path,
    }


def build_datasets(cfg: BuildDatasetsConfig) -> dict[str, Path]:
    """Build SFT/DPO datasets from ``cfg.input_dir`` into ``cfg.out_dir``.

    Steps:
    1. Load every ``*.jsonl`` in the input directory.
    2. Optionally anonymize the rows.
    3. Write ``sft.jsonl``, ``dpo.jsonl`` and ``metadata_schema.json``.
    4. Write seeded train/val/test splits under ``splits/``.
    5. Optionally save Hugging Face datasets under ``hf/``.
    6. Optionally copy clinical eval sets to ``eval/``.

    Every step appends an event to ``audit_log.jsonl``. If a step raises, a
    ``build_failed`` event is appended before the exception propagates.

    Returns:
        Mapping of output names (``sft``, ``dpo``, ``schema``, ``splits``,
        ``audit``, ``hf``) to paths.
    """
    input_path = Path(cfg.input_dir)
    out_path = ensure_dir(cfg.out_dir)
    audit_path = out_path / "audit_log.jsonl"
    _audit_write(
        audit_path,
        {
            "ts": _now_utc_iso(),
            "event": "build_start",
            "input_dir": str(input_path),
            "out_dir": str(out_path),
            "seed": cfg.seed,
            "split_ratios": cfg.split_ratios,
            "anonymize": cfg.anonymize,
            "export_hf": cfg.export_hf,
        },
    )
    try:
        return _run_build(cfg, input_path, out_path, audit_path)
    except Exception as exc:
        # Close this run's audit trail before propagating the error.
        _audit_write(
            audit_path,
            {
                "ts": _now_utc_iso(),
                "event": "build_failed",
                "error": f"{type(exc).__name__}: {exc}",
            },
        )
        raise