"""Dataset quality gate un dedupe helperi apmācībai.""" from __future__ import annotations import re from collections import Counter from dataclasses import asdict, dataclass, field from difflib import SequenceMatcher from typing import Any from maris_core.data.preprocessing import clean_text, record_to_training_text _PLACEHOLDER_TEXTS = { "n/a", "na", "none", "null", "placeholder", "test", "todo", "tbd", } _WORD_RE = re.compile(r"\w+", re.UNICODE) _REPEATED_SEGMENT_SPLIT_RE = re.compile(r"(?:\n+|(?<=[.!?])\s+)") _MIN_REPEATED_SEGMENT_LENGTH = 8 @dataclass(slots=True, frozen=True) class DatasetQualityGateConfig: """Konfigurācija training kvalitātes vārtiem.""" enabled: bool = True dedupe_enabled: bool = True min_text_chars: int = 4 max_text_chars: int = 8192 min_unique_char_ratio: float = 0.2 min_completion_chars: int = 12 max_prompt_echo_similarity: float = 0.92 max_repeated_line_fraction: float = 0.34 @dataclass(slots=True) class DatasetQualitySplitReport: """Viena split kvalitātes kopsavilkums.""" split_name: str input_records: int = 0 kept_records: int = 0 dropped_records: int = 0 duplicates_removed: int = 0 reasons: dict[str, int] = field(default_factory=dict) sample_rejections: list[dict[str, str]] = field(default_factory=list) def to_dict(self) -> dict[str, Any]: return asdict(self) @dataclass(slots=True) class DatasetQualityReport: """Pilns dataset quality gate artefakts.""" artifact_type: str config: dict[str, Any] splits: dict[str, DatasetQualitySplitReport] def to_dict(self) -> dict[str, Any]: return { "artifact_type": self.artifact_type, "config": self.config, "splits": {name: report.to_dict() for name, report in self.splits.items()}, "input_records": sum(report.input_records for report in self.splits.values()), "kept_records": sum(report.kept_records for report in self.splits.values()), "dropped_records": sum(report.dropped_records for report in self.splits.values()), "duplicates_removed": sum(report.duplicates_removed for report in 
self.splits.values()), } def apply_quality_gate_to_records( records: list[dict[str, Any]], *, split_name: str, config: DatasetQualityGateConfig, ) -> tuple[list[dict[str, Any]], DatasetQualitySplitReport]: """Filtrē ierakstus un noņem dublikātus pirms treniņa.""" if not config.enabled: report = DatasetQualitySplitReport( split_name=split_name, input_records=len(records), kept_records=len(records), ) return [_normalize_record(record) for record in records], report report = DatasetQualitySplitReport(split_name=split_name) kept_records: list[dict[str, Any]] = [] seen_signatures: set[str] = set() for raw_record in records: report.input_records += 1 record = _normalize_record(raw_record) training_text = record_to_training_text(record, max_chars=config.max_text_chars) rejection_reason = _get_rejection_reason(training_text, record, config) if rejection_reason is not None: _record_rejection(report, rejection_reason, training_text) continue signature = _training_text_signature(training_text) if config.dedupe_enabled and signature in seen_signatures: report.duplicates_removed += 1 _record_rejection(report, "duplicate_training_text", training_text) continue seen_signatures.add(signature) kept_records.append(record) report.kept_records += 1 report.dropped_records = report.input_records - report.kept_records return kept_records, report def build_dataset_quality_report( *, config: DatasetQualityGateConfig, train_report: DatasetQualitySplitReport, eval_report: DatasetQualitySplitReport | None = None, ) -> DatasetQualityReport: """Izveido serializējamu kvalitātes artefaktu.""" splits = {train_report.split_name: train_report} if eval_report is not None: splits[eval_report.split_name] = eval_report return DatasetQualityReport( artifact_type="dataset-quality-report", config=asdict(config), splits=splits, ) def _normalize_record(record: dict[str, Any]) -> dict[str, Any]: normalized: dict[str, Any] = {} for key, value in record.items(): normalized[key] = _normalize_value(value) return 
normalized def _normalize_value(value: Any) -> Any: if isinstance(value, str): return clean_text(value) if isinstance(value, dict): return {str(key): _normalize_value(item) for key, item in value.items()} if isinstance(value, list): return [_normalize_value(item) for item in value] return value def _get_rejection_reason( training_text: str, record: dict[str, Any], config: DatasetQualityGateConfig, ) -> str | None: normalized_text = clean_text(training_text) if not normalized_text: return "empty_training_text" if len(normalized_text) < config.min_text_chars: return "too_short" if len(normalized_text) > config.max_text_chars: return "too_long" if normalized_text.casefold() in _PLACEHOLDER_TEXTS: return "placeholder_text" if _has_repeated_line_noise(normalized_text, config.max_repeated_line_fraction): return "repeated_line_noise" if _looks_repetitive(normalized_text, config.min_unique_char_ratio): return "low_information_density" conversation_rejection = _conversation_rejection_reason(record, config) if conversation_rejection is not None: return conversation_rejection return None def _looks_repetitive(value: str, min_unique_char_ratio: float) -> bool: compact = "".join(character for character in value.casefold() if not character.isspace()) if not compact: return True unique_ratio = len(set(compact)) / len(compact) if unique_ratio >= min_unique_char_ratio: return False words = _WORD_RE.findall(value.casefold()) if len(words) >= 6: # Garākiem strukturētiem promptiem simbolu dažādība var būt zema marķējuma/JSON dēļ, # tāpēc izmantojam arī vārdu dažādību, lai neatmestu jēgpilnus ierakstus. 
unique_word_ratio = len(set(words)) / len(words) if unique_word_ratio >= 0.45: return False return True def _conversation_rejection_reason( record: dict[str, Any], config: DatasetQualityGateConfig, ) -> str | None: prompt, completion = _conversation_pair(record) if prompt is None or completion is None: return None if not prompt or not completion: return "invalid_conversation_pair" if prompt.casefold() == completion.casefold(): return "invalid_conversation_pair" if completion.casefold() in _PLACEHOLDER_TEXTS: return "placeholder_response" if len(completion) < config.min_completion_chars and len(prompt) >= config.min_completion_chars: return "response_too_short" if _looks_like_prompt_echo( prompt, completion, max_prompt_echo_similarity=config.max_prompt_echo_similarity, ): return "prompt_echo_response" return None def _conversation_pair(record: dict[str, Any]) -> tuple[str | None, str | None]: prompt = _coerce_quality_text( record.get("user"), record.get("prompt"), record.get("instruction"), record.get("input"), ) completion = _coerce_quality_text( record.get("assistant"), record.get("completion"), record.get("response"), record.get("output"), ) return prompt, completion def _coerce_quality_text(*values: Any) -> str | None: for value in values: if isinstance(value, str): normalized = clean_text(value) if normalized: return normalized return None def _looks_like_prompt_echo( prompt: str, completion: str, *, max_prompt_echo_similarity: float, ) -> bool: prompt_normalized = prompt.casefold() completion_normalized = completion.casefold() if prompt_normalized == completion_normalized: return True if ( completion_normalized.startswith(prompt_normalized) or prompt_normalized.startswith(completion_normalized) ) and len(completion_normalized) <= int(len(prompt_normalized) * 1.2): return True prompt_tokens = prompt_normalized.split() completion_tokens = completion_normalized.split() if not prompt_tokens or not completion_tokens: return False shared_tokens = 
len(set(prompt_tokens) & set(completion_tokens)) completion_overlap = shared_tokens / max(len(set(completion_tokens)), 1) if completion_overlap < 0.8: return False similarity = SequenceMatcher(a=prompt_normalized, b=completion_normalized).ratio() return similarity >= max_prompt_echo_similarity and len(completion_normalized) <= int( len(prompt_normalized) * 1.2 ) def _has_repeated_line_noise(value: str, max_repeated_line_fraction: float) -> bool: lines = [ line.strip().casefold() for line in _REPEATED_SEGMENT_SPLIT_RE.split(value) if line.strip() ] if len(lines) < 3: return False repeated_counts = Counter(line for line in lines if len(line) >= _MIN_REPEATED_SEGMENT_LENGTH) if not repeated_counts: return False return max(repeated_counts.values()) / len(lines) > max_repeated_line_fraction def _training_text_signature(value: str) -> str: return " ".join(value.casefold().split()) def _record_rejection( report: DatasetQualitySplitReport, reason: str, training_text: str, ) -> None: report.reasons[reason] = report.reasons.get(reason, 0) + 1 if len(report.sample_rejections) < 5: report.sample_rejections.append( { "reason": reason, "preview": training_text[:160], } )
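

# Illustrative sketch only (hypothetical helper, not called by the gate): it recomputes
# the ratio that `_has_repeated_line_noise` compares with `max_repeated_line_fraction`,
# assuming the same split regex and minimum segment length as the module constants.
# It shows one consequence of the 0.34 default: three distinct segments give a largest
# share of 1/3 (about 0.333), which passes, while any repeat among three segments gives
# 2/3 and is flagged as noise.
def _example_repeated_segment_fraction(value: str) -> float:
    """Return the share of segments held by the most common long segment.

    Returns 0.0 for fewer than three segments or when no segment is long enough,
    mirroring the early exits in `_has_repeated_line_noise`.
    """
    # Local imports keep this sketch runnable on its own.
    import re
    from collections import Counter

    # Same pattern as `_REPEATED_SEGMENT_SPLIT_RE`: split on newlines or after . ! ?
    split_re = re.compile(r"(?:\n+|(?<=[.!?])\s+)")
    segments = [s.strip().casefold() for s in split_re.split(value) if s.strip()]
    if len(segments) < 3:
        return 0.0
    # 8 mirrors `_MIN_REPEATED_SEGMENT_LENGTH`; shorter segments still count in the
    # denominator but are never treated as repeats.
    counts = Counter(s for s in segments if len(s) >= 8)
    if not counts:
        return 0.0
    return max(counts.values()) / len(segments)


# Example: "Buy now today. Buy now today. Something else here." splits into three
# sentences, one of which appears twice, so the sketch returns 2/3 (above 0.34).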