"""Standalone post-processing rules extracted from solar-eval multi_step pipeline.""" import csv import re from collections import Counter from pathlib import Path # --- Vocabulary substitution --- def load_vocabulary(csv_path: str) -> list[dict[str, str]]: """Load vocabulary CSV with wrong/correct columns. Supports both: - Headered CSV with columns 'wrong'/'before' and 'correct'/'after' - Headerless CSV where first two columns are (wrong, correct) Args: csv_path: Path to vocabulary CSV file. Returns: List of dicts with 'wrong' and 'correct' keys. """ path = Path(csv_path) if not path.exists(): return [] vocab: list[dict[str, str]] = [] header_names = {"wrong", "correct", "before", "after"} with open(path, encoding="utf-8") as f: rows = list(csv.reader(f)) if not rows: return [] # Detect header: first row contains any of the expected column names first_row_lower = [c.strip().lower() for c in rows[0]] has_header = any(name in first_row_lower for name in header_names) if has_header: wrong_idx = next((i for i, c in enumerate(first_row_lower) if c in {"wrong", "before"}), 0) correct_idx = next( (i for i, c in enumerate(first_row_lower) if c in {"correct", "after"}), 1 ) data_rows = rows[1:] else: wrong_idx, correct_idx = 0, 1 data_rows = rows for row in data_rows: if len(row) <= max(wrong_idx, correct_idx): continue wrong = row[wrong_idx].strip() correct = row[correct_idx].strip() if wrong and correct: vocab.append({"wrong": wrong, "correct": correct}) return vocab def apply_vocabulary(text: str, vocab: list[dict[str, str]]) -> str: """Apply vocabulary substitutions. Args: text: Input text. vocab: List of {wrong, correct} dicts. Returns: Text with vocabulary corrections applied. """ for entry in vocab: text = text.replace(entry["wrong"], entry["correct"]) return text # --- Pronoun/title post-processing --- # Korean particle correction map: vowel-ending -> consonant-ending PARTICLE_CORRECTION_MAP: dict[str, str] = { "는": "은", "가": "이", "를": "을", "와": "과", "로": "으로", "여": "이여", "라": "이라", "랑": "이랑", "다": "이다", "였다": "이었다", "라면": "이라면", "라서": "이라서", "로부터": "으로부터", } # Particles that are valid for both vowel and consonant endings NEUTRAL_PARTICLES = {"만", "도", "께서"} def _ends_with_consonant(char: str) -> bool: """Check if a Korean character ends with a consonant (has jongseong).""" if not char: return False code = ord(char) if 0xAC00 <= code <= 0xD7A3: return (code - 0xAC00) % 28 != 0 return False def apply_pronoun_postprocess(text: str, replacements: dict[str, str]) -> str: """Replace words and correct trailing Korean particles. Args: text: Input text. replacements: Mapping of old_word -> new_word. Returns: Text with replacements and corrected particles. 
""" # Build all particles sorted by length (longest first for greedy match) all_particles = sorted( list(PARTICLE_CORRECTION_MAP.keys()) + list(PARTICLE_CORRECTION_MAP.values()) + list(NEUTRAL_PARTICLES), key=len, reverse=True, ) # Deduplicate while preserving order seen: set[str] = set() unique_particles: list[str] = [] for p in all_particles: if p not in seen: seen.add(p) unique_particles.append(p) particle_pattern = "|".join(re.escape(p) for p in unique_particles) for old, new in replacements.items(): regex = re.compile(f"({re.escape(old)})({particle_pattern})?") def _replace(match: re.Match, new_word: str = new) -> str: particle = match.group(2) if particle is None: return new_word # If the new word ends with consonant and old particle is vowel form, correct it if new_word and _ends_with_consonant(new_word[-1]): corrected = PARTICLE_CORRECTION_MAP.get(particle, particle) return new_word + corrected return new_word + particle text = regex.sub(_replace, text) return text # --- Currency compact --- _CURRENCY_UNITS = r"(원|명|개|대|위안|달러|유로|엔|톤|권|건|회|장|마리|그루|송이|필)" _BOUNDARY = r'(?=[\s.,!?()"\'가-힣]|$)' _NUMBER = r"(\d+(?:\.\d+)?)" _SCALE = r"(만|억|조)" # Multi-digit: "3억 5000만 원" -> "3억5000만원" _CURRENCY_MULTI_RE = re.compile( rf"{_NUMBER}{_SCALE}\s*(\d+){_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}" ) # Single: "3억 원" -> "3억원" _CURRENCY_SINGLE_RE = re.compile(rf"{_NUMBER}{_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}") def apply_currency_compact(text: str) -> str: """Remove spaces between numbers and currency/unit markers. Args: text: Input text. Returns: Text with compacted currency expressions. """ text = _CURRENCY_MULTI_RE.sub(r"\1\2\3\4\5", text) text = _CURRENCY_SINGLE_RE.sub(r"\1\2\3", text) return text # --- Comma removal --- _COMMA_RE = re.compile(r"(?<=\d),(?=\d)") def apply_comma_removal(text: str) -> str: """Remove thousand-separator commas from numbers. Args: text: Input text. Returns: Text with commas removed from numbers. """ return _COMMA_RE.sub("", text) # --- Unit unicode conversion --- # Regex-based unit substitution matching the production Gradio reference. # Each pattern requires a digit prefix and enforces a non-letter boundary so # "mm" inside "mmol" or "cm" inside "cmH2O" is not mistakenly converted. _UNIT_PATTERNS: list[tuple[str, str]] = [ # Area (squared) (r"(?P\d+(?:\.\d+)?)\s*(?:mm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g㎟"), (r"(?P\d+(?:\.\d+)?)\s*(?:cm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g㎠"), (r"(?P\d+(?:\.\d+)?)\s*(?:km(?:\^?2|²))\b(?![A-Za-zµ])", r"\g㎢"), (r"(?P\d+(?:\.\d+)?)\s*(?:m(?:\^?2|²))\b(?![A-Za-zµ])", r"\g㎡"), # Volume (cubed) (r"(?P\d+(?:\.\d+)?)\s*(?:mm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g㎣"), (r"(?P\d+(?:\.\d+)?)\s*(?:cm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g㎤"), (r"(?P\d+(?:\.\d+)?)\s*(?:km(?:\^?3|³))\b(?![A-Za-zµ])", r"\g㎦"), (r"(?P\d+(?:\.\d+)?)\s*(?:m(?:\^?3|³))\b(?![A-Za-zµ])", r"\g㎥"), # Length (r"(?P\d+(?:\.\d+)?)\s*(?:μm|um)\b(?![A-Za-zµ])", r"\g㎛"), (r"(?P\d+(?:\.\d+)?)\s*(?:mm)\b(?![A-Za-zµ])", r"\g㎜"), (r"(?P\d+(?:\.\d+)?)\s*(?:cm)\b(?![A-Za-zµ])", r"\g㎝"), (r"(?P\d+(?:\.\d+)?)\s*(?:km)\b(?![A-Za-zµ])", r"\g㎞"), # Mass (r"(?P\d+(?:\.\d+)?)\s*kg\b(?![A-Za-zµ])", r"\g㎏"), (r"(?P\d+(?:\.\d+)?)\s*mg\b(?![A-Za-zµ])", r"\g㎎"), ] _UNIT_COMPILED = [(re.compile(p), r) for p, r in _UNIT_PATTERNS] def apply_unit_unicode(text: str) -> str: """Convert ` + unit` tokens to compact Unicode equivalents. Only substitutes when preceded by a digit and followed by a non-letter boundary, so tokens like "mmol" / "cmH2O" are left alone. 

# --- Unit unicode conversion ---

# Regex-based unit substitution matching the production Gradio reference.
# Each pattern requires a digit prefix and enforces a non-letter boundary so
# "mm" inside "mmol" or "cm" inside "cmH2O" is not mistakenly converted.
_UNIT_PATTERNS: list[tuple[str, str]] = [
    # Area (squared)
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎟"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎠"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎢"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎡"),
    # Volume (cubed)
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎣"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎤"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎦"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎥"),
    # Length
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:μm|um)\b(?![A-Za-zµ])", r"\g<num>㎛"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm)\b(?![A-Za-zµ])", r"\g<num>㎜"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm)\b(?![A-Za-zµ])", r"\g<num>㎝"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km)\b(?![A-Za-zµ])", r"\g<num>㎞"),
    # Mass
    (r"(?P<num>\d+(?:\.\d+)?)\s*kg\b(?![A-Za-zµ])", r"\g<num>㎏"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*mg\b(?![A-Za-zµ])", r"\g<num>㎎"),
]

_UNIT_COMPILED = [(re.compile(p), r) for p, r in _UNIT_PATTERNS]


def apply_unit_unicode(text: str) -> str:
    """Convert `<digit> + unit` tokens to compact Unicode equivalents.

    Only substitutes when preceded by a digit and followed by a non-letter
    boundary, so tokens like "mmol" / "cmH2O" are left alone.
    Matches the production Gradio regex set (area/volume/length/mass).

    Args:
        text: Input text.

    Returns:
        Text with unit strings replaced by Unicode characters.
    """
    for pat, rep in _UNIT_COMPILED:
        text = pat.sub(rep, text)
    return text


# --- Correction filter ---


def _compute_lcs_indices(orig_tokens: list[str], pred_tokens: list[str]) -> list[tuple[int, int]]:
    """Compute LCS matching indices between two token lists using DP."""
    m, n = len(orig_tokens), len(pred_tokens)
    if m == 0 or n == 0:
        return []

    # DP table
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if orig_tokens[i - 1] == pred_tokens[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    # Backtrack to find matching pairs
    pairs: list[tuple[int, int]] = []
    i, j = m, n
    while i > 0 and j > 0:
        if orig_tokens[i - 1] == pred_tokens[j - 1]:
            pairs.append((i - 1, j - 1))
            i -= 1
            j -= 1
        elif dp[i - 1][j] >= dp[i][j - 1]:
            i -= 1
        else:
            j -= 1

    pairs.reverse()
    return pairs


def apply_correction_filter(
    text: str,
    original: str,
    max_char_diff: int = 2,
    allow_spacing: bool = True,
) -> str:
    """Filter corrections, reverting changes that exceed max_char_diff.

    Tokenizes both texts, computes LCS, and for each gap between matching
    tokens, checks if the correction is "safe" (small enough change).
    Unsafe corrections are reverted to original.

    Args:
        text: Corrected text.
        original: Original text before correction.
        max_char_diff: Maximum allowed character-level differences.
        allow_spacing: Whether to allow spacing-only changes.

    Returns:
        Filtered text with unsafe corrections reverted.
    """
    orig_tokens = original.split()
    pred_tokens = text.split()
    if not orig_tokens or not pred_tokens:
        return text

    lcs_pairs = _compute_lcs_indices(orig_tokens, pred_tokens)

    def is_safe(o_text: str, c_text: str) -> bool:
        if o_text == c_text:
            return True
        if allow_spacing and o_text.replace(" ", "") == c_text.replace(" ", ""):
            return True
        if abs(len(o_text) - len(c_text)) <= 1:
            min_len = min(len(o_text), len(c_text))
            d = sum(1 for i in range(min_len) if o_text[i] != c_text[i])
            d += abs(len(o_text) - len(c_text))
            if d <= max_char_diff:
                return True
        return False

    result_tokens: list[str] = []
    prev_orig = -1
    prev_pred = -1

    for oi, pi in lcs_pairs:
        # Handle gap before this LCS match
        orig_gap = " ".join(orig_tokens[prev_orig + 1 : oi])
        pred_gap = " ".join(pred_tokens[prev_pred + 1 : pi])
        if orig_gap or pred_gap:
            if is_safe(orig_gap, pred_gap):
                result_tokens.append(pred_gap)
            else:
                result_tokens.append(orig_gap)
        # Add the matched token
        result_tokens.append(pred_tokens[pi])
        prev_orig = oi
        prev_pred = pi

    # Handle trailing gap
    orig_gap = " ".join(orig_tokens[prev_orig + 1 :])
    pred_gap = " ".join(pred_tokens[prev_pred + 1 :])
    if orig_gap or pred_gap:
        if is_safe(orig_gap, pred_gap):
            result_tokens.append(pred_gap)
        else:
            result_tokens.append(orig_gap)

    return " ".join(t for t in result_tokens if t)
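
# Illustrative behaviour (sample strings are made up). The correction-filter
# examples show a small typo fix being kept and a heavy rewrite being reverted:
#
#   >>> apply_unit_unicode("길이 10 cm, 무게 2.5 kg")
#   '길이 10㎝, 무게 2.5㎏'
#   >>> apply_unit_unicode("혈중 5 mmol")   # "mm" inside "mmol" is untouched
#   '혈중 5 mmol'
#   >>> apply_correction_filter("오늘 날씨가 좋다", original="오늘 날시가 좋다")
#   '오늘 날씨가 좋다'
#   >>> apply_correction_filter("대학병원 응급실에 실려 갔다", original="병원에 갔다")
#   '병원에 갔다'
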
""" if not text: return [] parts = _PARA_SPLIT_RE.split(text) return [p.strip() for p in parts if p.strip()] def apply_paragraph_dedupe( text: str, original: str, min_len: int = 40, prefix_len: int = 30, ) -> str: """LLM이 뱉은 중복 문단을 제거한다. **언제 제거하는가** - 같은 문단(정규화 후 완전일치) 이 output 에서 `input 등장 횟수 + 1` 이상 등장하면 뒤쪽 것을 drop. - 앞 ``prefix_len`` 자가 동일한 문단이 out > in 로 반복되면 drop (모델이 미세 교정으로 다르게 뱉은 echo 잡기). **언제 보존하는가** - 문단 길이 < ``min_len`` (◇소제목, 한 줄 인용 등 합법적 반복). - ``output_count <= input_count`` (저자가 의도해서 원문에 이미 반복 포함. 저자 의도를 훼손하지 않음). - 첫 등장은 항상 유지. 두 번째 이후 등장부터 조건에 맞으면 drop. Args: text: 파이프라인 출력 텍스트. original: 파이프라인 입력 (저자 원문). None 이면 빈 문자열로 간주해 input_count=0. 이 경우 output 에서 2 회 이상 동일 문단 등장 → 모두 drop. min_len: 이 길이 미만 정규화 문단은 dedupe 에서 제외. prefix_len: near-dup 탐지에 쓰는 앞자르기 길이. Returns: Dedupe 된 텍스트. 문단 구분자는 ``\\n`` 으로 재조립 (기존 파이프라인 출력 관례와 동일). 중복이 없으면 입력을 그대로 돌려준다. """ if not text: return text # Input 은 "normalized 전체 문자열" 로 두고 substring count 를 쓴다. # 이전 구현은 input 을 문단 Counter 로 셌는데, LLM 이 문단 경계를 재구조화 # (예: 줄바꿈 없이 중복 문장이 들어있던 한 문단을 여러 문단으로 쪼갬) 하면 # output 문단이 input 문단과 exact match 가 안 되어 input_count=0 으로 잡혀, # 정당한 중복 (저자/원본이 의도한 반복) 까지 drop 되는 버그가 있었다. in_norm = _normalize_paragraph(original or "") out_paras = _split_output_paragraphs(text) out_exact_seen: Counter[str] = Counter() out_prefix_seen: Counter[str] = Counter() kept: list[str] = [] any_dropped = False for para in out_paras: norm = _normalize_paragraph(para) if len(norm) < min_len: kept.append(para) continue prefix = norm[:prefix_len] out_exact_seen[norm] += 1 out_prefix_seen[prefix] += 1 # input 문자열 전체에서 해당 문단(또는 prefix)이 몇 번 substring 으로 # 등장하는지 집계. output_count 가 input_count 를 초과할 때만 drop. in_exact_count = in_norm.count(norm) if in_norm else 0 in_prefix_count = in_norm.count(prefix) if in_norm else 0 exact_dup = ( out_exact_seen[norm] > in_exact_count and out_exact_seen[norm] >= 2 ) near_dup = ( out_prefix_seen[prefix] > in_prefix_count and out_prefix_seen[prefix] >= 2 ) if exact_dup or near_dup: any_dropped = True continue kept.append(para) if not any_dropped: return text return "\n".join(kept)