# mongol-editor-llm-v1 / handler.py
# Author: Tsedee — "Add serverless handler (v1)", commit ffd1353 (verified)
"""
MonSub LLM Editor โ€” Self-bootstrapping RunPod Serverless Handler
Loads: Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilled (base)
+ Tsedee/mongol-editor-llm-v1 (LoRA adapter) [swap to -v2 after v2 training]
Accepts batches of raw Whisper-style text segments and returns edited
Mongolian subtitle text with post-processing:
- Brand name correction (chitaโ†’GTA, ะฐะธั„ะพะฝโ†’iPhone, etc.)
- Hallucination guard (rejects outputs that are too different from input)
- Chain-of-thought stripping (keeps only "ะ—ะฐัะฒะฐั€ะปะฐัะฐะฝ ั…ัƒะฒะธะปะฑะฐั€:" content)
- </think> tag cleanup
API:
Input (JSON):
{
"texts": ["text 1", "text 2", ...], # required
"mode": "edit" | "summarize" | "rewrite", # default: "edit"
"instruction": "optional custom prompt", # optional
"skip_post_processing": false # optional
}
Output:
{
"edited": ["edited 1", "edited 2", ...],
"stats": { "count": N, "time_s": T, "tokens_per_s": X },
"fallback_used": [idx1, idx2, ...] # indices where hallucination guard fired
}
"""
import os, sys, subprocess, time
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# BOOTSTRAP
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def ensure(pkg_import, pip_name=None):
    """Make *pkg_import* importable, pip-installing it on demand.

    Args:
        pkg_import: importable module name (e.g. "huggingface_hub").
        pip_name: pip requirement string when it differs from the module
            name (e.g. "transformers==5.5.0"); defaults to pkg_import.

    Raises:
        subprocess.CalledProcessError: if the pip install itself fails.
        ImportError: if the install succeeds but the module still cannot
            be imported — surfaced here, at bootstrap, rather than later
            at first use.
    """
    try:
        __import__(pkg_import)
    except ImportError:
        name = pip_name or pkg_import
        print(f"[BOOT] installing {name}...", flush=True)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--quiet", "--no-cache-dir", name],
            check=True,
        )
        # Verify the install actually produced an importable module; the
        # cache invalidation makes freshly installed dists visible to the
        # import system within this same process.
        import importlib
        importlib.invalidate_caches()
        __import__(pkg_import)
print("[BOOT] LLM editor handler starting...", flush=True)
t0 = time.time()
# Bootstrap dependencies: (import name, pip requirement) pairs; each call
# is a no-op when the package is already present.
_BOOT_DEPS = [
    ("runpod", None),
    ("transformers", "transformers==5.5.0"),
    ("peft", "peft==0.18.1"),
    ("accelerate", "accelerate>=1.0.0"),
    ("huggingface_hub", None),
]
for _mod, _req in _BOOT_DEPS:
    ensure(_mod, _req)
print(f"[BOOT] deps ready in {time.time()-t0:.1f}s", flush=True)
# ── Module-level: only stdlib + runpod (heavy ML imports deferred to load_model) ──
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # deterministic cuBLAS workspace sizing
import re
import traceback
import runpod
# Hugging Face access token; empty string falls back to anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Base model and LoRA adapter repos, both overridable via environment
# variables (swap ADAPTER_REPO to the -v2 repo after v2 training).
BASE_MODEL = os.environ.get("BASE_MODEL", "Jackrong/Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilled")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "Tsedee/mongol-editor-llm-v1")
# Populated once by load_model() on the first request; None until then.
MODEL = None
TOKENIZER = None
torch = None  # lazy-loaded
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# BRAND CORRECTION DICT โ€” post-processing safety net
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Applied AFTER model output to catch brand names the model missed.
# Case-insensitive substring match with word boundaries where possible.
# NOTE: multiple spellings map to one canonical form because ASR output
# transliterates brand names into Cyrillic inconsistently.
BRAND_FIXES = [
    # (pattern_regex, replacement) — applied in declaration order
    # Games
    (r"\bั‡ะธั‚ะฐ\s*5\b", "GTA 5"),
    (r"\bะถะธั‚ะฐ\s*5\b", "GTA 5"),
    (r"\bะณั‚ะฐ\s*5\b", "GTA 5"),
    (r"\bั‡ะธั‚ะฐ\s*6\b", "GTA 6"),
    (r"\bะถะธั‚ะฐ\s*6\b", "GTA 6"),
    (r"\bะณั‚ะฐ\s*6\b", "GTA 6"),
    (r"\bั„ะธั„ะฐ\b", "FIFA"),
    (r"\bะบะพะป\s*ะพั„\s*ะดัŽั‚ะธ\b", "Call of Duty"),
    (r"\bะบะฐะปะป\s*ะพั„\s*ะดัŽั‚ะธ\b", "Call of Duty"),
    (r"\bะผะฐะนะฝะบั€ะฐั„ั‚\b", "Minecraft"),
    (r"\bะผะฐะนะฝ\s*ะบั€ะฐั„ั‚\b", "Minecraft"),
    (r"\bั€ะพะฑะปะพะบั\b", "Roblox"),
    (r"\bั„ะพั€ั‚ะฝะฐะนั‚\b", "Fortnite"),
    (r"\bะฒะฐะปัŒะพั€ะฐะฝั‚\b", "Valorant"),
    (r"\bะฒะฐะปะพั€ะฐะฝั‚\b", "Valorant"),
    (r"\bะฑะฐะณัั‚ะฐั€ะธ\b", "Rockstar Games"),
    (r"\bะฑะฐะณัั‚ะฐั€\b", "Rockstar Games"),
    (r"\bะฟัƒะฑะณ\b", "PUBG"),
    (r"\bะบั\s*ะณะพ\b", "CS:GO"),
    (r"\bะดะพั‚ะฐ\s*2\b", "Dota 2"),
    (r"\bัŽะฑะธัะพั„ั‚\b", "Ubisoft"),
    (r"\bัั‚ะธะผ\b", "Steam"),
    # Tech
    (r"\bะฐะธั„ะพะฝ\b", "iPhone"),
    (r"\bะฐะนั„ะพะฝ\b", "iPhone"),
    (r"\bะธะฟะฐะด\b", "iPad"),
    (r"\bะฐะนะฟะฐะด\b", "iPad"),
    (r"\bะผะฐะบะฑาฏาฏะบ\b", "MacBook"),
    (r"\bะผะฐะบะฑัƒะบ\b", "MacBook"),
    (r"\bัะนั€ะฟะพะดั\b", "AirPods"),
    (r"\bัะฐะผััƒะฝะณ\b", "Samsung"),
    (r"\bะณัƒะณะป\b", "Google"),
    (r"\bะณาฏาฏะณัะป\b", "Google"),
    (r"\bั…ัƒะฐะฒะตะน\b", "Huawei"),
    (r"\bัˆะฐะพะผะธ\b", "Xiaomi"),
    (r"\bััะพะผะธ\b", "Xiaomi"),
    (r"\bั€ะตะดะผะธ\b", "Redmi"),
    (r"\bัะฟะป\b", "Apple"),
    # Apps / Social
    (r"\bัŽั‚ัƒะฑ\b", "YouTube"),
    (r"\bัŽั‚าฏาฏะฑ\b", "YouTube"),
    (r"\bั‚ะธะบ\s*ั‚ะพะบ\b", "TikTok"),
    (r"\bั‚ะธะบั‚ะพะบ\b", "TikTok"),
    (r"\bะธะฝัั‚ะฐะณั€ะฐะผ\b", "Instagram"),
    (r"\bั„ัะนัะฑาฏาฏะบ\b", "Facebook"),
    (r"\bั„ะตะนัะฑัƒะบ\b", "Facebook"),
    (r"\bะฒะฐั†ะฐะฟ\b", "WhatsApp"),
    (r"\bะฒะฐั‚ัะฐะฟ\b", "WhatsApp"),
    (r"\bั‚ะตะปะตะณั€ะฐะผ\b", "Telegram"),
    (r"\bะดะธัะบะพั€ะด\b", "Discord"),
    (r"\bั‚ะฒะธั‚ั‚ะตั€\b", "Twitter"),
    (r"\bัะฟะพั‚ะธั„ะฐะน\b", "Spotify"),
    (r"\bะฝะตั‚ั„ะปะธะบั\b", "Netflix"),
    (r"\bัƒะฑะตั€\b", "Uber"),
    (r"\bั‡ะฐั‚\s*ะถะฟั‚\b", "ChatGPT"),
    (r"\bั‡ะฐั‚ะณะฟั‚\b", "ChatGPT"),
    (r"\bะผะธะดะถะพั€ะฝะธ\b", "Midjourney"),
    # Music / celebs
    (r"\bะฑั‚ั\b", "BTS"),
    (r"\bะฑั‚ัั\b", "BTS"),
    (r"\bะฑะปัะบะฟะธะฝะบ\b", "BLACKPINK"),
    (r"\bะฑะปัะบ\s*ะฟะธะฝะบ\b", "BLACKPINK"),
    # Common proper nouns (capitalization fixes for Mongolian names/orgs)
    (r"\bัƒะปะฐะฐะฝะฑะฐะฐั‚ะฐั€\b", "ะฃะปะฐะฐะฝะฑะฐะฐั‚ะฐั€"),
    (r"\bะผะพะฝะณะพะป\s+ัƒะปั\b", "ะœะพะฝะณะพะป ะฃะปั"),
    (r"\bะทะฐัะณะธะนะฝ\s+ะณะฐะทะฐั€\b", "ะ—ะฐัะณะธะนะฝ ะณะฐะทะฐั€"),
    (r"\bัƒะธั…\b", "ะฃะ˜ะฅ"),
    (r"\bะผัƒะธั\b", "ะœะฃะ˜ะก"),
]
# Compile once at import time; patterns run on every output segment.
COMPILED_BRAND_FIXES = [(re.compile(pat, re.IGNORECASE), rep) for pat, rep in BRAND_FIXES]
def apply_brand_fixes(text: str) -> str:
    """Substitute every precompiled brand-name pattern into *text*.

    Empty or None input is returned untouched; otherwise each
    (pattern, replacement) pair from COMPILED_BRAND_FIXES is applied in
    declaration order, case-insensitively.
    """
    if not text:
        return text
    result = text
    for rx, canonical in COMPILED_BRAND_FIXES:
        result = rx.sub(canonical, result)
    return result
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# OUTPUT PARSING & HALLUCINATION GUARD
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def strip_reasoning(raw_output: str) -> str:
    """Extract the final edited text from a reasoning-style model response.

    The training format places the answer after a marker line
    ("ะ—ะฐัะฒะฐั€ะปะฐัะฐะฝ ั…ัƒะฒะธะปะฑะฐั€:" or a known variant), sometimes followed by a
    </think> tag and a duplicated answer. Strategy: take everything after
    the first known marker, cut at </think>, handle a stray <think> block,
    drop a leading numbered reasoning list, and strip whitespace. When no
    marker is present the raw text is cleaned the same way and returned.
    """
    if not raw_output:
        return ""
    text = raw_output
    # First known answer marker wins; the primary marker is tried first.
    for marker in (
        "ะ—ะฐัะฒะฐั€ะปะฐัะฐะฝ ั…ัƒะฒะธะปะฑะฐั€:",
        "ะ—ะฐัะฒะฐั€ะปะฐัะฐะฝ ำฉะณาฏาฏะปะฑัั€:",
        "ะญั†ัะธะนะฝ ั…ัƒะฒะธะปะฑะฐั€:",
        "ะ—ำฉะฒ ั…ัƒะฒะธะปะฑะฐั€:",
    ):
        if marker in text:
            text = text.split(marker, 1)[1]
            break
    # Anything after </think> duplicates the answer — keep only the front.
    text = text.split("</think>", 1)[0]
    if "<think>" in text:
        # Prefer the content before <think>; when that is blank, fall back
        # to whatever follows the closing tag inside the remainder.
        before, after = text.split("<think>", 1)
        text = before if before.strip() else after.split("</think>", 1)[-1]
    # A leading numbered list ("1. ...") is chain-of-thought bleed: skip
    # it, up to and including the blank line that terminates the list.
    kept = []
    in_list = False
    for row in text.strip().split("\n"):
        row = row.rstrip()
        bare = row.strip()
        if re.match(r"^\d+\.\s", bare):
            in_list = True
            continue
        if in_list:
            if bare == "":
                in_list = False
            continue
        kept.append(row)
    result = "\n".join(kept).strip()
    return result or text.strip()
def hallucination_guard(original: str, edited: str, max_ratio: float = 1.6) -> tuple[str, bool]:
    """Reject model output that looks hallucinated; fall back to the input.

    Args:
        original: the raw input segment.
        edited: the model's proposed edit.
        max_ratio: length-inflation threshold for the "too long" rule.

    Returns:
        (text, fallback_used) — (edited, False) when the edit passes all
        checks, or (original, True) when it is empty, far too long, far
        too short, or shares too few words with the original.
    """
    if not edited:
        return original, True
    base_len = max(len(original), 1)
    out_len = len(edited)
    # Invented content: output ballooned well past the input length.
    too_long = out_len > base_len * max_ratio and out_len > base_len + 40
    # Unexpected truncation: output collapsed to a fraction of the input.
    too_short = out_len < base_len * 0.4 and base_len > 20
    if too_long or too_short:
        return original, True
    # Topic drift: fewer than 30% of original words survive into the edit.
    source_words = set(re.findall(r"\w+", original.lower()))
    output_words = set(re.findall(r"\w+", edited.lower()))
    if source_words and len(source_words & output_words) / len(source_words) < 0.3:
        return original, True
    return edited, False
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# MODEL LOADING (lazy, fork-safe)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def load_model():
    """Lazily load tokenizer, base model, and LoRA adapter into globals.

    Idempotent: returns immediately once MODEL is populated. The heavy
    imports (torch / transformers / peft) happen here, not at module
    import, so the worker boots fast and only pays the load cost on the
    first request.

    Side effects: sets module globals MODEL, TOKENIZER, torch.
    """
    global MODEL, TOKENIZER, torch
    if MODEL is not None:
        return
    t = time.time()
    print("[LOAD] importing torch...", flush=True)
    import torch as _torch
    torch = _torch
    print("[LOAD] importing transformers + peft...", flush=True)
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel
    cuda_ok = torch.cuda.is_available()
    print(f"[LOAD] CUDA available: {cuda_ok}", flush=True)
    if cuda_ok:
        print(f"[LOAD] device: {torch.cuda.get_device_name(0)}", flush=True)
        torch.cuda.init()
        torch.backends.cuda.matmul.allow_tf32 = True  # faster matmul on Ampere+
    print(f"[LOAD] tokenizer from {ADAPTER_REPO}...", flush=True)
    TOKENIZER = AutoTokenizer.from_pretrained(
        ADAPTER_REPO, token=HF_TOKEN, trust_remote_code=True
    )
    if TOKENIZER.pad_token is None:
        TOKENIZER.pad_token = TOKENIZER.eos_token
    print(f"[LOAD] base model {BASE_MODEL}...", flush=True)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        token=HF_TOKEN,
        attn_implementation="eager",
    )
    print(f"[LOAD] adapter {ADAPTER_REPO}...", flush=True)
    MODEL = PeftModel.from_pretrained(base, ADAPTER_REPO, token=HF_TOKEN)
    MODEL.eval()
    # BUG FIX: torch.cuda.memory_allocated() raises on CPU-only torch
    # builds — only report VRAM when CUDA is actually available.
    vram = f"VRAM {torch.cuda.memory_allocated()/1e9:.2f}GB" if cuda_ok else "CPU mode"
    print(f"[LOAD] ready in {time.time()-t:.1f}s ยท {vram}", flush=True)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# INFERENCE
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Built-in Mongolian prompts per request mode; a request-supplied
# "instruction" field overrides these (see handler()).
INSTRUCTIONS = {
    "edit": "ะ”ะฐั€ะฐะฐั… ASR-ััั ะณะฐั€ัะฐะฝ ั‚ะตะบัั‚ะธะนะณ ะทะฐัะฒะฐั€ะปะฐะถ, ะทำฉะฒ subtitle ะฑะพะปะณะพะฝะพ ัƒัƒ.",
    "summarize": "ะ”ะฐั€ะฐะฐั… ะฑะธั‡ะปัะณะธะนะฝ ะฐะณัƒัƒะปะณั‹ะณ ั‚ะพะฒั‡ะธะปะฝะพ ัƒัƒ.",
    "rewrite": "ะ”ะฐั€ะฐะฐั… ำฉะณาฏาฏะปะฑัั€ะธะนะณ ัƒั€ะฐะฝ ะฑะธั‡ะปัะณั‚ัะน ะฑะพะปะณะพะฝ ะทะฐัะฝะฐ ัƒัƒ.",
}
def generate_one(text: str, instruction: str, max_new_tokens: int = 256) -> str:
    """Greedy-decode the model's response for *instruction* applied to *text*.

    Builds a single-turn chat prompt, truncates the tokenized input to
    1024 tokens, generates deterministically (no sampling), and returns
    only the newly generated portion, whitespace-trimmed.
    """
    prompt = TOKENIZER.apply_chat_template(
        [{"role": "user", "content": f"{instruction}\n\n{text}"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    enc = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=1024).to(MODEL.device)
    prompt_len = enc["input_ids"].shape[1]
    with torch.no_grad():
        generated = MODEL.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=1.0,
            repetition_penalty=1.05,
            pad_token_id=TOKENIZER.pad_token_id,
        )
    # Strip the echoed prompt: keep only tokens past the input length.
    completion_ids = generated[0][prompt_len:]
    return TOKENIZER.decode(completion_ids, skip_special_tokens=True).strip()
def handler(event):
    """RunPod serverless entry point.

    Expects event["input"] with a required "texts" list plus optional
    "mode", "instruction", "skip_post_processing", "max_new_tokens"
    (see module docstring). Per-segment failures fall back to the
    original text and are reported via "fallback_used" instead of
    failing the whole batch. Returns {"error": ...} on invalid input
    or unexpected top-level failure.
    """
    try:
        t_total = time.time()
        inp = event.get("input", {}) or {}
        texts = inp.get("texts")
        # BUG FIX: validate BEFORE load_model() so malformed requests
        # fail fast instead of paying the full model-load cost.
        if not texts or not isinstance(texts, list):
            return {"error": "Missing 'texts' list in input"}
        load_model()
        mode = inp.get("mode", "edit")
        custom_instruction = inp.get("instruction")
        skip_post = bool(inp.get("skip_post_processing", False))
        max_new_tokens = int(inp.get("max_new_tokens", 256))
        instruction = custom_instruction or INSTRUCTIONS.get(mode, INSTRUCTIONS["edit"])
        edited = []
        fallback_used = []
        total_tokens = 0
        for i, text in enumerate(texts):
            # BUG FIX: a truthy non-string element used to raise at
            # .strip() outside the per-item try and abort the whole
            # batch; pass blanks and non-strings through unchanged.
            if not isinstance(text, str) or not text.strip():
                edited.append(text)
                continue
            try:
                raw = generate_one(text, instruction, max_new_tokens=max_new_tokens)
                parsed = strip_reasoning(raw)
                if mode == "edit" and not skip_post:
                    # Hallucination guard, then brand fixes (brand fixes
                    # apply to both the fallback and the accepted edit).
                    guarded, is_fallback = hallucination_guard(text, parsed)
                    guarded = apply_brand_fixes(guarded)
                    if is_fallback:
                        fallback_used.append(i)
                    edited.append(guarded)
                else:
                    edited.append(parsed)
                total_tokens += len(raw.split())
            except Exception as e:
                print(f"[ERR] segment {i}: {e}", flush=True)
                traceback.print_exc()
                # On any per-segment failure, keep the original text.
                edited.append(text)
                fallback_used.append(i)
        elapsed = time.time() - t_total
        return {
            "edited": edited,
            "stats": {
                "count": len(texts),
                "time_s": round(elapsed, 2),
                # Rough throughput: whitespace-token count, not tokenizer tokens.
                "tokens_per_s": round(total_tokens / elapsed, 1) if elapsed > 0 else 0,
            },
            "fallback_used": fallback_used,
            "mode": mode,
            "model": ADAPTER_REPO,
        }
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# ENTRY POINT
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
if __name__ == "__main__":
    # Log total cold-start bootstrap time, then hand control to the RunPod
    # serverless worker loop, which invokes handler() once per job.
    print(f"[BOOT] total bootstrap time: {time.time()-t0:.1f}s", flush=True)
    runpod.serverless.start({"handler": handler})