Woolv7007 commited on
Commit
d39a018
·
verified ·
1 Parent(s): 7eb94fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -3,71 +3,64 @@ from transformers import pipeline
3
  import requests
4
  import json
5
 
6
- # Model name on Hugging Face
7
  model_name = "Woolv7007/egyptian-text-classification"
8
 
9
- # Load label names from labels.json
10
  labels_url = f"https://huggingface.co/{model_name}/resolve/main/labels.json"
11
  try:
12
  response = requests.get(labels_url)
13
  response.raise_for_status()
14
  labels = response.json()
15
-
16
- # Convert to list if it's a dictionary
17
  if isinstance(labels, dict):
18
  labels = list(labels.values())
19
-
20
  print("Labels loaded:", labels)
21
-
22
  except requests.exceptions.RequestException as e:
23
  print("Failed to load labels.json:", e)
24
  labels = None
25
 
26
- # Load the text classification pipeline
27
  pipe = pipeline("text-classification", model=model_name)
28
- print("Model pipeline loaded.")
29
 
30
- # Define the prediction function
31
  def predict(text):
32
- print("Input text:", text)
33
-
34
  try:
35
- # Run prediction
36
  result = pipe(text)[0]
37
- print("Model output:", result)
38
 
39
- # Extract label index from format like "LABEL_3"
40
  label_id = int(result['label'].replace("LABEL_", ""))
41
- print("Label ID:", label_id)
42
-
43
- # Get the label name using the index
44
  label_text = labels[label_id] if labels and label_id < len(labels) else result['label']
45
- print("Final label:", label_text)
 
 
 
 
46
 
47
- # Get the confidence score
48
  confidence = round(result['score'], 3)
49
- print("Confidence:", confidence)
50
 
51
- # Build JSON result
52
  json_output = {
53
- "prediction": label_text,
 
54
  "confidence": confidence
55
  }
56
 
57
- return label_text, json.dumps(json_output, indent=4, ensure_ascii=False)
58
 
59
  except Exception as e:
60
- print("Error:", e)
61
- return "Error", json.dumps({"error": str(e)}, indent=4, ensure_ascii=False)
 
62
 
63
- # Create the Gradio interface
64
  gr.Interface(
65
  fn=predict,
66
- inputs=gr.Textbox(lines=3, placeholder="Enter a sentence in Egyptian Arabic..."),
67
  outputs=[
68
- gr.Label(label="Predicted Category"),
69
- gr.Textbox(label="JSON Output")
70
  ],
71
- title="Egyptian Arabic Text Classifier",
72
- description="This model classifies Egyptian Arabic text into categories such as Neutral, Offensive, Racist, Religious Discrimination, Ads, etc."
73
  ).launch()
 
3
  import requests
4
  import json
5
 
6
+ # Model name on Hugging Face Hub
7
  model_name = "Woolv7007/egyptian-text-classification"
8
 
9
+ # Load labels.json from Hugging Face
10
  labels_url = f"https://huggingface.co/{model_name}/resolve/main/labels.json"
11
  try:
12
  response = requests.get(labels_url)
13
  response.raise_for_status()
14
  labels = response.json()
 
 
15
  if isinstance(labels, dict):
16
  labels = list(labels.values())
 
17
  print("Labels loaded:", labels)
 
18
  except requests.exceptions.RequestException as e:
19
  print("Failed to load labels.json:", e)
20
  labels = None
21
 
22
+ # Load the model pipeline
23
  pipe = pipeline("text-classification", model=model_name)
24
+ print("Model loaded.")
25
 
26
+ # Prediction function
27
  def predict(text):
28
+ print("Input:", text)
 
29
  try:
 
30
  result = pipe(text)[0]
31
+ print("Raw result:", result)
32
 
 
33
  label_id = int(result['label'].replace("LABEL_", ""))
 
 
 
34
  label_text = labels[label_id] if labels and label_id < len(labels) else result['label']
35
+ print("Mapped label:", label_text)
36
+
37
+ # Define which labels are considered "True"
38
+ true_labels = ["ads", "neutral"]
39
+ prediction_bool = label_text.lower() in true_labels
40
 
 
41
  confidence = round(result['score'], 3)
 
42
 
 
43
  json_output = {
44
+ "prediction": prediction_bool,
45
+ "original_label": label_text,
46
  "confidence": confidence
47
  }
48
 
49
+ return str(prediction_bool), json.dumps(json_output, indent=4, ensure_ascii=False)
50
 
51
  except Exception as e:
52
+ error_msg = str(e)
53
+ print("Prediction error:", error_msg)
54
+ return "Error", json.dumps({"error": error_msg}, indent=4, ensure_ascii=False)
55
 
56
+ # Gradio interface
57
  gr.Interface(
58
  fn=predict,
59
+ inputs=gr.Textbox(lines=3, placeholder="Enter Egyptian Arabic text..."),
60
  outputs=[
61
+ gr.Textbox(label="Prediction (True/False)"),
62
+ gr.Textbox(label="Full JSON Output")
63
  ],
64
+ title="Egyptian Text Classification",
65
+ description="This model classifies Egyptian Arabic text. Only 'ads' and 'neutral' are considered True; all other labels are considered False."
66
  ).launch()