Spaces:
Running
Running
Commit
·
7150d27
1
Parent(s):
2cc687a
Deploy application
Browse files- README.md +4 -5
- app.py +186 -0
- arabic_diacritizer_common/__init__.py +24 -0
- arabic_diacritizer_common/cleaners.py +141 -0
- arabic_diacritizer_common/constants.py +150 -0
- arabic_diacritizer_common/postprocessor.py +81 -0
- arabic_diacritizer_common/segmenter.py +94 -0
- arabic_diacritizer_common/tokenizer.py +169 -0
- diacritizer/__init__.py +11 -0
- diacritizer/diacritizer.py +132 -0
- diacritizer/exceptions.py +11 -0
- diacritizer/hub_manager.py +103 -0
- diacritizer/predictor.py +46 -0
- requirements.txt +12 -0
README.md
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title: Arabic Diacritizer
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.47.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Arabic Diacritizer Demo
|
| 3 |
+
emoji: ⚡
|
| 4 |
+
colorFrom: blue
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.47.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import time
|
| 3 |
+
from diacritizer import Diacritizer, ModelNotFound
|
| 4 |
+
|
| 5 |
+
# Metadata for every architecture/size combination served by the demo.
# "size" is the checkpoint size string shown in the UI; "details" is a short
# human-readable description surfaced in the info panel.
MODEL_INFO = {
    "bilstm": {
        "display_name": "BiLSTM",
        "models": {
            "medium": {
                "size": "4 MB",
                "details": "Balanced speed and accuracy.",
            },
            "large": {
                "size": "15.5 MB",
                "details": "Highest accuracy model.",
            },
        },
    },
    "bigru": {
        "display_name": "BiGRU",
        "models": {
            "medium": {
                "size": "3.8 MB",
                "details": "Slightly faster than BiLSTM with comparable accuracy.",
            },
            "large": {
                "size": "14.9 MB",
                "details": "High accuracy alternative to the BiLSTM model.",
            },
        },
    },
}

# Cache of loaded Diacritizer instances keyed by "architecture/size".
MODEL_CACHE = {}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_model(architecture: str, size: str, progress=gr.Progress()):
    """
    Return a cached Diacritizer for the given architecture/size, loading it on demand.

    Reports loading status through gr.Progress and converts load failures
    into gr.Error so the UI can display them to the user.
    """
    cache_key = f"{architecture}/{size}"
    if cache_key in MODEL_CACHE:
        return MODEL_CACHE[cache_key]

    progress(0.5, desc=f"Loading {architecture}/{size} model...")
    try:
        model = Diacritizer(architecture=architecture, size=size)
    except ModelNotFound:
        raise gr.Error(
            f"The requested model ({cache_key}) was not found on the Hugging Face Hub."
        )
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred while loading the model: {e}")

    MODEL_CACHE[cache_key] = model
    return model
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def diacritize_text(text: str, architecture: str, size: str, progress=gr.Progress()):
    """
    Diacritize *text* with the selected model, reporting progress to the UI.

    Returns a (diacritized_text, inference_time, info_markdown) tuple matching
    the Gradio outputs [output_text, inference_time_output, info_display].
    """
    if not text or not text.strip():
        return "", "0.000s", "Please enter some text to diacritize."

    progress(0, desc="Loading model...")
    diacritizer = get_model(architecture, size, progress)

    progress(0.8, desc="Diacritizing text...")
    # perf_counter() is a monotonic clock, so the measurement cannot jump or go
    # negative if the system wall clock is adjusted mid-inference (time.time()
    # could); it is the recommended clock for elapsed-time measurement.
    start_time = time.perf_counter()
    diacritized_text = diacritizer.diacritize(text)
    elapsed = time.perf_counter() - start_time

    inference_time = f"{elapsed:.3f}s"

    # Update the info text with the final result details
    model_details = MODEL_INFO[architecture]["models"][size]["details"]
    final_info_text = (
        f"**Model:** {architecture}/{size} | "
        f"**Size:** {MODEL_INFO[architecture]['models'][size]['size']} | {model_details}"
    )

    return diacritized_text, inference_time, final_info_text
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def update_available_sizes(architecture: str):
    """Callback to update the size choices when the architecture changes."""
    sizes = [size for size in MODEL_INFO[architecture]["models"]]
    # Rebuild the Radio so its choices match the newly selected architecture,
    # defaulting to the first available size.
    return gr.Radio(
        choices=sizes,
        value=sizes[0],
        label="Model Size",
        info="Select the model size.",
    )
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Soft theme with Latin + Arabic Google fonts so both scripts render cleanly.
theme = gr.themes.Soft(
    primary_hue="zinc",
    secondary_hue="blue",
    neutral_hue="slate",
    font=(gr.themes.GoogleFont("Noto Sans"), gr.themes.GoogleFont("Noto Sans Arabic")),
).set(
    body_background_fill_dark="#111827"  # A slightly off-black for dark mode
)

DESCRIPTION = """
# ⚡ End-to-End Arabic Diacritizer
A lightweight and efficient model for automatic Arabic diacritization.
Select an architecture and size, enter some text, and see it in action. For more details, visit the
[GitHub repository](https://github.com/muhammad-abdelsattar/arabic-diacritizer).
"""

# UI layout: selectors + info panel on top, input/output textboxes below,
# examples and wiring at the bottom. The custom CSS hides the Gradio footer.
with gr.Blocks(theme=theme, css=".footer {display: none !important}") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                # Choices are (label, value) pairs built from MODEL_INFO, so
                # the callback receives the internal key (e.g. "bilstm").
                arch_selector = gr.Radio(
                    choices=[
                        (info["display_name"], arch)
                        for arch, info in MODEL_INFO.items()
                    ],
                    label="Model Architecture",
                    value="bilstm",
                    info="Select the model architecture.",
                )
                model_selector = gr.Radio(
                    choices=["medium", "large"],
                    label="Model Size",
                    value="medium",
                    info="Select the model size.",
                )
            # Seeded with the default model's info; overwritten on each run.
            info_display = gr.Markdown(
                "**Model:** bilstm/medium | **Size:** 4 MB | Balanced speed and accuracy. (Formerly 'small')",
                elem_id="info-display",
            )

        with gr.Column(scale=1):
            inference_time_output = gr.Textbox(
                label="Inference Time", interactive=False, max_lines=1
            )

    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text (Undiacritized)",
                placeholder="اكتب جملة عربية غير مشكولة هنا...",
                lines=8,
                rtl=True,
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="Output Text (Diacritized)",
                lines=8,
                rtl=True,
                interactive=False,
            )

    submit_button = gr.Button("Diacritize ✨", variant="primary")

    gr.Examples(
        [
            ["أعلنت الشركة عن نتائجها المالية للربع الأول من العام."],
            ["إن مع العسر يسرا."],
            ["هل يمكن للذكاء الاصطناعي أن يكون مبدعا؟"],
            ["كان المتنبي شاعرا عظيما في العصر العباسي."],
        ],
        inputs=input_text,
        label="Examples",
    )

    submit_button.click(
        fn=diacritize_text,
        inputs=[input_text, arch_selector, model_selector],
        outputs=[output_text, inference_time_output, info_display],
    )

    # When architecture changes, update the available sizes
    arch_selector.change(
        fn=update_available_sizes, inputs=arch_selector, outputs=model_selector
    )


if __name__ == "__main__":
    # Pre-load the default model for a faster first-time user experience
    print("Pre-loading default 'bilstm/medium' model...")
    get_model(architecture="bilstm", size="medium")
    print("Default model loaded successfully.")

    demo.launch()
|
arabic_diacritizer_common/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .constants import (
|
| 2 |
+
ArabicDiacritics,
|
| 3 |
+
VALID_ARABIC_CHARS,
|
| 4 |
+
DIACRITIC_CHARS,
|
| 5 |
+
ARABIC_LETTERS,
|
| 6 |
+
ARABIC_LETTERS_REGEX,
|
| 7 |
+
)
|
| 8 |
+
from .cleaners import TextCleaner, DiacriticValidator
|
| 9 |
+
from .segmenter import TextSegmenter
|
| 10 |
+
from .tokenizer import CharTokenizer
|
| 11 |
+
from .postprocessor import Postprocessor
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"ArabicDiacritics",
|
| 15 |
+
"VALID_ARABIC_CHARS",
|
| 16 |
+
"DIACRITIC_CHARS",
|
| 17 |
+
"ARABIC_LETTERS",
|
| 18 |
+
"TextCleaner",
|
| 19 |
+
"Postprocessor",
|
| 20 |
+
"DiacriticValidator",
|
| 21 |
+
"TextSegmenter",
|
| 22 |
+
"CharTokenizer",
|
| 23 |
+
"ARABIC_LETTERS_REGEX",
|
| 24 |
+
]
|
arabic_diacritizer_common/cleaners.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
from .constants import (
|
| 4 |
+
VALID_ARABIC_CHARS,
|
| 5 |
+
DIACRITIC_CHARS,
|
| 6 |
+
INVALID_SEQUENCES,
|
| 7 |
+
ALL_VALID_DIACRITICS,
|
| 8 |
+
ArabicDiacritics,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# Whitespace regex
|
| 12 |
+
_whitespace_re = re.compile(r"\s+")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TextCleaner:
    """Modular text cleaning utilities.

    All methods are stateless @staticmethods operating on plain strings; the
    character sets come from constants.py.
    """

    @staticmethod
    def collapse_whitespace(text: str) -> str:
        """Collapse multiple whitespace characters into a single space."""
        # Call .sub on the precompiled pattern directly instead of re.sub.
        return _whitespace_re.sub(" ", text).strip()

    @staticmethod
    def filter_valid_arabic(text: str) -> str:
        """Keep only valid Arabic characters, punctuation, and diacritics."""
        return "".join(char for char in text if char in VALID_ARABIC_CHARS)

    @staticmethod
    def remove_diacritics(text: str) -> str:
        """Remove all diacritic characters from text."""
        return "".join(ch for ch in text if ch not in DIACRITIC_CHARS)

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize common text irregularities & diacritic order."""
        # Canonicalize diacritic order (make sure Shadda always comes first)
        for invalid, correct in INVALID_SEQUENCES.items():
            text = text.replace(invalid, correct)

        # Normalize alef variants to bare alef (optional, safer for training)
        normalize_map = {"أ": "ا", "إ": "ا", "آ": "ا", "ٱ": "ا"}
        for k, v in normalize_map.items():
            text = text.replace(k, v)

        # Remove Tatweel (ـ) since it is purely decorative
        text = text.replace("ـ", "")
        return text

    @staticmethod
    def clean_text(
        text: str, keep_valid_only: bool = True, normalize: bool = False
    ) -> str:
        """Complete cleaning pipeline: normalize → optional filtering → collapse ws"""
        if normalize:
            text = TextCleaner.normalize_text(text)
        if keep_valid_only:
            text = TextCleaner.filter_valid_arabic(text)
        return TextCleaner.collapse_whitespace(text)

    @staticmethod
    def strip_diacritics(text: str) -> str:
        """Efficiently strips all diacritic marks from an Arabic string.

        Alias of remove_diacritics, kept for backward compatibility.
        """
        # Delegate instead of duplicating the comprehension so the two public
        # names can never drift apart.
        return TextCleaner.remove_diacritics(text)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class DiacriticValidator:
    """Handles diacritic validation and extraction"""

    @staticmethod
    def extract_diacritics(text: str) -> Tuple[str, List[str]]:
        """
        Extract base text and list of diacritics.
        Each base character gets an associated diacritic string (possibly multiple).

        Returns (base_text, diacritics) with len(base_text) == len(diacritics);
        diacritics[i] is the (possibly empty) diacritic string attached to
        base_text[i].

        Example:
            "بَّ" → ("ب", ["َّ"])
        """
        base_chars = []
        diacritics = []

        i = 0
        while i < len(text):
            char = text[i]
            if char in DIACRITIC_CHARS:
                # attach to previous base character if exists
                if base_chars:
                    # Append this diacritic to most recent slot
                    diacritics[-1] = diacritics[-1] + char
                else:
                    # Stray diacritic at beginning — skip or treat as invalid
                    pass
            else:
                # New base char: allocate diacritic slot
                base_chars.append(char)
                diacritics.append("")
            i += 1

        # Normalize combined diacritics to canonical representations
        normalized_diacritics = []
        for d in diacritics:
            if d in ALL_VALID_DIACRITICS:
                normalized_diacritics.append(d)
            else:
                # try to reorder if contains shadda + vowel
                if "ّ" in d:
                    # move shadda to front (canonical order per constants.py)
                    d = "ّ" + "".join(c for c in d if c != "ّ")
                # keep only known chars; unknown marks are silently dropped
                d = "".join(c for c in d if c in DIACRITIC_CHARS)
                normalized_diacritics.append(d)
        return "".join(base_chars), normalized_diacritics

    @staticmethod
    def validate_diacritics(
        text: str, require_any: bool = False, strict: bool = False
    ) -> str:
        """
        Validate that text diacritics are well-formed.
        - require_any: if True, reject sentences with no diacritics at all.
        - strict: if True, reject unknown/malformed diacritics, else sanitize them.
        Returns text if valid, otherwise "".
        """
        try:
            base_text, diacritics_list = DiacriticValidator.extract_diacritics(text)

            # Optionally require that at least one diacritic is present
            if require_any:
                if not any(
                    d
                    for d in diacritics_list
                    if d != ArabicDiacritics.NO_DIACRITIC.value
                ):
                    return ""

            # In strict mode, reject any diacritic not in valid set
            # (note: checks the *normalized* diacritics, so sequences fixable
            # by reordering still pass).
            if strict:
                for d in diacritics_list:
                    if d not in ALL_VALID_DIACRITICS:
                        return ""
            return text
        except Exception:
            # Validation is best-effort: any failure marks the text invalid.
            return ""
|
arabic_diacritizer_common/constants.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import enum
|
| 2 |
+
from typing import Set, FrozenSet
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ArabicDiacritics(enum.Enum):
    """All possible Arabic diacritics (standard + extended)."""

    NO_DIACRITIC = ""
    SUKOON = "ْ"
    SHADDA = "ّ"
    DAMMA = "ُ"
    FATHA = "َ"
    KASRA = "ِ"
    TANWEEN_DAMMA = "ٌ"
    TANWEEN_FATHA = "ً"
    TANWEEN_KASRA = "ٍ"
    # Combinations where order may vary in Unicode
    SHADDA_PLUS_DAMMA = "ُّ"  # normalize to 'shadda then vowel'
    SHADDA_PLUS_FATHA = "َّ"
    SHADDA_PLUS_KASRA = "ِّ"
    SHADDA_PLUS_TANWEEN_DAMMA = "ٌّ"
    SHADDA_PLUS_TANWEEN_FATHA = "ًّ"
    SHADDA_PLUS_TANWEEN_KASRA = "ٍّ"

    # Quranic / orthographic additions
    DAGGER_ALEF = "ٰ"  # superscript Alef (dagger)
    MADDA = "ٓ"  # Maddah
    WASLA = "ٱ"  # Hamzat Wasl (technically letter with mark)

    @classmethod
    def chars(cls) -> Set[str]:
        """Return set of atomic (single-character) diacritics."""
        # Shadda combos and WASLA are deliberately excluded: they are not
        # atomic combining marks.
        atomic = (
            cls.SUKOON,
            cls.SHADDA,
            cls.DAMMA,
            cls.FATHA,
            cls.KASRA,
            cls.TANWEEN_DAMMA,
            cls.TANWEEN_FATHA,
            cls.TANWEEN_KASRA,
            cls.DAGGER_ALEF,
            cls.MADDA,
        )
        return {member.value for member in atomic}

    @classmethod
    def valid_combinations(cls) -> Set[str]:
        """Return full set of valid diacritic combinations."""
        # Singles: every atomic mark that may stand alone, plus the empty
        # "no diacritic" label. Bare SHADDA is not a valid final label.
        singles = (
            cls.NO_DIACRITIC,
            cls.SUKOON,
            cls.DAMMA,
            cls.FATHA,
            cls.KASRA,
            cls.TANWEEN_DAMMA,
            cls.TANWEEN_FATHA,
            cls.TANWEEN_KASRA,
            cls.DAGGER_ALEF,
            cls.MADDA,
        )
        shadda_combos = (
            cls.SHADDA_PLUS_DAMMA,
            cls.SHADDA_PLUS_FATHA,
            cls.SHADDA_PLUS_KASRA,
            cls.SHADDA_PLUS_TANWEEN_DAMMA,
            cls.SHADDA_PLUS_TANWEEN_FATHA,
            cls.SHADDA_PLUS_TANWEEN_KASRA,
        )
        return {member.value for member in singles + shadda_combos}

    @classmethod
    def is_valid_diacritic(cls, diacritic: str) -> bool:
        """True when *diacritic* is one of the valid combinations."""
        return diacritic in cls.valid_combinations()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Character sets
WORD_SEPARATOR = " "


# Arabic letters base Unicode block (0600–06FF covers standard Arabic letters)
# Range 0x0621–0x064A: Hamza through Yaa (stops just before the harakat range).
ARABIC_LETTERS_BASE = [chr(c) for c in range(0x0621, 0x064B)]
# Extended Arabic letters (found in borrowed words, Persian/Urdu usage)
ARABIC_LETTERS_EXTENDED_BLOCK = [
    "ى",  # Alef Maqsura
    "ة",  # Taa Marbuta
    "پ",
    "چ",
    "ڤ",
    "گ",  # Persian/Urdu additions
]
ALEF_VARIANTS = {"ا", "أ", "إ", "آ"}

# Merge all letters
ARABIC_LETTERS = frozenset(
    ARABIC_LETTERS_BASE + ARABIC_LETTERS_EXTENDED_BLOCK + list(ALEF_VARIANTS)
)

# Punctuation
PUNCTUATIONS = frozenset(
    {
        ".",
        "،",
        ":",
        "؛",
        "-",
        "؟",
        "!",
        "(",
        ")",
        "[",
        "]",
        '"',
        "«",
        "»",
        "/",
        ";",
        ",",
        "…",
        "ـ",  # ellipsis + tatweel
    }
)
SENTENCE_DELIMITERS = {".", "؟", "!", "،", ":", "؛", "…"}
WORD_DELIMITERS = {WORD_SEPARATOR, *SENTENCE_DELIMITERS}

# Diacritics sets
DIACRITIC_CHARS = ArabicDiacritics.chars()
ALL_VALID_DIACRITICS = ArabicDiacritics.valid_combinations()

# All valid characters
VALID_ARABIC_CHARS = {WORD_SEPARATOR, *ARABIC_LETTERS, *PUNCTUATIONS, *DIACRITIC_CHARS}

# Text normalization (fixes diacritic ordering inconsistencies)
# NOTE(review): the key/value pairs below differ only in combining-mark
# byte order, which most editors render identically — verify with a hex
# dump before editing these literals.
INVALID_SEQUENCES = {
    # Normalize to canonical "SHADDA first, VOWEL after"
    "َّ": "َّ",  # fatha + shadda → shadda + fatha
    "ِّ": "ِّ",  # kasra + shadda → shadda + kasra
    "ُّ": "ُّ",  # damma + shadda → shadda + damma
    "ًّ": "ًّ",  # tanween fatha
    "ٍّ": "ٍّ",  # tanween kasra
    "ٌّ": "ٌّ",  # tanween damma
    # Punctuation spacing corrections
    " ،": "،",
    " .": ".",
    " ؟": "؟",
    " ؛": "؛",
    " …": "…",
}

# Regex for Arabic letters (character class built from the merged letter set)
ARABIC_LETTERS_REGEX = re.compile(f'[{"".join(ARABIC_LETTERS)}]+')
|
arabic_diacritizer_common/postprocessor.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .constants import DIACRITIC_CHARS, ArabicDiacritics
|
| 2 |
+
from .cleaners import DiacriticValidator, TextCleaner
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Postprocessor:
    """
    Provides rule-based post-processing to correct common linguistic errors
    in model-generated diacritized text.
    """

    @staticmethod
    def postprocess(text: str) -> str:
        """Apply all correction rules in order and return the fixed text."""
        text = Postprocessor._correct_tanween_fatha_placement(text)
        text = Postprocessor._remove_diacritics_from_alifs(text)
        return text

    @staticmethod
    def _correct_tanween_fatha_placement(text: str) -> str:
        """
        Corrects the placement of Tanween Fatha (ً) from a final Alif (ا)
        to the preceding character. This is a common model error.

        Example: "مَرْحَبَاً" (incorrect) -> "مَرْحَبًا" (correct)
        """
        words = text.split(" ")
        corrected_words = []
        tanween_fatha = ArabicDiacritics.TANWEEN_FATHA.value

        for word in words:

            # Ensure the word is long enough to have a character before a final Alif.
            if len(word) < 2:
                corrected_words.append(word)
                continue

            base_chars, diacritics = DiacriticValidator.extract_diacritics(word)

            if (
                len(base_chars) > 1
                and base_chars[-1] == "ا"
                and diacritics[-1] == tanween_fatha
            ):

                # Move the tanween onto the preceding character; this overwrites
                # any diacritic the model put there.
                diacritics[-2] = tanween_fatha

                # Ensure the final Alif is left with no diacritic.
                diacritics[-1] = ""

                # Re-interleave base characters with their diacritic slots.
                corrected_word = "".join(
                    [c + d for c, d in zip(base_chars, diacritics)]
                )
                corrected_words.append(corrected_word)
            else:
                corrected_words.append(word)

        return " ".join(corrected_words)

    @staticmethod
    def _remove_diacritics_from_alifs(text: str) -> str:
        """
        Removes any diacritics from plain Alif (ا) and Alif Maqsura (ى)
        anywhere within a word. These characters should not carry short vowels.

        Example:
            - "عَلَىَ" -> "عَلَى"
            - "اِسْم" -> "اِسْم" (This is correct; the model predicted a base letter, not a diacritic)
            - "كِتَابُ" -> "كِتَاب" (Incorrect model output gets corrected)
        """
        words = text.split(" ")
        corrected_words = []
        for word in words:
            base_chars, diacritics = DiacriticValidator.extract_diacritics(word)

            # Blank out the diacritic slot of every bare Alif / Alef Maqsura.
            for i in range(len(base_chars)):
                if base_chars[i] == "ا" or base_chars[i] == "ى":
                    diacritics[i] = ""

            corrected_word = "".join([c + d for c, d in zip(base_chars, diacritics)])
            corrected_words.append(corrected_word)

        return " ".join(corrected_words)
|
arabic_diacritizer_common/segmenter.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List
|
| 3 |
+
from .cleaners import TextCleaner
|
| 4 |
+
from .constants import DIACRITIC_CHARS
|
| 5 |
+
|
| 6 |
+
def grapheme_length(text: str) -> int:
    """Return logical length of text, counting base characters only (ignore diacritics)."""
    # Booleans sum as 0/1, so this counts the non-diacritic characters.
    return sum(ch not in DIACRITIC_CHARS for ch in text)
|
| 9 |
+
|
| 10 |
+
class TextSegmenter:
    """Handles text segmentation and sentence splitting"""

    # Regex for sentence boundaries (includes multi-char delimiters)
    # Treats sequences like "؟!" or "..." as a single delimiter
    SENTENCE_BOUNDARY_RE = re.compile(r"(؟!|!|\?|\.{2,}|…|،|؛)")

    @staticmethod
    def segment_sentences(max_chars: int, line: str) -> List[str]:
        """
        Segment long lines into shorter sentences with max length constraint.

        Length is measured in base characters (diacritics excluded) via
        grapheme_length. Returns a list of whitespace-normalized segments;
        empty input yields [].
        """
        line = line.strip()
        if not line:
            return []

        # Fast path: already short enough, just normalize whitespace.
        if grapheme_length(line) <= max_chars:
            return [TextCleaner.collapse_whitespace(line)]

        # Perform segmentation
        return TextSegmenter._do_segment_sentences(line, max_chars)

    @staticmethod
    def _do_segment_sentences(line: str, max_chars: int) -> List[str]:
        """
        Internal recursive sentence segmentation logic.
        """
        # Split based on boundary regex (keeps delimiters)
        parts = []
        last_idx = 0
        for match in TextSegmenter.SENTENCE_BOUNDARY_RE.finditer(line):
            start, end = match.span()
            segment = line[last_idx:start].strip()
            delimiter = match.group()
            # NOTE: a delimiter preceded only by whitespace is dropped here
            # (segment is empty, so nothing is appended).
            if segment:
                parts.append(segment + delimiter)
            last_idx = end

        # Trailing text after the last delimiter.
        if last_idx < len(line):
            remainder = line[last_idx:].strip()
            if remainder:
                parts.append(remainder)

        # Now filter by length
        results: List[str] = []
        for sent in parts:
            sent = TextCleaner.collapse_whitespace(sent)
            if not sent:
                continue
            if grapheme_length(sent) <= max_chars:
                results.append(sent)
            else:
                # Recursive split if still too long
                subsegments = TextSegmenter._recursive_split(sent, max_chars)
                results.extend(subsegments)

        return results

    @staticmethod
    def _recursive_split(text: str, max_chars: int) -> List[str]:
        """
        Splits oversized text recursively by words if necessary.

        A single word longer than max_chars is emitted as its own segment
        rather than being broken mid-word.
        """
        words = text.split()
        if not words:
            return []

        segments = []
        cur_segment = []
        cur_len = 0  # running grapheme length of cur_segment incl. separators

        for w in words:
            # +1 accounts for the space that would precede this word.
            if cur_len + grapheme_length(w) + 1 > max_chars:
                if cur_segment:
                    segments.append(" ".join(cur_segment))
                cur_segment = [w]
                cur_len = grapheme_length(w)
            else:
                cur_segment.append(w)
                cur_len += grapheme_length(w) + 1

        if cur_segment:
            segments.append(" ".join(cur_segment))

        return segments
|
arabic_diacritizer_common/tokenizer.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Dict, Tuple, Optional
|
| 4 |
+
|
| 5 |
+
from .constants import ArabicDiacritics, ARABIC_LETTERS, VALID_ARABIC_CHARS
|
| 6 |
+
from .cleaners import DiacriticValidator, TextCleaner
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CharTokenizer:
    """
    Character-level tokenizer for Arabic diacritization.

    Input: bare characters (without diacritics)
    Output: per-character diacritic labels (including NO_DIACRITIC)
    """

    def __init__(
        self,
        char2id: Optional[Dict[str, int]] = None,
        diacritic2id: Optional[Dict[str, int]] = None,
        include_punct: bool = True,
        extra_chars: Optional[List[str]] = None,
    ):
        """
        Build the tokenizer.

        If no vocab mappings are provided, defaults are built from
        constants.py (Arabic letters, optionally other valid characters,
        plus any ``extra_chars``).

        Args:
            char2id: Optional precomputed character -> id mapping.
            diacritic2id: Optional precomputed diacritic -> id mapping.
            include_punct: Include non-letter valid characters in the vocab.
            extra_chars: Extra characters to append to the vocabulary.
        """
        if char2id is None or diacritic2id is None:
            # Base vocabulary from constants
            vocab_chars = list(ARABIC_LETTERS)
            if include_punct:
                vocab_chars += [
                    c for c in VALID_ARABIC_CHARS if c not in ARABIC_LETTERS
                ]
            if extra_chars:
                vocab_chars += extra_chars
            vocab_chars = sorted(set(vocab_chars))

            # Char vocab (+PAD, +UNK)
            char2id = {"<PAD>": 0, "<UNK>": 1}
            char2id.update({ch: idx + 2 for idx, ch in enumerate(vocab_chars)})

            # Diacritic vocab (includes NO_DIACRITIC "")
            diacritic2id = {
                d: i
                for i, d in enumerate(sorted(ArabicDiacritics.valid_combinations()))
            }

        self.char2id = char2id
        self.id2char = {i: c for c, i in char2id.items()}
        self.diacritic2id = diacritic2id
        self.id2diacritic = {i: d for d, i in diacritic2id.items()}

    def save(self, path: str) -> None:
        """Serialize both vocab mappings to a UTF-8 JSON file at *path*."""
        Path(path).write_text(
            json.dumps(
                {"char2id": self.char2id, "diacritic2id": self.diacritic2id},
                ensure_ascii=False,
                indent=2,
            ),
            encoding="utf-8",
        )

    @classmethod
    def load(cls, path: str):
        """Load a tokenizer previously written by :meth:`save`."""
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(data["char2id"], data["diacritic2id"])

    def encode(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Encode a diacritized string -> (input_ids, diacritic_labels).

        Diacritics are stripped from *text*; each remaining base character
        yields one input id and one diacritic label id. Unknown characters
        fall back to the <PAD> id (NOTE(review): <UNK> exists in the vocab
        but is not used here — presumably to match the training pipeline;
        confirm before changing).
        """
        base_text, diacritics = DiacriticValidator.extract_diacritics(text)

        input_ids = [self.char2id.get(ch, self.char2id["<PAD>"]) for ch in base_text]
        label_ids = [
            self.diacritic2id.get(
                d, self.diacritic2id[ArabicDiacritics.NO_DIACRITIC.value]
            )
            for d in diacritics
        ]
        return input_ids, label_ids

    def encode_for_inference(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Encode a string for inference -> (input_ids, diacritic_labels).

        BUGFIX: the original body referenced ``base_text`` / ``diacritics``
        without ever defining them, so every call raised NameError. The
        extraction step is now performed exactly as in :meth:`encode`; for
        undiacritized inference input the extracted labels are simply all
        NO_DIACRITIC, so delegating is safe.
        """
        return self.encode(text)

    def decode(
        self, input_ids: List[int], label_ids: List[int], cleanup_mode: str = "clean"
    ) -> str:
        """
        Decode (input_ids, label_ids) -> string with diacritics.

        Args:
            input_ids: List of character IDs.
            label_ids: List of predicted diacritic IDs.
            cleanup_mode (str): Determines the post-processing strategy.
                - "clean": (Default) Removes diacritics from non-Arabic letters (e.g., punctuation, spaces).
                - "raw": Returns the raw model output without any cleanup.

        Returns:
            The reconstructed, diacritized string.

        Raises:
            ValueError: If *cleanup_mode* is not "clean" or "raw".
        """
        if cleanup_mode not in {"clean", "raw"}:
            raise ValueError("cleanup_mode must be either 'clean' or 'raw'.")

        chars = [self.id2char.get(i, "<UNK>") for i in input_ids]
        diacs = [self.id2diacritic.get(i, "") for i in label_ids]

        if cleanup_mode == "raw":
            return "".join(ch + d for ch, d in zip(chars, diacs))

        # "clean" mode: attach a predicted diacritic only to Arabic letters;
        # punctuation, digits and spaces are emitted bare.
        cleaned_output = []
        for char, diac in zip(chars, diacs):
            if char in ARABIC_LETTERS:
                cleaned_output.append(char + diac)
            else:
                cleaned_output.append(char)

        return "".join(cleaned_output)

    def decode_inference(
        self,
        text_list: list,
        label_ids: list,
        cleanup_mode: str = "clean",
    ) -> str:
        """
        Re-attach predicted diacritics to the original bare characters.

        Args:
            text_list: List of chars of the original text without diacritics.
            label_ids: List of predicted diacritic IDs.
            cleanup_mode (str): "clean" (default) or "raw". NOTE(review):
                both modes currently behave identically here (diacritics are
                attached only to Arabic letters); the parameter is validated
                for API symmetry with :meth:`decode`.

        Returns:
            The reconstructed, diacritized string.

        Raises:
            ValueError: If *cleanup_mode* is not "clean" or "raw".
        """
        if cleanup_mode not in {"clean", "raw"}:
            raise ValueError("cleanup_mode must be either 'clean' or 'raw'.")

        diacs = [self.id2diacritic.get(i, "") for i in label_ids]
        cleaned_output = []
        for char, diac in zip(text_list, diacs):
            # Only attach a diacritic if the character is a valid Arabic letter
            if char in ARABIC_LETTERS:
                cleaned_output.append(str(char) + str(diac))
            else:
                cleaned_output.append(char)
        return "".join(cleaned_output)
|
diacritizer/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .diacritizer import Diacritizer
|
| 2 |
+
from .exceptions import ModelNotFound, InvalidInputError, DiacritizerException
|
| 3 |
+
|
| 4 |
+
# Public names re-exported at package level (order preserved as published).
__all__ = [
    "Diacritizer",
    "ModelNotFound",
    "InvalidInputError",
    "DiacritizerException"
]

# Package version string; keep in sync with the distribution metadata.
__version__ = "0.1.0"
|
diacritizer/diacritizer.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Union, List, Optional
|
| 4 |
+
import numpy as np
|
| 5 |
+
from arabic_diacritizer_common import (
|
| 6 |
+
CharTokenizer,
|
| 7 |
+
TextSegmenter,
|
| 8 |
+
Postprocessor,
|
| 9 |
+
DiacriticValidator,
|
| 10 |
+
ARABIC_LETTERS_REGEX,
|
| 11 |
+
TextCleaner,
|
| 12 |
+
DIACRITIC_CHARS,
|
| 13 |
+
)
|
| 14 |
+
from .predictor import OnnxPredictor
|
| 15 |
+
from .hub_manager import resolve_model_path, DEFAULT_HUB_REPO_ID
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Diacritizer:
    """High-level Arabic diacritization pipeline (ONNX model + tokenizer)."""

    def __init__(
        self,
        model_identifier: Optional[str] = None,
        architecture: str = "bilstm",
        size: str = "medium",
        revision: str = "main",
        force_sync: bool = False,
        use_gpu: bool = False,
    ):
        """
        Initializes the Diacritizer by loading the model and tokenizer.

        Args:
            model_identifier (str, optional): The identifier for the model. Can be a
                local path or a Hugging Face Hub repo ID. Defaults to the official repo.
            architecture (str): The model architecture ('bilstm', 'bigru', etc.).
                Defaults to "bilstm".
            size (str): The model size ('small', 'medium'). Defaults to "medium".
            revision (str): A specific model version from the Hub. Defaults to "main".
            force_sync (bool): If True, forces a re-download. Defaults to False.
            use_gpu (bool): If True, attempts to use CUDA. Defaults to False.
        """
        # -1 means "no explicit length limit"; kept for backward compatibility.
        self.max_length = -1

        repo_to_resolve = model_identifier or DEFAULT_HUB_REPO_ID

        # Resolve local-dir or Hub paths for the ONNX graph and vocab file.
        onnx_path, vocab_path = resolve_model_path(
            model_identifier=repo_to_resolve,
            architecture=architecture,
            size=size,
            revision=revision,
            force_sync=force_sync,
        )

        self.predictor = OnnxPredictor(onnx_path, use_gpu)

        vocab_data = json.loads(vocab_path.read_text(encoding="utf-8"))
        self.tokenizer = CharTokenizer(
            char2id=vocab_data["char2id"],
            diacritic2id=vocab_data["diacritic2id"],
        )
        self.segmenter = TextSegmenter()

    def _diacritize_sentence(self, text: str) -> str:
        """Diacritize one Arabic sentence; returns "" for blank input."""
        if not text.strip():
            return ""

        # Only the character ids are needed; the gold labels are discarded.
        input_ids, _ = self.tokenizer.encode(text)

        # Bare characters of the original text (diacritics stripped) — the
        # decode step re-attaches predicted diacritics onto these.
        text_list = list(TextCleaner.remove_diacritics(text))

        if not input_ids:
            return ""

        input_chars = np.array(input_ids, dtype=np.int64).reshape(1, -1)
        no_diacritic_id = self.tokenizer.diacritic2id.get("", 0)
        # All-"no diacritic" hints: the model receives no prior marks.
        input_hints = np.full_like(
            input_chars, fill_value=no_diacritic_id, dtype=np.int64
        )

        # inference
        logits = self.predictor.predict(input_ids=input_chars, hints=input_hints)
        predicted_diac_ids = np.argmax(logits, axis=-1)

        # Decode the predictions
        return self.tokenizer.decode_inference(
            text_list, predicted_diac_ids[0].tolist()
        )

    def diacritize(
        self, text: Union[str, List[str]], postprocess: bool = True
    ) -> Union[str, List[str]]:
        """
        Diacritizes text while preserving non-Arabic characters and structure.

        Any existing diacritics in the Arabic text are stripped before being
        processed by the model to ensure a consistent output.

        Args:
            text: A single string or a list of strings.
            postprocess: Apply :class:`Postprocessor` cleanup to each result.

        Returns:
            A diacritized string when *text* is a string, otherwise a list of
            diacritized strings (one per input).

        Raises:
            TypeError: If *text* is neither a string nor a list.
        """
        if isinstance(text, str):
            # Handle a single string input gracefully.
            if not text:
                return ""
            sentences = [text]
        elif isinstance(text, list):
            # BUGFIX: an empty list used to return "" (a str), breaking the
            # list-in/list-out contract; it now returns [].
            if not text:
                return []
            sentences = text
        else:
            raise TypeError("Input must be a string or a list of strings.")

        diacritized_list = [self._diacritize_sentence(s) for s in sentences]

        if postprocess:
            diacritized_list = [
                Postprocessor.postprocess(s) for s in diacritized_list
            ]

        # Mirror the input type on the way out.
        return diacritized_list[0] if isinstance(text, str) else diacritized_list
|
diacritizer/exceptions.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class DiacritizerException(Exception):
    """Root of the diacritizer package's exception hierarchy.

    Catching this single type catches every error the package raises.
    """
|
| 4 |
+
|
| 5 |
+
class ModelNotFound(DiacritizerException):
    """Signals that required model artifacts (ONNX graph, vocab, ...) are absent."""
|
| 8 |
+
|
| 9 |
+
class InvalidInputError(DiacritizerException):
    """Signals that the text handed to the diacritizer is not acceptable input."""
|
diacritizer/hub_manager.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
from huggingface_hub.errors import EntryNotFoundError, LocalEntryNotFoundError
|
| 7 |
+
|
| 8 |
+
from .exceptions import ModelNotFound
|
| 9 |
+
|
| 10 |
+
DEFAULT_HUB_REPO_ID = "muhammad7777/arabic-diacritizer-models"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _download_from_hub(
    repo_id: str, architecture: str, size: str, revision: str, force_sync: bool
) -> Tuple[str, str]:
    """
    Internal helper to download model artifacts from the Hub, using a nested subfolder.

    Strategy: unless *force_sync* is set, the local cache is consulted first
    (offline-first); on a cache miss the artifacts are fetched online.

    Args:
        repo_id: Hugging Face Hub repository id.
        architecture: Model architecture folder (e.g. "bilstm").
        size: Model size folder (e.g. "medium").
        revision: Git revision (branch, tag, or commit) to fetch.
        force_sync: Skip the cache and force a fresh download.

    Returns:
        (onnx_path, vocab_path) as strings.

    Raises:
        ModelNotFound: If the artifacts do not exist for this
            architecture/size/revision, or the download fails.
    """
    # Construct the nested path (e.g., "bilstm/medium")
    model_subfolder = f"{architecture}/{size}"

    def _fetch(**download_kwargs) -> Tuple[str, str]:
        # Download both artifacts with identical settings; previously this
        # pair of calls was duplicated verbatim for the cache and online paths.
        onnx = hf_hub_download(
            repo_id=repo_id,
            filename="model.onnx",
            subfolder=model_subfolder,
            revision=revision,
            **download_kwargs,
        )
        vocab = hf_hub_download(
            repo_id=repo_id,
            filename="vocab.json",
            subfolder=model_subfolder,
            revision=revision,
            **download_kwargs,
        )
        return onnx, vocab

    if not force_sync:
        try:
            # Attempt to load from cache first (offline-first)
            return _fetch(local_files_only=True)
        except LocalEntryNotFoundError:
            pass  # Not found in cache, proceed to download.

    # Download from the Hub (online fallback)
    try:
        return _fetch(force_download=force_sync)
    except EntryNotFoundError as e:
        # Make the error message more informative
        raise ModelNotFound(
            f"Could not find model for architecture '{architecture}' and size '{size}' "
            f"at revision '{revision}' in repository '{repo_id}'. "
            f"Please check the Hub for available models."
        ) from e
    except Exception as e:
        raise ModelNotFound(
            f"Failed to download model from the Hub. Please check your internet connection. "
            f"Original error: {e}"
        ) from e
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def resolve_model_path(
    model_identifier: str, architecture: str, size: str, revision: str, force_sync: bool
) -> Tuple[Path, Path]:
    """
    Resolve the ONNX and vocab artifact paths for a model.

    A *model_identifier* naming an existing local directory is used as-is
    (it must already contain both artifacts); any other identifier is
    treated as a Hugging Face Hub repo ID and fetched via the Hub helper,
    honoring *architecture*, *size*, *revision* and *force_sync*.

    Raises:
        ModelNotFound: If a local directory is missing either artifact.
    """
    # Case 1: a local directory — both artifacts must already be present.
    if os.path.isdir(model_identifier):
        local_dir = Path(model_identifier)
        onnx_file = local_dir / "model.onnx"
        vocab_file = local_dir / "vocab.json"

        if onnx_file.exists() and vocab_file.exists():
            return onnx_file, vocab_file

        raise ModelNotFound(
            f"Local model directory must contain 'model.onnx' and 'vocab.json'. "
            f"Path: {local_dir}"
        )

    # Case 2: a Hub repository ID — delegate the download logic.
    hub_onnx, hub_vocab = _download_from_hub(
        repo_id=model_identifier,
        architecture=architecture,
        size=size,
        revision=revision,
        force_sync=force_sync,
    )
    return Path(hub_onnx), Path(hub_vocab)
|
diacritizer/predictor.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import numpy as np
|
| 3 |
+
import onnxruntime as ort
|
| 4 |
+
from .exceptions import ModelNotFound
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class OnnxPredictor:
    """Thin wrapper around an ONNX Runtime session for diacritic prediction."""

    def __init__(self, model_path: Path, use_gpu: bool = False):
        """
        Initializes the ONNX Runtime session.

        Args:
            model_path: Path to the .onnx model file.
            use_gpu: Prefer the CUDA execution provider when True.

        Raises:
            ModelNotFound: If the model file does not exist at the given path.
        """
        if not model_path.exists():
            raise ModelNotFound(f"ONNX model file not found at: {model_path}")

        # CPU is always kept as a fallback provider; CUDA is tried first
        # when requested. Customize this list for other target hardware.
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )

        self.session = ort.InferenceSession(str(model_path), providers=providers)

        model_inputs = self.session.get_inputs()
        self.input_name = model_inputs[0].name
        self.hints_name = model_inputs[1].name
        self.output_name = self.session.get_outputs()[0].name

    def predict(self, input_ids: np.ndarray, hints: np.ndarray) -> np.ndarray:
        """
        Runs inference on a batch of tokenized input IDs.

        Args:
            input_ids: A numpy array of shape (batch_size, sequence_length).
            hints: A numpy array of shape (batch_size, sequence_length).

        Returns:
            A numpy array of logits of shape (batch_size, sequence_length, num_classes).
        """
        feed = {self.input_name: input_ids, self.hints_name: hints}
        # run() returns a list of outputs; this model exposes a single one.
        return self.session.run([self.output_name], feed)[0]
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gradio for the UI
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
|
| 4 |
+
# Core dependencies
|
| 5 |
+
onnxruntime
|
| 6 |
+
numpy
|
| 7 |
+
huggingface-hub
|
| 8 |
+
|
| 9 |
+
# --- Custom Local Packages ---
|
| 10 |
+
# These will be installed from the .whl files in the repo
|
| 11 |
+
# arabic_diacritizer_common
|
| 12 |
+
# arabic_diacritizer
|