|
|
"""
SARF Decoder - Balanced version preserving word boundaries.
"""

import json

import re
|
|
|
|
|
class SARFDecoder:
    """Decode SARF tokenizer output with balanced Arabic handling.

    Pipeline (see :meth:`decode`):
      1. :meth:`decode_bytes`  - undo the GPT-2 style byte-to-unicode mapping.
      2. :meth:`reverse_pua`   - map private-use-area placeholders back to morphemes.
      3. :meth:`post_process`  - re-join Arabic diacritics/affixes that tokenization
         separated with spaces, and fix punctuation/digit/URL spacing.
    """

    # Arabic diacritics: tanween, harakat, shadda, sukun, hamza marks, dagger alif.
    DIACRITICS = set('\u064b\u064c\u064d\u064e\u064f\u0650\u0651\u0652\u0653\u0654\u0655\u0670')

    # Base Arabic letters plus lam-alef presentation-form ligatures.
    AR_LETTERS = set('ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىيﻻﻷﻹﻵ')

    # Two-letter clitic/inflection suffixes that should re-attach to the stem.
    SUFFIX_PAIRS = {'ها', 'هم', 'هن', 'كم', 'كن', 'نا', 'ون', 'ين', 'ات', 'وا', 'تم', 'تن', 'ني', 'ته', 'تك', 'يم', 'ية'}

    def __init__(self, morf_map_path):
        """Load the morpheme map and build the byte decoder.

        Parameters
        ----------
        morf_map_path : str
            Path to a UTF-8 JSON file mapping morpheme strings to their
            private-use-area (PUA) placeholder characters.
        """
        with open(morf_map_path, 'r', encoding='utf-8') as f:
            self.morf_map = json.load(f)
        # PUA placeholder -> original morpheme string.
        self.reverse_map = {v: k for k, v in self.morf_map.items()}
        self._build_byte_decoder()
        self.morphemes = set(self.morf_map.keys())

    def _build_byte_decoder(self):
        """Build the inverse of the GPT-2 ``bytes_to_unicode`` table.

        Printable latin-1 bytes map to themselves; the remaining bytes are
        shifted into code points 256+ by the encoder, so this reverses that
        shift: ``self.byte_decoder`` maps each unicode char the tokenizer
        emits back to its single raw byte.
        """
        bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
        cs = bs[:]
        n = 0
        for b in range(2**8):
            if b not in bs:
                bs.append(b)
                cs.append(2**8 + n)
                n += 1
        self.byte_decoder = {chr(c): bytes([b]) for b, c in zip(bs, cs)}

    def decode_bytes(self, text):
        """Map tokenizer-output characters back to bytes and decode as UTF-8.

        Characters outside the byte table (e.g. PUA placeholders) are kept
        by encoding them directly; decoding uses ``errors='replace'`` so
        malformed sequences never raise.
        """
        result = []
        for char in text:
            if char in self.byte_decoder:
                result.append(self.byte_decoder[char])
            else:
                # Not a byte-encoded char (e.g. a PUA placeholder): pass through.
                result.append(char.encode('utf-8', errors='replace'))
        try:
            return b''.join(result).decode('utf-8', errors='replace')
        except Exception:
            # Defensive fallback; was a bare `except:` which would also swallow
            # KeyboardInterrupt/SystemExit.
            return text

    def reverse_pua(self, text):
        """Replace every PUA placeholder in *text* with its original morpheme."""
        for pua, morph in self.reverse_map.items():
            text = text.replace(pua, morph)
        return text

    def post_process(self, text):
        """Smart post-processing preserving word boundaries.

        Iteratively re-attaches diacritics, the definite article and common
        suffixes that tokenization split off with spaces, then normalises
        punctuation, digit runs and URL/email fragments. Each merge loop is
        bounded and exits early once the text stops changing.
        """
        # Normalise whitespace first.
        text = re.sub(r' +', ' ', text)
        text = text.strip()

        if not text:
            return text

        # 1) Re-attach "letter <space> diacritic".
        for _ in range(15):
            prev = text
            text = re.sub(r'([\u0621-\u064A\u066E-\u06D3]) ([\u064B-\u0655\u0670])', r'\1\2', text)
            if text == prev:
                break

        # 2) Merge "diacritic <space> diacritic" (e.g. shadda + haraka).
        for _ in range(10):
            prev = text
            text = re.sub(r'([\u064B-\u0655\u0670]) ([\u064B-\u0655\u0670])', r'\1\2', text)
            if text == prev:
                break

        # 3) "diacritic <space> letters+diacritic": glue a diacritised stem back on.
        for _ in range(15):
            prev = text
            text = re.sub(r'([\u064B-\u0655\u0670]) ([\u0621-\u064A\u066E-\u06D3]+)([\u064B-\u0655\u0670])', r'\1\2\3', text)
            if text == prev:
                break

        # 4) Diacritic followed by a final bare-letter chunk; excludes chunks
        #    starting with alef, which usually begin a new word.
        for _ in range(10):
            prev = text
            text = re.sub(r'([\u064B-\u0655\u0670]) ([^\u0627\s][\u0621-\u064A\u066E-\u06D3]*)(?=\s|$)', r'\1\2', text)
            if text == prev:
                break

        # 5) Re-attach the definite article "ال" to the following word.
        for _ in range(5):
            prev = text
            text = re.sub(r'(ال) ([\u0621-\u064A\u066E-\u06D3])', r'\1\2', text)
            if text == prev:
                break

        # 6) Single-letter clitics (ك ه ي ن م ت) at a word/end boundary.
        for _ in range(5):
            prev = text
            text = re.sub(
                r'([\u0621-\u064A\u066E-\u06D3\u064B-\u0655\u0670]+) ([كهينمت])(?=\s|$|[^\u0600-\u06FF])',
                r'\1\2',
                text
            )
            if text == prev:
                break

        # 7) Two-letter suffixes (pronouns, plural/feminine endings).
        for suffix in self.SUFFIX_PAIRS:
            text = re.sub(
                rf'([\u0621-\u064A\u066E-\u06D3\u064B-\u0655\u0670]+) ({suffix})(?=\s|$|[^\u0600-\u06FF])',
                r'\1\2',
                text
            )

        # 8) Punctuation hugs the preceding word (Latin and Arabic marks).
        text = re.sub(r' ([.,!?;:،؛؟%])', r'\1', text)

        # 9) Merge split digit runs (Western and Eastern Arabic numerals).
        for _ in range(5):
            prev = text
            text = re.sub(r'([0-9٠-٩]) ([0-9٠-٩])', r'\1\2', text)
            if text == prev:
                break

        # 10) Repair URL/email fragments broken around '@', '://' and '.tld'.
        text = re.sub(r' ?@ ?', '@', text)
        text = re.sub(r' ?: ?// ?', '://', text)
        text = re.sub(r' ?\. ?([a-zA-Z]{2,})(\s|$)', r'.\1\2', text)

        # Final whitespace normalisation.
        text = re.sub(r' +', ' ', text)
        return text.strip()

    def decode(self, raw_output):
        """Full decode: bytes -> PUA reversal -> Arabic post-processing."""
        text = self.decode_bytes(raw_output)
        text = self.reverse_pua(text)
        text = self.post_process(text)
        return text
|
|
|
|
|
|
|
|
class SARFTokenizer:
    """Complete SARF tokenizer: byte rewrite -> HF fast tokenizer -> SARF decode."""

    def __init__(self, tokenizer_path, morf_map_path):
        """Load the HF tokenizer, the byte rewriter and the SARF decoder.

        Parameters
        ----------
        tokenizer_path : str
            Path to a HuggingFace ``tokenizer.json`` file.
        morf_map_path : str
            Path to the morpheme->PUA JSON map shared by rewriter and decoder.
        """
        # Local imports keep heavy/optional dependencies out of module import time.
        from transformers import PreTrainedTokenizerFast
        import sys

        # NOTE(review): hard-coded project root; consider making this configurable.
        project_root = "/root/workspace/smctm"
        # Guard against growing sys.path with duplicates on repeated instantiation.
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from scripts.rewrite_bytes import ByteRewriter

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        self.rewriter = ByteRewriter(morf_map_path)
        self.decoder = SARFDecoder(morf_map_path)

    def encode(self, text, add_special_tokens=False):
        """Rewrite *text* into SARF byte space, then encode to token ids."""
        rewritten = self.rewriter.rewrite_text(text)
        return self.tokenizer.encode(rewritten, add_special_tokens=add_special_tokens)

    def decode(self, token_ids, skip_special_tokens=True):
        """Decode *token_ids* to raw tokenizer text, then through the SARF decoder."""
        raw = self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
        return self.decoder.decode(raw)

    def __len__(self):
        """Vocabulary size of the underlying tokenizer."""
        return len(self.tokenizer)

    @property
    def vocab_size(self):
        # Mirrors __len__ for HF-style API compatibility.
        return len(self.tokenizer)
|
|
|