AR_Detection / app.py
Salajmi1's picture
Update app.py
71440a9 verified
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# --- 1. Configuration ---
# Relative path of the directory that holds the saved fine-tuned model and
# tokenizer files. Keeping it a plain folder name avoids deployment path issues.
MODEL_PATH = "Merged_AraBERT_Optuna_EvalNew_Infer"

# Class index -> human-readable stance name, in the classifier's output order.
# NOTE(review): this ordering must match the label ids used during training —
# confirm against the checkpoint's config before changing it.
ID2LABEL = dict(enumerate(("Favor", "Against", "None")))
# --- 2. Load Saved Model and Tokenizer ---
# Load the fine-tuned classifier and its tokenizer once, at import time, so
# every request reuses them.
print("Loading model and tokenizer from local directory...")
try:
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
except Exception as e:
    # The app cannot serve anything without the model, so abort startup.
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so the hosting environment sees a failed start. (A bare ``exit()``
    # exits with status 0 and relies on the ``site``-injected builtin.)
    raise SystemExit(f"Error loading model: {e}") from e

# Inference only: explicitly put the model in evaluation mode (disables dropout).
model.eval()
print("Model and tokenizer loaded successfully.")
# --- 3. Prediction Function ---
def predict_stance(text_input: str):
    """Classify the stance of a piece of Arabic text.

    Args:
        text_input: Raw text from the input textbox. ``None``, empty and
            whitespace-only strings are treated as "no input".

    Returns:
        A ``{label: probability}`` dict with one entry per model output class
        (the format the ``gr.Label`` output expects), or an empty dict when
        there is nothing to classify.
    """
    # Whitespace-only text would tokenize to just the special tokens and still
    # yield a meaningless prediction, so treat it the same as empty input.
    if not text_input or not text_input.strip():
        return {}

    # Tokenize, then move the tensors to wherever the model lives so inference
    # keeps working if the model is placed on a GPU.
    inputs = tokenizer(
        text_input,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    ).to(model.device)

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis; row 0 is the single input sequence.
    # .tolist() yields plain Python floats (and copies off the GPU if needed).
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    # Fall back to a generic name rather than raising KeyError if the
    # checkpoint has more classes than ID2LABEL describes.
    return {
        ID2LABEL.get(idx, f"LABEL_{idx}"): prob
        for idx, prob in enumerate(probabilities)
    }
# --- 4. Gradio Interface Definition ---
# Components, copy and examples for the web demo are built first, then wired
# into a single Interface around predict_stance.
_input_box = gr.Textbox(
    lines=5,
    placeholder="أدخل النص العربي هنا لتحليل الموقف...",
    label="Arabic Text Input",
)
_output_label = gr.Label(num_top_classes=3, label="Stance Prediction")

# NOTE(review): the description says "CamelBERT" while MODEL_PATH points at an
# AraBERT checkpoint — confirm which base model is correct before publishing.
_DESCRIPTION = " ".join((
    "This demo uses a fine-tuned CamelBERT model to predict the stance (Favor, Against, or None) of Arabic text.",
    "The model was trained on the Mawqif dataset and achieved an F1-score of 84.1% on the 3-class problem.",
    "Enter a sentence to see the prediction.",
))

# One example row per sample sentence (the Interface has a single input).
_EXAMPLES = [
    [sample]
    for sample in (
        "الحكومة اتخذت القرار الصائب، هذا ما كنا ننتظره.",
        "أسوأ قرار تم اتخاذه على الإطلاق، أنا ضده تماما.",
        "الطقس في الدمام اليوم حار جداً.",
    )
]

# NOTE(review): "[Your Name Here]" is still a placeholder.
_ARTICLE = "".join((
    "<p style='text-align: center; margin-top: 20px;'>",
    "A project fine-tuned and deployed by [Your Name Here].",
    "</p>",
))

# NOTE(review): newer Gradio releases replace `allow_flagging` with
# `flagging_mode`; check against the pinned gradio version.
demo = gr.Interface(
    fn=predict_stance,
    inputs=_input_box,
    outputs=_output_label,
    title="🔎 Arabic Stance Detection Model",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
    allow_flagging="never",
    article=_ARTICLE,
)

# --- 5. Launch the Application ---
# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()