File size: 17,488 Bytes
ffd1353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
"""
MonSub LLM Editor — Self-bootstrapping RunPod Serverless Handler

Loads: Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilled (base)
     + Tsedee/mongol-editor-llm-v1 (LoRA adapter) [swap to -v2 after v2 training]

Accepts batches of raw Whisper-style text segments and returns edited
Mongolian subtitle text with post-processing:
  - Brand name correction (chita→GTA, аифон→iPhone, etc.)
  - Hallucination guard (rejects outputs that are too different from input)
  - Chain-of-thought stripping (keeps only "Засварласан хувилбар:" content)
  - </think> tag cleanup

API:
    Input (JSON):
    {
        "texts": ["text 1", "text 2", ...],            # required
        "mode":  "edit" | "summarize" | "rewrite",     # default: "edit"
        "instruction": "optional custom prompt",        # optional
        "skip_post_processing": false                    # optional
    }

    Output:
    {
        "edited":   ["edited 1", "edited 2", ...],
        "stats":    { "count": N, "time_s": T, "tokens_per_s": X },
        "fallback_used": [idx1, idx2, ...]  # indices where hallucination guard fired
    }
"""
import os, sys, subprocess, time

# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# BOOTSTRAP
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def ensure(pkg_import, pip_name=None):
    """Make sure ``pkg_import`` can be imported, pip-installing it if it cannot.

    Args:
        pkg_import: Module name to try importing.
        pip_name: Requirement string handed to ``pip install`` when the import
            fails (may carry a version specifier). Defaults to ``pkg_import``.

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.

    Note:
        Only importability is checked. If the module is already installed, no
        install runs, so a version specifier in ``pip_name`` is not enforced.
    """
    import importlib

    try:
        importlib.import_module(pkg_import)
        return
    except ImportError:
        pass

    name = pip_name or pkg_import
    print(f"[BOOT] installing {name}...", flush=True)
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-cache-dir", name],
        check=True,
    )
    # The package appeared on disk after the interpreter started; drop the
    # import system's cached directory listings so later imports can find it.
    importlib.invalidate_caches()


# Bootstrap runs at import time: announce start-up and time the dependency check.
print("[BOOT] LLM editor handler starting...", flush=True)
t0 = time.time()

# ensure() only installs when the import fails, so the version specifiers
# below apply to fresh installs, not to packages that are already present.
ensure("runpod")
ensure("transformers", "transformers==5.5.0")
ensure("peft", "peft==0.18.1")
ensure("accelerate", "accelerate>=1.0.0")
ensure("huggingface_hub")

print(f"[BOOT] deps ready in {time.time()-t0:.1f}s", flush=True)

# ── Module-level: only stdlib + runpod ──────────────────────────────
# Set here, before torch is imported (torch is loaded lazily in load_model()).
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import re
import traceback
import runpod

# Hugging Face access token; an empty string when HF_TOKEN is not set.
HF_TOKEN     = os.environ.get("HF_TOKEN", "")
# Base checkpoint and LoRA adapter repo ids; both can be overridden via the environment.
BASE_MODEL   = os.environ.get("BASE_MODEL", "Jackrong/Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilled")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "Tsedee/mongol-editor-llm-v1")

# Filled in by load_model() on first use; None until then.
MODEL     = None
TOKENIZER = None
torch = None  # lazy-loaded

# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# BRAND CORRECTION DICT โ€” post-processing safety net
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Applied AFTER model output to catch brand names the model missed.
# Case-insensitive substring match with word boundaries where possible.
# (pattern_regex, replacement) pairs mapping Cyrillic phonetic spellings (and
# common ASR mis-hearings) to the canonical brand / proper-noun spelling.
# NOTE(review): the Cyrillic text was reconstructed from a mis-decoded copy of
# this table (UTF-8 read through a single-byte code page, which made the
# patterns unmatchable); confirm the spellings against the canonical list.
BRAND_FIXES = [
    # Games
    (r"\bчита\s*5\b",            "GTA 5"),
    (r"\bжита\s*5\b",            "GTA 5"),
    (r"\bгта\s*5\b",             "GTA 5"),
    (r"\bчита\s*6\b",            "GTA 6"),
    (r"\bжита\s*6\b",            "GTA 6"),
    (r"\bгта\s*6\b",             "GTA 6"),
    (r"\bфифа\b",                "FIFA"),
    (r"\bкол\s*оф\s*дюти\b",     "Call of Duty"),
    (r"\bкалл\s*оф\s*дюти\b",    "Call of Duty"),
    (r"\bмайнкрафт\b",           "Minecraft"),
    (r"\bмайн\s*крафт\b",        "Minecraft"),
    (r"\bроблокс\b",             "Roblox"),
    (r"\bфортнайт\b",            "Fortnite"),
    (r"\bвальорант\b",           "Valorant"),
    (r"\bвалорант\b",            "Valorant"),
    (r"\bбагстари\b",            "Rockstar Games"),
    (r"\bбагстар\b",             "Rockstar Games"),
    (r"\bпубг\b",                "PUBG"),
    (r"\bкс\s*го\b",             "CS:GO"),
    (r"\bдота\s*2\b",            "Dota 2"),
    (r"\bюбисофт\b",             "Ubisoft"),
    (r"\bстим\b",                "Steam"),

    # Tech
    (r"\bаифон\b",               "iPhone"),
    (r"\bайфон\b",               "iPhone"),
    (r"\bипад\b",                "iPad"),
    (r"\bайпад\b",               "iPad"),
    (r"\bмакбүүк\b",             "MacBook"),
    (r"\bмакбук\b",              "MacBook"),
    (r"\bэйрподс\b",             "AirPods"),
    (r"\bсамсунг\b",             "Samsung"),
    (r"\bгугл\b",                "Google"),
    (r"\bгүүгэл\b",              "Google"),
    (r"\bхуавей\b",              "Huawei"),
    (r"\bшаоми\b",               "Xiaomi"),
    (r"\bсяоми\b",               "Xiaomi"),
    (r"\bредми\b",               "Redmi"),
    (r"\bэпл\b",                 "Apple"),

    # Apps / Social
    (r"\bютуб\b",                "YouTube"),
    (r"\bютүүб\b",               "YouTube"),
    (r"\bтик\s*ток\b",           "TikTok"),
    (r"\bтикток\b",              "TikTok"),
    (r"\bинстаграм\b",           "Instagram"),
    (r"\bфэйсбүүк\b",            "Facebook"),
    (r"\bфейсбук\b",             "Facebook"),
    (r"\bвацап\b",               "WhatsApp"),
    (r"\bватсап\b",              "WhatsApp"),
    (r"\bтелеграм\b",            "Telegram"),
    (r"\bдискорд\b",             "Discord"),
    (r"\bтвиттер\b",             "Twitter"),
    (r"\bспотифай\b",            "Spotify"),
    (r"\bнетфликс\b",            "Netflix"),
    (r"\bубер\b",                "Uber"),
    (r"\bчат\s*жпт\b",           "ChatGPT"),
    (r"\bчатгпт\b",              "ChatGPT"),
    (r"\bмиджорни\b",            "Midjourney"),

    # Music / celebs
    (r"\bбтс\b",                 "BTS"),
    (r"\bбтэс\b",                "BTS"),
    (r"\bблэкпинк\b",            "BLACKPINK"),
    (r"\bблэк\s*пинк\b",         "BLACKPINK"),

    # Common proper nouns
    (r"\bулаанбаатар\b",         "Улаанбаатар"),
    (r"\bмонгол\s+улс\b",        "Монгол Улс"),
    (r"\bзасгийн\s+газар\b",     "Засгийн газар"),
    (r"\bуих\b",                 "УИХ"),
    (r"\bмуис\b",                "МУИС"),
]

# Compiled once at import; IGNORECASE also folds Cyrillic case for str patterns.
COMPILED_BRAND_FIXES = [(re.compile(pat, re.IGNORECASE), rep) for pat, rep in BRAND_FIXES]


def apply_brand_fixes(text: str) -> str:
    """Replace known phonetic brand spellings in ``text`` with canonical names.

    Every pattern in ``COMPILED_BRAND_FIXES`` is applied in table order,
    case-insensitively. Falsy input (``""`` or ``None``) is returned unchanged.

    Args:
        text: Text to correct.

    Returns:
        The text with all matching spellings substituted.
    """
    if not text:
        return text
    for pattern, replacement in COMPILED_BRAND_FIXES:
        text = pattern.sub(replacement, text)
    return text


# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# OUTPUT PARSING & HALLUCINATION GUARD
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def strip_reasoning(raw_output: str) -> str:
    """Extract the final edited text from raw model output.

    The training format is (Mongolian markers are matched literally)::

        Энэ өгүүлбэрт дараах зүйлс засах хэрэгтэй:
        1. ...
        2. ...

        Засварласан хувилбар:
        <FINAL TEXT>
        </think>
        <FINAL TEXT again>

    Only ``<FINAL TEXT>`` is wanted. Strategy:
      1. Keep what follows the first marker found ("Засварласан хувилбар:",
         then the fallback markers in order).
      2. If a ``<think>`` tag remains, keep the text before it, or, when
         nothing precedes it, the text after the closing ``</think>``.
      3. Cut at ``</think>``; anything after it is a duplicate.
      4. Drop numbered-list runs ("1. ...") up to the next blank line.
      5. If that removes everything, return the stripped text from step 3.

    Args:
        raw_output: Decoded model completion.

    Returns:
        The cleaned text, or ``""`` for empty input.
    """
    if not raw_output:
        return ""

    text = raw_output

    # Primary marker first, then fallbacks; the first one present wins.
    markers = (
        "Засварласан хувилбар:",
        "Засварласан өгүүлбэр:",
        "Эцсийн хувилбар:",
        "Зөв хувилбар:",
    )
    for marker in markers:
        if marker in text:
            text = text.split(marker, 1)[1]
            break

    # Handle an explicit <think> block BEFORE cutting at </think>; otherwise
    # the closing tag is already gone and the reasoning itself is returned.
    if "<think>" in text:
        before, _, after_open = text.partition("<think>")
        if before.strip():
            text = before
        else:
            # No closing tag (truncated generation) leaves after_open as-is.
            text = after_open.split("</think>", 1)[-1]

    # Cut at </think>: anything after it is a duplicate of the final text.
    if "</think>" in text:
        text = text.split("</think>", 1)[0]

    # Chain-of-thought sometimes bleeds in as a numbered list; skip from a
    # list item through the following blank line.
    lines = [ln.rstrip() for ln in text.strip().split("\n")]
    cleaned = []
    skip_list = False
    for ln in lines:
        stripped = ln.strip()
        if re.match(r"^\d+\.\s", stripped):
            skip_list = True
            continue
        if skip_list and stripped == "":
            skip_list = False
            continue
        if skip_list:
            continue
        cleaned.append(ln)

    out = "\n".join(cleaned).strip()
    return out or text.strip()


def hallucination_guard(original: str, edited: str, max_ratio: float = 1.6) -> tuple[str, bool]:
    """Decide whether the model's edit is trustworthy or the original should be kept.

    The edit is rejected (the original is returned) when any of these holds:
      * the edit is empty;
      * it is longer than ``max_ratio`` times the original AND more than 40
        characters longer;
      * the original exceeds 20 characters and the edit is under 40% of its length;
      * fewer than 30% of the original's distinct lower-cased words appear in
        the edit.

    Args:
        original: Source text segment.
        edited: Model-produced replacement.
        max_ratio: Allowed growth factor in characters.

    Returns:
        ``(text, fallback_used)``: ``(edited, False)`` when accepted,
        ``(original, True)`` when rejected.
    """
    rejected = (original, True)
    if not edited:
        return rejected

    # An empty original counts as length 1 for the ratio checks.
    src_chars = max(len(original), 1)
    out_chars = len(edited)

    grew_too_much = out_chars > src_chars * max_ratio and out_chars > src_chars + 40
    shrank_too_much = out_chars < src_chars * 0.4 and src_chars > 20
    if grew_too_much or shrank_too_much:
        return rejected

    word_re = r"\w+"
    src_vocab = set(re.findall(word_re, original.lower()))
    if src_vocab:
        out_vocab = set(re.findall(word_re, edited.lower()))
        kept_fraction = len(src_vocab & out_vocab) / len(src_vocab)
        if kept_fraction < 0.3:
            return rejected

    return edited, False


# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# MODEL LOADING (lazy, fork-safe)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def load_model():
    """Load torch, the tokenizer, the base model and the LoRA adapter once.

    Populates the module globals ``torch``, ``TOKENIZER`` and ``MODEL``. Later
    calls return immediately once ``MODEL`` is set. Heavy imports (torch,
    transformers, peft) happen here rather than at module import time.

    The tokenizer is read from ``ADAPTER_REPO``; the base weights come from
    ``BASE_MODEL`` and the adapter from ``ADAPTER_REPO`` is attached on top.
    """
    global MODEL, TOKENIZER, torch
    if MODEL is not None:
        return

    t = time.time()
    print("[LOAD] importing torch...", flush=True)
    import torch as _torch
    torch = _torch

    print("[LOAD] importing transformers + peft...", flush=True)
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel

    has_cuda = torch.cuda.is_available()
    print(f"[LOAD] CUDA available: {has_cuda}", flush=True)
    if has_cuda:
        print(f"[LOAD] device: {torch.cuda.get_device_name(0)}", flush=True)
        torch.cuda.init()
        torch.backends.cuda.matmul.allow_tf32 = True

    # HF_TOKEN defaults to "" when unset. Pass None in that case so the hub
    # client treats the request as "no explicit token" instead of sending an
    # empty credential.
    hf_token = HF_TOKEN or None

    print(f"[LOAD] tokenizer from {ADAPTER_REPO}...", flush=True)
    TOKENIZER = AutoTokenizer.from_pretrained(
        ADAPTER_REPO, token=hf_token, trust_remote_code=True
    )
    # generate() needs a pad token id; fall back to EOS when none is defined.
    if TOKENIZER.pad_token is None:
        TOKENIZER.pad_token = TOKENIZER.eos_token

    print(f"[LOAD] base model {BASE_MODEL}...", flush=True)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        token=hf_token,
        attn_implementation="eager",
    )

    print(f"[LOAD] adapter {ADAPTER_REPO}...", flush=True)
    MODEL = PeftModel.from_pretrained(base, ADAPTER_REPO, token=hf_token)
    MODEL.eval()

    # Only report VRAM when a CUDA device is actually in use.
    vram = f" · VRAM {torch.cuda.memory_allocated()/1e9:.2f}GB" if has_cuda else ""
    print(f"[LOAD] ready in {time.time()-t:.1f}s{vram}", flush=True)


# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# INFERENCE
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Default Mongolian prompt per request mode; handler() falls back to "edit"
# for unknown modes. English gloss:
#   edit      - "Correct the following ASR text and make it a proper subtitle."
#   summarize - "Summarize the content of the following recording."
#   rewrite   - "Rewrite the following sentence in polished prose."
# NOTE(review): text reconstructed from a mis-decoded copy; confirm it matches
# the prompts used when the adapter was trained.
INSTRUCTIONS = {
    "edit":      "Дараах ASR-ээс гарсан текстийг засварлаж, зөв subtitle болгоно уу.",
    "summarize": "Дараах бичлэгийн агуулгыг товчилно уу.",
    "rewrite":   "Дараах өгүүлбэрийг уран бичлэгтэй болгон засна уу.",
}


def generate_one(text: str, instruction: str, max_new_tokens: int = 256) -> str:
    """Generate the model's completion for one text segment.

    The instruction and the text are joined into a single user message,
    rendered through the tokenizer's chat template, and decoded greedily.

    Args:
        text: Segment to process.
        instruction: Prompt placed before the segment.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded completion only (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    messages = [{"role": "user", "content": f"{instruction}\n\n{text}"}]
    prompt = TOKENIZER.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=1024)
    encoded = encoded.to(MODEL.device)
    # Everything up to this index in the output is the echoed prompt.
    prompt_len = encoded["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "temperature": 1.0,
        "repetition_penalty": 1.05,
        "pad_token_id": TOKENIZER.pad_token_id,
    }
    with torch.no_grad():
        sequences = MODEL.generate(**encoded, **generation_kwargs)

    completion_ids = sequences[0][prompt_len:]
    return TOKENIZER.decode(completion_ids, skip_special_tokens=True).strip()


def handler(event):
    """RunPod serverless entry point.

    Reads ``event["input"]``:
      * ``texts`` (required): non-empty list of segments.
      * ``mode``: key into ``INSTRUCTIONS`` (default ``"edit"``; unknown
        values use the "edit" instruction).
      * ``instruction``: optional prompt overriding the mode's default.
      * ``skip_post_processing``: when truthy, skips the hallucination guard
        and brand fixes.
      * ``max_new_tokens``: generation cap (default 256).

    Post-processing (guard, then brand fixes) runs only when ``mode == "edit"``
    and it is not skipped. Blank or non-string items are returned unchanged
    without calling the model. A segment whose processing raises is returned
    unchanged and its index is added to ``fallback_used``.

    Returns:
        A dict with ``edited``, ``stats``, ``fallback_used``, ``mode`` and
        ``model``; or ``{"error": message}`` on invalid input or an
        unexpected failure.
    """
    try:
        t_total = time.time()
        load_model()

        inp = event.get("input", {}) or {}
        texts = inp.get("texts")
        if not texts or not isinstance(texts, list):
            return {"error": "Missing 'texts' list in input"}

        mode = inp.get("mode", "edit")
        custom_instruction = inp.get("instruction")
        skip_post = bool(inp.get("skip_post_processing", False))
        max_new_tokens = int(inp.get("max_new_tokens", 256))

        instruction = custom_instruction or INSTRUCTIONS.get(mode, INSTRUCTIONS["edit"])

        edited = []
        fallback_used = []
        total_tokens = 0

        for i, text in enumerate(texts):
            # Pass blank and non-string items through untouched. Checking the
            # type first keeps one malformed item (e.g. a number) from raising
            # outside the per-segment try and failing the whole batch.
            if not isinstance(text, str) or not text.strip():
                edited.append(text)
                continue

            try:
                raw = generate_one(text, instruction, max_new_tokens=max_new_tokens)
                parsed = strip_reasoning(raw)

                if mode == "edit" and not skip_post:
                    # Hallucination guard
                    guarded, is_fallback = hallucination_guard(text, parsed)
                    # Brand fixes (applied to both fallback and edit)
                    guarded = apply_brand_fixes(guarded)
                    if is_fallback:
                        fallback_used.append(i)
                    edited.append(guarded)
                else:
                    edited.append(parsed)

                # Whitespace-separated word count, used as a rough token proxy.
                total_tokens += len(raw.split())
            except Exception as e:
                print(f"[ERR] segment {i}: {e}", flush=True)
                traceback.print_exc()
                # On any failure, return the original text unchanged
                edited.append(text)
                fallback_used.append(i)

        elapsed = time.time() - t_total
        return {
            "edited": edited,
            "stats": {
                "count": len(texts),
                "time_s": round(elapsed, 2),
                "tokens_per_s": round(total_tokens / elapsed, 1) if elapsed > 0 else 0,
            },
            "fallback_used": fallback_used,
            "mode": mode,
            "model": ADAPTER_REPO,
        }

    except Exception as e:
        # Top-level boundary: report the failure to the caller instead of raising.
        traceback.print_exc()
        return {"error": str(e)}


# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# ENTRY POINT
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# When run as a script: log the time elapsed since t0 (set during bootstrap)
# and start the RunPod serverless worker with handler() as the job callback.
if __name__ == "__main__":
    print(f"[BOOT] total bootstrap time: {time.time()-t0:.1f}s", flush=True)
    runpod.serverless.start({"handler": handler})