# Hugging Face Spaces app.
# NOTE(review): the hosted Space page reported "Runtime error" at scrape time.
import sys

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# --- 1. Configuration ---
# Path to the directory containing the fine-tuned model artifacts.
# Using a clean relative name like this avoids path errors during deployment.
MODEL_PATH = "Merged_AraBERT_Optuna_EvalNew_Infer"

# Maps the model's output class IDs (0, 1, 2) to human-readable stance labels.
ID2LABEL = {0: "Favor", 1: "Against", 2: "None"}

# --- 2. Load Saved Model and Tokenizer ---
print("Loading model and tokenizer from local directory...")
try:
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Abort with a non-zero status so the hosting platform sees the failure.
    # sys.exit(1) is preferred over exit(), which is a site-module convenience
    # intended for interactive sessions and would exit with status 0.
    sys.exit(1)
# --- 3. Prediction Function ---
def predict_stance(text_input: str) -> dict:
    """Classify the stance of an Arabic sentence.

    Runs *text_input* through the fine-tuned sequence-classification model
    and returns a mapping of stance label -> probability, which is the shape
    gr.Label expects for rendering a confidence bar chart.

    Args:
        text_input: The raw Arabic text entered by the user.

    Returns:
        A dict of {label: probability}; empty when the input is blank.
    """
    # Guard against empty AND whitespace-only input: the original check
    # (`if not text_input`) let strings like "   " through to the tokenizer.
    if not text_input or not text_input.strip():
        return {}

    # Tokenize the input text, truncating to the model's training length.
    inputs = tokenizer(
        text_input,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax converts raw logits into a probability distribution.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # Format as {label: probability} for the gr.Label output component.
    return {ID2LABEL[i]: float(p) for i, p in enumerate(probabilities)}
# --- 4. Gradio Interface Definition ---
# Defines the title, description, and UI components of the web demo.
demo = gr.Interface(
    fn=predict_stance,
    inputs=gr.Textbox(
        lines=5,
        placeholder="أدخل النص العربي هنا لتحليل الموقف...",
        label="Arabic Text Input",
    ),
    outputs=gr.Label(num_top_classes=3, label="Stance Prediction"),
    title="🔎 Arabic Stance Detection Model",
    description=(
        "This demo uses a fine-tuned CamelBERT model to predict the stance (Favor, Against, or None) of Arabic text. "
        "The model was trained on the Mawqif dataset and achieved an F1-score of 84.1% on the 3-class problem. "
        "Enter a sentence to see the prediction."
    ),
    examples=[
        ["الحكومة اتخذت القرار الصائب، هذا ما كنا ننتظره."],
        ["أسوأ قرار تم اتخاذه على الإطلاق، أنا ضده تماما."],
        ["الطقس في الدمام اليوم حار جداً."],
    ],
    # `allow_flagging="never"` was deprecated in Gradio 4 and removed in
    # Gradio 5; passing it there raises a TypeError at startup (a likely
    # cause of a Spaces "Runtime error"). `flagging_mode` is the current
    # equivalent — pin gradio>=5 in requirements.txt to match.
    flagging_mode="never",
    article=(
        "<p style='text-align: center; margin-top: 20px;'>"
        "A project fine-tuned and deployed by [Your Name Here]."
        "</p>"
    ),
)
# --- 5. Launch the Application ---
# Start the Gradio web server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()