Muhammad7777 commited on
Commit
7150d27
·
1 Parent(s): 2cc687a

Deploy application

Browse files
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
- title: Arabic Diacritizer
3
- emoji: 🐠
4
- colorFrom: yellow
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.47.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Arabic Diacritizer Demo
3
+ emoji: 🐠
4
+ colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.47.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
 
 
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from diacritizer import Diacritizer, ModelNotFound
4
+
5
+ MODEL_INFO = {
6
+ "bilstm": {
7
+ "display_name": "BiLSTM",
8
+ "models": {
9
+ "medium": {
10
+ "size": "4 MB",
11
+ "details": "Balanced speed and accuracy.",
12
+ },
13
+ "large": {
14
+ "size": "15.5 MB",
15
+ "details": "Highest accuracy model.",
16
+ },
17
+ },
18
+ },
19
+ "bigru": {
20
+ "display_name": "BiGRU",
21
+ "models": {
22
+ "medium": {
23
+ "size": "3.8 MB",
24
+ "details": "Slightly faster than BiLSTM with comparable accuracy.",
25
+ },
26
+ "large": {
27
+ "size": "14.9 MB",
28
+ "details": "High accuracy alternative to the BiLSTM model.",
29
+ },
30
+ },
31
+ },
32
+ }
33
+
34
+ MODEL_CACHE = {}
35
+
36
+
37
def get_model(architecture: str, size: str, progress=gr.Progress()):
    """
    Return a cached Diacritizer for (architecture, size), loading it on first use.

    Emits a gr.Progress update while loading and converts load failures into
    gr.Error so they surface in the Gradio UI instead of crashing the worker.
    """
    key = f"{architecture}/{size}"
    model = MODEL_CACHE.get(key)
    if model is not None:
        return model

    progress(0.5, desc=f"Loading {architecture}/{size} model...")
    try:
        model = Diacritizer(architecture=architecture, size=size)
    except ModelNotFound:
        raise gr.Error(
            f"The requested model ({key}) was not found on the Hugging Face Hub."
        )
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred while loading the model: {e}")
    MODEL_CACHE[key] = model
    return model
54
+
55
+
56
def diacritize_text(text: str, architecture: str, size: str, progress=gr.Progress()):
    """
    Diacritize *text* with the selected model and report timing plus model info.

    Returns a (diacritized_text, inference_time, info_markdown) tuple matching
    the three UI outputs wired to the submit button.
    """
    if not text or not text.strip():
        return "", "0.000s", "Please enter some text to diacritize."

    progress(0, desc="Loading model...")
    model = get_model(architecture, size, progress)

    progress(0.8, desc="Diacritizing text...")
    started = time.time()
    result = model.diacritize(text)
    elapsed = time.time() - started

    # Refresh the info panel with details of the model that actually ran.
    entry = MODEL_INFO[architecture]["models"][size]
    info_markdown = (
        f"**Model:** {architecture}/{size} | "
        f"**Size:** {entry['size']} | {entry['details']}"
    )

    return result, f"{elapsed:.3f}s", info_markdown
78
+
79
+
80
def update_available_sizes(architecture: str):
    """Refresh the size radio choices after the architecture selection changes."""
    sizes = [size for size in MODEL_INFO[architecture]["models"]]
    # A freshly constructed Radio replaces the old one's choices and value.
    return gr.Radio(
        choices=sizes,
        value=sizes[0],  # default to the first available size
        label="Model Size",
        info="Select the model size.",
    )
90
+
91
+
92
theme = gr.themes.Soft(
    primary_hue="zinc",
    secondary_hue="blue",
    neutral_hue="slate",
    font=(gr.themes.GoogleFont("Noto Sans"), gr.themes.GoogleFont("Noto Sans Arabic")),
).set(
    body_background_fill_dark="#111827"  # A slightly off-black for dark mode
)

DESCRIPTION = """
# ⚡ End-to-End Arabic Diacritizer
A lightweight and efficient model for automatic Arabic diacritization.
Select an architecture and size, enter some text, and see it in action. For more details, visit the
[GitHub repository](https://github.com/muhammad-abdelsattar/arabic-diacritizer).
"""

with gr.Blocks(theme=theme, css=".footer {display: none !important}") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                arch_selector = gr.Radio(
                    choices=[
                        (info["display_name"], arch)
                        for arch, info in MODEL_INFO.items()
                    ],
                    label="Model Architecture",
                    value="bilstm",
                    info="Select the model architecture.",
                )
                model_selector = gr.Radio(
                    choices=["medium", "large"],
                    label="Model Size",
                    value="medium",
                    info="Select the model size.",
                )
            # Default text mirrors MODEL_INFO["bilstm"]["medium"] so the panel
            # is consistent before the first inference (a stale
            # "(Formerly 'small')" suffix was removed here).
            info_display = gr.Markdown(
                "**Model:** bilstm/medium | **Size:** 4 MB | Balanced speed and accuracy.",
                elem_id="info-display",
            )

        with gr.Column(scale=1):
            inference_time_output = gr.Textbox(
                label="Inference Time", interactive=False, max_lines=1
            )

    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text (Undiacritized)",
                placeholder="اكتب جملة عربية غير مشكولة هنا...",
                lines=8,
                rtl=True,
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="Output Text (Diacritized)",
                lines=8,
                rtl=True,
                interactive=False,
            )

    submit_button = gr.Button("Diacritize ✨", variant="primary")

    gr.Examples(
        [
            ["أعلنت الشركة عن نتائجها المالية للربع الأول من العام."],
            ["إن مع العسر يسرا."],
            ["هل يمكن للذكاء الاصطناعي أن يكون مبدعا؟"],
            ["كان المتنبي شاعرا عظيما في العصر العباسي."],
        ],
        inputs=input_text,
        label="Examples",
    )

    submit_button.click(
        fn=diacritize_text,
        inputs=[input_text, arch_selector, model_selector],
        outputs=[output_text, inference_time_output, info_display],
    )

    # When architecture changes, update the available sizes
    arch_selector.change(
        fn=update_available_sizes, inputs=arch_selector, outputs=model_selector
    )


if __name__ == "__main__":
    # Pre-load the default model for a faster first-time user experience
    print("Pre-loading default 'bilstm/medium' model...")
    get_model(architecture="bilstm", size="medium")
    print("Default model loaded successfully.")

    demo.launch()
arabic_diacritizer_common/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import (
2
+ ArabicDiacritics,
3
+ VALID_ARABIC_CHARS,
4
+ DIACRITIC_CHARS,
5
+ ARABIC_LETTERS,
6
+ ARABIC_LETTERS_REGEX,
7
+ )
8
+ from .cleaners import TextCleaner, DiacriticValidator
9
+ from .segmenter import TextSegmenter
10
+ from .tokenizer import CharTokenizer
11
+ from .postprocessor import Postprocessor
12
+
13
+ __all__ = [
14
+ "ArabicDiacritics",
15
+ "VALID_ARABIC_CHARS",
16
+ "DIACRITIC_CHARS",
17
+ "ARABIC_LETTERS",
18
+ "TextCleaner",
19
+ "Postprocessor",
20
+ "DiacriticValidator",
21
+ "TextSegmenter",
22
+ "CharTokenizer",
23
+ "ARABIC_LETTERS_REGEX",
24
+ ]
arabic_diacritizer_common/cleaners.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Tuple
3
+ from .constants import (
4
+ VALID_ARABIC_CHARS,
5
+ DIACRITIC_CHARS,
6
+ INVALID_SEQUENCES,
7
+ ALL_VALID_DIACRITICS,
8
+ ArabicDiacritics,
9
+ )
10
+
11
+ # Whitespace regex
12
+ _whitespace_re = re.compile(r"\s+")
13
+
14
+
15
class TextCleaner:
    """Modular text cleaning utilities for Arabic text."""

    @staticmethod
    def collapse_whitespace(text: str) -> str:
        """Collapse multiple whitespace characters into a single space."""
        return re.sub(_whitespace_re, " ", text).strip()

    @staticmethod
    def filter_valid_arabic(text: str) -> str:
        """Keep only valid Arabic characters, punctuation, and diacritics."""
        return "".join(char for char in text if char in VALID_ARABIC_CHARS)

    @staticmethod
    def remove_diacritics(text: str) -> str:
        """Remove all diacritic characters from text."""
        return "".join(ch for ch in text if ch not in DIACRITIC_CHARS)

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize common text irregularities & diacritic order."""
        # Canonicalize diacritic order (make sure Shadda always comes first)
        for invalid, correct in INVALID_SEQUENCES.items():
            text = text.replace(invalid, correct)

        # Normalize alef variants to bare alef (optional, safer for training)
        normalize_map = {"أ": "ا", "إ": "ا", "آ": "ا", "ٱ": "ا"}
        for k, v in normalize_map.items():
            text = text.replace(k, v)

        # Remove Tatweel (ـ) since it is purely decorative
        text = text.replace("ـ", "")
        return text

    @staticmethod
    def clean_text(
        text: str, keep_valid_only: bool = True, normalize: bool = False
    ) -> str:
        """Complete cleaning pipeline: normalize → optional filtering → collapse ws."""
        if normalize:
            text = TextCleaner.normalize_text(text)
        if keep_valid_only:
            text = TextCleaner.filter_valid_arabic(text)
        return TextCleaner.collapse_whitespace(text)

    @staticmethod
    def strip_diacritics(text: str) -> str:
        """Strip all diacritic marks from an Arabic string.

        Alias of remove_diacritics, kept for backward compatibility; the
        original carried a byte-for-byte duplicate implementation.
        """
        return TextCleaner.remove_diacritics(text)
64
+
65
+
66
class DiacriticValidator:
    """Handles diacritic validation and extraction"""

    @staticmethod
    def extract_diacritics(text: str) -> Tuple[str, List[str]]:
        """
        Extract base text and list of diacritics.
        Each base character gets an associated diacritic string (possibly multiple).
        The returned lists are parallel: diacritics[i] belongs to base char i.
        Example:
            "بَّ" → ("ب", ["َّ"])
        """
        base_chars = []
        diacritics = []

        i = 0
        while i < len(text):
            char = text[i]
            if char in DIACRITIC_CHARS:
                # attach to previous base character if exists
                if base_chars:
                    # Append this diacritic to most recent slot
                    diacritics[-1] = diacritics[-1] + char
                else:
                    # Stray diacritic at beginning — skip or treat as invalid
                    pass
            else:
                # New base char: allocate diacritic slot
                base_chars.append(char)
                diacritics.append("")
            i += 1

        # Normalize combined diacritics to canonical representations
        normalized_diacritics = []
        for d in diacritics:
            if d in ALL_VALID_DIACRITICS:
                normalized_diacritics.append(d)
            else:
                # try to reorder if contains shadda + vowel
                if "ّ" in d:
                    # move shadda to front
                    d = "ّ" + "".join(c for c in d if c != "ّ")
                # keep only known chars
                d = "".join(c for c in d if c in DIACRITIC_CHARS)
                normalized_diacritics.append(d)
        return "".join(base_chars), normalized_diacritics

    @staticmethod
    def validate_diacritics(
        text: str, require_any: bool = False, strict: bool = False
    ) -> str:
        """
        Validate that text diacritics are well-formed.
        - require_any: if True, reject sentences with no diacritics at all.
        - strict: if True, reject unknown/malformed diacritics, else sanitize them.
        Returns text if valid, otherwise "".
        """
        try:
            # base_text is unused here; extraction is done for its side effect
            # of producing the per-character diacritic list.
            base_text, diacritics_list = DiacriticValidator.extract_diacritics(text)

            # Optionally require that at least one diacritic is present
            if require_any:
                if not any(
                    d
                    for d in diacritics_list
                    if d != ArabicDiacritics.NO_DIACRITIC.value
                ):
                    return ""

            # In strict mode, reject any diacritic not in valid set
            if strict:
                for d in diacritics_list:
                    if d not in ALL_VALID_DIACRITICS:
                        return ""
            return text
        except Exception:
            # Deliberately broad: any failure while parsing means the sentence
            # is treated as invalid rather than crashing a data pipeline.
            return ""
arabic_diacritizer_common/constants.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from typing import Set, FrozenSet
3
+ import re
4
+
5
+
6
class ArabicDiacritics(enum.Enum):
    """All possible Arabic diacritics (standard + extended).

    NOTE: member values are raw Unicode combining marks; the SHADDA_PLUS_*
    values encode a specific codepoint order (shadda first) — do not
    re-normalize or reorder these literals.
    """

    NO_DIACRITIC = ""
    SUKOON = "ْ"
    SHADDA = "ّ"
    DAMMA = "ُ"
    FATHA = "َ"
    KASRA = "ِ"
    TANWEEN_DAMMA = "ٌ"
    TANWEEN_FATHA = "ً"
    TANWEEN_KASRA = "ٍ"
    # Combinations where order may vary in Unicode
    SHADDA_PLUS_DAMMA = "ُّ"  # normalize to 'shadda then vowel'
    SHADDA_PLUS_FATHA = "َّ"
    SHADDA_PLUS_KASRA = "ِّ"
    SHADDA_PLUS_TANWEEN_DAMMA = "ٌّ"
    SHADDA_PLUS_TANWEEN_FATHA = "ًّ"
    SHADDA_PLUS_TANWEEN_KASRA = "ٍّ"

    # Quranic / orthographic additions
    DAGGER_ALEF = "ٰ"  # superscript Alef (dagger)
    MADDA = "ٓ"  # Maddah
    WASLA = "ٱ"  # Hamzat Wasl (technically letter with mark)

    @classmethod
    def chars(cls) -> Set[str]:
        """Return set of atomic (single-character) diacritics."""
        # WASLA is intentionally excluded: it is a letter form, not a
        # combining mark, so it must not be stripped as a diacritic.
        return {
            cls.SUKOON.value,
            cls.SHADDA.value,
            cls.DAMMA.value,
            cls.FATHA.value,
            cls.KASRA.value,
            cls.TANWEEN_DAMMA.value,
            cls.TANWEEN_FATHA.value,
            cls.TANWEEN_KASRA.value,
            cls.DAGGER_ALEF.value,
            cls.MADDA.value,
        }

    @classmethod
    def valid_combinations(cls) -> Set[str]:
        """Return full set of valid diacritic combinations."""
        return {
            cls.NO_DIACRITIC.value,
            # Singles
            cls.SUKOON.value,
            cls.DAMMA.value,
            cls.FATHA.value,
            cls.KASRA.value,
            cls.TANWEEN_DAMMA.value,
            cls.TANWEEN_FATHA.value,
            cls.TANWEEN_KASRA.value,
            cls.DAGGER_ALEF.value,
            cls.MADDA.value,
            # Shadda combos
            cls.SHADDA_PLUS_DAMMA.value,
            cls.SHADDA_PLUS_FATHA.value,
            cls.SHADDA_PLUS_KASRA.value,
            cls.SHADDA_PLUS_TANWEEN_DAMMA.value,
            cls.SHADDA_PLUS_TANWEEN_FATHA.value,
            cls.SHADDA_PLUS_TANWEEN_KASRA.value,
        }

    @classmethod
    def is_valid_diacritic(cls, diacritic: str) -> bool:
        """Return True if *diacritic* is one of the valid combinations."""
        return diacritic in cls.valid_combinations()
74
+
75
+
76
+ # Character sets
77
+ WORD_SEPARATOR = " "
78
+
79
+
80
+ # Arabic letters base Unicode block (0600–06FF covers standard Arabic letters)
81
+ ARABIC_LETTERS_BASE = [chr(c) for c in range(0x0621, 0x064B)]
82
+ # Extended Arabic letters (found in borrowed words, Persian/Urdu usage)
83
+ ARABIC_LETTERS_EXTENDED_BLOCK = [
84
+ "ى", # Alef Maqsura
85
+ "ة", # Taa Marbuta
86
+ "پ",
87
+ "چ",
88
+ "ڤ",
89
+ "گ", # Persian/Urdu additions
90
+ ]
91
+ ALEF_VARIANTS = {"ا", "أ", "إ", "آ"}
92
+
93
+ # Merge all letters
94
+ ARABIC_LETTERS = frozenset(
95
+ ARABIC_LETTERS_BASE + ARABIC_LETTERS_EXTENDED_BLOCK + list(ALEF_VARIANTS)
96
+ )
97
+
98
+ # Punctuation
99
+ PUNCTUATIONS = frozenset(
100
+ {
101
+ ".",
102
+ "،",
103
+ ":",
104
+ "؛",
105
+ "-",
106
+ "؟",
107
+ "!",
108
+ "(",
109
+ ")",
110
+ "[",
111
+ "]",
112
+ '"',
113
+ "«",
114
+ "»",
115
+ "/",
116
+ ";",
117
+ ",",
118
+ "…",
119
+ "ـ", # ellipsis + tatweel
120
+ }
121
+ )
122
+ SENTENCE_DELIMITERS = {".", "؟", "!", "،", ":", "؛", "…"}
123
+ WORD_DELIMITERS = {WORD_SEPARATOR, *SENTENCE_DELIMITERS}
124
+
125
+ # Diacritics sets
126
+ DIACRITIC_CHARS = ArabicDiacritics.chars()
127
+ ALL_VALID_DIACRITICS = ArabicDiacritics.valid_combinations()
128
+
129
+ # All valid characters
130
+ VALID_ARABIC_CHARS = {WORD_SEPARATOR, *ARABIC_LETTERS, *PUNCTUATIONS, *DIACRITIC_CHARS}
131
+
132
+ # Text normalization (fixes diacritic ordering inconsistencies)
133
+ INVALID_SEQUENCES = {
134
+ # Normalize to canonical "SHADDA first, VOWEL after"
135
+ "َّ": "َّ", # fatha + shadda → shadda + fatha
136
+ "ِّ": "ِّ", # kasra + shadda → shadda + kasra
137
+ "ُّ": "ُّ", # damma + shadda → shadda + damma
138
+ "ًّ": "ًّ", # tanween fatha
139
+ "ٍّ": "ٍّ", # tanween kasra
140
+ "ٌّ": "ٌّ", # tanween damma
141
+ # Punctuation spacing corrections
142
+ " ،": "،",
143
+ " .": ".",
144
+ " ؟": "؟",
145
+ " ؛": "؛",
146
+ " …": "…",
147
+ }
148
+
149
+ # Regex for Arabic letters
150
+ ARABIC_LETTERS_REGEX = re.compile(f'[{"".join(ARABIC_LETTERS)}]+')
arabic_diacritizer_common/postprocessor.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import DIACRITIC_CHARS, ArabicDiacritics
2
+ from .cleaners import DiacriticValidator, TextCleaner
3
+
4
+
5
class Postprocessor:
    """
    Rule-based post-processing that fixes common linguistic errors in
    model-generated diacritized text.
    """

    @staticmethod
    def postprocess(text: str) -> str:
        """Apply all correction rules to *text* and return the result."""
        text = Postprocessor._correct_tanween_fatha_placement(text)
        return Postprocessor._remove_diacritics_from_alifs(text)

    @staticmethod
    def _correct_tanween_fatha_placement(text: str) -> str:
        """
        Move a Tanween Fatha (ً) predicted on a word-final Alif (ا) onto the
        preceding character — a frequent model mistake.

        Example: "مَرْحَبَاً" (incorrect) -> "مَرْحَبًا" (correct)
        """
        tanween_fatha = ArabicDiacritics.TANWEEN_FATHA.value
        fixed_words = []

        for word in text.split(" "):
            # Too short to have a character preceding a final Alif.
            if len(word) < 2:
                fixed_words.append(word)
                continue

            chars, marks = DiacriticValidator.extract_diacritics(word)

            misplaced = (
                len(chars) > 1
                and chars[-1] == "ا"
                and marks[-1] == tanween_fatha
            )
            if misplaced:
                marks[-2] = tanween_fatha
                marks[-1] = ""  # final Alif must end up bare
                fixed_words.append("".join(c + m for c, m in zip(chars, marks)))
            else:
                fixed_words.append(word)

        return " ".join(fixed_words)

    @staticmethod
    def _remove_diacritics_from_alifs(text: str) -> str:
        """
        Strip diacritics from plain Alif (ا) and Alif Maqsura (ى) anywhere in
        a word; these characters should not carry short vowels.

        Example:
            - "عَلَىَ" -> "عَلَى"
            - "اِسْم" -> "اِسْم" (already correct; that mark sits on a base letter)
        """
        fixed_words = []
        for word in text.split(" "):
            chars, marks = DiacriticValidator.extract_diacritics(word)
            marks = ["" if c in ("ا", "ى") else m for c, m in zip(chars, marks)]
            fixed_words.append("".join(c + m for c, m in zip(chars, marks)))

        return " ".join(fixed_words)
arabic_diacritizer_common/segmenter.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+ from .cleaners import TextCleaner
4
+ from .constants import DIACRITIC_CHARS
5
+
6
def grapheme_length(text: str) -> int:
    """Return logical length of text, counting base characters only (ignore diacritics)."""
    return len([ch for ch in text if ch not in DIACRITIC_CHARS])
9
+
10
class TextSegmenter:
    """Handles text segmentation and sentence splitting."""

    # Regex for sentence boundaries (includes multi-char delimiters)
    # Treats sequences like "؟!" or "..." as a single delimiter
    SENTENCE_BOUNDARY_RE = re.compile(r"(؟!|!|\?|\.{2,}|…|،|؛)")

    @staticmethod
    def segment_sentences(max_chars: int, line: str) -> List[str]:
        """
        Segment long lines into shorter sentences with max length constraint.

        Args:
            max_chars: Maximum logical (grapheme) length per segment.
            line: The input line to segment.

        Returns:
            Whitespace-normalized segments; empty list for blank input.
        """
        line = line.strip()
        if not line:
            return []

        # Already short enough: only normalize whitespace.
        if grapheme_length(line) <= max_chars:
            return [TextCleaner.collapse_whitespace(line)]

        return TextSegmenter._do_segment_sentences(line, max_chars)

    @staticmethod
    def _do_segment_sentences(line: str, max_chars: int) -> List[str]:
        """
        Split *line* at sentence delimiters, then length-filter the pieces.
        """
        # Split on the boundary regex, keeping each delimiter attached to the
        # sentence that precedes it.
        parts = []
        last_idx = 0
        for match in TextSegmenter.SENTENCE_BOUNDARY_RE.finditer(line):
            start, end = match.span()
            segment = line[last_idx:start].strip()
            delimiter = match.group()
            if segment:
                parts.append(segment + delimiter)
            last_idx = end

        if last_idx < len(line):
            remainder = line[last_idx:].strip()
            if remainder:
                parts.append(remainder)

        # Filter by length, splitting oversized sentences word-by-word.
        results: List[str] = []
        for sent in parts:
            sent = TextCleaner.collapse_whitespace(sent)
            if not sent:
                continue
            if grapheme_length(sent) <= max_chars:
                results.append(sent)
            else:
                results.extend(TextSegmenter._recursive_split(sent, max_chars))

        return results

    @staticmethod
    def _recursive_split(text: str, max_chars: int) -> List[str]:
        """
        Greedily pack words into segments of at most *max_chars* graphemes.

        A single word longer than max_chars is emitted as its own (oversized)
        segment rather than being broken mid-word.
        """
        words = text.split()
        if not words:
            return []

        segments: List[str] = []
        cur_segment: List[str] = []
        cur_len = 0  # grapheme length of " ".join(cur_segment)

        for w in words:
            wlen = grapheme_length(w)
            # Cost of appending w: the word itself, plus one separating space
            # only when the segment is non-empty. (Charging the separator for
            # the first word — as the previous version did — flushed segments
            # one word early.)
            projected = wlen if not cur_segment else cur_len + 1 + wlen
            if cur_segment and projected > max_chars:
                segments.append(" ".join(cur_segment))
                cur_segment = [w]
                cur_len = wlen
            else:
                cur_segment.append(w)
                cur_len = projected

        if cur_segment:
            segments.append(" ".join(cur_segment))

        return segments
arabic_diacritizer_common/tokenizer.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import List, Dict, Tuple, Optional
4
+
5
+ from .constants import ArabicDiacritics, ARABIC_LETTERS, VALID_ARABIC_CHARS
6
+ from .cleaners import DiacriticValidator, TextCleaner
7
+
8
+
9
class CharTokenizer:
    """
    Character-level tokenizer for Arabic diacritization.

    Input: bare characters (without diacritics)
    Output: per-character diacritic labels (including NO_DIACRITIC)
    """

    def __init__(
        self,
        char2id: Optional[Dict[str, int]] = None,
        diacritic2id: Optional[Dict[str, int]] = None,
        include_punct: bool = True,
        extra_chars: Optional[List[str]] = None,
    ):
        """
        Build a tokenizer from explicit vocab mappings, or defaults.

        Args:
            char2id: Character -> id mapping. Built from constants when None.
            diacritic2id: Diacritic -> id mapping. Built from constants when None.
            include_punct: Include punctuation/space in the default char vocab.
            extra_chars: Extra characters added to the default char vocab.
        """
        if char2id is None or diacritic2id is None:
            # Base vocabulary from constants
            vocab_chars = list(ARABIC_LETTERS)
            if include_punct:
                vocab_chars += [
                    c for c in VALID_ARABIC_CHARS if c not in ARABIC_LETTERS
                ]
            if extra_chars:
                vocab_chars += extra_chars
            vocab_chars = sorted(set(vocab_chars))

            # Char vocab (+PAD, +UNK)
            char2id = {"<PAD>": 0, "<UNK>": 1}
            char2id.update({ch: idx + 2 for idx, ch in enumerate(vocab_chars)})

            # Diacritic vocab (includes NO_DIACRITIC "")
            diacritic2id = {
                d: i
                for i, d in enumerate(sorted(ArabicDiacritics.valid_combinations()))
            }

        self.char2id = char2id
        self.id2char = {i: c for c, i in char2id.items()}
        self.diacritic2id = diacritic2id
        self.id2diacritic = {i: d for d, i in diacritic2id.items()}

    def save(self, path: str):
        """Serialize both vocab mappings to *path* as UTF-8 JSON."""
        Path(path).write_text(
            json.dumps(
                {"char2id": self.char2id, "diacritic2id": self.diacritic2id},
                ensure_ascii=False,
                indent=2,
            ),
            encoding="utf-8",
        )

    @classmethod
    def load(cls, path: str):
        """Reconstruct a tokenizer previously written by save()."""
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(data["char2id"], data["diacritic2id"])

    def encode(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Encode a diacritized string → (input_ids, diacritic_labels).
        """
        base_text, diacritics = DiacriticValidator.extract_diacritics(text)

        # NOTE(review): unknown characters fall back to <PAD>, not <UNK> —
        # confirm this is intentional before changing (affects training data).
        input_ids = [self.char2id.get(ch, self.char2id["<PAD>"]) for ch in base_text]
        label_ids = [
            self.diacritic2id.get(
                d, self.diacritic2id[ArabicDiacritics.NO_DIACRITIC.value]
            )
            for d in diacritics
        ]
        return input_ids, label_ids

    def encode_for_inference(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Encode possibly-diacritized text → (input_ids, diacritic_labels).

        Any diacritics already present are separated out via extraction, so
        the input ids cover base characters only — mirroring encode().

        Bug fix: the original body referenced `base_text`/`diacritics` without
        ever defining them, raising NameError on every call; the missing
        extraction step is restored here.
        """
        base_text, diacritics = DiacriticValidator.extract_diacritics(text)

        input_ids = [self.char2id.get(ch, self.char2id["<PAD>"]) for ch in base_text]
        label_ids = [
            self.diacritic2id.get(
                d, self.diacritic2id[ArabicDiacritics.NO_DIACRITIC.value]
            )
            for d in diacritics
        ]
        return input_ids, label_ids

    def decode(
        self, input_ids: List[int], label_ids: List[int], cleanup_mode: str = "clean"
    ) -> str:
        """
        Decode (input_ids, label_ids) -> string with diacritics.

        Args:
            input_ids: List of character IDs.
            label_ids: List of predicted diacritic IDs.
            cleanup_mode (str): Determines the post-processing strategy.
                - "clean": (Default) Removes diacritics from non-Arabic letters (e.g., punctuation, spaces).
                - "raw": Returns the raw model output without any cleanup.

        Returns:
            The reconstructed, diacritized string.

        Raises:
            ValueError: If cleanup_mode is not "clean" or "raw".
        """
        if cleanup_mode not in {"clean", "raw"}:
            raise ValueError("cleanup_mode must be either 'clean' or 'raw'.")

        chars = [self.id2char.get(i, "<UNK>") for i in input_ids]
        diacs = [self.id2diacritic.get(i, "") for i in label_ids]

        if cleanup_mode == "raw":
            return "".join(ch + d for ch, d in zip(chars, diacs))

        # Default is "clean" mode
        cleaned_output = []
        for char, diac in zip(chars, diacs):
            # Only attach a diacritic if the character is a valid Arabic letter
            if char in ARABIC_LETTERS:
                cleaned_output.append(char + diac)
            else:
                # Append the character without the predicted diacritic
                cleaned_output.append(char)

        return "".join(cleaned_output)

    def decode_inference(
        self,
        text_list: list,
        label_ids: list,
        cleanup_mode: str = "clean",
    ):
        """
        Decode (text_list, label_ids) -> string with diacritics.

        Args:
            text_list: List of chars in the original text without diacritics.
            label_ids: List of predicted diacritic IDs.
            cleanup_mode (str): Determines the post-processing strategy.
                - "clean": (Default) Removes diacritics from non-Arabic letters (e.g., punctuation, spaces).
                - "raw": Returns the raw model output without any cleanup.

        Returns:
            The reconstructed, diacritized string.

        Raises:
            ValueError: If cleanup_mode is not "clean" or "raw".
        """
        if cleanup_mode not in {"clean", "raw"}:
            raise ValueError("cleanup_mode must be either 'clean' or 'raw'.")

        diacs = [self.id2diacritic.get(i, "") for i in label_ids]
        cleaned_output = []
        for char, diac in zip(text_list, diacs):
            # Only attach a diacritic if the character is a valid Arabic letter
            if char in ARABIC_LETTERS:
                cleaned_output.append(str(char) + str(diac))
            else:
                # Append the character without the predicted diacritic
                cleaned_output.append(char)
        return "".join(cleaned_output)
diacritizer/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .diacritizer import Diacritizer
2
+ from .exceptions import ModelNotFound, InvalidInputError, DiacritizerException
3
+
4
+ __all__ = [
5
+ "Diacritizer",
6
+ "ModelNotFound",
7
+ "InvalidInputError",
8
+ "DiacritizerException"
9
+ ]
10
+
11
+ __version__ = "0.1.0"
diacritizer/diacritizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Union, List, Optional
4
+ import numpy as np
5
+ from arabic_diacritizer_common import (
6
+ CharTokenizer,
7
+ TextSegmenter,
8
+ Postprocessor,
9
+ DiacriticValidator,
10
+ ARABIC_LETTERS_REGEX,
11
+ TextCleaner,
12
+ DIACRITIC_CHARS,
13
+ )
14
+ from .predictor import OnnxPredictor
15
+ from .hub_manager import resolve_model_path, DEFAULT_HUB_REPO_ID
16
+
17
+
18
class Diacritizer:
    def __init__(
        self,
        model_identifier: Optional[str] = None,
        architecture: str = "bilstm",
        size: str = "medium",
        revision: str = "main",
        force_sync: bool = False,
        use_gpu: bool = False,
    ):
        """
        Initializes the Diacritizer by loading the model and tokenizer.

        Args:
            model_identifier (str, optional): The identifier for the model. Can be a
                local path or a Hugging Face Hub repo ID. Defaults to the official repo.
            architecture (str): The model architecture ('bilstm', 'bigru', etc.).
                Defaults to "bilstm".
            size (str): The model size ('small', 'medium'). Defaults to "medium".
            revision (str): A specific model version from the Hub. Defaults to "main".
            force_sync (bool): If True, forces a re-download. Defaults to False.
            use_gpu (bool): If True, attempts to use CUDA. Defaults to False.
        """
        # Currently unused; kept so existing code touching the attribute works.
        self.max_length = -1

        repo_to_resolve = model_identifier or DEFAULT_HUB_REPO_ID

        # Resolve (and possibly download) the ONNX model + vocab for the
        # requested architecture/size.
        onnx_path, vocab_path = resolve_model_path(
            model_identifier=repo_to_resolve,
            architecture=architecture,
            size=size,
            revision=revision,
            force_sync=force_sync,
        )

        self.predictor = OnnxPredictor(onnx_path, use_gpu)

        vocab_data = json.loads(vocab_path.read_text(encoding="utf-8"))
        self.tokenizer = CharTokenizer(
            char2id=vocab_data["char2id"],
            diacritic2id=vocab_data["diacritic2id"],
        )
        self.segmenter = TextSegmenter()

    def _diacritize_sentence(self, text: str) -> str:
        """Run the model on a single sentence and return the diacritized string."""
        if not text.strip():
            return ""

        input_ids, diacritic_ids = self.tokenizer.encode(text)

        # Bare characters of the input, aligned with the model's predictions.
        text_list = list(TextCleaner.remove_diacritics(text))

        original_len = len(input_ids)
        if original_len == 0:
            return ""

        input_chars = np.array(input_ids).astype(np.int64).reshape(1, -1)
        # All hints set to NO_DIACRITIC: the model predicts every mark itself.
        no_diacritic_id = self.tokenizer.diacritic2id.get("", 0)
        input_hints = np.full_like(
            input_chars, fill_value=no_diacritic_id, dtype=np.int64
        )
        # inference
        logits = self.predictor.predict(input_ids=input_chars, hints=input_hints)
        predicted_diac_ids = np.argmax(logits, axis=-1)

        # Decode the predictions
        return self.tokenizer.decode_inference(
            text_list, predicted_diac_ids[0].tolist()
        )

    def diacritize(
        self, text: Union[str, List[str]], postprocess: bool = True
    ) -> Union[str, List[str]]:
        """
        Diacritizes text while preserving non-Arabic characters and structure.

        Any existing diacritics in the input are stripped before processing
        so the model's output is consistent.

        Args:
            text: A string or a list of strings.
            postprocess: If True, apply rule-based corrections to the output.

        Returns:
            A diacritized string for string input, or a list of diacritized
            strings for list input. (The original implementation returned ""
            even for an empty list, violating the list-in/list-out contract.)

        Raises:
            TypeError: If *text* is neither a string nor a list.
        """
        if not text:
            # Mirror the input container type on empty input.
            return [] if isinstance(text, list) else ""

        if isinstance(text, str):
            # To handle a single string input gracefully
            sentences = [text]
        else:
            sentences = text

        if not isinstance(sentences, list):
            raise TypeError("Input must be a string or a list of strings.")

        results = [self._diacritize_sentence(s) for s in sentences]

        if postprocess:
            results = [Postprocessor.postprocess(r) for r in results]

        if isinstance(text, str):
            return results[0]

        return results
diacritizer/exceptions.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
class DiacritizerException(Exception):
    """Base exception for all errors raised by the diacritizer package."""


class ModelNotFound(DiacritizerException):
    """Raised when the model files (ONNX, vocab, etc.) cannot be found."""


class InvalidInputError(DiacritizerException):
    """Raised when the input text provided to the diacritizer is invalid."""
diacritizer/hub_manager.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Tuple
4
+
5
+ from huggingface_hub import hf_hub_download
6
+ from huggingface_hub.errors import EntryNotFoundError, LocalEntryNotFoundError
7
+
8
+ from .exceptions import ModelNotFound
9
+
10
+ DEFAULT_HUB_REPO_ID = "muhammad7777/arabic-diacritizer-models"
11
+
12
+
13
def _download_from_hub(
    repo_id: str, architecture: str, size: str, revision: str, force_sync: bool
) -> Tuple[str, str]:
    """
    Internal helper to download model artifacts from the Hub.

    Artifacts live under a nested subfolder per model variant
    (e.g. "bilstm/medium"). Resolution is offline-first: the local
    Hub cache is consulted before any network request.
    """
    subfolder = f"{architecture}/{size}"

    def fetch(filename: str, **extra) -> str:
        # Single place for the repo/subfolder/revision plumbing.
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            revision=revision,
            **extra,
        )

    # Offline-first: try to resolve both artifacts from the local cache.
    if not force_sync:
        try:
            onnx_cached = fetch("model.onnx", local_files_only=True)
            vocab_cached = fetch("vocab.json", local_files_only=True)
            return onnx_cached, vocab_cached
        except LocalEntryNotFoundError:
            # Not (fully) cached yet; fall through to a real download.
            pass

    # Online fallback: download from the Hub (forced when force_sync is set).
    try:
        onnx_remote = fetch("model.onnx", force_download=force_sync)
        vocab_remote = fetch("vocab.json", force_download=force_sync)
        return onnx_remote, vocab_remote
    except EntryNotFoundError as e:
        # Make the error message more informative
        raise ModelNotFound(
            f"Could not find model for architecture '{architecture}' and size '{size}' "
            f"at revision '{revision}' in repository '{repo_id}'. "
            f"Please check the Hub for available models."
        ) from e
    except Exception as e:
        raise ModelNotFound(
            f"Failed to download model from the Hub. Please check your internet connection. "
            f"Original error: {e}"
        ) from e
72
+
73
+
74
def resolve_model_path(
    model_identifier: str, architecture: str, size: str, revision: str, force_sync: bool
) -> Tuple[Path, Path]:
    """
    Resolves model artifact paths, now with architecture awareness.

    The identifier is either a local directory containing the artifacts
    or a Hugging Face Hub repository ID to download them from.
    """
    model_dir = Path(model_identifier)

    # Case 1: the identifier points at a local directory holding the artifacts.
    if os.path.isdir(model_identifier):
        onnx_file = model_dir / "model.onnx"
        vocab_file = model_dir / "vocab.json"
        if onnx_file.exists() and vocab_file.exists():
            return onnx_file, vocab_file
        raise ModelNotFound(
            f"Local model directory must contain 'model.onnx' and 'vocab.json'. "
            f"Path: {model_dir}"
        )

    # Case 2: treat the identifier as a Hub repository ID and delegate the
    # download/caching logic to the Hub helper.
    onnx_str, vocab_str = _download_from_hub(
        repo_id=model_identifier,
        architecture=architecture,
        size=size,
        revision=revision,
        force_sync=force_sync,
    )
    return Path(onnx_str), Path(vocab_str)
diacritizer/predictor.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ from .exceptions import ModelNotFound
5
+
6
+
7
class OnnxPredictor:
    """Wraps an ONNX Runtime session and exposes a single predict() call."""

    def __init__(self, model_path: Path, use_gpu: bool = False):
        """
        Initializes the ONNX Runtime session.

        Args:
            model_path: Path to the .onnx model file.
            use_gpu: Whether to use the GPU for inference. Defaults to False.

        Raises:
            ModelNotFound: If the model file does not exist at the given path.
        """
        if not model_path.exists():
            raise ModelNotFound(f"ONNX model file not found at: {model_path}")

        # Prefer CUDA when requested; onnxruntime falls back along this list.
        execution_providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )

        self.session = ort.InferenceSession(
            str(model_path), providers=execution_providers
        )
        model_inputs = self.session.get_inputs()
        self.input_name = model_inputs[0].name
        self.hints_name = model_inputs[1].name
        self.output_name = self.session.get_outputs()[0].name

    def predict(self, input_ids: np.ndarray, hints: np.ndarray) -> np.ndarray:
        """
        Runs inference on a batch of tokenized input IDs.

        Args:
            input_ids: A numpy array of shape (batch_size, sequence_length).
            hints: A numpy array of shape (batch_size, sequence_length).

        Returns:
            A numpy array of logits of shape (batch_size, sequence_length, num_classes).
        """
        feeds = {self.input_name: input_ids, self.hints_name: hints}
        # session.run returns a list; the first entry holds the logits.
        return self.session.run([self.output_name], feeds)[0]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio for the UI
2
+ gradio>=4.0.0
3
+
4
+ # Core dependencies
5
+ onnxruntime
6
+ numpy
7
+ huggingface-hub
8
+
9
+ # --- Custom Local Packages ---
10
+ # These will be installed from the .whl files in the repo
11
+ # arabic_diacritizer_common
12
+ # arabic_diacritizer