Abelex commited on
Commit
6b85f31
·
verified ·
1 Parent(s): c9ce5ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -34
app.py CHANGED
@@ -9,7 +9,7 @@ MODEL_NAME = "Abelex/afro-xlmr-large"
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # --------------------------------------------------
12
- # Load tokenizer & model (CRITICAL FIX)
13
  # --------------------------------------------------
14
  tokenizer = AutoTokenizer.from_pretrained(
15
  MODEL_NAME,
@@ -28,8 +28,8 @@ model.eval()
28
  # Prediction function
29
  # --------------------------------------------------
30
  def classify_text(text):
31
- if not text.strip():
32
- return "⚠️ Please enter Amharic text", {}
33
 
34
  inputs = tokenizer(
35
  text,
@@ -41,50 +41,54 @@ def classify_text(text):
41
 
42
  with torch.no_grad():
43
  outputs = model(**inputs)
44
-
45
- # 🔑 Handle custom forward outputs safely
46
- if hasattr(outputs, "logits"):
47
- logits = outputs.logits
48
- else:
49
- logits = outputs[0]
50
-
51
  probs = torch.softmax(logits, dim=-1)[0]
52
 
53
  pred_id = torch.argmax(probs).item()
54
-
55
- # Safe label handling
56
- id2label = getattr(model.config, "id2label", None)
57
- if id2label:
58
- pred_label = id2label.get(pred_id, str(pred_id))
59
- else:
60
- pred_label = f"Class {pred_id}"
61
 
62
  scores = {
63
- (id2label.get(i, f"Class {i}") if id2label else f"Class {i}"): float(probs[i])
64
  for i in range(len(probs))
65
  }
66
 
67
- return pred_label, scores
 
 
 
68
 
69
  # --------------------------------------------------
70
- # Gradio UI
71
  # --------------------------------------------------
72
- with gr.Blocks(title="Amharic Text Classification") as demo:
73
- gr.Markdown("""
74
- ## 📄 Amharic Text Classification
75
- **Model:** Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext
76
- """)
77
-
78
- input_text = gr.Textbox(
79
- lines=8,
80
- label="Input Amharic Text",
81
- placeholder="እባክዎ የአማርኛ ጽሑፍ እዚህ ያስገቡ..."
 
82
  )
83
 
84
- classify_btn = gr.Button("🔍 Classify")
 
 
 
 
 
 
 
 
 
 
85
 
86
- output_label = gr.Label(label="Predicted Label")
87
- output_scores = gr.JSON(label="Class Probabilities")
 
88
 
89
  classify_btn.click(
90
  fn=classify_text,
@@ -92,7 +96,24 @@ with gr.Blocks(title="Amharic Text Classification") as demo:
92
  outputs=[output_label, output_scores]
93
  )
94
 
95
- gr.Markdown("---\nBuilt with ❤️ using Gradio & Hugging Face")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # --------------------------------------------------
98
  # Launch
 
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # --------------------------------------------------
12
+ # Load tokenizer & model
13
  # --------------------------------------------------
14
  tokenizer = AutoTokenizer.from_pretrained(
15
  MODEL_NAME,
 
28
  # Prediction function
29
  # --------------------------------------------------
30
  def classify_text(text):
31
+ if not text or not text.strip():
32
+ return gr.Markdown("⚠️ **Please enter Amharic text**"), None
33
 
34
  inputs = tokenizer(
35
  text,
 
41
 
42
  with torch.no_grad():
43
  outputs = model(**inputs)
44
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
 
 
 
 
 
 
45
  probs = torch.softmax(logits, dim=-1)[0]
46
 
47
  pred_id = torch.argmax(probs).item()
48
+ id2label = getattr(model.config, "id2label", {})
49
+ pred_label = id2label.get(pred_id, f"Class {pred_id}")
 
 
 
 
 
50
 
51
  scores = {
52
+ id2label.get(i, f"Class {i}"): float(probs[i])
53
  for i in range(len(probs))
54
  }
55
 
56
+ return (
57
+ gr.Markdown(f"### 🏷️ **{pred_label}**"),
58
+ scores
59
+ )
60
 
61
  # --------------------------------------------------
62
+ # Gradio UI (MINIMAL & ATTRACTIVE)
63
  # --------------------------------------------------
64
+ with gr.Blocks(
65
+ title="Amharic Text Classification",
66
+ theme=gr.themes.Soft()
67
+ ) as demo:
68
+
69
+ gr.Markdown(
70
+ """
71
+ # 🇪🇹 Amharic Text Classification
72
+ <small>Powered by **Afro-XLMR / AfriBERTa**</small>
73
+ """,
74
+ elem_id="title"
75
  )
76
 
77
+ with gr.Column(scale=1):
78
+ input_text = gr.Textbox(
79
+ lines=5,
80
+ placeholder="እባክዎ የአማርኛ ጽሑፍ እዚህ ያስገቡ...",
81
+ show_label=False
82
+ )
83
+
84
+ classify_btn = gr.Button(
85
+ "Classify",
86
+ variant="primary"
87
+ )
88
 
89
+ with gr.Column():
90
+ output_label = gr.Markdown()
91
+ output_scores = gr.JSON(label="Class Probabilities")
92
 
93
  classify_btn.click(
94
  fn=classify_text,
 
96
  outputs=[output_label, output_scores]
97
  )
98
 
99
+ gr.Examples(
100
+ examples=[
101
+ ["የኢትዮጵያ ኢኮኖሚ በአዲስ እቅድ እየተሻሻለ ነው።"],
102
+ ["የእግር ኳስ ውድድር በአዲስ መልክ ተጀመረ።"]
103
+ ],
104
+ inputs=input_text,
105
+ label="Examples"
106
+ )
107
+
108
+ gr.Markdown(
109
+ """
110
+ ---
111
+ <small>
112
+ Model: <b>Abelex/afro-xlmr-large</b><br>
113
+ Built with ❤️ using Gradio
114
+ </small>
115
+ """
116
+ )
117
 
118
  # --------------------------------------------------
119
  # Launch