| """Standalone post-processing rules extracted from solar-eval multi_step pipeline.""" |
|
|
| import csv |
| import re |
| from collections import Counter |
| from pathlib import Path |
|
|
| |
|
|
|
|
def load_vocabulary(csv_path: str) -> list[dict[str, str]]:
    """Load vocabulary CSV with wrong/correct columns.

    Supports both:
    - Headered CSV with columns 'wrong'/'before' and 'correct'/'after'
    - Headerless CSV where the first two columns are (wrong, correct)

    Args:
        csv_path: Path to vocabulary CSV file.

    Returns:
        List of dicts with 'wrong' and 'correct' keys.
    """
    path = Path(csv_path)
    if not path.exists():
        return []

    vocab: list[dict[str, str]] = []
    header_names = {"wrong", "correct", "before", "after"}

    with open(path, encoding="utf-8") as f:
        rows = list(csv.reader(f))

    if not rows:
        return []

    # Detect whether the first row is a header by looking for known column names.
    first_row_lower = [c.strip().lower() for c in rows[0]]
    has_header = any(name in first_row_lower for name in header_names)

    if has_header:
        wrong_idx = next((i for i, c in enumerate(first_row_lower) if c in {"wrong", "before"}), 0)
        correct_idx = next(
            (i for i, c in enumerate(first_row_lower) if c in {"correct", "after"}), 1
        )
        data_rows = rows[1:]
    else:
        wrong_idx, correct_idx = 0, 1
        data_rows = rows

    for row in data_rows:
        if len(row) <= max(wrong_idx, correct_idx):
            continue
        wrong = row[wrong_idx].strip()
        correct = row[correct_idx].strip()
        if wrong and correct:
            vocab.append({"wrong": wrong, "correct": correct})
    return vocab
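
# Illustrative usage (a sketch, not taken from the pipeline; the CSV path and
# rows below are hypothetical):
#
#     vocab.csv:
#         wrong,correct
#         인공지능모델,인공지능 모델
#
#     >>> load_vocabulary("vocab.csv")
#     [{'wrong': '인공지능모델', 'correct': '인공지능 모델'}]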


def apply_vocabulary(text: str, vocab: list[dict[str, str]]) -> str:
    """Apply vocabulary substitutions.

    Args:
        text: Input text.
        vocab: List of {wrong, correct} dicts.

    Returns:
        Text with vocabulary corrections applied.
    """
    for entry in vocab:
        text = text.replace(entry["wrong"], entry["correct"])
    return text
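
# Illustrative usage (a sketch; substitutions are plain string replacement,
# applied in list order; the sample entry is made up):
#
#     >>> apply_vocabulary("인공지능모델 출시", [{"wrong": "인공지능모델", "correct": "인공지능 모델"}])
#     '인공지능 모델 출시'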


# Post-positional particle forms used after a vowel-final word, mapped to the
# forms used after a consonant-final (jongseong) word.
PARTICLE_CORRECTION_MAP: dict[str, str] = {
    "는": "은",
    "가": "이",
    "를": "을",
    "와": "과",
    "로": "으로",
    "여": "이여",
    "라": "이라",
    "랑": "이랑",
    "다": "이다",
    "였다": "이었다",
    "라면": "이라면",
    "라서": "이라서",
    "로부터": "으로부터",
}

# Particles whose form does not depend on the preceding syllable.
NEUTRAL_PARTICLES = {"만", "도", "께서"}


def _ends_with_consonant(char: str) -> bool:
    """Check if a Korean character ends with a consonant (has jongseong)."""
    if not char:
        return False
    code = ord(char)
    if 0xAC00 <= code <= 0xD7A3:
        return (code - 0xAC00) % 28 != 0
    return False
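
# Illustrative check (a sketch): precomposed Hangul syllables are laid out in
# blocks of 28 codepoints, and offset 0 within a block has no final consonant.
#
#     >>> _ends_with_consonant("님")   # has jongseong ㅁ
#     True
#     >>> _ends_with_consonant("녀")   # no jongseong
#     False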


def apply_pronoun_postprocess(text: str, replacements: dict[str, str]) -> str:
    """Replace words and correct trailing Korean particles.

    Args:
        text: Input text.
        replacements: Mapping of old_word -> new_word.

    Returns:
        Text with replacements and corrected particles.
    """
    # Longest particles first so multi-character particles win over their prefixes.
    all_particles = sorted(
        list(PARTICLE_CORRECTION_MAP.keys())
        + list(PARTICLE_CORRECTION_MAP.values())
        + list(NEUTRAL_PARTICLES),
        key=len,
        reverse=True,
    )
    # De-duplicate while preserving the length-sorted order.
    seen: set[str] = set()
    unique_particles: list[str] = []
    for p in all_particles:
        if p not in seen:
            seen.add(p)
            unique_particles.append(p)

    particle_pattern = "|".join(re.escape(p) for p in unique_particles)

    for old, new in replacements.items():
        regex = re.compile(f"({re.escape(old)})({particle_pattern})?")

        def _replace(match: re.Match, new_word: str = new) -> str:
            particle = match.group(2)
            if particle is None:
                return new_word
            # Switch to the consonant-following particle form only when the
            # replacement word ends with a final consonant (jongseong).
            if new_word and _ends_with_consonant(new_word[-1]):
                corrected = PARTICLE_CORRECTION_MAP.get(particle, particle)
                return new_word + corrected
            return new_word + particle

        text = regex.sub(_replace, text)

    return text
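
# Illustrative usage (a sketch; the replacement mapping is made up). Because
# "선생님" ends in a consonant, the particle "는" is switched to "은":
#
#     >>> apply_pronoun_postprocess("그녀는 떠났다", {"그녀": "선생님"})
#     '선생님은 떠났다'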


# Building blocks for currency/unit compaction.
_CURRENCY_UNITS = r"(원|명|개|대|위안|달러|유로|엔|톤|권|건|회|장|마리|그루|송이|필)"
_BOUNDARY = r'(?=[\s.,!?()"\'가-힣]|$)'
_NUMBER = r"(\d+(?:\.\d+)?)"
_SCALE = r"(만|억|조)"

# Two scale groups, e.g. "3억 5000만 원".
_CURRENCY_MULTI_RE = re.compile(
    rf"{_NUMBER}{_SCALE}\s*(\d+){_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}"
)
# Single scale group, e.g. "5만 원".
_CURRENCY_SINGLE_RE = re.compile(rf"{_NUMBER}{_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}")


def apply_currency_compact(text: str) -> str:
    """Remove spaces between numbers and currency/unit markers.

    Args:
        text: Input text.

    Returns:
        Text with compacted currency expressions.
    """
    text = _CURRENCY_MULTI_RE.sub(r"\1\2\3\4\5", text)
    text = _CURRENCY_SINGLE_RE.sub(r"\1\2\3", text)
    return text
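
# Illustrative usage (a sketch; the sentence is made up):
#
#     >>> apply_currency_compact("사업비는 3억 5000만 원, 인원은 20만 명이다.")
#     '사업비는 3억5000만원, 인원은 20만명이다.'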


_COMMA_RE = re.compile(r"(?<=\d),(?=\d)")


def apply_comma_removal(text: str) -> str:
    """Remove thousand-separator commas from numbers.

    Args:
        text: Input text.

    Returns:
        Text with commas removed from numbers.
    """
    return _COMMA_RE.sub("", text)


# (pattern, replacement) pairs for compact Unicode unit glyphs. Within each
# group the mm/cm/km forms come before the bare "m" form so the longer
# prefixes win.
_UNIT_PATTERNS: list[tuple[str, str]] = [
    # Area
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎟"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎠"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎢"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎡"),
    # Volume
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎣"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎤"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎦"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎥"),
    # Length
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:μm|um)\b(?![A-Za-zµ])", r"\g<num>㎛"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm)\b(?![A-Za-zµ])", r"\g<num>㎜"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm)\b(?![A-Za-zµ])", r"\g<num>㎝"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km)\b(?![A-Za-zµ])", r"\g<num>㎞"),
    # Mass
    (r"(?P<num>\d+(?:\.\d+)?)\s*kg\b(?![A-Za-zµ])", r"\g<num>㎏"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*mg\b(?![A-Za-zµ])", r"\g<num>㎎"),
]
_UNIT_COMPILED = [(re.compile(p), r) for p, r in _UNIT_PATTERNS]


def apply_unit_unicode(text: str) -> str:
    """Convert `<number> + unit` tokens to compact Unicode equivalents.

    Only substitutes when preceded by a digit and followed by a non-letter
    boundary, so tokens like "mmol" / "cmH2O" are left alone. Matches the
    production Gradio regex set (area/volume/length/mass).

    Args:
        text: Input text.

    Returns:
        Text with unit strings replaced by Unicode characters.
    """
    for pat, rep in _UNIT_COMPILED:
        text = pat.sub(rep, text)
    return text
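
# Illustrative usage (a sketch; the phrase is made up). The boundary and
# lookahead requirements mean "5 cm" converts while "5 cmH2O" would not:
#
#     >>> apply_unit_unicode("길이 5 cm, 무게 1.5 kg")
#     '길이 5㎝, 무게 1.5㎏'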


def _compute_lcs_indices(orig_tokens: list[str], pred_tokens: list[str]) -> list[tuple[int, int]]:
    """Compute LCS matching indices between two token lists using DP."""
    m, n = len(orig_tokens), len(pred_tokens)
    if m == 0 or n == 0:
        return []

    # Build the LCS length table.
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if orig_tokens[i - 1] == pred_tokens[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    # Backtrack to recover the matching (orig_index, pred_index) pairs.
    pairs: list[tuple[int, int]] = []
    i, j = m, n
    while i > 0 and j > 0:
        if orig_tokens[i - 1] == pred_tokens[j - 1]:
            pairs.append((i - 1, j - 1))
            i -= 1
            j -= 1
        elif dp[i - 1][j] >= dp[i][j - 1]:
            i -= 1
        else:
            j -= 1
    pairs.reverse()
    return pairs
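
# Illustrative result (a sketch): only exactly matching tokens pair up, in order.
#
#     >>> _compute_lcs_indices(["오늘", "날시가", "좋다"], ["오늘", "날씨가", "좋다"])
#     [(0, 0), (2, 2)]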


def apply_correction_filter(
    text: str,
    original: str,
    max_char_diff: int = 2,
    allow_spacing: bool = True,
) -> str:
    """Filter corrections, reverting changes that exceed max_char_diff.

    Tokenizes both texts, computes LCS, and for each gap between matching
    tokens, checks if the correction is "safe" (small enough change).
    Unsafe corrections are reverted to the original.

    Args:
        text: Corrected text.
        original: Original text before correction.
        max_char_diff: Maximum allowed character-level differences.
        allow_spacing: Whether to allow spacing-only changes.

    Returns:
        Filtered text with unsafe corrections reverted.
    """
    orig_tokens = original.split()
    pred_tokens = text.split()

    if not orig_tokens or not pred_tokens:
        return text

    lcs_pairs = _compute_lcs_indices(orig_tokens, pred_tokens)

    def is_safe(o_text: str, c_text: str) -> bool:
        if o_text == c_text:
            return True
        if allow_spacing and o_text.replace(" ", "") == c_text.replace(" ", ""):
            return True
        if abs(len(o_text) - len(c_text)) <= 1:
            min_len = min(len(o_text), len(c_text))
            d = sum(1 for i in range(min_len) if o_text[i] != c_text[i])
            d += abs(len(o_text) - len(c_text))
            if d <= max_char_diff:
                return True
        return False

    result_tokens: list[str] = []
    prev_orig = -1
    prev_pred = -1

    for oi, pi in lcs_pairs:
        # Tokens strictly between consecutive LCS matches form a correction gap.
        orig_gap = " ".join(orig_tokens[prev_orig + 1 : oi])
        pred_gap = " ".join(pred_tokens[prev_pred + 1 : pi])

        if orig_gap or pred_gap:
            if is_safe(orig_gap, pred_gap):
                result_tokens.append(pred_gap)
            else:
                result_tokens.append(orig_gap)

        # Matched tokens are identical in both texts; keep them as-is.
        result_tokens.append(pred_tokens[pi])
        prev_orig = oi
        prev_pred = pi

    # Handle the trailing gap after the last LCS match.
    orig_gap = " ".join(orig_tokens[prev_orig + 1 :])
    pred_gap = " ".join(pred_tokens[prev_pred + 1 :])
    if orig_gap or pred_gap:
        if is_safe(orig_gap, pred_gap):
            result_tokens.append(pred_gap)
        else:
            result_tokens.append(orig_gap)

    return " ".join(t for t in result_tokens if t)


_PARA_SPLIT_RE = re.compile(r"\n+")
_WS_RE = re.compile(r"\s+")


def _normalize_paragraph(p: str) -> str:
    """Collapse spaces/tabs/newlines into a normal form used for duplicate matching."""
    return _WS_RE.sub(" ", p).strip()


def _split_output_paragraphs(text: str) -> list[str]:
    """Split output text on \\n+ boundaries and return the non-empty paragraphs.

    Fragments are passed on without whitespace normalization (only stripped),
    so the punctuation/spacing details produced by the model are preserved as
    much as possible when reassembling (stripping is safe because reassembly
    re-joins paragraphs with \\n).
    """
    if not text:
        return []
    parts = _PARA_SPLIT_RE.split(text)
    return [p.strip() for p in parts if p.strip()]
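
# Illustrative behavior of the two helpers (a sketch; strings are made up):
#
#     >>> _split_output_paragraphs("첫 문단\n\n둘째  문단\n")
#     ['첫 문단', '둘째  문단']
#     >>> _normalize_paragraph("둘째  문단\n")
#     '둘째 문단'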


def apply_paragraph_dedupe(
    text: str,
    original: str,
    min_len: int = 40,
    prefix_len: int = 30,
) -> str:
    """Remove duplicated paragraphs emitted by the LLM.

    **When a paragraph is dropped**
    - An identical paragraph (exact match after normalization) appears in the
      output at least ``input count + 1`` times: the later occurrences are
      dropped.
    - Paragraphs sharing the same first ``prefix_len`` characters repeat more
      often in the output than in the input: dropped (catches echoes the model
      re-emitted with minor edits).

    **When a paragraph is kept**
    - Paragraph length < ``min_len`` (legitimate repetition such as '◇'
      sub-headings or one-line quotes).
    - ``output_count <= input_count`` (the author already repeated it on
      purpose in the source text; do not override the author's intent).
    - The first occurrence is always kept; only the second and later
      occurrences are candidates for dropping.

    Args:
        text: Pipeline output text.
        original: Pipeline input (the author's source text). If None it is
            treated as an empty string, i.e. input_count = 0, so any paragraph
            appearing two or more times in the output loses every occurrence
            after the first.
        min_len: Normalized paragraphs shorter than this are excluded from
            dedupe.
        prefix_len: Prefix length used for near-duplicate detection.

    Returns:
        Deduplicated text, reassembled with ``\\n`` as the paragraph separator
        (same convention as the existing pipeline output). If nothing was
        dropped, the input text is returned unchanged.
    """
    if not text:
        return text

    # Normalize the entire input once; per-paragraph counts against the input
    # are taken via substring search on this normalized form.
    in_norm = _normalize_paragraph(original or "")

    out_paras = _split_output_paragraphs(text)
    out_exact_seen: Counter[str] = Counter()
    out_prefix_seen: Counter[str] = Counter()

    kept: list[str] = []
    any_dropped = False
    for para in out_paras:
        norm = _normalize_paragraph(para)
        if len(norm) < min_len:
            kept.append(para)
            continue

        prefix = norm[:prefix_len]
        out_exact_seen[norm] += 1
        out_prefix_seen[prefix] += 1

        # How many times does this paragraph (or its prefix) already occur in
        # the input?
        in_exact_count = in_norm.count(norm) if in_norm else 0
        in_prefix_count = in_norm.count(prefix) if in_norm else 0

        exact_dup = (
            out_exact_seen[norm] > in_exact_count
            and out_exact_seen[norm] >= 2
        )
        near_dup = (
            out_prefix_seen[prefix] > in_prefix_count
            and out_prefix_seen[prefix] >= 2
        )

        if exact_dup or near_dup:
            any_dropped = True
            continue

        kept.append(para)

    if not any_dropped:
        return text

    return "\n".join(kept)