Files changed (1) hide show
  1. app.py +44 -77
app.py CHANGED
@@ -2,22 +2,39 @@ import os
2
  import yaml
3
  import gdown
4
  import gradio as gr
5
- from predict import PredictTri
6
  from huggingface_hub import hf_hub_download
7
 
 
 
 
8
  output_path = "tashkeela-d2.pt"
9
  gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
10
  if not os.path.exists(output_path):
 
11
  model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
12
  gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False, fuzzy=True)
 
13
 
 
14
  output_path = "vocab.vec"
15
  if not os.path.exists(output_path):
 
16
  vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
17
  gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False, fuzzy=True)
 
18
 
 
19
  if not os.path.exists("td2/tashkeela-ashaar-td2.pt"):
 
20
  hf_hub_download(repo_id="munael/Partial-Arabic-Diacritization-TD2", filename="tashkeela-ashaar-td2.pt", local_dir="td2")
 
 
 
 
 
 
 
21
 
22
  with open("config.yaml", 'r', encoding="utf-8") as file:
23
  config = yaml.load(file, Loader=yaml.FullLoader)
@@ -30,6 +47,8 @@ predictor = PredictTri(config)
30
  current_model_name = "TD2"
31
  config["model-name"] = current_model_name
32
 
 
 
33
  def diacritze_full(text, model_name):
34
  global current_model_name, predictor
35
  if model_name != current_model_name:
@@ -55,106 +74,54 @@ def diacritze_partial(text, mask_mode, threshold, model_name):
55
  diacritized_lines = predictor.predict_partial(do_partial=do_partial, lines=text.split('\n'))
56
  return diacritized_lines
57
 
 
 
58
  with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
59
 
60
  gr.Markdown(
61
  """
62
- <img src='https://huggingface.co/spaces/bkhmsi/Partial-Arabic-Diacritization/resolve/main/PartialDD.png' style='float:right; margin: 0 0 10px 10px; width: 20%'/>
63
-
64
- <h1> Partial Diacritization: A Context-Contrastive Inference Approach </h1>
65
- <h2> Authors: Muhammad ElNokrashy, Badr AlKhamissi </h2>
66
- <h3> Paper: <a href="https://arxiv.org/abs/2401.08919"> https://arxiv.org/abs/2401.08919 </a> </h3>
67
- <details>
68
- <summary>Abstract</summary>
69
- <p>Diacritization plays a pivotal role in improving readability and disambiguating the meaning of Arabic texts. Efforts have so far focused on marking every eligible character (Full Diacritization). Comparatively overlooked, Partial Diacritization (PD) is the selection of a subset of characters to be marked to aid comprehension where needed. Research has indicated that excessive diacritic marks can hinder skilled readers---reducing reading speed and accuracy. We conduct a behavioral experiment and show that partially marked text is often easier to read than fully marked text, and sometimes easier than plain text. In this light, we introduce Context-Contrastive Partial Diacritization (CCPD)---a novel approach to PD which integrates seamlessly with existing Arabic diacritization systems. CCPD processes each word twice, once with context and once without, and diacritizes only the characters with disparities between the two inferences. Further, we introduce novel indicators for measuring partial diacritization quality {SR, PDER, HDER, ERE}, essential for establishing this as a machine learning task. Lastly, we introduce TD2, a Transformer-variant of an established model which offers a markedly different performance profile on our proposed indicators compared to all other known systems.</p>
70
- </details>
71
  """)
72
 
73
-
74
  model_choice = gr.Dropdown(
75
  choices=["D2", "TD2"],
76
- label="Diacritization Model",
77
  value=current_model_name
78
  )
79
 
80
- with gr.Tab(label="Partial Diacritization") as partial_settings:
81
  with gr.Row():
82
- masking_mode = gr.Radio(choices=["Hard", "Soft"], value="Hard", label="Masking Mode")
83
- threshold_slider = gr.Slider(label="Soft Masking Threshold", minimum=0, maximum=1, value=0.1)
84
 
85
  partial_input_txt = gr.Textbox(
86
- placeholder="اكتب هنا",
87
- lines=5,
88
- label="Input",
89
- type='text',
90
- rtl=True,
91
- text_align='right',
92
  )
93
-
94
  partial_output_txt = gr.Textbox(
95
- lines=5,
96
- label="Output",
97
- type='text',
98
- rtl=True,
99
- text_align='right',
100
- show_copy_button=True,
101
- )
102
-
103
- partial_btn = gr.Button(value="Shakkel")
104
- partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider, model_choice], outputs=[partial_output_txt], queue=True)
105
-
106
- gr.Examples(
107
- examples=[
108
- ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0, "TD2"],
109
- ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Soft", 0.1, "TD2"],
110
- ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Soft", 0.01, "TD2"],
111
- ],
112
- inputs=[partial_input_txt, masking_mode, threshold_slider, model_choice],
113
- outputs=partial_output_txt,
114
- fn=diacritze_partial,
115
- cache_examples=True,
116
  )
 
 
117
 
118
- with gr.Tab(label="Full Diacritization"):
119
-
120
  full_input_txt = gr.Textbox(
121
- placeholder="اكتب هنا",
122
- lines=5,
123
- label="Input",
124
- type='text',
125
- rtl=True,
126
- text_align='right',
127
  )
128
-
129
  full_output_txt = gr.Textbox(
130
- lines=5,
131
- label="Output",
132
- type='text',
133
- rtl=True,
134
- text_align='right',
135
- show_copy_button=True,
136
  )
 
 
137
 
138
- full_btn = gr.Button(value="Shakkel")
139
- full_btn.click(diacritze_full, inputs=[full_input_txt, model_choice], outputs=[full_output_txt], queue=True)
140
-
141
- gr.Examples(
142
- examples=[
143
- ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "TD2"],
144
- ],
145
- inputs=[full_input_txt, model_choice],
146
- outputs=full_output_txt,
147
- fn=diacritze_full,
148
- cache_examples=True,
149
- )
150
 
151
  if __name__ == "__main__":
 
 
 
152
  demo.queue().launch(
153
- # share=False,
154
- # debug=False,
155
- # server_port=7860,
156
- # server_name="0.0.0.0",
157
- # ssl_verify=False,
158
- # ssl_certfile="cert.pem",
159
- # ssl_keyfile="key.pem"
160
  )
 
2
  import yaml
3
  import gdown
4
  import gradio as gr
5
+ from predict import PredictTri # تأكد من وجود ملف predict.py في نفس المجلد
6
  from huggingface_hub import hf_hub_download
7
 
8
+ # --- تنزيل ملفات النموذج المطلوبة إذا لم تكن موجودة ---
9
+
10
+ # تنزيل النموذج الأول من Google Drive
11
  output_path = "tashkeela-d2.pt"
12
  gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
13
  if not os.path.exists(output_path):
14
+ print(f"Downloading model file: {output_path}...")
15
  model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
16
  gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False, fuzzy=True)
17
+ print("Download complete.")
18
 
19
+ # تنزيل ملف المفردات (vocab) من Google Drive
20
  output_path = "vocab.vec"
21
  if not os.path.exists(output_path):
22
+ print(f"Downloading vocab file: {output_path}...")
23
  vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
24
  gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False, fuzzy=True)
25
+ print("Download complete.")
26
 
27
+ # تنزيل النموذج الثاني من Hugging Face Hub
28
  if not os.path.exists("td2/tashkeela-ashaar-td2.pt"):
29
+ print("Downloading TD2 model from Hugging Face Hub...")
30
  hf_hub_download(repo_id="munael/Partial-Arabic-Diacritization-TD2", filename="tashkeela-ashaar-td2.pt", local_dir="td2")
31
+ print("Download complete.")
32
+
33
+ # --- تحميل الإعدادات وتهيئة المتنبئ ---
34
+
35
+ # تأكد من وجود ملف config.yaml
36
+ if not os.path.exists("config.yaml"):
37
+ raise FileNotFoundError("Error: config.yaml file not found. Please create it with the required configurations.")
38
 
39
  with open("config.yaml", 'r', encoding="utf-8") as file:
40
  config = yaml.load(file, Loader=yaml.FullLoader)
 
47
  current_model_name = "TD2"
48
  config["model-name"] = current_model_name
49
 
50
+ # --- تعريف دوال التشكيل (نقاط النهاية للـ API) ---
51
+
52
  def diacritze_full(text, model_name):
53
  global current_model_name, predictor
54
  if model_name != current_model_name:
 
74
  diacritized_lines = predictor.predict_partial(do_partial=do_partial, lines=text.split('\n'))
75
  return diacritized_lines
76
 
77
+ # --- بناء واجهة المستخدم باستخدام Gradio ---
78
+
79
  with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
80
 
81
  gr.Markdown(
82
  """
83
+ <h1>أداة تشكيل النصوص العربية</h1>
84
+ <p>هذا الخادم يعمل محلياً على جهازك.</p>
 
 
 
 
 
 
 
85
  """)
86
 
 
87
  model_choice = gr.Dropdown(
88
  choices=["D2", "TD2"],
89
+ label="نموذج التشكيل",
90
  value=current_model_name
91
  )
92
 
93
+ with gr.Tab(label="التشكيل الجزئي") as partial_settings:
94
  with gr.Row():
95
+ masking_mode = gr.Radio(choices=["Hard", "Soft"], value="Hard", label="نمط الإخفاء")
96
+ threshold_slider = gr.Slider(label="عتبة الإخفاء الناعم", minimum=0, maximum=1, value=0.1)
97
 
98
  partial_input_txt = gr.Textbox(
99
+ placeholder="اكتب هنا", lines=5, label="النص المدخل", rtl=True
 
 
 
 
 
100
  )
 
101
  partial_output_txt = gr.Textbox(
102
+ lines=5, label="النص المخرَج", rtl=True, show_copy_button=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
+ partial_btn = gr.Button(value="شَكِّل جزئياً")
105
+ partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider, model_choice], outputs=[partial_output_txt], api_name="diacritze_partial")
106
 
107
+ with gr.Tab(label="التشكيل الكامل"):
 
108
  full_input_txt = gr.Textbox(
109
+ placeholder="اكتب هنا", lines=5, label="النص المدخل", rtl=True
 
 
 
 
 
110
  )
 
111
  full_output_txt = gr.Textbox(
112
+ lines=5, label="النص المخرَج", rtl=True, show_copy_button=True
 
 
 
 
 
113
  )
114
+ full_btn = gr.Button(value="شَكِّل بالكامل")
115
+ full_btn.click(diacritze_full, inputs=[full_input_txt, model_choice], outputs=[full_output_txt], api_name="diacritze_full")
116
 
117
+ # --- تشغيل الخادم (الجزء الذي تم تعديله) ---
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  if __name__ == "__main__":
120
+ print("Starting Gradio server on local network...")
121
+ # هذا السطر يجعل الخادم متاحاً على شبكتك المحلية
122
+ # استخدم عنوان IP الخاص بجهازك مع المنفذ 7860 للاتصال به من الأجهزة الأخرى
123
  demo.queue().launch(
124
+ server_name="0.0.0.0",
125
+ server_port=7860,
126
+ share=False
 
 
 
 
127
  )