# fadimari's picture
# Update app.py
# b65ca7a verified
# -*- coding: utf-8 -*-
# app.py — CAMeL‑Lab GED→GEC (بدون camel_tools / بدون ged_tags في generate)
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BertForTokenClassification, MBartForConditionalGeneration
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# 1) Hub checkpoint names (qalb14 can be swapped for qalb15 if desired).
GED_MODEL_NAME = "CAMeL-Lab/camelbert-msa-qalb14-ged-13"
GEC_MODEL_NAME = "CAMeL-Lab/arabart-qalb14-gec-ged-13"


# 2) Load the tokenizers and models once, at import time.
def _load_pair(model_cls, checkpoint):
    """Return ``(tokenizer, model)`` for *checkpoint*.

    The model is moved to ``DEVICE`` and switched to ``eval()`` mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint).to(DEVICE).eval()
    return tokenizer, model


ged_tokenizer, ged_model = _load_pair(BertForTokenClassification, GED_MODEL_NAME)
gec_tokenizer, gec_model = _load_pair(MBartForConditionalGeneration, GEC_MODEL_NAME)
# 3) Core pipeline: GED (token classification) followed by GEC (seq2seq).
def _model_window(model, tokenizer):
    """Return the longest token sequence (incl. special tokens) *model* accepts."""
    # tokenizer.model_max_length is a huge sentinel when the checkpoint's
    # tokenizer config does not set it, so also bound by the model config.
    return min(model.config.max_position_embeddings, tokenizer.model_max_length)


def _ged_word_labels(words):
    """Predict one GED label per word in *words*.

    Each word is labelled by the prediction on its first sub-token (assumes the
    GED checkpoint puts word labels there -- TODO confirm against its training
    setup). Words that fall beyond the GED model's window, or that produce no
    sub-tokens at all, get ``None``.
    """
    enc = ged_tokenizer(
        words,
        is_split_into_words=True,
        truncation=True,  # >512 sub-tokens would otherwise crash BERT
        max_length=_model_window(ged_model, ged_tokenizer),
        return_tensors="pt",
    )
    word_ids = enc.word_ids(0)  # requires a fast tokenizer (AutoTokenizer default)
    with torch.no_grad():
        logits = ged_model(**{k: v.to(DEVICE) for k, v in enc.items()}).logits[0]
    # argmax(logits) == argmax(softmax(logits)); no normalisation needed.
    pred_ids = logits.argmax(dim=-1).tolist()

    labels = [None] * len(words)
    previous = None
    for position, word_id in enumerate(word_ids):
        # Special tokens ([CLS]/[SEP]) have word_id None and are skipped.
        if word_id is not None and word_id != previous:
            labels[word_id] = ged_model.config.id2label[pred_ids[position]]
        previous = word_id
    return labels


def _gec_input_chunks(words):
    """Split *words* into GEC input-id lists that each fit the GEC model window.

    Every chunk is ``[bos] + sub-token ids of consecutive whole words + [eos]``.
    A single word longer than the window is truncated rather than dropped.
    """
    budget = _model_window(gec_model, gec_tokenizer) - 2  # room for bos/eos
    chunks, current = [], []
    for word in words:
        ids = gec_tokenizer.convert_tokens_to_ids(gec_tokenizer.tokenize(word))[:budget]
        if current and len(current) + len(ids) > budget:
            chunks.append(current)
            current = []
        current.extend(ids)
    if current:
        chunks.append(current)
    bos, eos = gec_tokenizer.bos_token_id, gec_tokenizer.eos_token_id
    return [[bos] + chunk + [eos] for chunk in chunks]


def _gec_generate(input_ids):
    """Beam-search a correction for one chunk of GEC input ids and decode it."""
    window = _model_window(gec_model, gec_tokenizer)
    with torch.no_grad():
        generated = gec_model.generate(
            torch.tensor([input_ids], device=DEVICE),
            num_beams=5,
            # Original heuristic (input length + 32, at least 128), capped at
            # the decoder's positional window so it cannot overflow it.
            max_length=min(max(128, len(input_ids) + 32), window),
            num_return_sequences=1,
        )
    return gec_tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def camel_gec_correct(text: str) -> str:
    """Correct Arabic *text* with the CAMeL-Lab GED and GEC models.

    Args:
        text: raw user input; it is split into words on whitespace.

    Returns:
        The corrected text. If *text* is empty or whitespace-only, returns the
        Arabic warning "⚠️ يرجى إدخال نص." ("please enter text") instead.
    """
    if not text or not text.strip():
        return "⚠️ يرجى إدخال نص."

    words = text.split()

    # (a) GED: one predicted error label per word.
    # NOTE(review): these labels are not passed to generate() (the file header
    # says "without ged_tags in generate"), so they do not affect the output.
    # Either feed them to a GED-aware generate or drop this pass to save compute.
    ged_labels = _ged_word_labels(words)

    # (b)-(d) GEC: every word goes into the GEC input. The input is chunked to
    # fit the model window, each chunk is corrected, and the pieces are rejoined.
    corrected = [_gec_generate(ids) for ids in _gec_input_chunks(words)]
    return " ".join(piece for piece in corrected if piece)
# 4) Gradio interface
def ui_correct(text):
    """Gradio click handler: return ``camel_gec_correct(text)`` unchanged."""
    corrected_text = camel_gec_correct(text)
    return corrected_text
# Page layout: Gradio renders components in the order they are created inside
# the Blocks context, so keep the statements below in this order.
with gr.Blocks(title="المصحح النحوي عربي (CAMeL‑Lab GED→GEC)") as demo:
    gr.Markdown("### تصحيح نحوي عربي باستخدام نماذج CAMeL‑Lab (GED ثم GEC)")
    # Input text box.
    inp = gr.Textbox(
        label="النص",
        lines=8,
        placeholder="أدخل النص العربي هنا...",
    )
    # "Correct" button.
    btn = gr.Button("تصحيح")
    # Output box for the corrected text.
    out = gr.Textbox(
        label="النص المصحح",
        lines=8,
    )
    # Clicking the button sends the input box's text to ui_correct and shows
    # the result in the output box.
    btn.click(ui_correct, inputs=inp, outputs=out)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()