| """Dataset quality gate un dedupe helperi apmācībai.""" |
|
|
| from __future__ import annotations |
|
|
| import re |
| from collections import Counter |
| from dataclasses import asdict, dataclass, field |
| from difflib import SequenceMatcher |
| from typing import Any |
|
|
| from maris_core.data.preprocessing import clean_text, record_to_training_text |
|
|
| _PLACEHOLDER_TEXTS = { |
| "n/a", |
| "na", |
| "none", |
| "null", |
| "placeholder", |
| "test", |
| "todo", |
| "tbd", |
| } |
| _WORD_RE = re.compile(r"\w+", re.UNICODE) |
| _REPEATED_SEGMENT_SPLIT_RE = re.compile(r"(?:\n+|(?<=[.!?])\s+)") |
| _MIN_REPEATED_SEGMENT_LENGTH = 8 |
|
|
|
|
| @dataclass(slots=True, frozen=True) |
| class DatasetQualityGateConfig: |
| """Konfigurācija training kvalitātes vārtiem.""" |
|
|
| enabled: bool = True |
| dedupe_enabled: bool = True |
| min_text_chars: int = 4 |
| max_text_chars: int = 8192 |
| min_unique_char_ratio: float = 0.2 |
| min_completion_chars: int = 12 |
| max_prompt_echo_similarity: float = 0.92 |
| max_repeated_line_fraction: float = 0.34 |
|
|
|
|
| @dataclass(slots=True) |
| class DatasetQualitySplitReport: |
| """Viena split kvalitātes kopsavilkums.""" |
|
|
| split_name: str |
| input_records: int = 0 |
| kept_records: int = 0 |
| dropped_records: int = 0 |
| duplicates_removed: int = 0 |
| reasons: dict[str, int] = field(default_factory=dict) |
| sample_rejections: list[dict[str, str]] = field(default_factory=list) |
|
|
| def to_dict(self) -> dict[str, Any]: |
| return asdict(self) |
|
|
|
|
| @dataclass(slots=True) |
| class DatasetQualityReport: |
| """Pilns dataset quality gate artefakts.""" |
|
|
| artifact_type: str |
| config: dict[str, Any] |
| splits: dict[str, DatasetQualitySplitReport] |
|
|
| def to_dict(self) -> dict[str, Any]: |
| return { |
| "artifact_type": self.artifact_type, |
| "config": self.config, |
| "splits": {name: report.to_dict() for name, report in self.splits.items()}, |
| "input_records": sum(report.input_records for report in self.splits.values()), |
| "kept_records": sum(report.kept_records for report in self.splits.values()), |
| "dropped_records": sum(report.dropped_records for report in self.splits.values()), |
| "duplicates_removed": sum(report.duplicates_removed for report in self.splits.values()), |
| } |
|
|
|
|
def apply_quality_gate_to_records(
    records: list[dict[str, Any]],
    *,
    split_name: str,
    config: DatasetQualityGateConfig,
) -> tuple[list[dict[str, Any]], DatasetQualitySplitReport]:
    """Filter records through the quality gate and drop duplicates before training.

    Every record is first normalized with ``_normalize_record``. With the gate
    disabled, all normalized records are returned and the report simply counts
    them. Otherwise the training text of each record (from
    ``record_to_training_text`` with ``max_chars=config.max_text_chars``) is
    checked by ``_get_rejection_reason``; survivors are deduplicated, when
    ``config.dedupe_enabled`` is set, by a case- and whitespace-insensitive
    signature that keeps the first occurrence.

    Args:
        records: Raw training records.
        split_name: Name stored on the returned report.
        config: Gate switches and thresholds.

    Returns:
        ``(kept_records, report)``: the surviving normalized records in input
        order, and the statistics for this split.
    """
    if not config.enabled:
        passthrough_report = DatasetQualitySplitReport(
            split_name=split_name,
            input_records=len(records),
            kept_records=len(records),
        )
        return [_normalize_record(item) for item in records], passthrough_report

    report = DatasetQualitySplitReport(split_name=split_name)
    accepted: list[dict[str, Any]] = []
    signatures_seen: set[str] = set()

    for item in records:
        report.input_records += 1
        normalized = _normalize_record(item)
        text = record_to_training_text(normalized, max_chars=config.max_text_chars)

        reason = _get_rejection_reason(text, normalized, config)
        if reason is None:
            signature = _training_text_signature(text)
            if config.dedupe_enabled and signature in signatures_seen:
                report.duplicates_removed += 1
                reason = "duplicate_training_text"
            else:
                signatures_seen.add(signature)
                accepted.append(normalized)
                report.kept_records += 1
                continue

        # Only rejected records (quality failure or duplicate) get here.
        _record_rejection(report, reason, text)

    report.dropped_records = report.input_records - report.kept_records
    return accepted, report
|
|
|
|
def build_dataset_quality_report(
    *,
    config: DatasetQualityGateConfig,
    train_report: DatasetQualitySplitReport,
    eval_report: DatasetQualitySplitReport | None = None,
) -> DatasetQualityReport:
    """Build the serializable quality artifact from per-split reports.

    Args:
        config: Gate configuration; stored in the artifact via ``asdict``.
        train_report: Report for the training split.
        eval_report: Optional report for the evaluation split.

    Returns:
        A ``DatasetQualityReport`` tagged ``"dataset-quality-report"`` whose
        splits are keyed by each report's ``split_name``.

    Raises:
        ValueError: If ``eval_report`` has the same ``split_name`` as
            ``train_report``. Reports are keyed by split name, so the eval
            report would otherwise silently replace the training one and drop
            its counts from the totals.
    """
    splits = {train_report.split_name: train_report}
    if eval_report is not None:
        if eval_report.split_name in splits:
            raise ValueError(
                f"train and eval reports share split_name {eval_report.split_name!r}; "
                "split names must be unique"
            )
        splits[eval_report.split_name] = eval_report
    return DatasetQualityReport(
        artifact_type="dataset-quality-report",
        config=asdict(config),
        splits=splits,
    )
|
|
|
|
| def _normalize_record(record: dict[str, Any]) -> dict[str, Any]: |
| normalized: dict[str, Any] = {} |
| for key, value in record.items(): |
| normalized[key] = _normalize_value(value) |
| return normalized |
|
|
|
|
| def _normalize_value(value: Any) -> Any: |
| if isinstance(value, str): |
| return clean_text(value) |
| if isinstance(value, dict): |
| return {str(key): _normalize_value(item) for key, item in value.items()} |
| if isinstance(value, list): |
| return [_normalize_value(item) for item in value] |
| return value |
|
|
|
|
def _get_rejection_reason(
    training_text: str,
    record: dict[str, Any],
    config: DatasetQualityGateConfig,
) -> str | None:
    """Return the first failed quality check for a record, or ``None`` if it passes.

    The cleaned training text is checked in this order: empty, too short, too
    long, placeholder, repeated-segment noise, low information density. After
    that, the record's prompt/response pair (if it has one) is validated.
    """
    text = clean_text(training_text)
    if not text:
        return "empty_training_text"

    length = len(text)
    if length < config.min_text_chars:
        return "too_short"
    # NOTE(review): apply_quality_gate_to_records builds training_text with
    # max_chars=config.max_text_chars; if record_to_training_text truncates to
    # that limit, this branch can never fire. Confirm against its contract.
    if length > config.max_text_chars:
        return "too_long"

    if text.casefold() in _PLACEHOLDER_TEXTS:
        return "placeholder_text"
    if _has_repeated_line_noise(text, config.max_repeated_line_fraction):
        return "repeated_line_noise"
    if _looks_repetitive(text, config.min_unique_char_ratio):
        return "low_information_density"

    # Conversation-level checks; their None means the record passes the gate.
    return _conversation_rejection_reason(record, config)
|
|
|
|
| def _looks_repetitive(value: str, min_unique_char_ratio: float) -> bool: |
| compact = "".join(character for character in value.casefold() if not character.isspace()) |
| if not compact: |
| return True |
| unique_ratio = len(set(compact)) / len(compact) |
| if unique_ratio >= min_unique_char_ratio: |
| return False |
|
|
| words = _WORD_RE.findall(value.casefold()) |
| if len(words) >= 6: |
| |
| |
| unique_word_ratio = len(set(words)) / len(words) |
| if unique_word_ratio >= 0.45: |
| return False |
|
|
| return True |
|
|
|
|
| def _conversation_rejection_reason( |
| record: dict[str, Any], |
| config: DatasetQualityGateConfig, |
| ) -> str | None: |
| prompt, completion = _conversation_pair(record) |
| if prompt is None or completion is None: |
| return None |
| if not prompt or not completion: |
| return "invalid_conversation_pair" |
| if prompt.casefold() == completion.casefold(): |
| return "invalid_conversation_pair" |
| if completion.casefold() in _PLACEHOLDER_TEXTS: |
| return "placeholder_response" |
| if len(completion) < config.min_completion_chars and len(prompt) >= config.min_completion_chars: |
| return "response_too_short" |
| if _looks_like_prompt_echo( |
| prompt, |
| completion, |
| max_prompt_echo_similarity=config.max_prompt_echo_similarity, |
| ): |
| return "prompt_echo_response" |
| return None |
|
|
|
|
| def _conversation_pair(record: dict[str, Any]) -> tuple[str | None, str | None]: |
| prompt = _coerce_quality_text( |
| record.get("user"), |
| record.get("prompt"), |
| record.get("instruction"), |
| record.get("input"), |
| ) |
| completion = _coerce_quality_text( |
| record.get("assistant"), |
| record.get("completion"), |
| record.get("response"), |
| record.get("output"), |
| ) |
| return prompt, completion |
|
|
|
|
| def _coerce_quality_text(*values: Any) -> str | None: |
| for value in values: |
| if isinstance(value, str): |
| normalized = clean_text(value) |
| if normalized: |
| return normalized |
| return None |
|
|
|
|
| def _looks_like_prompt_echo( |
| prompt: str, |
| completion: str, |
| *, |
| max_prompt_echo_similarity: float, |
| ) -> bool: |
| prompt_normalized = prompt.casefold() |
| completion_normalized = completion.casefold() |
| if prompt_normalized == completion_normalized: |
| return True |
| if ( |
| completion_normalized.startswith(prompt_normalized) |
| or prompt_normalized.startswith(completion_normalized) |
| ) and len(completion_normalized) <= int(len(prompt_normalized) * 1.2): |
| return True |
| prompt_tokens = prompt_normalized.split() |
| completion_tokens = completion_normalized.split() |
| if not prompt_tokens or not completion_tokens: |
| return False |
| shared_tokens = len(set(prompt_tokens) & set(completion_tokens)) |
| completion_overlap = shared_tokens / max(len(set(completion_tokens)), 1) |
| if completion_overlap < 0.8: |
| return False |
| similarity = SequenceMatcher(a=prompt_normalized, b=completion_normalized).ratio() |
| return similarity >= max_prompt_echo_similarity and len(completion_normalized) <= int( |
| len(prompt_normalized) * 1.2 |
| ) |
|
|
|
|
def _has_repeated_line_noise(value: str, max_repeated_line_fraction: float) -> bool:
    """Return True when a single text segment is repeated too often.

    *value* is split with ``_REPEATED_SEGMENT_SPLIT_RE``; segments are stripped
    and case-folded, and blank ones are dropped. Texts with fewer than three
    segments are never flagged. Only segments of at least
    ``_MIN_REPEATED_SEGMENT_LENGTH`` characters are counted. The most frequent
    one must occur at least twice and make up more than
    ``max_repeated_line_fraction`` of all segments.
    """
    segments = [
        segment.strip().casefold()
        for segment in _REPEATED_SEGMENT_SPLIT_RE.split(value)
        if segment.strip()
    ]
    if len(segments) < 3:
        return False
    counts = Counter(
        segment for segment in segments if len(segment) >= _MIN_REPEATED_SEGMENT_LENGTH
    )
    if not counts:
        return False
    most_repeated = max(counts.values())
    # A segment seen once is not a repetition. Without this guard, any fraction
    # below 1/len(segments) would flag text in which nothing repeats at all.
    if most_repeated < 2:
        return False
    return most_repeated / len(segments) > max_repeated_line_fraction
|
|
|
|
| def _training_text_signature(value: str) -> str: |
| return " ".join(value.casefold().split()) |
|
|
|
|
| def _record_rejection( |
| report: DatasetQualitySplitReport, |
| reason: str, |
| training_text: str, |
| ) -> None: |
| report.reasons[reason] = report.reasons.get(reason, 0) + 1 |
| if len(report.sample_rejections) < 5: |
| report.sample_rejections.append( |
| { |
| "reason": reason, |
| "preview": training_text[:160], |
| } |
| ) |
|
|