File size: 4,997 Bytes
371efe0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Any


# Accepted question id shape: the letter "q" followed by three or more digits (e.g. q001).
ID_PATTERN = re.compile(r"^q\d{3,}$")
# Instruction text returned under the "instruction" key of every loaded benchmark payload.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful assistant. Answer in the same language as the user's question unless explicitly asked otherwise."
)


class DatasetValidationError(ValueError):
    """Signal that benchmark dataset content has an invalid shape or invalid fields."""


def _load_raw_dataset(path: Path) -> Any:
    """Read the JSON document stored at *path* and return the decoded value.

    Args:
        path: Location of the dataset file, read as UTF-8 text.

    Returns:
        Whatever the JSON decoder produces for the file's contents.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Dataset not found: {path}")


def _extract_records(raw_payload: Any) -> list[dict[str, Any]]:
    """Pull the list of question records out of a decoded dataset payload.

    Two layouts are accepted: a bare list, which is returned as-is, or a
    dict whose ``"questions"`` entry is a list.

    Args:
        raw_payload: The decoded JSON value.

    Returns:
        The record list itself (not a copy), so in-place edits to it are
        reflected in *raw_payload*.

    Raises:
        DatasetValidationError: If the payload matches neither layout.
    """
    if isinstance(raw_payload, list):
        return raw_payload

    if isinstance(raw_payload, dict):
        nested = raw_payload.get("questions")
        if isinstance(nested, list):
            return nested

    raise DatasetValidationError(
        "benchmark.json must be either a list of questions or an object with a 'questions' list."
    )


def _require_text_field(record: dict[str, Any], field_name: str, index: int) -> str:
    """Return the stripped text of a required field, rejecting absent or blank values.

    Non-string values are converted with ``str()`` before stripping, so a
    numeric value such as ``5`` is accepted as ``"5"``.

    Args:
        record: The question record being checked.
        field_name: Key that must be present in *record*.
        index: 1-based position of the record, used only in error messages.

    Returns:
        The field's value as text with surrounding whitespace removed.

    Raises:
        DatasetValidationError: If the key is absent, or its value is
            ``None`` or blank after stripping.
    """
    if field_name not in record:
        raise DatasetValidationError(f"Record #{index} is missing required field '{field_name}'.")

    raw_value = record[field_name]
    # A JSON null decodes to None; str(None) would produce the non-empty text
    # "None" and let a null field pass as if it held real content.
    value = "" if raw_value is None else str(raw_value).strip()
    if not value:
        raise DatasetValidationError(f"Record #{index} has empty '{field_name}'.")
    return value


def validate_question_records(records: list[dict[str, Any]]) -> None:
    """Check every question record and fail on the first problem found.

    Each record must be a dict with a non-blank ``id`` that matches
    ``ID_PATTERN`` and has not appeared earlier in the list, plus non-blank
    ``question`` and ``expected_answer`` fields.

    Args:
        records: Question records in dataset order.

    Raises:
        DatasetValidationError: On the first record that breaks a rule.
            Messages refer to records by 1-based position.
    """
    used_ids: set[str] = set()

    for position, entry in enumerate(records, start=1):
        if not isinstance(entry, dict):
            raise DatasetValidationError(f"Record #{position} must be an object.")

        question_id = _require_text_field(entry, "id", position)
        if ID_PATTERN.match(question_id) is None:
            raise DatasetValidationError(
                f"Record #{position} has invalid id '{question_id}'. Expected format like q001."
            )
        if question_id in used_ids:
            raise DatasetValidationError(f"Duplicate question id found: {question_id}")
        used_ids.add(question_id)

        # The remaining required fields only need to be present and non-blank.
        for required in ("question", "expected_answer"):
            _require_text_field(entry, required, position)


def load_benchmark_payload(dataset_path: Path) -> dict[str, Any]:
    """Load, validate and normalise the benchmark dataset at *dataset_path*.

    Args:
        dataset_path: Location of the benchmark JSON file.

    Returns:
        A dict with two keys: ``"instruction"`` holding
        ``DEFAULT_SYSTEM_PROMPT`` and ``"questions"`` holding one normalised
        dict per record. The source field ``question`` is exposed as
        ``prompt`` and ``topic`` as ``category`` (``"GENEL"`` when absent or
        blank); ``hardness_level`` and ``why_prepared`` default to ``""``.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        DatasetValidationError: If the payload layout or any record is invalid.
    """
    records = _extract_records(_load_raw_dataset(dataset_path))
    validate_question_records(records)

    def _text(value: Any) -> str:
        # Every exported field is coerced to text and trimmed the same way.
        return str(value).strip()

    questions: list[dict[str, Any]] = [
        {
            "id": _text(entry["id"]),
            "prompt": _text(entry["question"]),
            "expected_answer": _text(entry["expected_answer"]),
            "category": _text(entry.get("topic", "GENEL")) or "GENEL",
            "expected_source": "benchmark_json",
            "confidence": 1.0,
            "hardness_level": _text(entry.get("hardness_level", "")),
            "why_prepared": _text(entry.get("why_prepared", "")),
        }
        for entry in records
    ]

    return {"instruction": DEFAULT_SYSTEM_PROMPT, "questions": questions}


def save_expected_answer(dataset_path: Path, question_id: str, expected_answer: str) -> None:
    """Overwrite the expected answer of one question and rewrite the dataset file.

    Args:
        dataset_path: Location of the benchmark JSON file.
        question_id: Id of the record to update, compared against each
            record's stripped ``id``.
        expected_answer: New answer text; surrounding whitespace is removed.

    Raises:
        DatasetValidationError: If the answer is blank after stripping, or
            the existing dataset fails validation.
        FileNotFoundError: If the dataset file does not exist.
        KeyError: If no record carries *question_id*.
    """
    cleaned_answer = expected_answer.strip()
    if not cleaned_answer:
        raise DatasetValidationError("expected_answer cannot be empty.")

    payload = _load_raw_dataset(dataset_path)
    entries = _extract_records(payload)
    validate_question_records(entries)

    # Validation guarantees every entry is a dict, so None can serve as the
    # "no match" sentinel.
    target = next(
        (entry for entry in entries if str(entry.get("id", "")).strip() == question_id),
        None,
    )
    if target is None:
        raise KeyError(f"Question id not found: {question_id}")
    target["expected_answer"] = cleaned_answer

    # The entries list is part of the payload, so dumping the payload
    # persists the edit while keeping any surrounding keys.
    with dataset_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)


def backfill_missing_ids(dataset_path: Path) -> None:
    """Assign ids to records that lack one and save the dataset if anything changed.

    Records whose ``id`` is absent or blank after stripping receive a fresh
    ``qNNN`` id (zero-padded to at least three digits). Numbers already used
    by ids matching ``ID_PATTERN`` are never reused. When the dataset has no
    such ids, numbering for the first backfilled record starts from its
    1-based position; otherwise it continues after the highest number in use.

    The full record list is validated after backfilling, and the file is
    rewritten only when at least one id was assigned.

    Args:
        dataset_path: Location of the benchmark JSON file.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        DatasetValidationError: If the payload layout is invalid or any
            record fails validation after backfilling.
    """
    payload = _load_raw_dataset(dataset_path)
    entries = _extract_records(payload)

    # Numeric parts of every id that already has the expected shape.
    taken: set[int] = {
        int(text[1:])
        for text in (str(entry.get("id", "")).strip() for entry in entries)
        if ID_PATTERN.match(text)
    }

    upcoming = max(taken) + 1 if taken else 1
    dirty = False

    for position, entry in enumerate(entries, start=1):
        if str(entry.get("id", "")).strip():
            # Any non-blank id is left alone, even a malformed one;
            # validation below reports those.
            continue

        number = upcoming if taken else position
        while number in taken:
            number += 1

        entry["id"] = f"q{number:03d}"
        taken.add(number)
        upcoming = number + 1
        dirty = True

    validate_question_records(entries)

    if not dirty:
        return
    with dataset_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)