Abelex committed on
Commit
9710d8f
·
verified ·
1 Parent(s): 5b98dd2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
+ # --------------------------------------------------
6
+ # Configuration
7
+ # --------------------------------------------------
8
MODEL_NAME = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext"
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# --------------------------------------------------
# Load model and tokenizer
# --------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load the classifier, move it to DEVICE and put it in evaluation mode.
model = (
    AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    .to(DEVICE)
    .eval()
)
18
+
19
+ # --------------------------------------------------
20
+ # Prediction function
21
+ # --------------------------------------------------
22
def classify_text(text):
    """Classify a piece of text with the module-level model.

    Args:
        text: Input string coming from the UI. ``None``, the empty string
            and whitespace-only strings are all treated as "no input".

    Returns:
        A ``(label, scores)`` tuple. ``label`` is the name of the class
        with the highest softmax probability and ``scores`` maps every
        class name to its probability. For empty input, ``label`` is a
        warning message and ``scores`` is an empty dict.
    """
    # Guard against None as well as blank strings so a missing value
    # produces the warning instead of an AttributeError on .strip().
    if not text or not text.strip():
        return "⚠️ Please enter Amharic text.", {}

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(DEVICE)

    with torch.no_grad():
        logits = model(**inputs).logits
        # Only one text is passed in, so keep the first row.
        probs = torch.softmax(logits, dim=-1)[0]

    # Fall back to the numeric id when the config has no name for a class.
    id2label = model.config.id2label

    # Predicted label
    pred_id = torch.argmax(probs).item()
    pred_label = id2label.get(pred_id, str(pred_id))

    # All label probabilities, converted to plain Python floats in one call.
    scores = {
        id2label.get(class_id, str(class_id)): prob
        for class_id, prob in enumerate(probs.tolist())
    }

    return pred_label, scores
50
+
51
+ # --------------------------------------------------
52
+ # Gradio UI
53
+ # --------------------------------------------------
54
# Page layout: header, input box, trigger button, two outputs, footer.
with gr.Blocks(title="Amharic Text Classification") as demo:
    gr.Markdown(
        """
        ## 📄 Amharic Text Classification
        This app classifies **Amharic long text** using a pretrained **AfriBERTa model**.
        """
    )

    input_text = gr.Textbox(
        lines=8,
        # Amharic for "Please enter Amharic text here..."
        placeholder="እባክዎ የአማርኛ ጽሑፍ እዚህ ያስገቡ...",
        label="Input Text",
    )

    classify_btn = gr.Button("🔍 Classify")

    output_label = gr.Label(label="Predicted Label")
    output_scores = gr.JSON(label="Class Probabilities")

    # classify_text returns (label, scores); they map to the outputs in order.
    classify_btn.click(
        fn=classify_text,
        inputs=input_text,
        outputs=[output_label, output_scores],
    )

    gr.Markdown(
        """
        ---
        **Model:** Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext
        Built with ❤️ using Gradio & Hugging Face
        """
    )

# --------------------------------------------------
# Launch app
# --------------------------------------------------
if __name__ == "__main__":
    demo.launch()