| import os
|
| import re
|
| import string
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import gradio as gr
|
| from transformers import AutoTokenizer, AutoModel
|
| from peft import LoraConfig, get_peft_model
|
| from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
# --- Checkpoints -----------------------------------------------------------
# Hugging Face checkpoint used for both the tokenizer and the encoder backbone.
MODEL_NAME: str = "CAMeL-Lab/bert-base-arabic-camelbert-mix-sentiment"

# Hub repository and file name of the fine-tuned weights fetched at start-up.
# NOTE(review): the repo id mentions MARBERTv2 while the backbone above is
# CAMeLBERT -- confirm the checkpoint really belongs to this backbone.
REPO_ID: str = "mahmoudmohammad/MARBERTv2-Sentiment_Classification"
MODEL_FILE: str = "best_stl_model.pth"

# --- Tokenisation / LoRA hyper-parameters ----------------------------------
MAX_LEN: int = 256          # max_length handed to the tokenizer
LORA_R: int = 32            # LoRA rank
LORA_ALPHA: int = 64        # LoRA alpha
LORA_DROPOUT: float = 0.1   # dropout applied inside the LoRA layers

# --- Output classes --------------------------------------------------------
# Class index (position in the logits) -> human readable label.
LABELS = dict(enumerate(("Negative", "Neutral", "Positive")))

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
# Regex passes applied, in this order, to the raw text: (pattern, replacement).
_REGEX_PASSES = (
    (re.compile(r'http\S+|www\.\S+'), ''),    # URLs
    (re.compile(r'[@]\S+'), ''),              # @mentions
    # NOTE(review): by the time this pass runs, the mention pass has already
    # removed every "@" that is followed by a non-space character, so this
    # e-mail pattern can no longer match. Order kept as-is to preserve the
    # existing preprocessing output.
    (re.compile(r'\S+@\S+'), ' '),            # e-mail addresses
    (re.compile(r'\d+|[٠١٢٣٤٥٦٧٨٩]+'), ''),   # Western and Arabic-Indic digits
)

_WHITESPACE_RE = re.compile(r'\s+')


def _build_char_table():
    """Build the single-pass translation table used by ``clean_text``."""
    arabic_punc = '`÷×؛«»<>()*&^%][ـ،/:".،,\'{}~¦+|"…""–ـ'
    # '!' and '?' are deliberately kept in the text.
    eng_punc = string.punctuation.replace('!', '').replace('?', '')

    # Every punctuation character is deleted ...
    table = dict.fromkeys(map(ord, arabic_punc + eng_punc))

    # ... except '#' and '_', which turn into spaces so hashtag words survive.
    table[ord('#')] = ' '
    table[ord('_')] = ' '

    # Orthographic folding of Arabic letter variants.
    letter_folds = (
        ("إ", "ا"),
        ("أ", "ا"),
        ("آ", "ا"),
        ("ى", "ي"),
        ("ؤ", "و"),
        ("ئ", "ي"),
        ("ة", "ه"),
    )
    for variant, canonical in letter_folds:
        table[ord(variant)] = canonical
    return table


_CHAR_TABLE = _build_char_table()


def clean_text(text):
    """Normalise raw (mostly Arabic) text before tokenisation.

    Steps, in order: strip URLs, @mentions and digits; turn ``#`` and ``_``
    into spaces; fold Arabic letter variants (alef forms, alef maqsura,
    hamza carriers, ta marbuta); delete punctuation except ``!`` and ``?``;
    collapse runs of whitespace and trim.

    Args:
        text: The text to clean. A non-``str`` value is returned as
            ``str(text)`` without any cleaning.

    Returns:
        The cleaned string.
    """
    if not isinstance(text, str):
        return str(text)

    for pattern, replacement in _REGEX_PASSES:
        text = pattern.sub(replacement, text)

    # Character-level replacements and deletions happen in one pass; none of
    # the replacement targets is itself remapped or deleted, so this equals
    # applying the substitutions one after another.
    text = text.translate(_CHAR_TABLE)

    return _WHITESPACE_RE.sub(' ', text).strip()
|
|
|
|
|
|
|
|
|
class SentimentSTL(nn.Module):
    """Sentiment classifier: LoRA-adapted transformer encoder plus an MLP head.

    The encoder is loaded from ``MODEL_NAME`` and wrapped with LoRA adapters
    on its ``query`` and ``value`` projections. The sentence representation
    is the average of the first-token vector and the attention-masked mean
    of all token vectors; a two-layer MLP maps it to class logits.

    The attribute names ``encoder``, ``peft_model`` and ``classifier`` are
    state-dict key prefixes, so they must not be renamed while existing
    checkpoints are loaded with ``load_state_dict``.

    Args:
        num_classes: Number of output classes (size of the logits vector).
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(MODEL_NAME)

        peft_config = LoraConfig(
            task_type="FEATURE_EXTRACTION",
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            target_modules=["query", "value"],
        )
        self.peft_model = get_peft_model(self.encoder, peft_config)
        hidden_size = self.peft_model.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Compute class logits for a batch of tokenised inputs.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Mask with non-zero entries for real tokens; it is
                also used to weight the mean pooling.

        Returns:
            The output of the classifier head, one row of logits per input.
        """
        outputs = self.peft_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        # Masked mean over the sequence dimension. The denominator is clamped
        # so a fully-masked row yields zeros instead of NaN (0 / 0).
        mask = attention_mask.unsqueeze(-1).float()
        token_count = mask.sum(1).clamp(min=1e-9)
        mean_rep = (hidden * mask).sum(1) / token_count

        # Blend the first-token vector with the mean-pooled vector.
        cls_rep = hidden[:, 0, :]
        cls_rep = (cls_rep + mean_rep) / 2.0
        return self.classifier(cls_rep)
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# One-time start-up: fetch the fine-tuned weights, build the tokenizer and
# the model, and load the weights. `tokenizer` and `model` are module-level
# objects used by `predict`.
# ---------------------------------------------------------------------------
print("Downloading / verifying cached weights...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)

print("Initializing Tokenizer and Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Size the classifier head from the label table so the two cannot drift apart.
model = SentimentSTL(num_classes=len(LABELS)).to(device)

print("Loading trained weights into model...")
# The file comes from the network, so restrict unpickling to tensors and
# plain containers (weights_only=True) instead of running arbitrary pickle
# code. A state dict needs nothing more than that.
model.load_state_dict(
    torch.load(weights_path, map_location=device, weights_only=True)
)
model.eval()  # disable dropout for inference
print("Model initialized completely!")
|
|
|
|
|
|
|
|
|
def predict(text):
    """Return sentiment probabilities for one piece of text.

    The text is cleaned with ``clean_text``, tokenised with the module-level
    ``tokenizer`` and scored by the module-level ``model``.

    Args:
        text: Raw user input. ``None`` is treated as empty input; any other
            non-string value is converted with ``str``.

    Returns:
        A dict mapping every label name in ``LABELS`` to a float. All values
        are ``0.0`` when the input is empty, or becomes empty after cleaning
        (for example when it only contained URLs, digits or punctuation).
        Otherwise the values are the softmax probabilities of the model.
    """
    no_prediction = {name: 0.0 for name in LABELS.values()}

    text = "" if text is None else str(text)
    if not text.strip():
        return no_prediction

    # Check again after cleaning: cleaning can remove everything, and scoring
    # an empty sequence would produce a meaningless prediction.
    cleaned_text = clean_text(text)
    if not cleaned_text:
        return no_prediction

    enc = tokenizer(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    ids = enc['input_ids'].to(device)
    mask = enc['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(ids, mask)

    # Exactly one input was tokenised, so take row 0 explicitly rather than
    # relying on squeeze() to drop the batch dimension.
    probs = F.softmax(logits, dim=1)[0].cpu().tolist()

    return {LABELS[index]: float(prob) for index, prob in enumerate(probs)}
|
|
|
|
|
|
|
|
|
# Extra CSS injected into the Gradio page.
custom_css = """
body { font-family: 'Tajawal', sans-serif; }
.output-class { font-weight: bold; }
"""

# Sample sentences offered below the input box.
_EXAMPLE_SENTENCES = [
    "هذا المطعم يقدم طعام سيء جدًا ومحروق، لا أنصح أحد بزيارته.",
    "المؤتمر الصحفي لوزير الاتصالات تم بالأمس لمناقشة بعض الأمور الإقتصادية.",
    "بصراحة، الخدمة كانت ممتازة والموظفين غاية في الإحترام، أنصح بشدة!",
]

with gr.Blocks(css=custom_css, title="Arabic Sentiment Analyzer") as demo:
    # Page header: title followed by a short description.
    gr.Markdown("<center><h1>🌐 Arabic Sentiment Analysis — CAMeLBERT + LoRA</h1></center>")
    gr.Markdown(
        "<center><b>Analyze Arabic text and detect Negative, Neutral, or Positive sentiment.</b>"
        "<br> This model is balanced trained specifically on an augmented pipeline, "
        "leveraging LoRA representation over base features.</center>"
    )

    with gr.Row():
        # Left column: text input, submit button and example sentences.
        with gr.Column(scale=2):
            input_textbox = gr.Textbox(
                label="Arabic Input Text",
                placeholder="...اكتب أو ألصق الجملة العربية هنا",
                lines=5,
                rtl=True,
            )
            submit_button = gr.Button("Analyze / تحليل", variant="primary")

            gr.Examples(
                examples=_EXAMPLE_SENTENCES,
                inputs=input_textbox,
                label="Examples / أمثلة سريعة",
            )

        # Right column: per-class probabilities returned by `predict`.
        with gr.Column(scale=1):
            output_label = gr.Label(
                label="Sentiment Probabilities",
                num_top_classes=3,
            )

    # Clicking the button runs `predict` on the textbox content.
    submit_button.click(fn=predict, inputs=input_textbox, outputs=output_label)


if __name__ == "__main__":
    demo.launch(share=False)