"""Standalone post-processing rules extracted from solar-eval multi_step pipeline."""
import csv
import re
from collections import Counter
from pathlib import Path
# --- Vocabulary substitution ---
def load_vocabulary(csv_path: str) -> list[dict[str, str]]:
"""Load vocabulary CSV with wrong/correct columns.
Supports both:
- Headered CSV with columns 'wrong'/'before' and 'correct'/'after'
- Headerless CSV where first two columns are (wrong, correct)
Args:
csv_path: Path to vocabulary CSV file.
Returns:
List of dicts with 'wrong' and 'correct' keys.
"""
path = Path(csv_path)
if not path.exists():
return []
vocab: list[dict[str, str]] = []
header_names = {"wrong", "correct", "before", "after"}
with open(path, encoding="utf-8") as f:
rows = list(csv.reader(f))
if not rows:
return []
# Detect header: first row contains any of the expected column names
first_row_lower = [c.strip().lower() for c in rows[0]]
has_header = any(name in first_row_lower for name in header_names)
if has_header:
wrong_idx = next((i for i, c in enumerate(first_row_lower) if c in {"wrong", "before"}), 0)
correct_idx = next(
(i for i, c in enumerate(first_row_lower) if c in {"correct", "after"}), 1
)
data_rows = rows[1:]
else:
wrong_idx, correct_idx = 0, 1
data_rows = rows
for row in data_rows:
if len(row) <= max(wrong_idx, correct_idx):
continue
wrong = row[wrong_idx].strip()
correct = row[correct_idx].strip()
if wrong and correct:
vocab.append({"wrong": wrong, "correct": correct})
return vocab
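# Illustrative input, assuming a hypothetical headerless vocab.csv where each
# row is (wrong, correct):
#   솔라,Solar
#   업스테이지,Upstage
# load_vocabulary("vocab.csv")
# -> [{"wrong": "솔라", "correct": "Solar"}, {"wrong": "업스테이지", "correct": "Upstage"}]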
def apply_vocabulary(text: str, vocab: list[dict[str, str]]) -> str:
"""Apply vocabulary substitutions.
Args:
text: Input text.
vocab: List of {wrong, correct} dicts.
Returns:
Text with vocabulary corrections applied.
"""
for entry in vocab:
text = text.replace(entry["wrong"], entry["correct"])
return text
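# Illustrative example (hypothetical sample strings):
#   apply_vocabulary("솔라 모델 출시", [{"wrong": "솔라", "correct": "Solar"}])
#   -> "Solar 모델 출시"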
# --- Pronoun/title post-processing ---
# Korean particle correction map: the particle form used after a vowel-final
# word -> the form used after a consonant-final word
PARTICLE_CORRECTION_MAP: dict[str, str] = {
"는": "은",
"가": "이",
"를": "을",
"와": "과",
"로": "으로",
"여": "이여",
"라": "이라",
"랑": "이랑",
"다": "이다",
"였다": "이었다",
"라면": "이라면",
"라서": "이라서",
"로부터": "으로부터",
}
# Particles that are valid for both vowel and consonant endings
NEUTRAL_PARTICLES = {"만", "도", "께서"}
def _ends_with_consonant(char: str) -> bool:
"""Check if a Korean character ends with a consonant (has jongseong)."""
if not char:
return False
code = ord(char)
if 0xAC00 <= code <= 0xD7A3:
return (code - 0xAC00) % 28 != 0
return False
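# Sketch of the jongseong arithmetic above: Hangul syllables are laid out in
# blocks of 28 final-consonant slots, with slot 0 meaning "no final consonant".
#   _ends_with_consonant("강") -> True   ("강" carries the final consonant ㅇ)
#   _ends_with_consonant("가") -> False  (no final consonant)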
def apply_pronoun_postprocess(text: str, replacements: dict[str, str]) -> str:
"""Replace words and correct trailing Korean particles.
Args:
text: Input text.
replacements: Mapping of old_word -> new_word.
Returns:
Text with replacements and corrected particles.
"""
# Build all particles sorted by length (longest first for greedy match)
all_particles = sorted(
list(PARTICLE_CORRECTION_MAP.keys())
+ list(PARTICLE_CORRECTION_MAP.values())
+ list(NEUTRAL_PARTICLES),
key=len,
reverse=True,
)
# Deduplicate while preserving order
seen: set[str] = set()
unique_particles: list[str] = []
for p in all_particles:
if p not in seen:
seen.add(p)
unique_particles.append(p)
particle_pattern = "|".join(re.escape(p) for p in unique_particles)
for old, new in replacements.items():
regex = re.compile(f"({re.escape(old)})({particle_pattern})?")
def _replace(match: re.Match, new_word: str = new) -> str:
particle = match.group(2)
if particle is None:
return new_word
# If the new word ends with consonant and old particle is vowel form, correct it
if new_word and _ends_with_consonant(new_word[-1]):
corrected = PARTICLE_CORRECTION_MAP.get(particle, particle)
return new_word + corrected
return new_word + particle
text = regex.sub(_replace, text)
return text
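# Illustrative example (hypothetical replacement): "그녀" ends in a vowel, so
# the source uses "는"; the substitute "선생님" ends in a consonant, so the
# particle is corrected to "은".
#   apply_pronoun_postprocess("그녀는 떠났다", {"그녀": "선생님"})
#   -> "선생님은 떠났다"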
# --- Currency compact ---
_CURRENCY_UNITS = r"(원|명|개|대|위안|달러|유로|엔|톤|권|건|회|장|마리|그루|송이|필)"
_BOUNDARY = r'(?=[\s.,!?()"\'가-힣]|$)'
_NUMBER = r"(\d+(?:\.\d+)?)"
_SCALE = r"(만|억|조)"
# Multi-digit: "3억 5000만 원" -> "3억5000만원"
_CURRENCY_MULTI_RE = re.compile(
rf"{_NUMBER}{_SCALE}\s*(\d+){_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}"
)
# Single: "3억 원" -> "3억원"
_CURRENCY_SINGLE_RE = re.compile(rf"{_NUMBER}{_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}")
def apply_currency_compact(text: str) -> str:
"""Remove spaces between numbers and currency/unit markers.
Args:
text: Input text.
Returns:
Text with compacted currency expressions.
"""
text = _CURRENCY_MULTI_RE.sub(r"\1\2\3\4\5", text)
text = _CURRENCY_SINGLE_RE.sub(r"\1\2\3", text)
return text
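# Illustrative example (hypothetical sentence):
#   apply_currency_compact("매출은 3억 5000만 원입니다")
#   -> "매출은 3억5000만원입니다"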
# --- Comma removal ---
_COMMA_RE = re.compile(r"(?<=\d),(?=\d)")
def apply_comma_removal(text: str) -> str:
"""Remove thousand-separator commas from numbers.
Args:
text: Input text.
Returns:
Text with commas removed from numbers.
"""
return _COMMA_RE.sub("", text)
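# Illustrative example:
#   apply_comma_removal("총 1,234,567원")
#   -> "총 1234567원"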
# --- Unit unicode conversion ---
# Regex-based unit substitution matching the production Gradio reference.
# Each pattern requires a digit prefix and enforces a non-letter boundary so
# "mm" inside "mmol" or "cm" inside "cmH2O" is not mistakenly converted.
_UNIT_PATTERNS: list[tuple[str, str]] = [
# Area (squared)
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎟"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎠"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎢"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎡"),
# Volume (cubed)
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎣"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎤"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎦"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎥"),
# Length
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:μm|um)\b(?![A-Za-zµ])", r"\g<num>㎛"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm)\b(?![A-Za-zµ])", r"\g<num>㎜"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm)\b(?![A-Za-zµ])", r"\g<num>㎝"),
(r"(?P<num>\d+(?:\.\d+)?)\s*(?:km)\b(?![A-Za-zµ])", r"\g<num>㎞"),
# Mass
(r"(?P<num>\d+(?:\.\d+)?)\s*kg\b(?![A-Za-zµ])", r"\g<num>㎏"),
(r"(?P<num>\d+(?:\.\d+)?)\s*mg\b(?![A-Za-zµ])", r"\g<num>㎎"),
]
_UNIT_COMPILED = [(re.compile(p), r) for p, r in _UNIT_PATTERNS]
def apply_unit_unicode(text: str) -> str:
"""Convert `<number> + unit` tokens to compact Unicode equivalents.
Only substitutes when preceded by a digit and followed by a non-letter
boundary, so tokens like "mmol" / "cmH2O" are left alone. Matches the
production Gradio regex set (area/volume/length/mass).
Args:
text: Input text.
Returns:
Text with unit strings replaced by Unicode characters.
"""
for pat, rep in _UNIT_COMPILED:
text = pat.sub(rep, text)
return text
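# Illustrative examples (hypothetical sentences; the unit must sit at a
# non-letter boundary to convert):
#   apply_unit_unicode("길이 5 km, 넓이 25 cm2, 무게 10 kg")
#   -> "길이 5㎞, 넓이 25㎠, 무게 10㎏"
#   apply_unit_unicode("농도 2 mmol")  # "mm" inside "mmol" is left untouched
#   -> "농도 2 mmol"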
# --- Correction filter ---
def _compute_lcs_indices(orig_tokens: list[str], pred_tokens: list[str]) -> list[tuple[int, int]]:
"""Compute LCS matching indices between two token lists using DP."""
m, n = len(orig_tokens), len(pred_tokens)
if m == 0 or n == 0:
return []
# DP table
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if orig_tokens[i - 1] == pred_tokens[j - 1]:
dp[i][j] = dp[i - 1][j - 1] + 1
else:
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
# Backtrack to find matching pairs
pairs: list[tuple[int, int]] = []
i, j = m, n
while i > 0 and j > 0:
if orig_tokens[i - 1] == pred_tokens[j - 1]:
pairs.append((i - 1, j - 1))
i -= 1
j -= 1
elif dp[i - 1][j] >= dp[i][j - 1]:
i -= 1
else:
j -= 1
pairs.reverse()
return pairs
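# Illustrative example: tokens at indices 0 and 2 match, so the result is the
# list of (orig_idx, pred_idx) pairs for those positions.
#   _compute_lcs_indices(["오늘", "날씨가", "좋다"], ["오늘", "정말", "좋다"])
#   -> [(0, 0), (2, 2)]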
def apply_correction_filter(
text: str,
original: str,
max_char_diff: int = 2,
allow_spacing: bool = True,
) -> str:
"""Filter corrections, reverting changes that exceed max_char_diff.
Tokenizes both texts, computes LCS, and for each gap between matching
tokens, checks if the correction is "safe" (small enough change).
Unsafe corrections are reverted to original.
Args:
text: Corrected text.
original: Original text before correction.
max_char_diff: Maximum allowed character-level differences.
allow_spacing: Whether to allow spacing-only changes.
Returns:
Filtered text with unsafe corrections reverted.
"""
orig_tokens = original.split()
pred_tokens = text.split()
if not orig_tokens or not pred_tokens:
return text
lcs_pairs = _compute_lcs_indices(orig_tokens, pred_tokens)
def is_safe(o_text: str, c_text: str) -> bool:
if o_text == c_text:
return True
if allow_spacing and o_text.replace(" ", "") == c_text.replace(" ", ""):
return True
if abs(len(o_text) - len(c_text)) <= 1:
min_len = min(len(o_text), len(c_text))
d = sum(1 for i in range(min_len) if o_text[i] != c_text[i])
d += abs(len(o_text) - len(c_text))
if d <= max_char_diff:
return True
return False
result_tokens: list[str] = []
prev_orig = -1
prev_pred = -1
for oi, pi in lcs_pairs:
# Handle gap before this LCS match
orig_gap = " ".join(orig_tokens[prev_orig + 1 : oi])
pred_gap = " ".join(pred_tokens[prev_pred + 1 : pi])
if orig_gap or pred_gap:
if is_safe(orig_gap, pred_gap):
result_tokens.append(pred_gap)
else:
result_tokens.append(orig_gap)
# Add the matched token
result_tokens.append(pred_tokens[pi])
prev_orig = oi
prev_pred = pi
# Handle trailing gap
orig_gap = " ".join(orig_tokens[prev_orig + 1 :])
pred_gap = " ".join(pred_tokens[prev_pred + 1 :])
if orig_gap or pred_gap:
if is_safe(orig_gap, pred_gap):
result_tokens.append(pred_gap)
else:
result_tokens.append(orig_gap)
return " ".join(t for t in result_tokens if t)
# --- Paragraph-level duplication removal -------------------------------------
_PARA_SPLIT_RE = re.compile(r"\n+")
_WS_RE = re.compile(r"\s+")
def _normalize_paragraph(p: str) -> str:
"""공백/탭/개행 압축 정규형. 중복 매칭용."""
return _WS_RE.sub(" ", p).strip()
def _split_output_paragraphs(text: str) -> list[str]:
"""Output 텍스트를 \\n+ 경계로 쪼개 비어있지 않은 문단 리스트 반환.
정규화하지 않은 원본 조각을 그대로 넘겨서, 재조립 시 모델이 쓴 구두점/
스페이싱 디테일이 최대한 보존되도록 한다 (strip 은 조립 시 \\n 으로
다시 감싸므로 안전).
"""
if not text:
return []
parts = _PARA_SPLIT_RE.split(text)
return [p.strip() for p in parts if p.strip()]
def apply_paragraph_dedupe(
text: str,
original: str,
min_len: int = 40,
prefix_len: int = 30,
) -> str:
"""LLM이 뱉은 중복 문단을 제거한다.
**언제 제거하는가**
- 같은 문단(정규화 후 완전일치) 이 output 에서 `input 등장 횟수 + 1`
이상 등장하면 뒤쪽 것을 drop.
- 앞 ``prefix_len`` 자가 동일한 문단이 out > in 로 반복되면 drop
(모델이 미세 교정으로 다르게 뱉은 echo 잡기).
**언제 보존하는가**
- 문단 길이 < ``min_len`` (◇소제목, 한 줄 인용 등 합법적 반복).
- ``output_count <= input_count`` (저자가 의도해서 원문에 이미 반복
포함. 저자 의도를 훼손하지 않음).
- 첫 등장은 항상 유지. 두 번째 이후 등장부터 조건에 맞으면 drop.
Args:
text: 파이프라인 출력 텍스트.
original: 파이프라인 입력 (저자 원문). None 이면 빈 문자열로 간주해
input_count=0. 이 경우 output 에서 2 회 이상 동일 문단 등장 →
모두 drop.
min_len: 이 길이 미만 정규화 문단은 dedupe 에서 제외.
prefix_len: near-dup 탐지에 쓰는 앞자르기 길이.
Returns:
Dedupe 된 텍스트. 문단 구분자는 ``\\n`` 으로 재조립 (기존 파이프라인
출력 관례와 동일). 중복이 없으면 입력을 그대로 돌려준다.
"""
if not text:
return text
    # Treat the input as one normalized string and count substring occurrences.
    # The previous implementation counted input paragraphs with a Counter, but
    # when the LLM restructured paragraph boundaries (e.g. split a single
    # paragraph containing duplicated sentences without line breaks into
    # several paragraphs), output paragraphs no longer matched input paragraphs
    # exactly, input_count came out 0, and there was a bug where legitimate
    # duplicates (repetition intended by the author/source) were dropped too.
in_norm = _normalize_paragraph(original or "")
out_paras = _split_output_paragraphs(text)
out_exact_seen: Counter[str] = Counter()
out_prefix_seen: Counter[str] = Counter()
kept: list[str] = []
any_dropped = False
for para in out_paras:
norm = _normalize_paragraph(para)
if len(norm) < min_len:
kept.append(para)
continue
prefix = norm[:prefix_len]
out_exact_seen[norm] += 1
out_prefix_seen[prefix] += 1
        # Count how many times this paragraph (or its prefix) appears as a
        # substring of the whole input string. Drop only when output_count
        # exceeds input_count.
in_exact_count = in_norm.count(norm) if in_norm else 0
in_prefix_count = in_norm.count(prefix) if in_norm else 0
exact_dup = (
out_exact_seen[norm] > in_exact_count
and out_exact_seen[norm] >= 2
)
near_dup = (
out_prefix_seen[prefix] > in_prefix_count
and out_prefix_seen[prefix] >= 2
)
if exact_dup or near_dup:
any_dropped = True
continue
kept.append(para)
if not any_dropped:
return text
return "\n".join(kept)