File size: 10,162 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""Dataset quality gate un dedupe helperi apmācībai."""

from __future__ import annotations

import re
from collections import Counter
from dataclasses import asdict, dataclass, field
from difflib import SequenceMatcher
from typing import Any

from maris_core.data.preprocessing import clean_text, record_to_training_text

_PLACEHOLDER_TEXTS = {
    "n/a",
    "na",
    "none",
    "null",
    "placeholder",
    "test",
    "todo",
    "tbd",
}
_WORD_RE = re.compile(r"\w+", re.UNICODE)
_REPEATED_SEGMENT_SPLIT_RE = re.compile(r"(?:\n+|(?<=[.!?])\s+)")
_MIN_REPEATED_SEGMENT_LENGTH = 8


@dataclass(slots=True, frozen=True)
class DatasetQualityGateConfig:
    """Configuration for the training-data quality gate.

    The thresholds are read by ``apply_quality_gate_to_records`` and its helper
    checks. ``build_dataset_quality_report`` embeds ``asdict`` of this object in
    the report, so field order and names are part of the serialized artifact.
    """

    # Master switch: when False, records are only normalized and never filtered.
    enabled: bool = True
    # Drop records whose case/whitespace-insensitive training-text signature was already kept.
    dedupe_enabled: bool = True
    # Minimum length (characters) of the cleaned training text ("too_short" below this).
    min_text_chars: int = 4
    # Passed as ``max_chars`` when rendering training text; also the "too_long" limit.
    max_text_chars: int = 8192
    # Minimum distinct/total ratio of non-whitespace characters before the
    # word-variety fallback decides whether the text is low-information.
    min_unique_char_ratio: float = 0.2
    # A completion shorter than this is rejected when the prompt is at least this long.
    min_completion_chars: int = 12
    # difflib SequenceMatcher ratio at or above which a near-copy completion is a prompt echo.
    max_prompt_echo_similarity: float = 0.92
    # Reject when the most frequent long segment exceeds this fraction of all segments.
    max_repeated_line_fraction: float = 0.34


@dataclass(slots=True)
class DatasetQualitySplitReport:
    """Quality summary for a single dataset split (e.g. train or eval)."""

    # Name of the split this report describes; used as its key in the full report.
    split_name: str
    # Number of records fed into the gate.
    input_records: int = 0
    # Number of records that passed every check (and dedupe).
    kept_records: int = 0
    # input_records - kept_records; duplicates are included in this count.
    dropped_records: int = 0
    # How many of the dropped records were removed as duplicates.
    duplicates_removed: int = 0
    # Rejection count per reason string (e.g. "too_short", "duplicate_training_text").
    reasons: dict[str, int] = field(default_factory=dict)
    # First few rejections as {"reason": ..., "preview": ...} dicts (capped by the recorder).
    sample_rejections: list[dict[str, str]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a recursive dict copy of this report via ``dataclasses.asdict``."""
        return asdict(self)


@dataclass(slots=True)
class DatasetQualityReport:
    """Pilns dataset quality gate artefakts."""

    artifact_type: str
    config: dict[str, Any]
    splits: dict[str, DatasetQualitySplitReport]

    def to_dict(self) -> dict[str, Any]:
        return {
            "artifact_type": self.artifact_type,
            "config": self.config,
            "splits": {name: report.to_dict() for name, report in self.splits.items()},
            "input_records": sum(report.input_records for report in self.splits.values()),
            "kept_records": sum(report.kept_records for report in self.splits.values()),
            "dropped_records": sum(report.dropped_records for report in self.splits.values()),
            "duplicates_removed": sum(report.duplicates_removed for report in self.splits.values()),
        }


def apply_quality_gate_to_records(
    records: list[dict[str, Any]],
    *,
    split_name: str,
    config: DatasetQualityGateConfig,
) -> tuple[list[dict[str, Any]], DatasetQualitySplitReport]:
    """Filter records through the quality gate and drop duplicates before training.

    Every record is normalized with ``_normalize_record``. When
    ``config.enabled`` is False, all normalized records are returned and the
    report only counts them. Otherwise each record is rendered to training
    text, checked by ``_get_rejection_reason`` and, when
    ``config.dedupe_enabled`` is set, compared against the signatures of
    previously kept records.

    Returns:
        The kept (normalized) records and a per-split report with counts,
        per-reason rejection tallies and a few sample rejections.
    """
    if not config.enabled:
        total = len(records)
        passthrough_report = DatasetQualitySplitReport(
            split_name=split_name,
            input_records=total,
            kept_records=total,
        )
        return [_normalize_record(item) for item in records], passthrough_report

    report = DatasetQualitySplitReport(split_name=split_name)
    survivors: list[dict[str, Any]] = []
    kept_signatures: set[str] = set()

    for raw in records:
        report.input_records += 1
        normalized = _normalize_record(raw)
        text = record_to_training_text(normalized, max_chars=config.max_text_chars)

        reason = _get_rejection_reason(text, normalized, config)
        if reason is None:
            fingerprint = _training_text_signature(text)
            if config.dedupe_enabled and fingerprint in kept_signatures:
                report.duplicates_removed += 1
                reason = "duplicate_training_text"
            else:
                kept_signatures.add(fingerprint)

        if reason is not None:
            # Single reporting point for both quality rejections and duplicates.
            _record_rejection(report, reason, text)
            continue

        survivors.append(normalized)
        report.kept_records += 1

    report.dropped_records = report.input_records - report.kept_records
    return survivors, report


def build_dataset_quality_report(
    *,
    config: DatasetQualityGateConfig,
    train_report: DatasetQualitySplitReport,
    eval_report: DatasetQualitySplitReport | None = None,
) -> DatasetQualityReport:
    """Assemble the serializable quality artifact from per-split reports.

    Splits are keyed by ``split_name``. If the eval report shares the train
    report's name, the eval report replaces it under that key.
    """
    included = [train_report] if eval_report is None else [train_report, eval_report]
    return DatasetQualityReport(
        artifact_type="dataset-quality-report",
        config=asdict(config),
        splits={split_report.split_name: split_report for split_report in included},
    )


def _normalize_record(record: dict[str, Any]) -> dict[str, Any]:
    normalized: dict[str, Any] = {}
    for key, value in record.items():
        normalized[key] = _normalize_value(value)
    return normalized


def _normalize_value(value: Any) -> Any:
    if isinstance(value, str):
        return clean_text(value)
    if isinstance(value, dict):
        return {str(key): _normalize_value(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_normalize_value(item) for item in value]
    return value


def _get_rejection_reason(
    training_text: str,
    record: dict[str, Any],
    config: DatasetQualityGateConfig,
) -> str | None:
    """Return the first quality rule the record violates, or ``None`` if it passes.

    Text-level rules run in a fixed order on the cleaned training text:
    emptiness, length bounds, placeholder, repeated segments, information
    density. Prompt/completion rules from ``_conversation_rejection_reason``
    run last.
    """
    text = clean_text(training_text)
    if not text:
        return "empty_training_text"

    text_length = len(text)
    if text_length < config.min_text_chars:
        return "too_short"
    if text_length > config.max_text_chars:
        return "too_long"

    # Evaluated lazily and in order so the first failing rule wins.
    ordered_checks = (
        ("placeholder_text", lambda: text.casefold() in _PLACEHOLDER_TEXTS),
        (
            "repeated_line_noise",
            lambda: _has_repeated_line_noise(text, config.max_repeated_line_fraction),
        ),
        (
            "low_information_density",
            lambda: _looks_repetitive(text, config.min_unique_char_ratio),
        ),
    )
    for reason, is_violated in ordered_checks:
        if is_violated():
            return reason

    return _conversation_rejection_reason(record, config)


def _looks_repetitive(value: str, min_unique_char_ratio: float) -> bool:
    compact = "".join(character for character in value.casefold() if not character.isspace())
    if not compact:
        return True
    unique_ratio = len(set(compact)) / len(compact)
    if unique_ratio >= min_unique_char_ratio:
        return False

    words = _WORD_RE.findall(value.casefold())
    if len(words) >= 6:
        # Garākiem strukturētiem promptiem simbolu dažādība var būt zema marķējuma/JSON dēļ,
        # tāpēc izmantojam arī vārdu dažādību, lai neatmestu jēgpilnus ierakstus.
        unique_word_ratio = len(set(words)) / len(words)
        if unique_word_ratio >= 0.45:
            return False

    return True


def _conversation_rejection_reason(
    record: dict[str, Any],
    config: DatasetQualityGateConfig,
) -> str | None:
    """Validate the prompt/completion pair of a record, if it has one.

    Returns ``None`` when the record has no recognizable pair or the pair is
    acceptable. Otherwise returns the first failing reason.
    """
    user_text, reply_text = _conversation_pair(record)
    if user_text is None or reply_text is None:
        return None

    folded_user = user_text.casefold()
    folded_reply = reply_text.casefold()
    if not user_text or not reply_text or folded_user == folded_reply:
        return "invalid_conversation_pair"
    if folded_reply in _PLACEHOLDER_TEXTS:
        return "placeholder_response"

    # Only penalize a short reply when the prompt itself is not trivially short.
    min_chars = config.min_completion_chars
    if len(reply_text) < min_chars <= len(user_text):
        return "response_too_short"

    echoed = _looks_like_prompt_echo(
        user_text,
        reply_text,
        max_prompt_echo_similarity=config.max_prompt_echo_similarity,
    )
    return "prompt_echo_response" if echoed else None


def _conversation_pair(record: dict[str, Any]) -> tuple[str | None, str | None]:
    """Extract the (prompt, completion) texts from a record.

    Each side takes the first non-empty cleaned string among its candidate
    keys, in priority order. A side with no usable string is ``None``.
    """
    prompt_keys = ("user", "prompt", "instruction", "input")
    completion_keys = ("assistant", "completion", "response", "output")
    prompt = _coerce_quality_text(*(record.get(key) for key in prompt_keys))
    completion = _coerce_quality_text(*(record.get(key) for key in completion_keys))
    return prompt, completion


def _coerce_quality_text(*values: Any) -> str | None:
    for value in values:
        if isinstance(value, str):
            normalized = clean_text(value)
            if normalized:
                return normalized
    return None


def _looks_like_prompt_echo(
    prompt: str,
    completion: str,
    *,
    max_prompt_echo_similarity: float,
) -> bool:
    prompt_normalized = prompt.casefold()
    completion_normalized = completion.casefold()
    if prompt_normalized == completion_normalized:
        return True
    if (
        completion_normalized.startswith(prompt_normalized)
        or prompt_normalized.startswith(completion_normalized)
    ) and len(completion_normalized) <= int(len(prompt_normalized) * 1.2):
        return True
    prompt_tokens = prompt_normalized.split()
    completion_tokens = completion_normalized.split()
    if not prompt_tokens or not completion_tokens:
        return False
    shared_tokens = len(set(prompt_tokens) & set(completion_tokens))
    completion_overlap = shared_tokens / max(len(set(completion_tokens)), 1)
    if completion_overlap < 0.8:
        return False
    similarity = SequenceMatcher(a=prompt_normalized, b=completion_normalized).ratio()
    return similarity >= max_prompt_echo_similarity and len(completion_normalized) <= int(
        len(prompt_normalized) * 1.2
    )


def _has_repeated_line_noise(value: str, max_repeated_line_fraction: float) -> bool:
    """Return True when one long segment dominates the text.

    The text is split into lines/sentences with ``_REPEATED_SEGMENT_SPLIT_RE``.
    Empty segments are discarded and the rest are compared case-insensitively.
    Texts with fewer than 3 segments never count as noise. Only segments of at
    least ``_MIN_REPEATED_SEGMENT_LENGTH`` characters are counted, but the
    fraction is taken over all segments.
    """
    segments: list[str] = []
    for chunk in _REPEATED_SEGMENT_SPLIT_RE.split(value):
        stripped = chunk.strip()
        if stripped:
            segments.append(stripped.casefold())

    segment_total = len(segments)
    if segment_total < 3:
        return False

    long_segment_counts = Counter(
        segment for segment in segments if len(segment) >= _MIN_REPEATED_SEGMENT_LENGTH
    )
    if not long_segment_counts:
        return False
    top_count = long_segment_counts.most_common(1)[0][1]
    return top_count / segment_total > max_repeated_line_fraction


def _training_text_signature(value: str) -> str:
    return " ".join(value.casefold().split())


def _record_rejection(
    report: DatasetQualitySplitReport,
    reason: str,
    training_text: str,
) -> None:
    report.reasons[reason] = report.reasons.get(reason, 0) + 1
    if len(report.sample_rejections) < 5:
        report.sample_rejections.append(
            {
                "reason": reason,
                "preview": training_text[:160],
            }
        )