"""Standalone post-processing rules extracted from solar-eval multi_step pipeline."""

import csv
import re
from collections import Counter
from pathlib import Path

# --- Vocabulary substitution ---


def load_vocabulary(csv_path: str) -> list[dict[str, str]]:
    """Load vocabulary CSV with wrong/correct columns.

    Supports both:
    - Headered CSV with columns 'wrong'/'before' and 'correct'/'after'
    - Headerless CSV where first two columns are (wrong, correct)

    Args:
        csv_path: Path to vocabulary CSV file.

    Returns:
        List of dicts with 'wrong' and 'correct' keys.
    """
    path = Path(csv_path)
    if not path.exists():
        return []

    vocab: list[dict[str, str]] = []
    header_names = {"wrong", "correct", "before", "after"}

    with open(path, encoding="utf-8") as f:
        rows = list(csv.reader(f))

    if not rows:
        return []

    # Detect header: first row contains any of the expected column names
    first_row_lower = [c.strip().lower() for c in rows[0]]
    has_header = any(name in first_row_lower for name in header_names)

    if has_header:
        wrong_idx = next((i for i, c in enumerate(first_row_lower) if c in {"wrong", "before"}), 0)
        correct_idx = next(
            (i for i, c in enumerate(first_row_lower) if c in {"correct", "after"}), 1
        )
        data_rows = rows[1:]
    else:
        wrong_idx, correct_idx = 0, 1
        data_rows = rows

    for row in data_rows:
        if len(row) <= max(wrong_idx, correct_idx):
            continue
        wrong = row[wrong_idx].strip()
        correct = row[correct_idx].strip()
        if wrong and correct:
            vocab.append({"wrong": wrong, "correct": correct})
    return vocab


def apply_vocabulary(text: str, vocab: list[dict[str, str]]) -> str:
    """Apply vocabulary substitutions.

    Args:
        text: Input text.
        vocab: List of {wrong, correct} dicts.

    Returns:
        Text with vocabulary corrections applied.
    """
    for entry in vocab:
        text = text.replace(entry["wrong"], entry["correct"])
    return text
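

# Usage sketch (illustrative: "vocab.csv" is a hypothetical path, and the
# inline entry just mirrors the dict shape load_vocabulary() returns):
#
#     vocab = load_vocabulary("vocab.csv")
#     # or construct directly:
#     vocab = [{"wrong": "레시피", "correct": "조리법"}]
#     apply_vocabulary("이 레시피를 따라 하세요.", vocab)
#     # -> "이 조리법을 따라 하세요."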


# --- Pronoun/title post-processing ---

# Korean particle correction map: particle form used after a vowel-final
# word -> form used after a consonant-final word
PARTICLE_CORRECTION_MAP: dict[str, str] = {
    "는": "은",
    "가": "이",
    "를": "을",
    "와": "과",
    "로": "으로",
    "여": "이여",
    "라": "이라",
    "랑": "이랑",
    "다": "이다",
    "였다": "이었다",
    "라면": "이라면",
    "라서": "이라서",
    "로부터": "으로부터",
}

# Particles that are valid for both vowel and consonant endings
NEUTRAL_PARTICLES = {"만", "도", "께서"}


def _ends_with_consonant(char: str) -> bool:
    """Check if a Korean character ends with a consonant (has jongseong)."""
    if not char:
        return False
    code = ord(char)
    if 0xAC00 <= code <= 0xD7A3:
        return (code - 0xAC00) % 28 != 0
    return False


def apply_pronoun_postprocess(text: str, replacements: dict[str, str]) -> str:
    """Replace words and correct trailing Korean particles.

    Args:
        text: Input text.
        replacements: Mapping of old_word -> new_word.

    Returns:
        Text with replacements and corrected particles.
    """
    # Build all particles sorted by length (longest first for greedy match)
    all_particles = sorted(
        list(PARTICLE_CORRECTION_MAP.keys())
        + list(PARTICLE_CORRECTION_MAP.values())
        + list(NEUTRAL_PARTICLES),
        key=len,
        reverse=True,
    )
    # Deduplicate while preserving order
    seen: set[str] = set()
    unique_particles: list[str] = []
    for p in all_particles:
        if p not in seen:
            seen.add(p)
            unique_particles.append(p)

    particle_pattern = "|".join(re.escape(p) for p in unique_particles)

    for old, new in replacements.items():
        regex = re.compile(f"({re.escape(old)})({particle_pattern})?")

        def _replace(match: re.Match, new_word: str = new) -> str:
            particle = match.group(2)
            if particle is None:
                return new_word
            # If the new word ends with consonant and old particle is vowel form, correct it
            if new_word and _ends_with_consonant(new_word[-1]):
                corrected = PARTICLE_CORRECTION_MAP.get(particle, particle)
                return new_word + corrected
            return new_word + particle

        text = regex.sub(_replace, text)

    return text
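

# Example (illustrative replacement pair, not from the pipeline config):
#
#     apply_pronoun_postprocess("그녀는 그녀를 보았다", {"그녀": "학생"})
#     # -> "학생은 학생을 보았다"
#
# "학생" ends in a consonant (has jongseong), so the vowel-form particles
# 는/를 are corrected to 은/을 via PARTICLE_CORRECTION_MAP.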


# --- Currency compact ---

_CURRENCY_UNITS = r"(원|명|개|대|위안|달러|유로|엔|톤|권|건|회|장|마리|그루|송이|필)"
_BOUNDARY = r'(?=[\s.,!?()"\'가-힣]|$)'
_NUMBER = r"(\d+(?:\.\d+)?)"
_SCALE = r"(만|억|조)"

# Two scale units: "3억 5000만 원" -> "3억5000만원"
_CURRENCY_MULTI_RE = re.compile(
    rf"{_NUMBER}{_SCALE}\s*(\d+){_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}"
)
# Single scale unit: "3억 원" -> "3억원"
_CURRENCY_SINGLE_RE = re.compile(rf"{_NUMBER}{_SCALE}\s*{_CURRENCY_UNITS}{_BOUNDARY}")


def apply_currency_compact(text: str) -> str:
    """Remove spaces between numbers and currency/unit markers.

    Args:
        text: Input text.

    Returns:
        Text with compacted currency expressions.
    """
    text = _CURRENCY_MULTI_RE.sub(r"\1\2\3\4\5", text)
    text = _CURRENCY_SINGLE_RE.sub(r"\1\2\3", text)
    return text
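

# Example:
#
#     apply_currency_compact("예산은 3억 5000만 원, 참가자는 2만 명이다.")
#     # -> "예산은 3억5000만원, 참가자는 2만명이다."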


# --- Comma removal ---

_COMMA_RE = re.compile(r"(?<=\d),(?=\d)")


def apply_comma_removal(text: str) -> str:
    """Remove thousand-separator commas from numbers.

    Args:
        text: Input text.

    Returns:
        Text with commas removed from numbers.
    """
    return _COMMA_RE.sub("", text)
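

# Example:
#
#     apply_comma_removal("입장객 12,345,678명")
#     # -> "입장객 12345678명"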


# --- Unit unicode conversion ---

# Regex-based unit substitution matching the production Gradio reference.
# Each pattern requires a digit prefix and enforces a non-letter boundary so
# "mm" inside "mmol" or "cm" inside "cmH2O" is not mistakenly converted.
_UNIT_PATTERNS: list[tuple[str, str]] = [
    # Area (squared)
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎟"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎠"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎢"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?2|²))\b(?![A-Za-zµ])", r"\g<num>㎡"),
    # Volume (cubed)
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎣"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎤"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎦"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:m(?:\^?3|³))\b(?![A-Za-zµ])", r"\g<num>㎥"),
    # Length
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:μm|um)\b(?![A-Za-zµ])", r"\g<num>㎛"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:mm)\b(?![A-Za-zµ])", r"\g<num>㎜"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:cm)\b(?![A-Za-zµ])", r"\g<num>㎝"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*(?:km)\b(?![A-Za-zµ])", r"\g<num>㎞"),
    # Mass
    (r"(?P<num>\d+(?:\.\d+)?)\s*kg\b(?![A-Za-zµ])", r"\g<num>㎏"),
    (r"(?P<num>\d+(?:\.\d+)?)\s*mg\b(?![A-Za-zµ])", r"\g<num>㎎"),
]
_UNIT_COMPILED = [(re.compile(p), r) for p, r in _UNIT_PATTERNS]


def apply_unit_unicode(text: str) -> str:
    """Convert `<number> + unit` tokens to compact Unicode equivalents.

    Only substitutes when preceded by a digit and followed by a non-letter
    boundary, so tokens like "mmol" / "cmH2O" are left alone. Matches the
    production Gradio regex set (area/volume/length/mass).

    Args:
        text: Input text.

    Returns:
        Text with unit strings replaced by Unicode characters.
    """
    for pat, rep in _UNIT_COMPILED:
        text = pat.sub(rep, text)
    return text
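

# Example: matched unit tokens are compacted, while "mmol" is left alone
# because of the non-letter boundary:
#
#     apply_unit_unicode("길이 10 cm, 부피 2.5 m3, 농도 5 mmol")
#     # -> "길이 10㎝, 부피 2.5㎥, 농도 5 mmol"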


# --- Correction filter ---


def _compute_lcs_indices(orig_tokens: list[str], pred_tokens: list[str]) -> list[tuple[int, int]]:
    """Compute LCS matching indices between two token lists using DP."""
    m, n = len(orig_tokens), len(pred_tokens)
    if m == 0 or n == 0:
        return []

    # DP table
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if orig_tokens[i - 1] == pred_tokens[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    # Backtrack to find matching pairs
    pairs: list[tuple[int, int]] = []
    i, j = m, n
    while i > 0 and j > 0:
        if orig_tokens[i - 1] == pred_tokens[j - 1]:
            pairs.append((i - 1, j - 1))
            i -= 1
            j -= 1
        elif dp[i - 1][j] >= dp[i][j - 1]:
            i -= 1
        else:
            j -= 1
    pairs.reverse()
    return pairs
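

# Example: only the tokens on the longest common subsequence are paired:
#
#     _compute_lcs_indices(["a", "b", "c"], ["a", "x", "c"])
#     # -> [(0, 0), (2, 2)]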


def apply_correction_filter(
    text: str,
    original: str,
    max_char_diff: int = 2,
    allow_spacing: bool = True,
) -> str:
    """Filter corrections, reverting changes that exceed max_char_diff.

    Tokenizes both texts, computes LCS, and for each gap between matching
    tokens, checks if the correction is "safe" (small enough change).
    Unsafe corrections are reverted to original.

    Args:
        text: Corrected text.
        original: Original text before correction.
        max_char_diff: Maximum allowed character-level differences.
        allow_spacing: Whether to allow spacing-only changes.

    Returns:
        Filtered text with unsafe corrections reverted.
    """
    orig_tokens = original.split()
    pred_tokens = text.split()

    if not orig_tokens or not pred_tokens:
        return text

    lcs_pairs = _compute_lcs_indices(orig_tokens, pred_tokens)

    def is_safe(o_text: str, c_text: str) -> bool:
        if o_text == c_text:
            return True
        if allow_spacing and o_text.replace(" ", "") == c_text.replace(" ", ""):
            return True
        if abs(len(o_text) - len(c_text)) <= 1:
            min_len = min(len(o_text), len(c_text))
            d = sum(1 for i in range(min_len) if o_text[i] != c_text[i])
            d += abs(len(o_text) - len(c_text))
            if d <= max_char_diff:
                return True
        return False

    result_tokens: list[str] = []
    prev_orig = -1
    prev_pred = -1

    for oi, pi in lcs_pairs:
        # Handle gap before this LCS match
        orig_gap = " ".join(orig_tokens[prev_orig + 1 : oi])
        pred_gap = " ".join(pred_tokens[prev_pred + 1 : pi])

        if orig_gap or pred_gap:
            if is_safe(orig_gap, pred_gap):
                result_tokens.append(pred_gap)
            else:
                result_tokens.append(orig_gap)

        # Add the matched token
        result_tokens.append(pred_tokens[pi])
        prev_orig = oi
        prev_pred = pi

    # Handle trailing gap
    orig_gap = " ".join(orig_tokens[prev_orig + 1 :])
    pred_gap = " ".join(pred_tokens[prev_pred + 1 :])
    if orig_gap or pred_gap:
        if is_safe(orig_gap, pred_gap):
            result_tokens.append(pred_gap)
        else:
            result_tokens.append(orig_gap)

    return " ".join(t for t in result_tokens if t)


# --- Paragraph-level duplication removal ---

_PARA_SPLIT_RE = re.compile(r"\n+")
_WS_RE = re.compile(r"\s+")


def _normalize_paragraph(p: str) -> str:
    """공백/탭/개행 압축 정규형. 중복 매칭용."""
    return _WS_RE.sub(" ", p).strip()


def _split_output_paragraphs(text: str) -> list[str]:
    """Output 텍스트를 \\n+ 경계로 쪼개 비어있지 않은 문단 리스트 반환.

    정규화하지 않은 원본 조각을 그대로 넘겨서, 재조립 시 모델이 쓴 구두점/
    스페이싱 디테일이 최대한 보존되도록 한다 (strip 은 조립 시 \\n 으로
    다시 감싸므로 안전).
    """
    if not text:
        return []
    parts = _PARA_SPLIT_RE.split(text)
    return [p.strip() for p in parts if p.strip()]


def apply_paragraph_dedupe(
    text: str,
    original: str,
    min_len: int = 40,
    prefix_len: int = 30,
) -> str:
    """LLM이 뱉은 중복 문단을 제거한다.

    **언제 제거하는가**
      - 같은 문단(정규화 후 완전일치) 이 output 에서 `input 등장 횟수 + 1`
        이상 등장하면 뒤쪽 것을 drop.
      - 앞 ``prefix_len`` 자가 동일한 문단이 out > in 로 반복되면 drop
        (모델이 미세 교정으로 다르게 뱉은 echo 잡기).

    **언제 보존하는가**
      - 문단 길이 < ``min_len`` (◇소제목, 한 줄 인용 등 합법적 반복).
      - ``output_count <= input_count`` (저자가 의도해서 원문에 이미 반복
        포함. 저자 의도를 훼손하지 않음).
      - 첫 등장은 항상 유지. 두 번째 이후 등장부터 조건에 맞으면 drop.

    Args:
        text: 파이프라인 출력 텍스트.
        original: 파이프라인 입력 (저자 원문). None 이면 빈 문자열로 간주해
            input_count=0. 이 경우 output 에서 2 회 이상 동일 문단 등장 →
            모두 drop.
        min_len: 이 길이 미만 정규화 문단은 dedupe 에서 제외.
        prefix_len: near-dup 탐지에 쓰는 앞자르기 길이.

    Returns:
        Dedupe 된 텍스트. 문단 구분자는 ``\\n`` 으로 재조립 (기존 파이프라인
        출력 관례와 동일). 중복이 없으면 입력을 그대로 돌려준다.
    """
    if not text:
        return text

    # Treat the input as one normalized string and use substring counts.
    # An earlier implementation counted input paragraphs with a Counter, but
    # when the LLM restructures paragraph boundaries (e.g. splits a paragraph
    # whose repeated sentences had no line breaks into several paragraphs),
    # output paragraphs no longer exactly match input paragraphs, input_count
    # comes out as 0, and legitimate duplicates (repetition intended by the
    # author/source) get dropped too.
    in_norm = _normalize_paragraph(original or "")

    out_paras = _split_output_paragraphs(text)
    out_exact_seen: Counter[str] = Counter()
    out_prefix_seen: Counter[str] = Counter()

    kept: list[str] = []
    any_dropped = False
    for para in out_paras:
        norm = _normalize_paragraph(para)
        if len(norm) < min_len:
            kept.append(para)
            continue

        prefix = norm[:prefix_len]
        out_exact_seen[norm] += 1
        out_prefix_seen[prefix] += 1

        # Count how many times this paragraph (or its prefix) occurs as a
        # substring of the whole input string. Drop only when the output
        # count exceeds the input count.
        in_exact_count = in_norm.count(norm) if in_norm else 0
        in_prefix_count = in_norm.count(prefix) if in_norm else 0

        exact_dup = (
            out_exact_seen[norm] > in_exact_count
            and out_exact_seen[norm] >= 2
        )
        near_dup = (
            out_prefix_seen[prefix] > in_prefix_count
            and out_prefix_seen[prefix] >= 2
        )

        if exact_dup or near_dup:
            any_dropped = True
            continue

        kept.append(para)

    if not any_dropped:
        return text

    return "\n".join(kept)