import os
import re
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model
from huggingface_hub import hf_hub_download

# ==========================================
# 1. Config & Global Initialization (Cached)
# ==========================================
MODEL_NAME = "CAMeL-Lab/bert-base-arabic-camelbert-mix-sentiment"
# PLEASE UPDATE THIS TO YOUR CREATED MODEL REPO FROM STEP 1
REPO_ID = "mahmoudmohammad/MARBERTv2-Sentiment_Classification"
MODEL_FILE = "best_stl_model.pth"
MAX_LEN = 256
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.1

# Classes matched from sklearn LabelEncoder mapping: 0, 1, 2
LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==========================================
# 2. Text Preprocessing Function
# ==========================================

# Ordered rewrite rules applied by clean_text(). The order is significant.
# NOTE: the mention rule runs before the e-mail rule, so for "user@host" the
# "@host" part is already gone when the e-mail rule is tried. The order is
# kept as-is so that inference-time cleaning does not drift from the
# preprocessing the checkpoint was presumably trained with - confirm before
# changing.
_REWRITE_RULES = (
    # URLs
    (re.compile(r'http\S+|www\.\S+'), ''),
    # @mentions
    (re.compile(r'[@]\S+'), ''),
    # e-mail-like tokens
    (re.compile(r'\S+@\S+'), ' '),
    # Digits (any Unicode decimal digit via \d, plus explicit Arabic-Indic)
    (re.compile(r'\d+|[٠١٢٣٤٥٦٧٨٩]+'), ''),
    # Hashtag markers and underscores become word separators
    (re.compile(r'[#_]'), ' '),
    # Alef variants -> bare alef
    (re.compile("[إأآا]"), "ا"),
    # Alef maksura -> yeh
    (re.compile("ى"), "ي"),
    # Waw with hamza -> waw
    (re.compile("ؤ"), "و"),
    # Yeh with hamza -> yeh
    (re.compile("ئ"), "ي"),
    # Teh marbuta -> heh
    (re.compile("ة"), "ه"),
)

_ARABIC_PUNCT = '`÷×؛«»<>()*&^%][ـ،/:".،,\'{}~¦+|"…""–ـ'
# '!' and '?' are deliberately kept in the text.
_ENGLISH_PUNCT = string.punctuation.replace('!', '').replace('?', '')
# Deletion table: every listed punctuation character is dropped.
_PUNCT_DELETE_TABLE = str.maketrans('', '', _ARABIC_PUNCT + _ENGLISH_PUNCT)

_WHITESPACE_RUN = re.compile(r'\s+')


def clean_text(text):
    """Normalize raw text before tokenization.

    Applies, in order: URL removal, @mention removal, e-mail-like token
    removal, digit removal, '#'/'_' to space, Arabic letter normalization,
    punctuation removal (keeping '!' and '?'), and whitespace collapsing.

    Args:
        text: The raw input. Anything that is not a ``str`` is converted
            with ``str()`` and returned as-is, without any cleaning.

    Returns:
        The cleaned string with no leading or trailing whitespace.
    """
    if not isinstance(text, str):
        return str(text)

    cleaned = text
    for pattern, replacement in _REWRITE_RULES:
        cleaned = pattern.sub(replacement, cleaned)

    cleaned = cleaned.translate(_PUNCT_DELETE_TABLE)
    return _WHITESPACE_RUN.sub(' ', cleaned).strip()


# ==========================================
# 3.
# Model Architecture
# ==========================================
class SentimentSTL(nn.Module):
    """Single-task sentiment classifier: LoRA-adapted encoder plus an MLP head.

    The sentence representation is the average of the first-token ([CLS])
    vector and the attention-masked mean of all token vectors.

    NOTE: the attribute names ``encoder``, ``peft_model`` and ``classifier``
    define the state_dict keys that ``load_state_dict`` matches against the
    saved checkpoint, so they must not be renamed or removed.

    Args:
        num_classes: Number of output logits produced by the classifier head.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(MODEL_NAME)

        peft_config = LoraConfig(
            task_type="FEATURE_EXTRACTION",
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            target_modules=["query", "value"],
        )
        self.peft_model = get_peft_model(self.encoder, peft_config)

        hidden_size = self.peft_model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of tokenized inputs.

        Args:
            input_ids: Token id tensor passed through to the encoder.
            attention_mask: Tensor marking real tokens (1) versus padding (0);
                also used to weight the mean pooling.

        Returns:
            The output of the classifier head, one row of ``num_classes``
            logits per input sequence.
        """
        outputs = self.peft_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        # Masked mean over the sequence axis: padding positions contribute 0.
        mask = attention_mask.unsqueeze(-1).float()
        # Clamp the token count so a row with an all-zero mask yields zeros
        # instead of NaN; for any real input (count >= 1) this is a no-op.
        token_count = mask.sum(1).clamp(min=1e-9)
        mean_rep = (hidden * mask).sum(1) / token_count

        # First-token representation, blended 50/50 with the mean.
        cls_rep = hidden[:, 0, :]
        pooled = (cls_rep + mean_rep) / 2.0
        return self.classifier(pooled)


# ==========================================
# 4. Cache Load Logic
# ==========================================
print("Downloading / verifying cached weights...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)

print("Initializing Tokenizer and Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = SentimentSTL(num_classes=3).to(device)

print("Loading trained weights into model...")
# SECURITY: the checkpoint is downloaded from the Hub, and torch.load
# unpickles its content. weights_only=True restricts unpickling to tensors
# and plain containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load(weights_path, map_location=device, weights_only=True)
)
model.eval()  # Evaluation mode: disables dropout for deterministic inference

print("Model initialized completely!")

# ==========================================
# 5.
# Inference Logic for Gradio
# ==========================================
def predict(text):
    """Classify the sentiment of ``text`` and return per-class probabilities.

    Args:
        text: Raw user input. Non-string input (e.g. ``None``) and input
            that is empty after cleaning are treated as "nothing to score".

    Returns:
        A dict mapping every label name in ``LABELS`` to a float. The values
        are softmax probabilities, or all ``0.0`` when there is no text left
        to score.
    """
    no_scores = {label: 0.0 for label in LABELS.values()}

    # Guard before calling .strip(): non-string input has no text to score.
    if not isinstance(text, str) or not text.strip():
        return no_scores

    cleaned_text = clean_text(text)
    # Input made only of URLs, digits or punctuation is empty after cleaning;
    # scoring it would produce a prediction for an empty sequence.
    if not cleaned_text:
        return no_scores

    enc = tokenizer(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    ids = enc['input_ids'].to(device)
    mask = enc['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(ids, mask)

    # Softmax over the class axis, then take the single row of the batch.
    probs = F.softmax(logits, dim=1)[0].cpu().tolist()

    # Dictionary format expected by the gr.Label() output block.
    return {label: float(probs[idx]) for idx, label in LABELS.items()}


# ========================================== # 6. Build User Interface with Gradio # ========================================== custom_css = """ body { font-family: 'Tajawal', sans-serif; } .output-class { font-weight: bold; } """ with gr.Blocks(css=custom_css, title="Arabic Sentiment Analyzer") as demo: gr.Markdown("