Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
-
#
|
|
|
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
|
@@ -6,33 +8,32 @@ from transformers import AutoTokenizer, BertForTokenClassification, MBartForCond
|
|
| 6 |
|
| 7 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 8 |
|
| 9 |
-
# نماذج
|
| 10 |
-
GED_MODEL_NAME = "CAMeL-Lab/camelbert-msa-qalb14-ged-13"
|
| 11 |
-
GEC_MODEL_NAME = "CAMeL-Lab/arabart-qalb14-gec-ged-13"
|
| 12 |
|
|
|
|
| 13 |
ged_tokenizer = AutoTokenizer.from_pretrained(GED_MODEL_NAME)
|
| 14 |
ged_model = BertForTokenClassification.from_pretrained(GED_MODEL_NAME).to(DEVICE).eval()
|
| 15 |
|
| 16 |
gec_tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
|
| 17 |
gec_model = MBartForConditionalGeneration.from_pretrained(GEC_MODEL_NAME).to(DEVICE).eval()
|
| 18 |
|
|
|
|
| 19 |
def camel_gec_correct(text: str) -> str:
|
| 20 |
if not text or not text.strip():
|
| 21 |
return "⚠️ يرجى إدخال نص."
|
| 22 |
-
|
|
|
|
| 23 |
inputs = ged_tokenizer([text], return_tensors="pt")
|
| 24 |
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 25 |
with torch.no_grad():
|
| 26 |
logits = ged_model(**inputs).logits
|
| 27 |
-
probs =
|
| 28 |
pred_ids = torch.argmax(probs, -1)
|
| 29 |
pred_ged_labels = [ged_model.config.id2label[i.item()] for i in pred_ids]
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
ged_label2ids = getattr(gec_model.config, "ged_label2id", None)
|
| 33 |
-
if ged_label2ids is None:
|
| 34 |
-
raise RuntimeError("Model config missing ged_label2id mapping.")
|
| 35 |
-
|
| 36 |
tokens, ged_labels = [], []
|
| 37 |
for word, label in zip(text.split(), pred_ged_labels):
|
| 38 |
w_toks = gec_tokenizer.tokenize(word)
|
|
@@ -40,29 +41,33 @@ def camel_gec_correct(text: str) -> str:
|
|
| 40 |
tokens.extend(w_toks)
|
| 41 |
ged_labels.extend([label] * len(w_toks))
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
with torch.no_grad():
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
)[0]
|
|
|
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
with gr.Blocks(title="CAMeL
|
| 62 |
-
gr.Markdown("### تصحيح نحوي عربي
|
| 63 |
-
inp = gr.Textbox(label="النص", lines=
|
| 64 |
btn = gr.Button("تصحيح")
|
| 65 |
-
out = gr.Textbox(label="النص المصحح")
|
| 66 |
btn.click(fn=ui_correct, inputs=inp, outputs=out)
|
| 67 |
|
| 68 |
if __name__ == "__main__":
|
|
|
|
# -*- coding: utf-8 -*-
# app.py — CAMeL‑Lab GED→GEC (no camel_tools / no ged_tags passed to generate)

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BertForTokenClassification, MBartForConditionalGeneration

# Run on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Model names (qalb14 can be upgraded to qalb15 if desired).
GED_MODEL_NAME = "CAMeL-Lab/camelbert-msa-qalb14-ged-13"
GEC_MODEL_NAME = "CAMeL-Lab/arabart-qalb14-gec-ged-13"

# 2) Load tokenizers and models once at import time, in eval mode on DEVICE.
ged_tokenizer = AutoTokenizer.from_pretrained(GED_MODEL_NAME)
ged_model = BertForTokenClassification.from_pretrained(GED_MODEL_NAME).to(DEVICE).eval()

gec_tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
gec_model = MBartForConditionalGeneration.from_pretrained(GEC_MODEL_NAME).to(DEVICE).eval()
# 3) Core pipeline: GED tagging followed by GEC generation.
def camel_gec_correct(text: str) -> str:
    """Correct Arabic text.

    Runs grammatical error detection (GED) over *text*, aligns the predicted
    tags with the GEC tokenizer's subwords, then generates a corrected
    sentence with the seq2seq GEC model.

    Returns the corrected string, or an Arabic warning message for empty input.
    """
    if not text or not text.strip():
        return "⚠️ يرجى إدخال نص."  # empty-input guard ("please enter text")

    # (a) GED: per-token error detection directly on the raw text.
    encoding = ged_tokenizer([text], return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
    with torch.no_grad():
        logits = ged_model(**inputs).logits
    # squeeze(0) (not bare squeeze()) so a single-token input keeps its rank;
    # [1:-1] drops the [CLS]/[SEP] positions.
    probs = F.softmax(logits, dim=-1).squeeze(0)[1:-1]
    pred_ids = torch.argmax(probs, dim=-1)
    subword_labels = [ged_model.config.id2label[i.item()] for i in pred_ids]

    # BUG FIX: the original zipped *subword*-level GED labels against
    # text.split() *words*, so labels drifted out of alignment as soon as any
    # word was split into multiple GED subwords. Collapse subword labels to
    # one label per word (first-subword convention) via the fast tokenizer's
    # word_ids() before pairing them with words.
    try:
        word_ids = encoding.word_ids(0)  # fast tokenizers only
    except (ValueError, AttributeError):
        word_ids = None
    if word_ids is not None:
        first_label = {}
        for pos, wid in enumerate(word_ids):
            # `pos` indexes the full sequence incl. [CLS]; subword_labels
            # excludes it, hence the -1 offset.
            if wid is not None and wid not in first_label:
                first_label[wid] = subword_labels[pos - 1]
        word_labels = [first_label[w] for w in sorted(first_label)]
    else:
        word_labels = subword_labels  # legacy (possibly misaligned) fallback

    # (b) Re-tokenize with the GEC tokenizer, repeating each word's GED label
    # across all of that word's GEC subwords. `ged_labels` is kept for parity
    # with the GED-tagged variant of this model; generate() below ignores it.
    tokens, ged_labels = [], []
    for word, label in zip(text.split(), word_labels):
        w_toks = gec_tokenizer.tokenize(word)
        tokens.extend(w_toks)
        ged_labels.extend([label] * len(w_toks))

    # (c) Build the GEC input sequence manually (no ged_tags argument).
    input_ids = (
        [gec_tokenizer.bos_token_id]
        + gec_tokenizer.convert_tokens_to_ids(tokens)
        + [gec_tokenizer.eos_token_id]
    )

    # (d) Standard beam-search generation of the corrected sentence.
    with torch.no_grad():
        generated = gec_model.generate(
            torch.tensor([input_ids], device=DEVICE),
            num_beams=5,
            max_length=max(128, len(input_ids) + 32),  # headroom for insertions
            num_return_sequences=1,
        )
    corrected = gec_tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return corrected
|
| 61 |
|
# 4) Gradio UI
def ui_correct(text):
    """Gradio callback: thin pass-through to camel_gec_correct."""
    return camel_gec_correct(text)
# Declarative Gradio layout: one input box, a button, one output box.
with gr.Blocks(title="المصحح النحوي عربي (CAMeL‑Lab GED→GEC)") as demo:
    gr.Markdown("### تصحيح نحوي عربي باستخدام نماذج CAMeL‑Lab (GED ثم GEC)")
    inp = gr.Textbox(label="النص", lines=8, placeholder="أدخل النص العربي هنا...")
    btn = gr.Button("تصحيح")
    out = gr.Textbox(label="النص المصحح", lines=8)
    # Wire the button to the correction callback.
    btn.click(fn=ui_correct, inputs=inp, outputs=out)
| 72 |
|
| 73 |
if __name__ == "__main__":
|