Abelex committed on
Commit
bd815d9
Β·
verified Β·
1 Parent(s): c63b323

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -81
app.py CHANGED
@@ -1,97 +1,142 @@
 
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModel
 
4
 
5
- # ==================================================
6
- # Configuration
7
- # ==================================================
8
- MODEL_NAME = "Abelex/afro-xlmr-large"
9
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # ==================================================
12
- # Load tokenizer & model
13
- # ==================================================
14
- tokenizer = AutoTokenizer.from_pretrained(
15
- MODEL_NAME,
16
- trust_remote_code=True
17
- )
18
 
19
- model = AutoModel.from_pretrained(
20
- MODEL_NAME,
21
- trust_remote_code=True
22
- )
23
 
 
 
 
24
  model.to(DEVICE)
25
  model.eval()
26
 
27
- # ==================================================
28
- # Prediction function (FULLY FIXED)
29
- # ==================================================
30
- def classify_text(text):
31
- # ---- Validation ----
32
- if not text or not text.strip():
33
- return "⚠️ Please enter Amharic text.", None
34
-
35
- # ---- Tokenization ----
36
- inputs = tokenizer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  text,
38
- return_tensors="pt",
39
  truncation=True,
40
- padding=True,
41
- max_length=1024
42
- ).to(DEVICE)
43
-
44
- # ---- Inference ----
45
- with torch.no_grad():
46
- outputs = model(**inputs)
47
- logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
48
- probs = torch.softmax(logits, dim=-1)[0]
49
-
50
- # ---- Prediction ----
51
- pred_id = torch.argmax(probs).item()
52
- id2label = getattr(model.config, "id2label", {})
53
- pred_label = id2label.get(pred_id, f"Class {pred_id}")
54
-
55
- scores = {
56
- id2label.get(i, f"Class {i}"): float(probs[i])
57
- for i in range(len(probs))
58
- }
59
-
60
- return f"🏷️ **{pred_label}**", scores
61
-
62
- # ==================================================
63
- # Gradio UI (MINIMAL & STABLE)
64
- # ==================================================
65
- with gr.Blocks(
66
- title="Amharic Text Classification",
67
- theme=gr.themes.Soft()
68
- ) as demo:
69
-
70
- gr.Markdown("## πŸ‡ͺπŸ‡Ή Amharic Text Classification")
71
-
72
- input_text = gr.Textbox(
73
- lines=4,
74
- placeholder="αŠ₯α‰£αŠ­α‹Ž α‹¨αŠ αˆ›αˆ­αŠ› αŒ½αˆ‘α αŠ₯α‹šαˆ… α‹«αˆ΅αŒˆα‰‘...",
75
- show_label=False
76
  )
77
 
78
- classify_btn = gr.Button("Classify", variant="primary")
79
-
80
- output_label = gr.Markdown()
81
- output_scores = gr.JSON(label="Class Probabilities", visible=False)
82
-
83
- classify_btn.click(
84
- fn=classify_text,
85
- inputs=input_text,
86
- outputs=[output_label, output_scores]
87
- )
88
 
89
- gr.Markdown(
90
- "<small>Model: <b>Abelex/afro-xlmr-large</b></small>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
 
92
 
93
- # ==================================================
94
- # Launch
95
- # ==================================================
96
- if __name__ == "__main__":
97
- demo.launch()
 
1
import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------------
# 1. Load from Hugging Face Hub
# ----------------------------------------

# Repository of the fine-tuned classifier on the Hub.
# Change this to YOUR pushed model repo  <--- EDIT IF NEEDED
HUB_MODEL_ID = "Abelex/amharic-news-bert-multilingual-cased"

# Prefer GPU when available; everything below runs on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # model context window in TOKENS

# Pull the tokenizer and classification model straight from the Hub,
# move the model to the chosen device, and switch to inference mode.
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
model.to(DEVICE)
model.eval()

# id2label keys may be strings in the saved config -- normalize to int keys.
id2label = {int(k): v for k, v in model.config.id2label.items()}
num_labels = len(id2label)
27
+
28
# ----------------------------------------
# Helper: highlight tokens after MAX_LENGTH in red (HTML)
# ----------------------------------------
def highlight_token_overflow(text: str, max_tokens: int = 512) -> str:
    """
    Tokenize the input text and generate HTML where tokens beyond
    `max_tokens` are wrapped in red, showing exactly which tokens fall
    outside the model's context window.

    NOTE(review): the model input also includes special tokens (e.g.
    [CLS]/[SEP]), so slightly fewer than `max_tokens` content tokens
    actually survive truncation -- the red boundary is approximate.
    TODO: confirm against tokenizer.num_special_tokens_to_add().

    Args:
        text: Raw input text (no preprocessing applied).
        max_tokens: Context-window size in tokens (default 512).

    Returns:
        An HTML string visualizing all tokens, or a short italic notice
        when the input is empty or produces no tokens.
    """
    # Local import: the name `html` is used as a local variable below,
    # so the module is imported under an alias-free `escape` name.
    from html import escape

    if not text.strip():
        return "<i>No text provided.</i>"

    # Tokenize WITHOUT truncation so we can see ALL tokens.
    tokens = tokenizer.tokenize(text)
    if not tokens:
        return "<i>No tokens produced by tokenizer.</i>"

    spans = []
    for i, tok in enumerate(tokens):
        # quote=False reproduces the original &/<-/> escaping exactly;
        # stdlib escape replaces the hand-rolled .replace() chain.
        safe_tok = escape(tok, quote=False)

        if i >= max_tokens:
            spans.append(f"<span style='color:red;font-weight:bold;'>{safe_tok}</span>")
        else:
            spans.append(f"<span>{safe_tok}</span>")

    html = " ".join(spans)

    if len(tokens) > max_tokens:
        html += (
            f"<br><br>"
            f"<small style='color:red;'>"
            f"Note: Tokens in <b>red</b> are beyond the model context window "
            f"({max_tokens} tokens) and will be truncated."
            f"</small>"
        )
    else:
        html += (
            f"<br><br>"
            f"<small>Token count: {len(tokens)} (&le; {max_tokens}, no truncation).</small>"
        )

    return html
76
+
77
# ----------------------------------------
# 2. Prediction
# ----------------------------------------
def predict_amharic_news(text):
    """
    Classify a piece of Amharic news text.

    Args:
        text: Raw textbox input. May be None (Gradio passes None when
            the textbox is cleared) or empty/whitespace.

    Returns:
        Tuple of (prediction string, list of (label, probability) rows
        sorted by probability descending, HTML token visualization).
        For empty input: a prompt message, None, and a placeholder HTML.
    """
    # Guard None as well as blank input -- calling .strip() on None
    # would raise AttributeError (regression vs. the previous version).
    if not text or not text.strip():
        return "Please enter text.", None, "<i>No text provided.</i>"

    # For actual model inference: truncate to MAX_LENGTH tokens.
    # A single sequence (batch of 1) needs no padding, so the old
    # padding="max_length" (always padding to 512) is dropped.
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
        probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

    pred_id = int(np.argmax(probs))
    pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")

    # Probability table, highest probability first.
    rows = sorted(
        ((id2label.get(i, f"LABEL_{i}"), float(probs[i])) for i in range(num_labels)),
        key=lambda row: row[1],
        reverse=True,
    )

    # Build HTML showing tokens; tokens beyond MAX_LENGTH are red.
    token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)

    # Three outputs: prediction, probability table, token visualization.
    return f"Predicted Label: {pred_label}", rows, token_highlight_html
116
+
117
# ----------------------------------------
# 3. Gradio Interface
# ----------------------------------------
demo = gr.Interface(
    fn=predict_amharic_news,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Amharic News Text",
        placeholder="αŠ₯α‰£αŠ­α‹Ž α‹¨αŠ αˆ›αˆ­αŠ› α‹œαŠ“ αŒ½αˆ‘α α‹«αˆ΅αŒˆα‰‘..."
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Class Probabilities"
        ),
        gr.HTML(label="Tokenizer view (tokens > 512 are red)")
    ],
    title="Amharic News Classifier",
    # NOTE: the description previously said "XLM-RoBERTa", but
    # HUB_MODEL_ID points at a multilingual BERT checkpoint -- corrected.
    description=(
        "Multilingual BERT model loaded directly from Hugging Face Hub (raw text input, no preprocessing). "
        "Below, tokenizer output shows which tokens are beyond the 512-token context window (in red)."
    )
)

# Guard the launch so importing this module (e.g. from tests or another
# app) does not start the web server as a side effect.
if __name__ == "__main__":
    demo.launch()