# proofread-20261h-demo / postprocess.py
# dev-strender's picture
# fix: preserve bulk boundary \n + substring-count dedupe + 30-char prefix
# 725b08e
"""Standalone post-processing rules extracted from solar-eval multi_step pipeline."""
import csv
import re
from collections import Counter
from pathlib import Path
# --- Vocabulary substitution ---
def load_vocabulary(csv_path: str) -> list[dict[str, str]]:
    """Load a vocabulary CSV of (wrong, correct) substitution pairs.

    Supports both:
    - Headered CSV with columns 'wrong'/'before' and 'correct'/'after'
    - Headerless CSV where the first two columns are (wrong, correct)

    The file is decoded as ``utf-8-sig`` so a leading UTF-8 byte-order mark
    (common in spreadsheet exports) is discarded instead of being glued onto
    the first cell, where it would corrupt either the header lookup or the
    first "wrong" term.

    Args:
        csv_path: Path to vocabulary CSV file.

    Returns:
        List of dicts with 'wrong' and 'correct' keys, in file order. Rows
        that are too short, or whose wrong/correct cell is blank after
        stripping, are skipped. An empty list is returned when the path does
        not exist or the file contains no rows.
    """
    path = Path(csv_path)
    if not path.exists():
        return []

    # newline="" lets the csv module do its own newline handling, as its
    # documentation recommends (matters for quoted fields with line breaks).
    with open(path, encoding="utf-8-sig", newline="") as f:
        rows = list(csv.reader(f))
    if not rows:
        return []

    # Detect header: first row contains any of the expected column names.
    header_names = {"wrong", "correct", "before", "after"}
    first_row_lower = [c.strip().lower() for c in rows[0]]
    has_header = any(name in first_row_lower for name in header_names)

    if has_header:
        # Fall back to positional columns 0/1 when only one of the two
        # header names is present.
        wrong_idx = next(
            (i for i, c in enumerate(first_row_lower) if c in {"wrong", "before"}), 0
        )
        correct_idx = next(
            (i for i, c in enumerate(first_row_lower) if c in {"correct", "after"}), 1
        )
        data_rows = rows[1:]
    else:
        wrong_idx, correct_idx = 0, 1
        data_rows = rows

    vocab: list[dict[str, str]] = []
    needed_cols = max(wrong_idx, correct_idx)
    for row in data_rows:
        if len(row) <= needed_cols:
            continue
        wrong = row[wrong_idx].strip()
        correct = row[correct_idx].strip()
        if wrong and correct:
            vocab.append({"wrong": wrong, "correct": correct})
    return vocab
def apply_vocabulary(text: str, vocab: list[dict[str, str]]) -> str:
    """Run every vocabulary substitution over the text, in list order.

    Each entry is applied as a plain substring replacement on the result of
    the previous one, so a later entry can rewrite what an earlier entry
    produced.

    Args:
        text: Input text.
        vocab: List of {wrong, correct} dicts.

    Returns:
        Text with vocabulary corrections applied.
    """
    corrected = text
    for pair in vocab:
        wrong, right = pair["wrong"], pair["correct"]
        corrected = corrected.replace(wrong, right)
    return corrected
# --- Pronoun/title post-processing ---
# Korean particle correction map: vowel-ending -> consonant-ending
PARTICLE_CORRECTION_MAP: dict[str, str] = {
    "는": "은",
    "가": "이",
    "를": "을",
    "와": "과",
    "로": "으로",
    "여": "이여",
    "라": "이라",
    "랑": "이랑",
    "다": "이다",
    "였다": "이었다",
    "라면": "이라면",
    "라서": "이라서",
    "로부터": "으로부터",
}

# Particles that are valid for both vowel and consonant endings
NEUTRAL_PARTICLES = {"만", "도", "께서"}

# Precomposed Hangul syllables occupy U+AC00..U+D7A3; within that range the
# final consonant (jongseong) index is (code - 0xAC00) % 28, 0 meaning "none".
_HANGUL_FIRST = 0xAC00
_HANGUL_LAST = 0xD7A3
_JONGSEONG_COUNT = 28
_JONGSEONG_RIEUL = 8  # index of the final consonant ㄹ

# The (으)로 family keeps its vowel form after a ㄹ-final noun
# ("서울로", "서울로부터"), unlike the other particles in the map.
_RIEUL_KEEPS_VOWEL_FORM = frozenset({"로", "로부터"})


def _jongseong_index(char: str) -> int:
    """Return the jongseong index of a precomposed Hangul syllable.

    Returns 0 for an empty string, for a syllable without a final consonant,
    and for any character outside the precomposed Hangul range.
    """
    if not char:
        return 0
    code = ord(char)
    if _HANGUL_FIRST <= code <= _HANGUL_LAST:
        return (code - _HANGUL_FIRST) % _JONGSEONG_COUNT
    return 0


def _ends_with_consonant(char: str) -> bool:
    """Check if a Korean character ends with a consonant (has jongseong)."""
    return _jongseong_index(char) != 0


def _correct_particle(word: str, particle: str) -> str:
    """Return the particle form that agrees with the final syllable of *word*.

    Only the vowel-form -> consonant-form direction is corrected; a particle
    that is not a key of PARTICLE_CORRECTION_MAP is returned unchanged.
    """
    if not word:
        return particle
    final = _jongseong_index(word[-1])
    if final == 0:
        return particle
    if final == _JONGSEONG_RIEUL and particle in _RIEUL_KEEPS_VOWEL_FORM:
        # "서울" + "로" stays "서울로"; mapping it to "으로" would be wrong.
        return particle
    return PARTICLE_CORRECTION_MAP.get(particle, particle)


def apply_pronoun_postprocess(text: str, replacements: dict[str, str]) -> str:
    """Replace words and correct trailing Korean particles.

    Each ``old`` word is replaced by its ``new`` word wherever it occurs. If
    a known particle immediately follows the match and the new word ends in
    a final consonant, the vowel-form particle is swapped for its
    consonant-form counterpart. The (으)로 particles keep their vowel form
    after a ㄹ final. Replacements are applied sequentially in mapping order,
    so a later entry sees the output of earlier ones.

    Args:
        text: Input text.
        replacements: Mapping of old_word -> new_word. Entries with an empty
            old_word are ignored (an empty pattern would match at every
            position in the text).

    Returns:
        Text with replacements and corrected particles.
    """
    # Longest particles first so the alternation prefers "로부터" over "로".
    # dict.fromkeys de-duplicates while keeping first-seen order.
    candidates = dict.fromkeys(
        [
            *PARTICLE_CORRECTION_MAP.keys(),
            *PARTICLE_CORRECTION_MAP.values(),
            *NEUTRAL_PARTICLES,
        ]
    )
    unique_particles = sorted(candidates, key=len, reverse=True)
    particle_pattern = "|".join(re.escape(p) for p in unique_particles)

    for old, new in replacements.items():
        if not old:
            continue
        regex = re.compile(f"({re.escape(old)})({particle_pattern})?")

        # new is bound as a default argument so the callback does not depend
        # on the loop variable.
        def _replace(match: re.Match, new_word: str = new) -> str:
            particle = match.group(2)
            if particle is None:
                return new_word
            return new_word + _correct_particle(new_word, particle)

        text = regex.sub(_replace, text)
    return text
# --- Currency compact ---
_CURRENCY_UNITS = r"(원|명|개|대|위안|달러|유로|엔|톤|권|건|회|장|마리|그루|송이|필)"
_BOUNDARY = r'(?=[\s.,!?()"\'가-힣]|$)'
_NUMBER = r"(\d+(?:\.\d+)?)"
_SCALE = r"(만|억|조)"

# NOTE(review): _BOUNDARY accepts any Hangul syllable after the unit so that
# particles may follow ("3억 원을"). It therefore also lets a unit that is the
# first syllable of the next word match (e.g. "10만 대학생") -- confirm this
# is acceptable for the target texts.

# Two scale groups: "3억 5000만 원" -> "3억5000만원"
_CURRENCY_MULTI_RE = re.compile(
    _NUMBER + _SCALE + r"\s*(\d+)" + _SCALE + r"\s*" + _CURRENCY_UNITS + _BOUNDARY
)

# One scale group: "3억 원" -> "3억원"
_CURRENCY_SINGLE_RE = re.compile(
    _NUMBER + _SCALE + r"\s*" + _CURRENCY_UNITS + _BOUNDARY
)


def apply_currency_compact(text: str) -> str:
    """Close up the whitespace inside scaled number + unit expressions.

    Only expressions that carry a scale word (만/억/조) directly after the
    number are touched. The two-scale form is rewritten first, then the
    single-scale form.

    Args:
        text: Input text.

    Returns:
        Text with compacted currency expressions.
    """
    passes = (
        (_CURRENCY_MULTI_RE, r"\1\2\3\4\5"),
        (_CURRENCY_SINGLE_RE, r"\1\2\3"),
    )
    compacted = text
    for pattern, template in passes:
        compacted = pattern.sub(template, compacted)
    return compacted
# --- Comma removal ---
# A comma counts as a thousand separator only when it sits between digits AND
# is followed by exactly three digits. This keeps enumerations such as "1,2위"
# or "2019,2020년" intact while still collapsing "1,234,567".
_COMMA_RE = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")


def apply_comma_removal(text: str) -> str:
    """Remove thousand-separator commas from numbers.

    A comma is removed only when it is preceded by a digit and followed by a
    group of exactly three digits. Commas between digits that do not form
    such a group (for example "1,2") are left untouched.

    Args:
        text: Input text.

    Returns:
        Text with thousand-separator commas removed from numbers.
    """
    return _COMMA_RE.sub("", text)
# --- Unit unicode conversion ---
# Regex-based unit substitution matching the production Gradio reference.
# Each pattern requires a digit prefix and enforces a non-letter boundary so
# "mm" inside "mmol" or "cm" inside "cmH2O" is not mistakenly converted.
# Order matters: squared/cubed forms are tried before the bare length units,
# and two-letter prefixes (mm/cm/km) before the bare "m".
_UNIT_PATTERNS: list[tuple[str, str]] = [
    # Area (squared)
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎟"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎠"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎢"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎡"),
    # Volume (cubed)
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎣"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎤"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎦"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎥"),
    # Length
    # Micrometre is accepted with either GREEK SMALL LETTER MU (U+03BC) or
    # MICRO SIGN (U+00B5); the two look identical but are distinct code
    # points, and both occur in real text. Written as escapes to keep them
    # distinguishable in the source.
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:[\u03bc\u00b5]m|um)\b(?![A-Za-zµ])", r"\g<num>㎛"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm)\b(?![A-Za-zµ])", r"\g<num>㎜"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm)\b(?![A-Za-zµ])", r"\g<num>㎝"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km)\b(?![A-Za-zµ])", r"\g<num>㎞"),
    # Mass
    (r"(?P<num>\d+(?:\.\d+)?)\s*kg\b(?![A-Za-zµ])", r"\g<num>㎏"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*mg\b(?![A-Za-zµ])", r"\g<num>㎎"),
]

# Compiled once at import time; apply_unit_unicode runs them in table order.
_UNIT_COMPILED = [(re.compile(p), r) for p, r in _UNIT_PATTERNS]


def apply_unit_unicode(text: str) -> str:
    """Convert `<number> + unit` tokens to compact Unicode equivalents.

    Only substitutes when preceded by a digit and followed by a non-letter
    boundary, so tokens like "mmol" / "cmH2O" are left alone. Any whitespace
    between the number and the unit is removed by the substitution. Covers
    area, volume, length and mass units; the micrometre prefix may be written
    as "u", Greek mu, or the micro sign.

    Args:
        text: Input text.

    Returns:
        Text with unit strings replaced by Unicode characters.
    """
    for pat, rep in _UNIT_COMPILED:
        text = pat.sub(rep, text)
    return text
# --- Correction filter ---
def _compute_lcs_indices(orig_tokens: list[str], pred_tokens: list[str]) -> list[tuple[int, int]]:
    """Return the index pairs of one longest common subsequence of two token lists.

    Built with the classic dynamic-programming table, then recovered by
    walking back from the bottom-right corner. When both neighbours are
    equally good, the walk moves up (drops an original token) first.

    Returns:
        (orig_index, pred_index) pairs in ascending order; empty when either
        list is empty.
    """
    rows, cols = len(orig_tokens), len(pred_tokens)
    if not rows or not cols:
        return []

    # table[r][c] = LCS length of orig_tokens[:r] and pred_tokens[:c]
    table = [[0] * (cols + 1) for _ in range(rows + 1)]
    for r, o_tok in enumerate(orig_tokens, start=1):
        above, current = table[r - 1], table[r]
        for c, p_tok in enumerate(pred_tokens, start=1):
            if o_tok == p_tok:
                current[c] = above[c - 1] + 1
            else:
                current[c] = max(above[c], current[c - 1])

    # Walk back from the corner, collecting matches last-to-first.
    matched: list[tuple[int, int]] = []
    r, c = rows, cols
    while r and c:
        if orig_tokens[r - 1] == pred_tokens[c - 1]:
            r -= 1
            c -= 1
            matched.append((r, c))
        elif table[r - 1][c] >= table[r][c - 1]:
            r -= 1
        else:
            c -= 1
    return matched[::-1]


def apply_correction_filter(
    text: str,
    original: str,
    max_char_diff: int = 2,
    allow_spacing: bool = True,
) -> str:
    """Keep small corrections and revert large ones to the original wording.

    Both texts are split on whitespace and aligned through their longest
    common subsequence. Every stretch of tokens between two aligned tokens
    (and the stretch after the last one) is judged as a unit: a "safe"
    stretch keeps the corrected wording, any other stretch is replaced by
    the original wording. The result is re-joined with single spaces, so
    the original whitespace layout (including line breaks) is not kept.

    A stretch is safe when the two sides are identical, when they differ
    only in spaces (if ``allow_spacing``), or when their lengths differ by
    at most one and the position-by-position mismatches plus the length
    difference do not exceed ``max_char_diff``.

    Args:
        text: Corrected text.
        original: Original text before correction.
        max_char_diff: Maximum allowed character-level differences.
        allow_spacing: Whether to allow spacing-only changes.

    Returns:
        Filtered text with unsafe corrections reverted. ``text`` is returned
        untouched when either input has no tokens.
    """
    source_words = original.split()
    output_words = text.split()
    if not (source_words and output_words):
        return text

    def is_safe(o_text: str, c_text: str) -> bool:
        if o_text == c_text:
            return True
        if allow_spacing and o_text.replace(" ", "") == c_text.replace(" ", ""):
            return True
        length_gap = abs(len(o_text) - len(c_text))
        if length_gap > 1:
            return False
        # zip stops at the shorter side, i.e. compares the common prefix span.
        mismatches = sum(a != b for a, b in zip(o_text, c_text))
        return mismatches + length_gap <= max_char_diff

    # A sentinel anchor past the end of both lists lets the trailing stretch
    # go through the same code path as the stretches between real anchors.
    end_anchor = (len(source_words), len(output_words))
    anchors = [*_compute_lcs_indices(source_words, output_words), end_anchor]

    pieces: list[str] = []
    src_from = out_from = 0
    for src_at, out_at in anchors:
        before = " ".join(source_words[src_from:src_at])
        after = " ".join(output_words[out_from:out_at])
        if before or after:
            pieces.append(after if is_safe(before, after) else before)
        if out_at < len(output_words):
            # Real anchor: emit the aligned token itself.
            pieces.append(output_words[out_at])
        src_from, out_from = src_at + 1, out_at + 1

    # A reverted insertion contributes an empty piece; skip those.
    return " ".join(piece for piece in pieces if piece)
# --- Paragraph-level duplication removal -------------------------------------
_PARA_SPLIT_RE = re.compile(r"\n+")
_WS_RE = re.compile(r"\s+")


def _normalize_paragraph(p: str) -> str:
    """Collapse runs of whitespace (spaces, tabs, newlines) to one space and strip.

    This canonical form is what duplicate matching compares.
    """
    return _WS_RE.sub(" ", p).strip()


def _split_output_paragraphs(text: str) -> list[str]:
    """Split output text on runs of newlines and return the non-empty paragraphs.

    Paragraphs are stripped at both ends but not normalized internally, so
    the punctuation and spacing inside each paragraph survive reassembly.
    """
    if not text:
        return []
    stripped = (part.strip() for part in _PARA_SPLIT_RE.split(text))
    return [part for part in stripped if part]


def apply_paragraph_dedupe(
    text: str,
    original: str,
    min_len: int = 40,
    prefix_len: int = 30,
) -> str:
    """Remove paragraphs that the output repeats more often than the input does.

    **When a paragraph is dropped**
    - Exact duplicate: its normalized form has appeared in the output at
      least twice, and more often than it occurs in the input.
    - Near duplicate: the first ``prefix_len`` characters of its normalized
      form have appeared in the output at least twice, and more often than
      that prefix occurs in the input (catches echoes that differ only by a
      small correction further on).

    **When a paragraph is kept**
    - Its normalized length is below ``min_len`` (sub-headings, one-line
      quotes and similar legitimate repeats).
    - The output does not repeat it more often than the input does.
    - The first occurrence is always kept; only second and later
      occurrences can be dropped.

    Input occurrences are counted as substrings of the whole normalized
    input rather than per input paragraph, so a repeat that the input holds
    inside one paragraph is still counted when the output splits it across
    several paragraphs.

    Args:
        text: Pipeline output text.
        original: Pipeline input text. A falsy value (including None) is
            treated as an empty string, i.e. every input count is 0, so any
            second or later occurrence of a long paragraph is dropped.
        min_len: Normalized paragraphs shorter than this are never dropped.
        prefix_len: Number of leading characters compared for near-duplicate
            detection. A value of 0 or less disables near-duplicate
            detection (an empty prefix would match every paragraph).

    Returns:
        The de-duplicated text. When nothing is dropped, ``text`` is
        returned unchanged; otherwise the kept paragraphs are joined with a
        single newline.
    """
    if not text:
        return text

    in_norm = _normalize_paragraph(original or "")
    use_prefix = prefix_len > 0

    out_exact_seen: Counter[str] = Counter()
    out_prefix_seen: Counter[str] = Counter()
    kept: list[str] = []
    any_dropped = False

    for para in _split_output_paragraphs(text):
        norm = _normalize_paragraph(para)
        if len(norm) < min_len:
            kept.append(para)
            continue

        # Counters advance for every long paragraph, dropped or not, so they
        # always hold the number of occurrences seen so far in the output.
        out_exact_seen[norm] += 1
        seen_exact = out_exact_seen[norm]
        exact_dup = seen_exact >= 2 and seen_exact > in_norm.count(norm)

        near_dup = False
        if use_prefix:
            prefix = norm[:prefix_len]
            out_prefix_seen[prefix] += 1
            seen_prefix = out_prefix_seen[prefix]
            near_dup = seen_prefix >= 2 and seen_prefix > in_norm.count(prefix)

        if exact_dup or near_dup:
            any_dropped = True
            continue
        kept.append(para)

    if not any_dropped:
        return text
    return "\n".join(kept)