File size: 2,149 Bytes
0095f18
 
a55d9d5
0095f18
 
a55d9d5
 
 
 
 
0095f18
a55d9d5
 
 
0095f18
 
eb9d9bc
a55d9d5
 
0095f18
 
a55d9d5
eb9d9bc
 
 
a55d9d5
0095f18
a55d9d5
eb9d9bc
 
a55d9d5
 
eb9d9bc
 
a55d9d5
 
 
 
 
 
eb9d9bc
 
 
 
 
 
 
 
 
 
 
 
a55d9d5
eb9d9bc
 
a55d9d5
 
0095f18
a55d9d5
0095f18
d39a018
 
0095f18
d39a018
a55d9d5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
from transformers import pipeline
import requests
import json

# Model repository on the Hugging Face Hub.
model_name = "Woolv7007/egyptian-text-classification"

# Seconds to wait for the Hub before giving up on labels.json. Without a
# timeout, requests.get() can block application start-up indefinitely.
LABELS_TIMEOUT = 10


def _normalize_labels(raw):
    """Turn the parsed labels.json payload into a list indexed by class id.

    NOTE(review): the exact layout of labels.json is not visible here, so the
    following shapes are handled defensively:

    * a list is returned unchanged (position == class id);
    * a dict whose keys are integer-like (``{"0": "ads", ...}``) is ordered
      by the numeric key rather than by dict insertion order;
    * a dict whose values are integer-like (``{"ads": 0, ...}``) is inverted
      to id -> name;
    * any other dict falls back to its values in insertion order.

    Any other JSON type yields ``None``, so callers fall back to the model's
    raw ``LABEL_<id>`` names.
    """
    if isinstance(raw, list):
        return raw
    if not isinstance(raw, dict):
        return None
    try:
        by_id = {int(key): name for key, name in raw.items()}
    except (TypeError, ValueError):
        try:
            by_id = {int(idx): name for name, idx in raw.items()}
        except (TypeError, ValueError):
            return list(raw.values())
    return [by_id[i] for i in sorted(by_id)]


# Load labels.json from the Hugging Face Hub. On any failure, labels is None
# and predictions fall back to the model's raw label names.
labels_url = f"https://huggingface.co/{model_name}/resolve/main/labels.json"
try:
    response = requests.get(labels_url, timeout=LABELS_TIMEOUT)
    response.raise_for_status()
    labels = _normalize_labels(response.json())
    print("Labels loaded:", labels)
except (requests.exceptions.RequestException, ValueError) as e:
    # ValueError covers malformed JSON on requests versions whose
    # JSONDecodeError is not a RequestException subclass.
    print("Failed to load labels.json:", e)
    labels = None

# Build the text-classification pipeline a single time at import, so the
# model weights are loaded once and the same object is reused by predict().
pipe = pipeline(task="text-classification", model=model_name)
print("Model loaded.")

# Labels (matched case-insensitively) that count as a "True" prediction.
TRUE_LABELS = frozenset({"ads", "neutral"})


def _error_response(error_msg):
    """Log *error_msg* and build the ``("Error", json)`` pair shown in the UI."""
    print("Prediction error:", error_msg)
    return "Error", json.dumps({"error": error_msg}, indent=4, ensure_ascii=False)


def _resolve_label(raw_label):
    """Return a human-readable name for a pipeline label.

    ``LABEL_<id>`` labels are looked up in the module-level ``labels`` list
    when it was loaded and ``<id>`` is a valid non-negative index. Any other
    label (for example a real class name already supplied by the model
    config) is returned unchanged instead of failing on ``int()``.
    """
    prefix = "LABEL_"
    if not raw_label.startswith(prefix):
        return raw_label
    try:
        label_id = int(raw_label[len(prefix):])
    except ValueError:
        return raw_label
    if labels and 0 <= label_id < len(labels):
        # str() so a non-string entry in labels.json cannot break .lower().
        return str(labels[label_id])
    return raw_label


# Prediction function
def predict(text):
    """Classify *text* and report whether its label counts as True.

    Returns a ``(prediction, json_details)`` pair of strings for the two
    Gradio output textboxes:

    * on success, ``prediction`` is ``"True"`` or ``"False"``, and
      ``json_details`` is a pretty-printed JSON object with the boolean
      prediction, the resolved label and the confidence rounded to three
      decimals;
    * on failure (blank input or any exception raised while classifying),
      ``prediction`` is ``"Error"`` and ``json_details`` is
      ``{"error": <message>}``.
    """
    print("Input:", text)
    # Blank input has nothing to classify; report it instead of asking the
    # model for a meaningless label.
    if text is None or not str(text).strip():
        return _error_response("Input text is empty.")
    try:
        result = pipe(text)[0]
        print("Raw result:", result)

        label_text = _resolve_label(result["label"])
        print("Mapped label:", label_text)

        prediction_bool = label_text.lower() in TRUE_LABELS
        # float() keeps json.dumps working even if the score is a NumPy scalar.
        confidence = round(float(result["score"]), 3)

        json_output = {
            "prediction": prediction_bool,
            "original_label": label_text,
            "confidence": confidence,
        }
        return str(prediction_bool), json.dumps(json_output, indent=4, ensure_ascii=False)
    except Exception as e:
        # UI boundary: surface any failure in the output box rather than
        # crashing the request.
        return _error_response(str(e))

# Gradio UI: a three-line text input feeding predict(), whose two return
# values fill the "Prediction (True/False)" and "Full JSON Output" boxes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Egyptian Arabic text..."),
    outputs=[
        gr.Textbox(label="Prediction (True/False)"),
        gr.Textbox(label="Full JSON Output"),
    ],
    title="Egyptian Text Classification",
    description="This model classifies Egyptian Arabic text. Only 'ads' and 'neutral' are considered True; all other labels are considered False.",
)
demo.launch()