Abelex committed on
Commit
fe4863c
Β·
verified Β·
1 Parent(s): 201da85

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import gradio as gr
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+ # ----------------------------------------
8
+ # 1. Load from Hugging Face Hub
9
+ # ----------------------------------------
10
+
11
+ # Change this to YOUR pushed model repo
12
+ HUB_MODEL_ID = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext"
13
+ # <--- EDIT IF NEEDED
14
+
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ MAX_LENGTH = 512 # model context window in TOKENS
17
+
18
+ # Load tokenizer and model directly from HF Hub
19
+ tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
20
+ model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
21
+ model.to(DEVICE)
22
+ model.eval()
23
+
24
+ # Label mapping from config
25
+ id2label = {int(k): v for k, v in model.config.id2label.items()}
26
+ num_labels = len(id2label)
27
+
28
+ # ----------------------------------------
29
+ # Helper: highlight tokens after MAX_LENGTH in red (HTML)
30
+ # ----------------------------------------
31
+ def highlight_token_overflow(text: str, max_tokens: int = 512) -> str:
32
+ """
33
+ Tokenize the input text and generate HTML where tokens beyond
34
+ `max_tokens` are wrapped in red. This shows exactly which tokens
35
+ are outside the model's context window.
36
+ """
37
+ if not text.strip():
38
+ return "<i>No text provided.</i>"
39
+
40
+ # Tokenize without truncation (so we can see ALL tokens)
41
+ tokens = tokenizer.tokenize(text)
42
+ if len(tokens) == 0:
43
+ return "<i>No tokens produced by tokenizer.</i>"
44
+
45
+ spans = []
46
+ for i, tok in enumerate(tokens):
47
+ # minimal HTML escape
48
+ safe_tok = (
49
+ tok.replace("&", "&amp;")
50
+ .replace("<", "&lt;")
51
+ .replace(">", "&gt;")
52
+ )
53
+
54
+ if i >= max_tokens:
55
+ spans.append(f"<span style='color:red;font-weight:bold;'>{safe_tok}</span>")
56
+ else:
57
+ spans.append(f"<span>{safe_tok}</span>")
58
+
59
+ html = " ".join(spans)
60
+
61
+ if len(tokens) > max_tokens:
62
+ html += (
63
+ f"<br><br>"
64
+ f"<small style='color:red;'>"
65
+ f"Note: Tokens in <b>red</b> are beyond the model context window "
66
+ f"({max_tokens} tokens) and will be truncated."
67
+ f"</small>"
68
+ )
69
+ else:
70
+ html += (
71
+ f"<br><br>"
72
+ f"<small>Token count: {len(tokens)} (≀ {max_tokens}, no truncation).</small>"
73
+ )
74
+
75
+ return html
76
+
77
+ # ----------------------------------------
78
+ # 2. Prediction
79
+ # ----------------------------------------
80
+ def predict_amharic_news(text):
81
+ if not text.strip():
82
+ # Also return highlighted version (empty)
83
+ return "Please enter text.", None, "<i>No text provided.</i>"
84
+
85
+ # For actual model inference: truncate to MAX_LENGTH tokens
86
+ encoded = tokenizer(
87
+ text,
88
+ truncation=True,
89
+ padding="max_length",
90
+ max_length=MAX_LENGTH,
91
+ return_tensors="pt"
92
+ )
93
+
94
+ encoded = {k: v.to(DEVICE) for k, v in encoded.items()}
95
+
96
+ with torch.no_grad():
97
+ outputs = model(**encoded)
98
+ logits = outputs.logits
99
+ probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
100
+
101
+ pred_id = int(np.argmax(probs))
102
+ pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")
103
+
104
+ # Prepare probability table
105
+ rows = []
106
+ for i in range(num_labels):
107
+ rows.append((id2label.get(i, f"LABEL_{i}"), float(probs[i])))
108
+
109
+ rows = sorted(rows, key=lambda x: x[1], reverse=True)
110
+
111
+ # Build HTML showing tokens; tokens >512 in red
112
+ token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)
113
+
114
+ # Now we return 3 outputs: prediction, probs table, token visualization
115
+ return f"Predicted Label: {pred_label}", rows, token_highlight_html
116
+
117
+ # ----------------------------------------
118
+ # 3. Gradio Interface
119
+ # ----------------------------------------
120
+ demo = gr.Interface(
121
+ fn=predict_amharic_news,
122
+ inputs=gr.Textbox(
123
+ lines=5,
124
+ label="Enter Amharic News Text",
125
+ placeholder="αŠ₯α‰£αŠ­α‹Ž α‹¨αŠ αˆ›αˆ­αŠ› α‹œαŠ“ αŒ½αˆ‘α α‹«αˆ΅αŒˆα‰‘..."
126
+ ),
127
+ outputs=[
128
+ gr.Textbox(label="Prediction"),
129
+ gr.Dataframe(
130
+ headers=["Label", "Probability"],
131
+ label="Class Probabilities"
132
+ ),
133
+ gr.HTML(label="Tokenizer view (tokens > 512 are red)")
134
+ ],
135
+ title="Amharic News Classifier",
136
+ description=(
137
+ "XLM-RoBERTa model loaded directly from Hugging Face Hub (raw text input, no preprocessing). "
138
+ "Below, tokenizer output shows which tokens are beyond the 512-token context window (in red)."
139
+ )
140
+ )
141
+
142
+ demo.launch()