File size: 4,456 Bytes
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Corpus filtering, safety, and quality heuristics."""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Iterable


EMAIL_RE = re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE)
PHONE_RE = re.compile(r"(?:(?:\+?\d{1,3})?[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?){2}\d{4}")
SSN_RE = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
HTML_RE = re.compile(r"<[^>]+>")
MULTISPACE_RE = re.compile(r"[ \t]+")
NSFW_TERMS = {"porn", "explicit sex", "rape"}
HATE_TERMS = {"kill all", "ethnic cleansing"}
ALLOWED_LICENSES = {"permissive", "restricted"}
ALLOWED_LANGS = {"en", "es", "fr", "de", "hi", "zh", "ar", "pt"}


@dataclass(frozen=True)
class FilterConfig:
    """Policy controls for the filtering pipeline."""

    minimum_chars: int = 200
    maximum_chars: int = 200_000
    minimum_alpha_ratio: float = 0.45
    minimum_quality_score: float = 0.20
    language_confidence_threshold: float = 0.65


def normalize_text(text: str) -> str:
    """Strip tags and normalize whitespace."""
    text = HTML_RE.sub(" ", text)
    text = MULTISPACE_RE.sub(" ", text)
    return text.strip()


def detect_language(text: str) -> tuple[str, float]:
    """Use a light heuristic to assign a language code."""
    ascii_ratio = sum(ch.isascii() for ch in text) / max(len(text), 1)
    devanagari = sum("\u0900" <= ch <= "\u097f" for ch in text)
    cjk = sum("\u4e00" <= ch <= "\u9fff" for ch in text)
    arabic = sum("\u0600" <= ch <= "\u06ff" for ch in text)
    if cjk > 8:
        return "zh", 0.95
    if arabic > 8:
        return "ar", 0.95
    if devanagari > 8:
        return "hi", 0.95
    if ascii_ratio > 0.9:
        return "en", 0.80
    return "unknown", 0.40


def quality_score(text: str) -> float:
    """Score text using length, punctuation, and alphabetic density."""
    if not text:
        return 0.0
    alpha_ratio = sum(ch.isalpha() for ch in text) / len(text)
    punct_ratio = sum(ch in ".,;:!?()[]{}" for ch in text) / len(text)
    line_count = text.count("\n") + 1
    score = min(len(text) / 4000.0, 1.0) * 0.4 + alpha_ratio * 0.4 + min(punct_ratio * 8.0, 1.0) * 0.2
    if line_count < 2 and len(text) > 10_000:
        score *= 0.85
    return round(score, 4)


def quality_tier(score: float) -> str:
    """Map a numeric score to a quality tier."""
    if score >= 0.70:
        return "high"
    if score >= 0.40:
        return "medium"
    return "low"


def strip_pii(text: str) -> str:
    """Mask basic email, phone, and SSN patterns."""
    text = EMAIL_RE.sub("[EMAIL]", text)
    text = PHONE_RE.sub("[PHONE]", text)
    text = SSN_RE.sub("[SSN]", text)
    return text


def passes_safety_filter(text: str) -> bool:
    """Reject obviously unsafe content with simple keyword checks."""
    lower = text.lower()
    if any(term in lower for term in NSFW_TERMS):
        return False
    if any(term in lower for term in HATE_TERMS):
        return False
    return True


def license_allowed(category: str) -> bool:
    """Return whether the source license category is allowed."""
    return category in ALLOWED_LICENSES


def filter_record(record: dict[str, object], config: FilterConfig = FilterConfig()) -> dict[str, object] | None:
    """Apply the full filter pipeline to one record."""
    if not license_allowed(str(record.get("license_category", ""))):
        return None
    text = normalize_text(str(record.get("text", "")))
    if not (config.minimum_chars <= len(text) <= config.maximum_chars):
        return None
    lang, confidence = detect_language(text)
    if lang not in ALLOWED_LANGS or confidence < config.language_confidence_threshold:
        return None
    text = strip_pii(text)
    if not passes_safety_filter(text):
        return None
    score = quality_score(text)
    if score < config.minimum_quality_score:
        return None
    return {
        **record,
        "text": text,
        "lang": lang,
        "lang_confidence": confidence,
        "quality_score": score,
        "quality_tier": quality_tier(score),
        "token_count_estimate": max(1, len(text) // 4),
    }


def filter_corpus(records: Iterable[dict[str, object]], config: FilterConfig = FilterConfig()) -> list[dict[str, object]]:
    """Filter a corpus in memory."""
    kept: list[dict[str, object]] = []
    for record in records:
        filtered = filter_record(record, config)
        if filtered is not None:
            kept.append(filtered)
    return kept