# sarf_decoder.py — myte-parity-sweep (uploaded via huggingface_hub, rev f42e12a)
"""
SARF Decoder - Balanced version preserving word boundaries.
"""
import json
import re
class SARFDecoder:
    """Decode SARF tokenizer output with balanced Arabic handling.

    Decoding runs in three stages (see :meth:`decode`):

    1. byte-level detokenization of the GPT-2 style byte alphabet,
    2. replacement of private-use-area (PUA) placeholders with the
       morphemes they stand for,
    3. regex post-processing that re-joins Arabic morpheme fragments
       while preserving genuine word boundaries.
    """

    # Arabic diacritics: tanween/harakat, shadda, sukun, hamza marks, dagger alif.
    DIACRITICS = set('\u064b\u064c\u064d\u064e\u064f\u0650\u0651\u0652\u0653\u0654\u0655\u0670')
    # Base Arabic letters plus common lam-alef presentation-form ligatures.
    AR_LETTERS = set('ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىيﻻﻷﻹﻵ')
    # Two-character clitic/inflection suffixes re-attached during post-processing.
    SUFFIX_PAIRS = {'ها', 'هم', 'هن', 'كم', 'كن', 'نا', 'ون', 'ين', 'ات', 'وا',
                    'تم', 'تن', 'ني', 'ته', 'تك', 'يم', 'ية'}

    def __init__(self, morf_map_path):
        """Load the morpheme map (JSON, UTF-8) used for PUA substitution.

        Args:
            morf_map_path: path to a JSON object; assumed to map
                morpheme -> PUA placeholder character (TODO confirm
                against the map producer).
        """
        with open(morf_map_path, 'r', encoding='utf-8') as f:
            self.morf_map = json.load(f)
        # Invert to PUA -> morpheme for decoding.
        self.reverse_map = {v: k for k, v in self.morf_map.items()}
        self._build_byte_decoder()
        self.morphemes = set(self.morf_map.keys())

    def _build_byte_decoder(self):
        """Build the GPT-2 style unicode-char -> raw-byte lookup table.

        Printable byte values map to themselves; the remaining byte values
        are shifted into the 0x100+ code-point range so that every byte has
        a visible, reversible character representation.
        """
        bs = (list(range(ord("!"), ord("~") + 1))
              + list(range(ord("¡"), ord("¬") + 1))
              + list(range(ord("®"), ord("ÿ") + 1)))
        cs = bs[:]
        n = 0
        for b in range(2**8):
            if b not in bs:
                bs.append(b)
                cs.append(2**8 + n)
                n += 1
        self.byte_decoder = {chr(c): bytes([b]) for b, c in zip(bs, cs)}

    def decode_bytes(self, text):
        """Map byte-alphabet characters back to raw bytes, decode as UTF-8.

        Characters outside the byte alphabet are passed through via their
        own UTF-8 encoding. Returns *text* unchanged if decoding fails.
        """
        result = []
        for char in text:
            if char in self.byte_decoder:
                result.append(self.byte_decoder[char])
            else:
                # Char outside the byte alphabet: keep its UTF-8 encoding.
                result.append(char.encode('utf-8', errors='replace'))
        try:
            return b''.join(result).decode('utf-8', errors='replace')
        except Exception:  # FIX: was a bare `except:`; keep original-text fallback
            return text

    def reverse_pua(self, text):
        """Replace every PUA placeholder with its original morpheme."""
        for pua, morph in self.reverse_map.items():
            text = text.replace(pua, morph)
        return text

    def post_process(self, text):
        """Smart post-processing preserving word boundaries.

        Repeated bounded passes re-attach diacritics, suffixes and the
        definite article that tokenization split off, then apply generic
        punctuation/number/URL spacing cleanup.
        """
        # Step 1: Collapse multiple spaces
        text = re.sub(r' +', ' ', text)
        text = text.strip()
        if not text:
            return text
        # Step 2: Join letter + space + diacritic (always safe)
        for _ in range(15):
            prev = text
            text = re.sub(r'([\u0621-\u064A\u066E-\u06D3]) ([\u064B-\u0655\u0670])', r'\1\2', text)
            if text == prev:
                break
        # Step 3: Join diacritic + space + diacritic (always safe)
        for _ in range(10):
            prev = text
            text = re.sub(r'([\u064B-\u0655\u0670]) ([\u064B-\u0655\u0670])', r'\1\2', text)
            if text == prev:
                break
        # Step 4: Join diacritic + space + letter(s) + diacritic (within diacritized word)
        for _ in range(15):
            prev = text
            text = re.sub(r'([\u064B-\u0655\u0670]) ([\u0621-\u064A\u066E-\u06D3]+)([\u064B-\u0655\u0670])', r'\1\2\3', text)
            if text == prev:
                break
        # Step 5: Join diacritic + space + letter(s) at TRUE end of word.
        # Only if followed by space/end AND the fragment starts with an Arabic
        # letter other than ا (alef signals a new word).
        # FIX: the previous pattern `[^\u0627\s]` accepted ANY non-alef,
        # non-space character (Latin letters, digits, punctuation), gluing
        # foreign text onto Arabic diacritics; the first char is now
        # restricted to Arabic letters excluding ا (U+0627).
        for _ in range(10):
            prev = text
            text = re.sub(
                r'([\u064B-\u0655\u0670]) ([\u0621-\u0626\u0628-\u064A\u066E-\u06D3][\u0621-\u064A\u066E-\u06D3]*)(?=\s|$)',
                r'\1\2',
                text
            )
            if text == prev:
                break
        # Step 6: Join "ال" + space + letter (definite article)
        for _ in range(5):
            prev = text
            text = re.sub(r'(ال) ([\u0621-\u064A\u066E-\u06D3])', r'\1\2', text)
            if text == prev:
                break
        # Step 7: Join Arabic word + space + single suffix letter (not ا)
        for _ in range(5):
            prev = text
            text = re.sub(
                r'([\u0621-\u064A\u066E-\u06D3\u064B-\u0655\u0670]+) ([كهينمت])(?=\s|$|[^\u0600-\u06FF])',
                r'\1\2',
                text
            )
            if text == prev:
                break
        # Step 8: Join Arabic word + space + two-char suffix
        for suffix in self.SUFFIX_PAIRS:
            text = re.sub(
                rf'([\u0621-\u064A\u066E-\u06D3\u064B-\u0655\u0670]+) ({suffix})(?=\s|$|[^\u0600-\u06FF])',
                r'\1\2',
                text
            )
        # Step 9: Standard cleanup — punctuation spacing, digit runs, URLs/emails.
        text = re.sub(r' ([.,!?;:،؛؟%])', r'\1', text)
        for _ in range(5):
            prev = text
            text = re.sub(r'([0-9٠-٩]) ([0-9٠-٩])', r'\1\2', text)
            if text == prev:
                break
        text = re.sub(r' ?@ ?', '@', text)
        text = re.sub(r' ?: ?// ?', '://', text)
        text = re.sub(r' ?\. ?([a-zA-Z]{2,})(\s|$)', r'.\1\2', text)
        text = re.sub(r' +', ' ', text)
        return text.strip()

    def decode(self, raw_output):
        """Run the full pipeline: bytes -> PUA reversal -> post-processing."""
        text = self.decode_bytes(raw_output)
        text = self.reverse_pua(text)
        text = self.post_process(text)
        return text
class SARFTokenizer:
    """Complete SARF tokenizer: byte rewriting + HF fast tokenizer + SARF decoding."""

    def __init__(self, tokenizer_path, morf_map_path):
        """Wire up the fast tokenizer, the byte rewriter, and the decoder.

        Args:
            tokenizer_path: path to a `tokenizers` JSON file for
                ``PreTrainedTokenizerFast``.
            morf_map_path: path to the morpheme map JSON shared by the
                rewriter and the decoder.
        """
        from transformers import PreTrainedTokenizerFast
        import sys
        # NOTE(review): hard-coded workspace path; assumes the smctm repo
        # lives at this exact location — confirm before running elsewhere.
        sys.path.insert(0, "/root/workspace/smctm")
        from scripts.rewrite_bytes import ByteRewriter
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        self.rewriter = ByteRewriter(morf_map_path)
        self.decoder = SARFDecoder(morf_map_path)

    def encode(self, text, add_special_tokens=False):
        """Rewrite *text* into the SARF byte alphabet, then tokenize it."""
        sarf_text = self.rewriter.rewrite_text(text)
        ids = self.tokenizer.encode(sarf_text, add_special_tokens=add_special_tokens)
        return ids

    def decode(self, token_ids, skip_special_tokens=True):
        """Detokenize *token_ids*, then run the full SARF decoding pipeline."""
        surface = self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
        return self.decoder.decode(surface)

    def __len__(self):
        """Vocabulary size of the underlying tokenizer."""
        return len(self.tokenizer)

    @property
    def vocab_size(self):
        """Alias for ``len(self)``."""
        return len(self.tokenizer)