youssefreda9 committed on
Commit
8bea99d
·
1 Parent(s): 08ba334

feat: NLP-3 PuncAra-v1 Integration — Local punctuation model - Created src/nlp/punctuation/ package (rules + service) - Extracted from PuncAra.py: preprocessing, postprocessing, chunking logic - PunctuationChecker: lazy-loaded EncoderDecoderModel from bayan10/PuncAra-v1 - Wired into /api/analyze as Step 3 (Spelling -> Grammar -> Punctuation) - Updated /api/punctuation standalone endpoint - Updated /api/health to report punctuation status - Dockerfile pre-downloads PuncAra-v1 weights during build - Increased gunicorn timeout to 300s for full pipeline

Browse files
Dockerfile CHANGED
@@ -48,6 +48,17 @@ print('Spelling model + MLM cached!'); \
48
  # 3. Grammar — camel-tools MLE disambiguator data
49
  RUN camel_data -i light
50
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Copy application code
52
  COPY src/ ./src/
53
  COPY .env* ./
@@ -61,4 +72,5 @@ ENV PYTHONUNBUFFERED=1
61
  EXPOSE 7860
62
 
63
  # Start the app with gunicorn (single worker to minimize RAM)
64
- CMD ["gunicorn", "--chdir", "src", "app:app", "--bind", "0.0.0.0:7860", "--timeout", "120", "--workers", "1"]
 
 
48
  # 3. Grammar — camel-tools MLE disambiguator data
49
  RUN camel_data -i light
50
 
51
+ # 4. Punctuation model (PuncAra-v1 — EncoderDecoderModel)
52
+ RUN python -c "\
53
+ from transformers import EncoderDecoderModel, AutoTokenizer; \
54
+ repo = 'bayan10/PuncAra-v1'; \
55
+ print('Downloading PuncAra-v1 tokenizer...'); \
56
+ AutoTokenizer.from_pretrained(repo); \
57
+ print('Downloading PuncAra-v1 model...'); \
58
+ EncoderDecoderModel.from_pretrained(repo); \
59
+ print('PuncAra-v1 cached!'); \
60
+ "
61
+
62
  # Copy application code
63
  COPY src/ ./src/
64
  COPY .env* ./
 
72
  EXPOSE 7860
73
 
74
  # Start the app with gunicorn (single worker to minimize RAM)
75
+ # Timeout 300s: full pipeline (spelling ~50s + grammar ~8s + punctuation ~30s + cold start)
76
+ CMD ["gunicorn", "--chdir", "src", "app:app", "--bind", "0.0.0.0:7860", "--timeout", "300", "--workers", "1"]
PuncAra.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Untitled18.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1ebBGzEo4wbwwvReea_n0PRHdfYescKcs
8
+ """
9
+
10
+ import os
11
+ import torch
12
+ from transformers import EncoderDecoderModel, AutoTokenizer
13
+ import re
14
+
15
+ # Constants
16
+ HF_REPO_ID = "bayan10/PuncAra-v1"
17
+
18
+ # Module-level state (assigned by initialize_model)
19
+ device = None
20
+ test_model = None
21
+ test_tokenizer = None
22
+
23
def initialize_model(repo_id=HF_REPO_ID):
    """
    Load the punctuation model and its tokenizer from the Hugging Face Hub.

    Picks CUDA when ``torch.cuda.is_available()`` is true and CPU otherwise,
    moves the model to that device, and stores the results in the
    module-level globals ``device``, ``test_model`` and ``test_tokenizer``.
    Intended to be called once, before any call to ``predict_chunk``.

    Args:
        repo_id: Hub repository id to load the model and tokenizer from.
    """
    global device, test_model, test_tokenizer
    print(f"Loading test model directly from Hugging Face Hub: {repo_id}")

    # "cuda" is only ever selected when CUDA is actually available, so no
    # separate CPU-fallback branch is needed after this line.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading test model to: {device}")

    test_model = EncoderDecoderModel.from_pretrained(repo_id).to(device)
    test_tokenizer = AutoTokenizer.from_pretrained(repo_id)

    # Align the model's special-token ids with the tokenizer's
    # (CLS starts decoding, SEP ends it).
    test_model.config.decoder_start_token_id = test_tokenizer.cls_token_id
    test_model.config.bos_token_id = test_tokenizer.cls_token_id
    test_model.config.eos_token_id = test_tokenizer.sep_token_id
    test_model.config.pad_token_id = test_tokenizer.pad_token_id
    print("Model and Tokenizer loaded successfully!")
46
+
47
def predict_chunk(text_chunk):
    """
    Generate the punctuated version of one short text chunk.

    The chunk is stripped of diacritics, tokenized (input truncated to 128
    tokens) and decoded with 3-beam search, also capped at 128 tokens.

    Raises:
        RuntimeError: if ``initialize_model()`` has not populated the
            module-level model and tokenizer yet.
    """
    if test_model is None or test_tokenizer is None:
        raise RuntimeError("الموديل لم يتم تهيئته بعد. يرجى استدعاء initialize_model() أولاً.")

    # Normalize the input (diacritics removed) before it reaches the model.
    cleaned_chunk = arabic_preprocessing(text_chunk)

    encoded = test_tokenizer(
        cleaned_chunk,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    encoded = encoded.to(device)

    generation_options = {
        "attention_mask": encoded.attention_mask,
        "decoder_start_token_id": test_tokenizer.cls_token_id,
        "bos_token_id": test_tokenizer.cls_token_id,
        "eos_token_id": test_tokenizer.sep_token_id,
        "pad_token_id": test_tokenizer.pad_token_id,
        "max_length": 128,
        "num_beams": 3,
        "repetition_penalty": 1.2,
        "length_penalty": 1.0,
        "early_stopping": True,
        "do_sample": False,
    }
    generated = test_model.generate(encoded.input_ids, **generation_options)

    best_sequence = generated[0]
    return test_tokenizer.decode(best_sequence, skip_special_tokens=True)
73
+
74
# Arabic diacritics U+064B-U+0652 (tanween, fatha, damma, kasra, shadda,
# sukun). Compiled once because the function runs for every chunk.
_ARABIC_DIACRITICS_RE = re.compile(r'[\u064B-\u0652]')


def arabic_preprocessing(text):
    """Remove Arabic diacritics (U+064B-U+0652) and trim surrounding whitespace."""
    return _ARABIC_DIACRITICS_RE.sub('', text).strip()
78
+
79
def arabic_postprocessing(text):
    """
    Clean up the typography of punctuated Arabic text.

    Converts Latin ``, ; ?`` to their Arabic forms, tightens spacing inside
    brackets and guillemets, collapses repeated or contradictory marks,
    drops leading punctuation, and normalizes the spacing around marks.
    Separators inside numbers (``1,000``, ``12:30``, ``3.14``) are left
    untouched.

    Args:
        text: Text to clean; falsy values are returned unchanged.

    Returns:
        The cleaned text.
    """
    if not text:
        return text

    # 1. Shield separators that sit between two digits so the steps below
    #    neither arabize them nor insert a space after them. The decimal
    #    point is included; otherwise step 7 turns "3.14" into "3. 14".
    text = re.sub(r'(?<=\d),(?=\d)', '٪TEMP_COMMA٪', text)
    text = re.sub(r'(?<=\d):(?=\d)', '٪TEMP_COLON٪', text)
    text = re.sub(r'(?<=\d)\.(?=\d)', '٪TEMP_DOT٪', text)

    # 2. Arabize the Latin comma, semicolon and question mark.
    text = text.replace(',', '،').replace(';', '؛').replace('?', '؟')

    # 3. Remove padding just inside brackets and Arabic quotation marks.
    text = re.sub(r'\(\s+', '(', text)
    text = re.sub(r'\s+\)', ')', text)
    text = re.sub(r'\[\s+', '[', text)
    text = re.sub(r'\s+\]', ']', text)
    text = re.sub(r'«\s+', '«', text)
    text = re.sub(r'\s+»', '»', text)

    # 4. Collapse repeated marks; runs of 4+ dots become an ellipsis.
    text = re.sub(r'([،؛:!؟])\1+', r'\1', text)
    text = re.sub(r'\.{4,}', '...', text)

    # 5. Resolve contradictory adjacent marks left by joining chunks.
    text = re.sub(r'[،؛:]+([.!؟])', r'\1', text)
    text = re.sub(r'،؛|؛،', '؛', text)
    text = re.sub(r'([!؟])\.', r'\1', text)

    # 6. Drop stray punctuation at the very start of the text.
    text = re.sub(r'^[،؛:!؟. \t]+', '', text)

    # 7. Put a space after a mark that is directly followed by more text.
    #    Closing brackets/quotes are excluded so "(...!)" stays closed up.
    text = re.sub(r'([،؛:!؟.])(?=[^\s)\]»])', r'\1 ', text)

    # 8. Restore the shielded numeric separators.
    text = (
        text.replace('٪TEMP_COMMA٪', ',')
        .replace('٪TEMP_COLON٪', ':')
        .replace('٪TEMP_DOT٪', '.')
    )

    # 9. Attach each mark to the word before it.
    text = re.sub(r'\s+([،؛:!؟.])', r'\1', text)

    # 10. Collapse runs of horizontal whitespace only (newlines untouched).
    text = re.sub(r'[ \t]+', ' ', text).strip()
    return text
125
+
126
def fix_punctuation(text):
    """
    Punctuate one paragraph, splitting long input into 50-word chunks.

    Chunks do not overlap, so no word is ever sent to the model twice.
    Paragraphs of at most 50 words go to ``predict_chunk`` in one call.
    For longer ones, the whole run of punctuation at the end of every
    non-final chunk is removed, because the cut was forced and the
    sentence continues in the next chunk.

    Args:
        text: A single paragraph.

    Returns:
        The punctuated paragraph with whitespace collapsed to single
        spaces, or ``""`` when the input contains no words.
    """
    words = text.split()
    if not words:
        # Nothing to punctuate; do not run the model on empty input.
        return ""

    # The step equals the window size, which is what keeps chunks disjoint.
    window_size = 50
    total_words = len(words)

    if total_words <= window_size:
        result = predict_chunk(text)
    else:
        # Marks (and spaces between them) stripped at forced chunk cuts.
        boundary_chars = ".?!،؛:؟ \t"
        segments_output = []

        for start in range(0, total_words, window_size):
            chunk_text = " ".join(words[start:start + window_size])
            processed_segment = predict_chunk(chunk_text).strip()

            is_last_segment = (start + window_size) >= total_words
            if not is_last_segment:
                # Strip the entire trailing run, so "..." does not leave
                # a stray ".." behind at the join.
                processed_segment = processed_segment.rstrip(boundary_chars)

            segments_output.append(processed_segment)

        result = " ".join(segments_output)

    # Collapse any doubled whitespace introduced while joining.
    result = re.sub(r'\s+', ' ', result).strip()
    return result
164
+
165
def process_full_document(text):
    """
    Punctuate a multi-paragraph document.

    The text is split on newlines; blank lines are discarded. Every
    remaining line is treated as an independent paragraph, punctuated with
    ``fix_punctuation`` and cleaned with ``arabic_postprocessing``. The
    results are joined with a blank line between paragraphs. Falsy input
    is returned unchanged.
    """
    if not text:
        return text

    finished = []
    for raw_line in text.split('\n'):
        paragraph = raw_line.strip()
        if not paragraph:
            # Skip empty lines entirely.
            continue
        punctuated = fix_punctuation(paragraph)
        finished.append(arabic_postprocessing(punctuated))

    # Two newlines give a clear visual break between paragraphs.
    return "\n\n".join(finished)
src/app.py CHANGED
@@ -156,7 +156,7 @@ def health_check():
156
  'spelling': _spelling_available(),
157
  'autocomplete': False,
158
  'grammar': _grammar_available(),
159
- 'punctuation': False
160
  },
161
  'note': 'Free tier: summarization local, other models return input unchanged',
162
  'supabase': {
@@ -241,6 +241,15 @@ def _grammar_available():
241
  return False
242
 
243
 
 
 
 
 
 
 
 
 
 
244
  @app.route('/api/spelling', methods=['POST'])
245
  def spelling_correction():
246
  """
@@ -506,42 +515,47 @@ def grammar_correction():
506
  @app.route('/api/punctuation', methods=['POST'])
507
  def add_punctuation():
508
  """
509
- Add punctuation to Arabic text.
510
-
511
- Expected JSON payload:
512
  {
513
  "text": "Arabic text without punctuation"
514
  }
 
 
 
 
 
 
 
515
  """
516
- if not USE_HF_API and punctuation_model is None:
517
- return jsonify({
518
- 'error': 'Punctuation model not loaded. Please check server logs.',
519
- 'status': 'error'
520
- }), 503
521
-
522
  try:
523
  if not request.is_json:
524
  return jsonify({'error': 'Request must be JSON', 'status': 'error'}), 400
525
-
526
  data = request.get_json()
527
  text = data.get('text', '').strip()
528
-
529
  if not text:
530
  return jsonify({'error': 'Text is required', 'status': 'error'}), 400
531
-
532
  logger.info(f"Adding punctuation for text of length: {len(text)}")
533
- if USE_HF_API:
534
- punctuated = hf_add_punctuation(text)
535
- else:
536
- punctuated = punctuation_model.add_punctuation(text)
537
-
538
  return jsonify({
539
- 'punctuated': punctuated,
540
- 'status': 'success',
541
- 'original_length': len(text),
542
- 'punctuated_length': len(punctuated)
543
  })
544
-
 
 
 
 
 
 
545
  except Exception as e:
546
  logger.error(f"Error during punctuation: {str(e)}")
547
  logger.error(traceback.format_exc())
@@ -893,31 +907,29 @@ def analyze_text():
893
  except Exception as e:
894
  logger.error(f"[ANALYZE] Grammar failed: {e}")
895
 
896
- # 3. Punctuation (runs on grammar-corrected text)
897
- has_punctuation = USE_HF_API or punctuation_model
898
- if has_punctuation:
899
- try:
900
- t0 = time.time()
901
- logger.info(f"[ANALYZE] Step 3: Punctuation starting...")
902
- if USE_HF_API:
903
- corrected_punc = hf_add_punctuation(current_text)
904
- else:
905
- corrected_punc = punctuation_model.add_punctuation(current_text)
906
- logger.info(f"[ANALYZE] Step 3: Punctuation done in {time.time()-t0:.2f}s")
907
- if corrected_punc != current_text:
908
- diffs = get_word_diffs(current_text, corrected_punc)
909
- for d in diffs:
910
- orig_start, orig_end = map_range_to_original(d['start'], d['end'])
911
- suggestions.append({
912
- 'start': orig_start,
913
- 'end': orig_end,
914
- 'original': text[orig_start:orig_end],
915
- 'correction': d['correction'],
916
- 'type': 'punctuation'
917
- })
918
- current_text = corrected_punc
919
- except Exception as e:
920
- logger.error(f"[ANALYZE] Punctuation failed: {e}")
921
 
922
  total_time = time.time() - total_start
923
 
 
156
  'spelling': _spelling_available(),
157
  'autocomplete': False,
158
  'grammar': _grammar_available(),
159
+ 'punctuation': _punctuation_available()
160
  },
161
  'note': 'Free tier: summarization local, other models return input unchanged',
162
  'supabase': {
 
241
  return False
242
 
243
 
244
+ def _punctuation_available():
245
+ """Check if punctuation model is loaded (without triggering lazy load)."""
246
+ try:
247
+ from nlp.punctuation.punctuation_service import is_loaded
248
+ return is_loaded()
249
+ except Exception:
250
+ return False
251
+
252
+
253
  @app.route('/api/spelling', methods=['POST'])
254
  def spelling_correction():
255
  """
 
515
  @app.route('/api/punctuation', methods=['POST'])
516
  def add_punctuation():
517
  """
518
+ Add punctuation to Arabic text using PuncAra-v1.
519
+
520
+ Request JSON:
521
  {
522
  "text": "Arabic text without punctuation"
523
  }
524
+
525
+ Response JSON:
526
+ {
527
+ "status": "success",
528
+ "original_text": "...",
529
+ "corrected_text": "..."
530
+ }
531
  """
 
 
 
 
 
 
532
  try:
533
  if not request.is_json:
534
  return jsonify({'error': 'Request must be JSON', 'status': 'error'}), 400
535
+
536
  data = request.get_json()
537
  text = data.get('text', '').strip()
538
+
539
  if not text:
540
  return jsonify({'error': 'Text is required', 'status': 'error'}), 400
541
+
542
  logger.info(f"Adding punctuation for text of length: {len(text)}")
543
+ from nlp.punctuation.punctuation_service import get_punctuation_model
544
+ punc_checker = get_punctuation_model()
545
+ punctuated = punc_checker.correct(text)
546
+
 
547
  return jsonify({
548
+ 'original_text': text,
549
+ 'corrected_text': punctuated,
550
+ 'status': 'success'
 
551
  })
552
+
553
+ except RuntimeError as e:
554
+ logger.error(f"Punctuation model error: {e}")
555
+ return jsonify({
556
+ 'error': f'Punctuation model unavailable: {str(e)[:200]}',
557
+ 'status': 'error'
558
+ }), 503
559
  except Exception as e:
560
  logger.error(f"Error during punctuation: {str(e)}")
561
  logger.error(traceback.format_exc())
 
907
  except Exception as e:
908
  logger.error(f"[ANALYZE] Grammar failed: {e}")
909
 
910
+ # 3. Punctuation (runs on grammar-corrected text — PuncAra-v1 local model)
911
+ try:
912
+ t0 = time.time()
913
+ logger.info(f"[ANALYZE] Step 3: Punctuation starting...")
914
+ from nlp.punctuation.punctuation_service import get_punctuation_model
915
+ punc_checker = get_punctuation_model()
916
+ corrected_punc = punc_checker.correct(current_text)
917
+ logger.info(f"[ANALYZE] Step 3: Punctuation done in {time.time()-t0:.2f}s")
918
+ if corrected_punc != current_text:
919
+ diffs = get_word_diffs(current_text, corrected_punc)
920
+ for d in diffs:
921
+ orig_start, orig_end = map_range_to_original(d['start'], d['end'])
922
+ suggestions.append({
923
+ 'start': orig_start,
924
+ 'end': orig_end,
925
+ 'original': text[orig_start:orig_end],
926
+ 'correction': d['correction'],
927
+ 'type': 'punctuation'
928
+ })
929
+ mappers.append(OffsetMapper(current_text, corrected_punc))
930
+ current_text = corrected_punc
931
+ except Exception as e:
932
+ logger.error(f"[ANALYZE] Punctuation failed: {e}")
 
 
933
 
934
  total_time = time.time() - total_start
935
 
src/nlp/punctuation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # PuncAra punctuation package
src/nlp/punctuation/punctuation_rules.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuncAra — Arabic Punctuation Restoration Rules
2
+ # Extracted from PuncAra.py — preprocessing + postprocessing + chunking logic.
3
+ # All classes are imported by punctuation_service.py.
4
+
5
+ import re
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
# Arabic diacritics U+064B-U+0652 (tanween, fatha, damma, kasra, shadda,
# sukun). Compiled once because the function runs for every chunk.
_ARABIC_DIACRITICS_RE = re.compile(r'[\u064B-\u0652]')


def arabic_preprocessing(text: str) -> str:
    """Remove Arabic diacritics (U+064B-U+0652) and trim surrounding whitespace."""
    return _ARABIC_DIACRITICS_RE.sub('', text).strip()
15
+
16
+
17
def arabic_postprocessing(text: str) -> str:
    """
    Typographic cleanup and punctuation normalization after model inference.

    Converts Latin ``, ; ?`` to their Arabic forms, tightens spacing inside
    brackets and guillemets, collapses repeated or contradictory marks,
    drops leading punctuation, and normalizes the spacing around marks.
    Separators inside numbers (``1,000``, ``12:30``, ``3.14``) are left
    untouched.

    Args:
        text: Text to clean; falsy values are returned unchanged.

    Returns:
        The cleaned text.
    """
    if not text:
        return text

    # 1. Shield separators that sit between two digits so the steps below
    #    neither arabize them nor insert a space after them. The decimal
    #    point is included; otherwise step 7 turns "3.14" into "3. 14".
    text = re.sub(r'(?<=\d),(?=\d)', '٪TEMP_COMMA٪', text)
    text = re.sub(r'(?<=\d):(?=\d)', '٪TEMP_COLON٪', text)
    text = re.sub(r'(?<=\d)\.(?=\d)', '٪TEMP_DOT٪', text)

    # 2. Arabize typographic marks
    text = text.replace(',', '،').replace(';', '؛').replace('?', '؟')

    # 3. Fix internal spacing for brackets and Arabic quotes
    text = re.sub(r'\(\s+', '(', text)
    text = re.sub(r'\s+\)', ')', text)
    text = re.sub(r'\[\s+', '[', text)
    text = re.sub(r'\s+\]', ']', text)
    text = re.sub(r'«\s+', '«', text)
    text = re.sub(r'\s+»', '»', text)

    # 4. Remove repeated emotional marks; runs of 4+ dots become an ellipsis
    text = re.sub(r'([،؛:!؟])\1+', r'\1', text)
    text = re.sub(r'\.{4,}', '...', text)

    # 5. Fix chunk-join contradictions
    text = re.sub(r'[،؛:]+([.!؟])', r'\1', text)
    text = re.sub(r'،؛|؛،', '؛', text)
    text = re.sub(r'([!؟])\.', r'\1', text)

    # 6. Remove stray leading punctuation
    text = re.sub(r'^[،؛:!؟. \t]+', '', text)

    # 7. Ensure a space after a mark that is directly followed by text.
    #    Closing brackets/quotes are excluded so "(...!)" stays closed up.
    text = re.sub(r'([،؛:!؟.])(?=[^\s)\]»])', r'\1 ', text)

    # 8. Restore the shielded numeric separators
    text = (
        text.replace('٪TEMP_COMMA٪', ',')
        .replace('٪TEMP_COLON٪', ':')
        .replace('٪TEMP_DOT٪', '.')
    )

    # 9. Attach punctuation to preceding word
    text = re.sub(r'\s+([،؛:!؟.])', r'\1', text)

    # 10. Collapse horizontal spaces only (newlines untouched)
    text = re.sub(r'[ \t]+', ' ', text).strip()
    return text
src/nlp/punctuation/punctuation_service.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Punctuation Service — Lazy-loaded Arabic punctuation restoration.
3
+
4
+ Uses:
5
+ 1. bayan10/PuncAra-v1 (EncoderDecoderModel — local, seq2seq)
6
+ 2. Rule-based pre/post-processing from punctuation_rules.py
7
+
8
+ Model loaded on first request and kept in memory.
9
+ """
10
+
11
+ import logging
12
+ import time
13
+ import torch
14
+ import re
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # ── Lazy-loaded singletons ──
19
+ _punctuation_checker = None
20
+ _load_error = None
21
+
22
+ HF_REPO_ID = "bayan10/PuncAra-v1"
23
+
24
+
25
class PunctuationChecker:
    """
    Arabic punctuation restoration pipeline:
    1. Preprocessing (remove diacritics)
    2. Model inference (non-overlapping chunks of 50 words)
    3. Postprocessing (typographic cleanup)
    """

    # Words per inference chunk. The step equals the window, so chunks
    # never overlap and no word is sent to the model twice.
    WINDOW_SIZE = 50
    # Token cap applied to both the tokenized input and the generated output.
    MAX_TOKENS = 128
    # Marks (and spaces between them) stripped at forced chunk cuts.
    BOUNDARY_CHARS = ".?!،؛:؟ \t"

    def __init__(self, model, tokenizer, device):
        """Store the loaded model, its tokenizer and the target device."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def _predict_chunk(self, text_chunk: str) -> str:
        """Run model inference on a single chunk (input truncated to 128 tokens)."""
        from nlp.punctuation.punctuation_rules import arabic_preprocessing

        text_chunk = arabic_preprocessing(text_chunk)

        # NOTE(review): truncation=True means a chunk that tokenizes to more
        # than MAX_TOKENS would lose its tail here — confirm 50-word chunks
        # stay under the limit for expected input.
        inputs = self.tokenizer(
            text_chunk, return_tensors="pt",
            padding=True, truncation=True, max_length=self.MAX_TOKENS
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                decoder_start_token_id=self.tokenizer.cls_token_id,
                bos_token_id=self.tokenizer.cls_token_id,
                eos_token_id=self.tokenizer.sep_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                max_length=self.MAX_TOKENS,
                num_beams=3,
                repetition_penalty=1.2,
                length_penalty=1.0,
                early_stopping=True,
                do_sample=False
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _fix_punctuation(self, text: str) -> str:
        """
        Punctuate one paragraph using non-overlapping window chunking.

        Paragraphs of at most WINDOW_SIZE words are sent to the model in a
        single call and returned as-is. Longer paragraphs are processed
        chunk by chunk; the whole run of punctuation at the end of every
        non-final chunk is removed because the sentence continues in the
        next chunk. Input without any words yields ``""``.
        """
        words = text.split()
        if not words:
            # Nothing to punctuate; do not run the model on empty input.
            return ""

        total_words = len(words)
        window_size = self.WINDOW_SIZE

        if total_words <= window_size:
            return self._predict_chunk(text)

        segments_output = []
        for start in range(0, total_words, window_size):
            chunk_text = " ".join(words[start:start + window_size])
            processed_segment = self._predict_chunk(chunk_text).strip()

            is_last_segment = (start + window_size) >= total_words
            if not is_last_segment:
                # Strip the entire trailing run, so "..." does not leave
                # a stray ".." behind at the join.
                processed_segment = processed_segment.rstrip(self.BOUNDARY_CHARS)

            segments_output.append(processed_segment)

        result = " ".join(segments_output)
        # Collapse any doubled whitespace introduced while joining.
        return re.sub(r'\s+', ' ', result).strip()

    def correct(self, text: str) -> str:
        """
        Run full punctuation restoration on text.

        The text is split on newlines; blank lines are dropped and each
        remaining line is punctuated and post-processed independently, then
        the results are joined with a blank line between paragraphs.

        Returns:
            The punctuated text. Empty or whitespace-only input is returned
            unchanged, and the original text is returned if any step fails
            (the failure is logged with its traceback).
        """
        if not text or not text.strip():
            return text

        try:
            from nlp.punctuation.punctuation_rules import arabic_postprocessing

            # Split into paragraphs, ignoring empty lines
            paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
            processed_paragraphs = [
                arabic_postprocessing(self._fix_punctuation(paragraph))
                for paragraph in paragraphs
            ]

            result = "\n\n".join(processed_paragraphs)
            logger.info(f"Punctuation output: '{result[:80]}...' (input: '{text[:80]}...')")
            return result

        except Exception as e:
            # Best-effort by design: callers get their input back, but the
            # traceback is kept in the log so the failure can be diagnosed.
            logger.exception(f"Punctuation correction failed: {e}")
            return text
127
+
128
+
129
def get_punctuation_model():
    """
    Lazy-load the punctuation model on first call.

    The loaded PunctuationChecker is stored in a module-level singleton and
    returned by every later call. A failed load is remembered in
    ``_load_error`` and is not retried: later calls raise immediately.

    Returns:
        The shared PunctuationChecker instance.

    Raises:
        RuntimeError: if loading fails now or failed on an earlier call.
    """
    global _punctuation_checker, _load_error

    if _punctuation_checker is not None:
        return _punctuation_checker

    if _load_error is not None:
        raise RuntimeError(f"Punctuation model previously failed to load: {_load_error}")

    try:
        t0 = time.time()
        logger.info("Loading PuncAra-v1 punctuation model (lazy init)...")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Punctuation model device: {device}")

        # Imported here so transformers is only loaded when the model is needed.
        from transformers import EncoderDecoderModel, AutoTokenizer

        logger.info(f"Loading model from HF Hub: {HF_REPO_ID}")
        model = EncoderDecoderModel.from_pretrained(HF_REPO_ID)
        tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)

        # Configure special tokens (CLS starts decoding, SEP ends it)
        model.config.decoder_start_token_id = tokenizer.cls_token_id
        model.config.bos_token_id = tokenizer.cls_token_id
        model.config.eos_token_id = tokenizer.sep_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

        model = model.to(device)
        model.eval()

        _punctuation_checker = PunctuationChecker(model, tokenizer, device)

        elapsed = time.time() - t0
        logger.info(f"PuncAra-v1 ready in {elapsed:.1f}s")
        return _punctuation_checker

    except Exception as e:
        # Remember the failure so later calls fail fast instead of reloading.
        _load_error = str(e)
        # logger.exception records the message together with the traceback.
        logger.exception(f"Failed to load punctuation model: {e}")
        raise RuntimeError(f"Punctuation model load failed: {e}") from e
176
+
177
+
178
def is_loaded() -> bool:
    """Return True once the punctuation model singleton has been created."""
    checker = _punctuation_checker
    return checker is not None
181
+
182
+
183
def get_load_error() -> str:
    """Return the recorded model load error, or an empty string if there is none."""
    if _load_error:
        return _load_error
    return ""