# MarisUK's picture
# Maris AI model sync
# f440f03 verified
"""Dataset quality gate un dedupe helperi apmācībai."""
from __future__ import annotations
import re
from collections import Counter
from dataclasses import asdict, dataclass, field
from difflib import SequenceMatcher
from typing import Any
from maris_core.data.preprocessing import clean_text, record_to_training_text
# Texts treated as placeholders rather than real content. Callers compare the
# casefolded text against this set, so every entry is lowercase.
_PLACEHOLDER_TEXTS = {
    "n/a",
    "na",
    "none",
    "null",
    "placeholder",
    "test",
    "todo",
    "tbd",
}
# Word tokenizer used for the word-diversity fallback in `_looks_repetitive`.
_WORD_RE = re.compile(r"\w+", re.UNICODE)
# Splits text into segments at runs of newlines, or at whitespace that follows
# one of the sentence terminators ".", "!" or "?".
_REPEATED_SEGMENT_SPLIT_RE = re.compile(r"(?:\n+|(?<=[.!?])\s+)")
# Segments shorter than this many characters are ignored when counting repeats
# in `_has_repeated_line_noise`.
_MIN_REPEATED_SEGMENT_LENGTH = 8
@dataclass(slots=True, frozen=True)
class DatasetQualityGateConfig:
    """Configuration for the training quality gate.

    Instances are immutable (``frozen=True``). The thresholds below are read by
    the filtering helpers in this module.
    """

    # When False, records are only normalized; nothing is filtered or deduplicated.
    enabled: bool = True
    # When True, a record whose training-text signature was already seen is dropped.
    dedupe_enabled: bool = True
    # Cleaned training text shorter than this is rejected as "too_short".
    min_text_chars: int = 4
    # Passed as `max_chars` to `record_to_training_text`; cleaned text longer than
    # this is rejected as "too_long".
    max_text_chars: int = 8192
    # Minimum ratio of distinct non-whitespace characters to total non-whitespace
    # characters before the text is considered for "low_information_density".
    min_unique_char_ratio: float = 0.2
    # A completion shorter than this is rejected as "response_too_short", but only
    # when the prompt itself is at least this long.
    min_completion_chars: int = 12
    # `SequenceMatcher` ratio at or above which a completion of similar length to
    # the prompt is treated as an echo of the prompt.
    max_prompt_echo_similarity: float = 0.92
    # Largest allowed share of segments taken up by the most frequent segment
    # before the text is rejected as "repeated_line_noise".
    max_repeated_line_fraction: float = 0.34
@dataclass(slots=True)
class DatasetQualitySplitReport:
    """Quality summary for a single dataset split."""

    # Name of the split this report describes.
    split_name: str
    # Number of records seen for this split.
    input_records: int = 0
    # Number of records that passed the gate.
    kept_records: int = 0
    # Number of records removed; the gate sets this to input_records - kept_records.
    dropped_records: int = 0
    # Number of records dropped because their training text was already seen.
    duplicates_removed: int = 0
    # Rejection reason -> number of records rejected for that reason.
    reasons: dict[str, int] = field(default_factory=dict)
    # Example rejections, each a {"reason": ..., "preview": ...} mapping.
    sample_rejections: list[dict[str, str]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the report as a plain dict via `dataclasses.asdict`."""
        return asdict(self)
@dataclass(slots=True)
class DatasetQualityReport:
"""Pilns dataset quality gate artefakts."""
artifact_type: str
config: dict[str, Any]
splits: dict[str, DatasetQualitySplitReport]
def to_dict(self) -> dict[str, Any]:
return {
"artifact_type": self.artifact_type,
"config": self.config,
"splits": {name: report.to_dict() for name, report in self.splits.items()},
"input_records": sum(report.input_records for report in self.splits.values()),
"kept_records": sum(report.kept_records for report in self.splits.values()),
"dropped_records": sum(report.dropped_records for report in self.splits.values()),
"duplicates_removed": sum(report.duplicates_removed for report in self.splits.values()),
}
def apply_quality_gate_to_records(
    records: list[dict[str, Any]],
    *,
    split_name: str,
    config: DatasetQualityGateConfig,
) -> tuple[list[dict[str, Any]], DatasetQualitySplitReport]:
    """Filter records and remove duplicates before training.

    Each record is normalized with `_normalize_record`, rendered to training
    text with `record_to_training_text`, and checked by `_get_rejection_reason`.
    When `config.dedupe_enabled` is set, a surviving record is also dropped if
    the case- and whitespace-insensitive signature of its training text was
    already seen; the first occurrence is kept.

    Args:
        records: Raw records of one split.
        split_name: Name stored on the returned report.
        config: Gate thresholds and switches.

    Returns:
        A tuple of the kept, normalized records in input order and the
        per-split report. When `config.enabled` is False every record is
        normalized and returned, and the report counts all of them as kept.
    """
    if not config.enabled:
        # Gate switched off: normalize only, nothing is rejected.
        passthrough_report = DatasetQualitySplitReport(
            split_name=split_name,
            input_records=len(records),
            kept_records=len(records),
        )
        return [_normalize_record(record) for record in records], passthrough_report

    report = DatasetQualitySplitReport(split_name=split_name)
    kept_records: list[dict[str, Any]] = []
    seen_signatures: set[str] = set()
    for raw_record in records:
        report.input_records += 1
        record = _normalize_record(raw_record)
        training_text = record_to_training_text(record, max_chars=config.max_text_chars)
        rejection_reason = _get_rejection_reason(training_text, record, config)
        if rejection_reason is not None:
            _record_rejection(report, rejection_reason, training_text)
            continue
        if config.dedupe_enabled:
            # Signatures are only needed for deduplication; skipping them otherwise
            # avoids holding a copy of every kept text in memory.
            signature = _training_text_signature(training_text)
            if signature in seen_signatures:
                # Duplicates are tallied both in their own counter and in `reasons`.
                report.duplicates_removed += 1
                _record_rejection(report, "duplicate_training_text", training_text)
                continue
            seen_signatures.add(signature)
        kept_records.append(record)
        report.kept_records += 1
    report.dropped_records = report.input_records - report.kept_records
    return kept_records, report
def build_dataset_quality_report(
    *,
    config: DatasetQualityGateConfig,
    train_report: DatasetQualitySplitReport,
    eval_report: DatasetQualitySplitReport | None = None,
) -> DatasetQualityReport:
    """Build a serializable quality artifact from the per-split reports.

    Args:
        config: Gate configuration; stored on the artifact as a plain dict.
        train_report: Report for the training split.
        eval_report: Optional report for the evaluation split.

    Returns:
        A `DatasetQualityReport` whose `splits` maps each report's
        `split_name` to the report. If both reports share a name, the
        evaluation report replaces the training one.
    """
    split_reports = [train_report]
    if eval_report is not None:
        split_reports.append(eval_report)
    reports_by_name = {
        split_report.split_name: split_report for split_report in split_reports
    }
    return DatasetQualityReport(
        artifact_type="dataset-quality-report",
        config=asdict(config),
        splits=reports_by_name,
    )
def _normalize_record(record: dict[str, Any]) -> dict[str, Any]:
normalized: dict[str, Any] = {}
for key, value in record.items():
normalized[key] = _normalize_value(value)
return normalized
def _normalize_value(value: Any) -> Any:
if isinstance(value, str):
return clean_text(value)
if isinstance(value, dict):
return {str(key): _normalize_value(item) for key, item in value.items()}
if isinstance(value, list):
return [_normalize_value(item) for item in value]
return value
def _get_rejection_reason(
    training_text: str,
    record: dict[str, Any],
    config: DatasetQualityGateConfig,
) -> str | None:
    """Return the first reason to reject a record, or None if it passes.

    Text-level checks run on `clean_text(training_text)` in this order:
    "empty_training_text", "too_short", "too_long", "placeholder_text",
    "repeated_line_noise", "low_information_density". If none applies, the
    result of `_conversation_rejection_reason(record, config)` is returned.
    """
    text = clean_text(training_text)
    if not text:
        return "empty_training_text"

    text_length = len(text)
    if text_length < config.min_text_chars:
        return "too_short"
    if text_length > config.max_text_chars:
        return "too_long"

    if text.casefold() in _PLACEHOLDER_TEXTS:
        return "placeholder_text"
    if _has_repeated_line_noise(text, config.max_repeated_line_fraction):
        return "repeated_line_noise"
    if _looks_repetitive(text, config.min_unique_char_ratio):
        return "low_information_density"

    # Text-level checks passed; defer to the prompt/completion checks.
    return _conversation_rejection_reason(record, config)
def _looks_repetitive(value: str, min_unique_char_ratio: float) -> bool:
compact = "".join(character for character in value.casefold() if not character.isspace())
if not compact:
return True
unique_ratio = len(set(compact)) / len(compact)
if unique_ratio >= min_unique_char_ratio:
return False
words = _WORD_RE.findall(value.casefold())
if len(words) >= 6:
# Garākiem strukturētiem promptiem simbolu dažādība var būt zema marķējuma/JSON dēļ,
# tāpēc izmantojam arī vārdu dažādību, lai neatmestu jēgpilnus ierakstus.
unique_word_ratio = len(set(words)) / len(words)
if unique_word_ratio >= 0.45:
return False
return True
def _conversation_rejection_reason(
    record: dict[str, Any],
    config: DatasetQualityGateConfig,
) -> str | None:
    """Return a rejection reason for a prompt/completion record, or None.

    Args:
        record: Normalized record; prompt and completion are read through
            `_conversation_pair`.
        config: Supplies `min_completion_chars` and `max_prompt_echo_similarity`.

    Returns:
        None when the record has no usable prompt/completion pair or the pair
        passes every check. Otherwise one of "invalid_conversation_pair",
        "placeholder_response", "response_too_short" or "prompt_echo_response".
    """
    prompt, completion = _conversation_pair(record)
    if prompt is None and completion is None:
        # Neither side has text: not a conversation record.
        return None
    if prompt is None or completion is None:
        # Exactly one side has text. `_conversation_pair` reports blank strings as
        # None, so a plain emptiness test can never fire here. Look at the raw
        # fields instead: a side that is present as a string but blank makes the
        # pair invalid, while an absent or non-string side leaves the record alone.
        blank_side_fields = (
            ("user", "prompt", "instruction", "input")
            if prompt is None
            else ("assistant", "completion", "response", "output")
        )
        if any(isinstance(record.get(name), str) for name in blank_side_fields):
            return "invalid_conversation_pair"
        return None

    prompt_folded = prompt.casefold()
    completion_folded = completion.casefold()
    if prompt_folded == completion_folded:
        return "invalid_conversation_pair"
    if completion_folded in _PLACEHOLDER_TEXTS:
        return "placeholder_response"
    # A short answer is only suspicious when the prompt itself is not short.
    if (
        len(completion) < config.min_completion_chars
        and len(prompt) >= config.min_completion_chars
    ):
        return "response_too_short"
    if _looks_like_prompt_echo(
        prompt,
        completion,
        max_prompt_echo_similarity=config.max_prompt_echo_similarity,
    ):
        return "prompt_echo_response"
    return None
def _conversation_pair(record: dict[str, Any]) -> tuple[str | None, str | None]:
    """Extract the (prompt, completion) texts from a record.

    Each side is the first usable text, as chosen by `_coerce_quality_text`,
    among its candidate fields in the order listed below; a side with no
    usable text is None.
    """
    prompt_fields = ("user", "prompt", "instruction", "input")
    completion_fields = ("assistant", "completion", "response", "output")
    prompt_text = _coerce_quality_text(*(record.get(name) for name in prompt_fields))
    completion_text = _coerce_quality_text(
        *(record.get(name) for name in completion_fields)
    )
    return prompt_text, completion_text
def _coerce_quality_text(*values: Any) -> str | None:
for value in values:
if isinstance(value, str):
normalized = clean_text(value)
if normalized:
return normalized
return None
def _looks_like_prompt_echo(
prompt: str,
completion: str,
*,
max_prompt_echo_similarity: float,
) -> bool:
prompt_normalized = prompt.casefold()
completion_normalized = completion.casefold()
if prompt_normalized == completion_normalized:
return True
if (
completion_normalized.startswith(prompt_normalized)
or prompt_normalized.startswith(completion_normalized)
) and len(completion_normalized) <= int(len(prompt_normalized) * 1.2):
return True
prompt_tokens = prompt_normalized.split()
completion_tokens = completion_normalized.split()
if not prompt_tokens or not completion_tokens:
return False
shared_tokens = len(set(prompt_tokens) & set(completion_tokens))
completion_overlap = shared_tokens / max(len(set(completion_tokens)), 1)
if completion_overlap < 0.8:
return False
similarity = SequenceMatcher(a=prompt_normalized, b=completion_normalized).ratio()
return similarity >= max_prompt_echo_similarity and len(completion_normalized) <= int(
len(prompt_normalized) * 1.2
)
def _has_repeated_line_noise(value: str, max_repeated_line_fraction: float) -> bool:
    """Return True when one repeated segment dominates the text.

    The text is split with `_REPEATED_SEGMENT_SPLIT_RE`; blank segments are
    dropped and the rest are stripped and casefolded. Only segments of at
    least `_MIN_REPEATED_SEGMENT_LENGTH` characters are counted.

    Args:
        value: Text to inspect.
        max_repeated_line_fraction: Largest allowed share of all segments that
            the most frequent counted segment may take up.

    Returns:
        True when some counted segment occurs at least twice and its share of
        all segments exceeds `max_repeated_line_fraction`. Texts with fewer
        than three segments are never flagged.
    """
    segments = [
        piece.strip().casefold()
        for piece in _REPEATED_SEGMENT_SPLIT_RE.split(value)
        if piece.strip()
    ]
    if len(segments) < 3:
        return False

    occurrences = Counter(
        segment for segment in segments if len(segment) >= _MIN_REPEATED_SEGMENT_LENGTH
    )
    if not occurrences:
        return False

    top_count = max(occurrences.values())
    if top_count < 2:
        # A segment seen once is not a repetition. Without this guard a threshold
        # below 1/3 would flag short texts whose segments are all distinct.
        return False
    return top_count / len(segments) > max_repeated_line_fraction
def _training_text_signature(value: str) -> str:
return " ".join(value.casefold().split())
def _record_rejection(
report: DatasetQualitySplitReport,
reason: str,
training_text: str,
) -> None:
report.reasons[reason] = report.reasons.get(reason, 0) + 1
if len(report.sample_rejections) < 5:
report.sample_rejections.append(
{
"reason": reason,
"preview": training_text[:160],
}
)