fadimari committed on
Commit
b65ca7a
·
verified ·
1 Parent(s): 240af3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -30
app.py CHANGED
@@ -1,4 +1,6 @@
1
- # app.py (بدون camel_tools)
 
 
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
@@ -6,33 +8,32 @@ from transformers import AutoTokenizer, BertForTokenClassification, MBartForCond
6
 
7
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- # نماذج CAMeL‑Lab
10
- GED_MODEL_NAME = "CAMeL-Lab/camelbert-msa-qalb14-ged-13" # أو qalb15
11
- GEC_MODEL_NAME = "CAMeL-Lab/arabart-qalb14-gec-ged-13" # أو qalb15
12
 
 
13
  ged_tokenizer = AutoTokenizer.from_pretrained(GED_MODEL_NAME)
14
  ged_model = BertForTokenClassification.from_pretrained(GED_MODEL_NAME).to(DEVICE).eval()
15
 
16
  gec_tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
17
  gec_model = MBartForConditionalGeneration.from_pretrained(GEC_MODEL_NAME).to(DEVICE).eval()
18
 
 
19
  def camel_gec_correct(text: str) -> str:
20
  if not text or not text.strip():
21
  return "⚠️ يرجى إدخال نص."
22
- # 1) GED مباشرة على النص المُدخل
 
23
  inputs = ged_tokenizer([text], return_tensors="pt")
24
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
25
  with torch.no_grad():
26
  logits = ged_model(**inputs).logits
27
- probs = torch.softmax(logits, dim=-1).squeeze()[1:-1]
28
  pred_ids = torch.argmax(probs, -1)
29
  pred_ged_labels = [ged_model.config.id2label[i.item()] for i in pred_ids]
30
 
31
- # 2) محاذاة وسوم GED مع تقطيع GEC
32
- ged_label2ids = getattr(gec_model.config, "ged_label2id", None)
33
- if ged_label2ids is None:
34
- raise RuntimeError("Model config missing ged_label2id mapping.")
35
-
36
  tokens, ged_labels = [], []
37
  for word, label in zip(text.split(), pred_ged_labels):
38
  w_toks = gec_tokenizer.tokenize(word)
@@ -40,29 +41,33 @@ def camel_gec_correct(text: str) -> str:
40
  tokens.extend(w_toks)
41
  ged_labels.extend([label] * len(w_toks))
42
 
43
- input_ids = [gec_tokenizer.bos_token_id] + gec_tokenizer.convert_tokens_to_ids(tokens) + [gec_tokenizer.eos_token_id]
44
- label_ids = [ged_label2ids.get("UC", 0)] + [ged_label2ids.get(l, ged_label2ids.get("<pad>", 0)) for l in ged_labels] + [ged_label2ids.get("UC", 0)]
45
- attention_mask = [1] * len(input_ids)
 
46
 
47
-
48
- with torch.no_grad():
49
- generated = gec_model.generate(
50
- torch.tensor([input_ids], device=DEVICE),
51
- num_beams=5,
52
- max_length=max(64, len(input_ids) + 20),
53
- num_return_sequences=1,
54
- )
55
- return gec_tokenizer.batch_decode(
56
- generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
57
- )[0]
 
58
 
59
- def ui_correct(text): return camel_gec_correct(text)
 
 
60
 
61
- with gr.Blocks(title="CAMeL-Lab Arabic GEC/GED") as demo:
62
- gr.Markdown("### تصحيح نحوي عربي (CAMeL-Lab GEDGEC)")
63
- inp = gr.Textbox(label="النص", lines=6, placeholder="أدخل النص العربي...")
64
  btn = gr.Button("تصحيح")
65
- out = gr.Textbox(label="النص المصحح")
66
  btn.click(fn=ui_correct, inputs=inp, outputs=out)
67
 
68
  if __name__ == "__main__":
 
1
+ # -*- coding: utf-8 -*-
2
+ # app.py — CAMeL‑Lab GED→GEC (بدون camel_tools / بدون ged_tags في generate)
3
+
4
  import gradio as gr
5
  import torch
6
  import torch.nn.functional as F
 
8
 
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # 1) أسماء النماذج (يمكن ترقية qalb14 إلى qalb15 إن رغبت)
12
+ GED_MODEL_NAME = "CAMeL-Lab/camelbert-msa-qalb14-ged-13"
13
+ GEC_MODEL_NAME = "CAMeL-Lab/arabart-qalb14-gec-ged-13"
14
 
15
+ # 2) تحميل النماذج والمُرمّزات
16
  ged_tokenizer = AutoTokenizer.from_pretrained(GED_MODEL_NAME)
17
  ged_model = BertForTokenClassification.from_pretrained(GED_MODEL_NAME).to(DEVICE).eval()
18
 
19
  gec_tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
20
  gec_model = MBartForConditionalGeneration.from_pretrained(GEC_MODEL_NAME).to(DEVICE).eval()
21
 
22
+ # 3) الدالة الأساسية
23
  def camel_gec_correct(text: str) -> str:
24
  if not text or not text.strip():
25
  return "⚠️ يرجى إدخال نص."
26
+
27
+ # (أ) GED مباشرة على النص
28
  inputs = ged_tokenizer([text], return_tensors="pt")
29
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
30
  with torch.no_grad():
31
  logits = ged_model(**inputs).logits
32
+ probs = F.softmax(logits, dim=-1).squeeze()[1:-1] # تجاهل [CLS]/[SEP]
33
  pred_ids = torch.argmax(probs, -1)
34
  pred_ged_labels = [ged_model.config.id2label[i.item()] for i in pred_ids]
35
 
36
+ # ) محاذاة وسوم GED مع تقطيع GEC
 
 
 
 
37
  tokens, ged_labels = [], []
38
  for word, label in zip(text.split(), pred_ged_labels):
39
  w_toks = gec_tokenizer.tokenize(word)
 
41
  tokens.extend(w_toks)
42
  ged_labels.extend([label] * len(w_toks))
43
 
44
+ # (ج) نبني تسلسل الإدخال لـ GEC (بدون ged_tags)
45
+ input_ids = [gec_tokenizer.bos_token_id] + \
46
+ gec_tokenizer.convert_tokens_to_ids(tokens) + \
47
+ [gec_tokenizer.eos_token_id]
48
 
49
+ # (د) توليد التصحيح القياسي
50
+ with torch.no_grad():
51
+ generated = gec_model.generate(
52
+ torch.tensor([input_ids], device=DEVICE),
53
+ num_beams=5,
54
+ max_length=max(128, len(input_ids) + 32),
55
+ num_return_sequences=1,
56
+ )
57
+ corrected = gec_tokenizer.batch_decode(
58
+ generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
59
+ )[0]
60
+ return corrected
61
 
62
+ # 4) واجهة Gradio
63
+ def ui_correct(text):
64
+ return camel_gec_correct(text)
65
 
66
+ with gr.Blocks(title="المصحح النحوي عربي (CAMeLLab GED→GEC)") as demo:
67
+ gr.Markdown("### تصحيح نحوي عربي باستخدام نماذج CAMeLLab (GED ثم GEC)")
68
+ inp = gr.Textbox(label="النص", lines=8, placeholder="أدخل النص العربي هنا...")
69
  btn = gr.Button("تصحيح")
70
+ out = gr.Textbox(label="النص المصحح", lines=8)
71
  btn.click(fn=ui_correct, inputs=inp, outputs=out)
72
 
73
  if __name__ == "__main__":