import os
import re
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model
from huggingface_hub import hf_hub_download

# ==========================================
# 1. Config & Global Initialization (Cached)
# ==========================================
MODEL_NAME = "CAMeL-Lab/bert-base-arabic-camelbert-mix-sentiment"
# PLEASE UPDATE THIS TO YOUR CREATED MODEL REPO FROM STEP 1
REPO_ID = "mahmoudmohammad/MARBERTv2-Sentiment_Classification"
MODEL_FILE = "best_stl_model.pth"
MAX_LEN = 256
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.1

# Classes matched from sklearn LabelEncoder mapping: 0, 1, 2
LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==========================================
# 2. Text Preprocessing Function
# ==========================================
# Characters deleted outright by clean_text.
_ARABIC_PUNCT = '`÷×؛«»<>()*&^%][ـ،/:".،,\'{}~¦+|"…""–ـ'
# "!" and "?" are deliberately left out of the deletion set, so they survive
# cleaning (presumably because they carry sentiment — confirm).
_ENGLISH_PUNCT = string.punctuation.replace('!', '').replace('?', '')

# Single-pass character table, built once at import time instead of on every
# call. Later keys override earlier ones, so '#' and '_' become spaces even
# though they are also part of string.punctuation.
_CHAR_TABLE = str.maketrans({
    **dict.fromkeys(_ARABIC_PUNCT + _ENGLISH_PUNCT),  # value None -> delete
    '#': ' ',
    '_': ' ',
    # Orthographic normalization of Arabic letter variants.
    'إ': 'ا',
    'أ': 'ا',
    'آ': 'ا',
    'ى': 'ي',
    'ئ': 'ي',
    'ؤ': 'و',
    'ة': 'ه',
})


def clean_text(text):
    """Normalize raw text before tokenization.

    Steps, in order: strip URLs, strip ``@mentions``, strip digits (Western
    and Arabic-Indic), turn ``#`` / ``_`` into spaces, unify Arabic letter
    variants, delete punctuation (except ``!`` and ``?``), then collapse
    runs of whitespace and trim.

    Args:
        text: The raw input. Non-string values are NOT cleaned; they are
            returned as ``str(text)`` unchanged.

    Returns:
        The cleaned string (possibly empty).
    """
    if not isinstance(text, str):
        return str(text)

    text = re.sub(r'http\S+|www\.\S+', '', text)
    text = re.sub(r'[@]\S+', '', text)
    # NOTE(review): the mention rule above already consumes everything from
    # '@' onward, so this e-mail rule can no longer match anything. It is
    # left in place, in this order, to avoid silently changing preprocessing
    # — confirm against the training pipeline before reordering.
    text = re.sub(r'\S+@\S+', ' ', text)
    text = re.sub(r'\d+|[٠١٢٣٤٥٦٧٨٩]+', '', text)

    # Separator replacement, letter normalization and punctuation removal
    # are all independent per-character mappings, so one translate() pass
    # yields the same result as applying them one after another.
    text = text.translate(_CHAR_TABLE)

    return re.sub(r'\s+', ' ', text).strip()


# ==========================================
# 3.
# Model Architecture
# ==========================================
class SentimentSTL(nn.Module):
    """LoRA-adapted encoder with a small MLP classification head.

    The encoder named by ``MODEL_NAME`` is wrapped with LoRA adapters on its
    ``query`` and ``value`` modules. The sentence representation is the
    average of the first-token vector and the attention-masked mean of all
    token vectors, which is then fed to the classifier.

    Args:
        num_classes: Size of the output layer (number of logits per sample).
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(MODEL_NAME)
        peft_config = LoraConfig(
            task_type="FEATURE_EXTRACTION",
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            target_modules=["query", "value"],
        )
        # NOTE: the attribute names `encoder`, `peft_model` and `classifier`
        # determine the state_dict keys, so they must stay as they are for
        # the saved checkpoint to load.
        self.peft_model = get_peft_model(self.encoder, peft_config)

        hidden_size = self.peft_model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).
            attention_mask: Mask with the same leading two dimensions;
                non-zero marks tokens that take part in mean pooling.

        Returns:
            Tensor of logits with ``num_classes`` entries per sample.
        """
        outputs = self.peft_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        mask = attention_mask.unsqueeze(-1).float()
        # Clamp the token count so a fully masked row yields zeros rather
        # than NaN from a 0/0 division; normal inputs are unaffected.
        token_count = mask.sum(1).clamp(min=1e-9)
        mean_rep = (hidden * mask).sum(1) / token_count

        cls_rep = hidden[:, 0, :]
        pooled = (cls_rep + mean_rep) / 2.0
        return self.classifier(pooled)


# ==========================================
# 4. Cache Load Logic
# ==========================================
print("Downloading / verifying cached weights...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)

print("Initializing Tokenizer and Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = SentimentSTL(num_classes=3).to(device)

print("Loading trained weights into model...")
# The checkpoint is a downloaded file and torch.load is pickle-based:
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
state_dict = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Switch to inference behaviour (e.g. dropout disabled)
print("Model initialized completely!")

# ==========================================
# 5.
# Inference Logic for Gradio
# ==========================================
def predict(text):
    """Classify one piece of text and return per-label probabilities.

    Args:
        text: Raw user input from the textbox. ``None``, blank input, and
            input that is empty after ``clean_text`` are treated alike.

    Returns:
        Dict mapping every label name in ``LABELS`` to a float. For input
        with nothing to analyze all values are ``0.0``; otherwise they are
        the softmax probabilities of the model's logits.
    """
    no_result = {label: 0.0 for label in LABELS.values()}

    # Guard against None as well as blank strings.
    if not text or not text.strip():
        return no_result

    cleaned_text = clean_text(text)
    # Input made only of URLs, digits, punctuation, ... cleans down to
    # nothing; scoring an empty sequence would produce a meaningless result.
    if not cleaned_text:
        return no_result

    enc = tokenizer(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    ids = enc['input_ids'].to(device)
    mask = enc['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(ids, mask)

    # Apply Softmax to map network score between [0,1]; drop only the batch
    # dimension so the result is always a flat list of class probabilities.
    probs = F.softmax(logits, dim=1).squeeze(0).cpu().tolist()

    # Return formatted Dictionary for the gr.Label() output block
    return {LABELS[idx]: float(prob) for idx, prob in enumerate(probs)}


# ==========================================
# 6. Build User Interface with Gradio
# ==========================================
custom_css = (
    " body { font-family: 'Tajawal', sans-serif; }"
    " .output-class { font-weight: bold; } "
)

with gr.Blocks(css=custom_css, title="Arabic Sentiment Analyzer") as demo:
    # NOTE(review): these two strings spanned several physical lines inside
    # plain double quotes; the line breaks are kept as explicit "\n" so the
    # rendered text is unchanged. Any heading/HTML markup they may once have
    # had is not present in this file — confirm whether it should be added.
    gr.Markdown("\n\n🌐 Arabic Sentiment Analysis — CAMeLBERT + LoRA\n\n")
    gr.Markdown(
        "\nAnalyze Arabic text and detect Negative, Neutral, or Positive sentiment.\n"
        "This model is balanced trained specifically on an augmented pipeline, "
        "leveraging LoRA representation over base features.\n"
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_textbox = gr.Textbox(
                label="Arabic Input Text",
                placeholder="...اكتب أو ألصق الجملة العربية هنا",
                lines=5,
                rtl=True,
            )
            submit_button = gr.Button("Analyze / تحليل", variant="primary")

            gr.Examples(
                examples=[
                    "هذا المطعم يقدم طعام سيء جدًا ومحروق، لا أنصح أحد بزيارته.",
                    "المؤتمر الصحفي لوزير الاتصالات تم بالأمس لمناقشة بعض الأمور الإقتصادية.",
                    "بصراحة، الخدمة كانت ممتازة والموظفين غاية في الإحترام، أنصح بشدة!",
                ],
                inputs=input_textbox,
                label="Examples / أمثلة سريعة",
            )

        with gr.Column(scale=1):
            output_label = gr.Label(
                label="Sentiment Probabilities",
                num_top_classes=3,
            )

    # Event wiring must happen inside the Blocks context.
    submit_button.click(fn=predict, inputs=input_textbox, outputs=output_label)

# Launch UI!
if __name__ == "__main__":
    demo.launch(share=False)