Abelex committed on
Commit
ed7df25
·
verified ·
1 Parent(s): c7d8bdd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===============================
2
+ # Sentence-ChuLo Gradio Demo (HF Spaces Ready)
3
+ # ===============================
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import os
10
+ import re
11
+ from transformers import AutoTokenizer, AutoModel
12
+
13
# --------------------------------------------------
# Configuration
# --------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Backbone checkpoint name and the Hub repository id of the fine-tuned weights.
PRETRAINED = "Davlan/afro-xlmr-large"
HF_MODEL_ID = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_text"

CHUNK_SIZE = 512      # tokenizer max_length for a single chunk
MAX_CHUNKS = 8        # chunk lists are zero-padded up to this many rows
CHUNK_DMODEL = 256    # width of the chunk-level transformer
DROPOUT = 0.1
NUM_LABELS = 8

# WARNING: the order below defines the class ids and MUST match training.
id2label = dict(enumerate((
    "Politics",
    "Business",
    "Sports",
    "Technology",
    "Health",
    "Entertainment",
    "Education",
    "Other",
)))
38
+
39
+ # ========================================================
40
+ # MODEL
41
+ # ========================================================
42
class HybridSentenceChuLo(nn.Module):
    """Chunk-based document classifier on top of a pretrained encoder.

    Every chunk is encoded by the backbone, reduced to a single vector with
    a learned attention query over its tokens, mixed with the other chunks
    by a small Transformer encoder, mean-pooled over the non-empty chunks
    and finally classified.
    """

    def __init__(self, pretrained_name, num_labels):
        """Create the backbone and the chunk-level layers.

        Args:
            pretrained_name: checkpoint identifier handed to
                ``AutoModel.from_pretrained``.
            num_labels: number of output classes of the final linear layer.
        """
        super().__init__()

        # NOTE(review): trust_remote_code=True allows the checkpoint repo to
        # execute its own Python on load - confirm this is really required.
        self.bert = AutoModel.from_pretrained(pretrained_name, trust_remote_code=True)

        backbone_dim = self.bert.config.hidden_size

        # Project token vectors to CHUNK_DMODEL only when the widths differ.
        if backbone_dim == CHUNK_DMODEL:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(backbone_dim, CHUNK_DMODEL)

        # Learned query used to score tokens for attention pooling.
        self.token_attn_vec = nn.Parameter(torch.randn(CHUNK_DMODEL))

        self.chunk_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=CHUNK_DMODEL,
                nhead=8,
                dim_feedforward=4 * CHUNK_DMODEL,
                batch_first=True,
                dropout=DROPOUT,
            ),
            num_layers=2,
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(CHUNK_DMODEL),
            nn.Linear(CHUNK_DMODEL, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of chunked documents.

        Args:
            input_ids: 3-D tensor unpacked as (batch, chunks, tokens).
            attention_mask: tensor viewed with the same layout; a chunk
                whose mask sums to zero is treated as padding.

        Returns:
            The output of ``self.classifier`` applied to the pooled
            document vectors.
        """
        batch, n_chunks, seq_len = input_ids.shape

        # Fold the chunk axis into the batch axis so the backbone sees 2-D input.
        ids_2d = input_ids.view(batch * n_chunks, seq_len)
        mask_2d = attention_mask.view(batch * n_chunks, seq_len)

        hidden = self.bert(input_ids=ids_2d, attention_mask=mask_2d).last_hidden_state
        projected = self.proj(hidden)

        # Attention pooling over tokens; masked positions get the most
        # negative finite score so softmax effectively ignores them.
        scores = torch.matmul(projected, self.token_attn_vec)
        lowest = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask_2d == 0, lowest)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)

        pooled = (projected * weights).sum(dim=1)
        chunk_vecs = pooled.view(batch, n_chunks, CHUNK_DMODEL)

        # A chunk counts as real when at least one of its tokens is unmasked.
        has_tokens = attention_mask.sum(dim=2) > 0
        is_padding = ~has_tokens

        chunk_out = self.chunk_transformer(chunk_vecs, src_key_padding_mask=is_padding)

        # Mean over real chunks only; the clamp guards the all-padding case.
        keep = (~is_padding).unsqueeze(-1).float()
        doc_vec = (chunk_out * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1e-6)

        return self.classifier(doc_vec)
96
+
97
+ # ========================================================
98
+ # Load tokenizer & model
99
+ # ========================================================
100
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)

model = HybridSentenceChuLo(
    pretrained_name=PRETRAINED,
    num_labels=NUM_LABELS,
).to(DEVICE)

# Load the fine-tuned weights from the HF Hub.
# NOTE(review): pytorch_model.bin is a pickle file - only point HF_MODEL_ID
# at a repository you trust.
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/pytorch_model.bin",
    map_location=DEVICE,
)

# strict=False accepts mismatching keys, so surface them explicitly instead
# of silently serving predictions from randomly initialised layers.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint/model key mismatch - "
        f"missing keys: {list(load_result.missing_keys)}; "
        f"unexpected keys: {list(load_result.unexpected_keys)}"
    )
model.eval()
114
+
115
+ # ========================================================
116
+ # Sentence Utilities
117
+ # ========================================================
118
# Whitespace that follows a sentence terminator: Ethiopic full stop (U+1362),
# Ethiopic semicolon (U+1364), "!" or "?". Written with escapes because the
# literal glyphs had been corrupted into the syllables U+1352 / U+1340 by an
# encoding round-trip, which made Amharic punctuation never match.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[\u1362\u1364!?])\s+")


def split_sentences(text):
    """Split *text* into sentences.

    A boundary is any run of whitespace that directly follows the Ethiopic
    full stop, the Ethiopic semicolon, ``!`` or ``?``. The terminator stays
    attached to its sentence.

    Args:
        text: the document to split.

    Returns:
        A list of stripped, non-empty sentence strings (empty when *text*
        holds only whitespace).
    """
    stripped = (piece.strip() for piece in _SENTENCE_BOUNDARY.split(text))
    return [sentence for sentence in stripped if sentence]
120
+
121
def select_topk(sentences):
    """Pick the beginning, middle and end sentences of a document.

    Args:
        sentences: sequence of sentences in document order.

    Returns:
        ``[]`` for an empty input; otherwise exactly three entries taken
        from positions ``0``, ``len(sentences) // 2`` and ``-1``. Short
        inputs therefore repeat entries (one sentence yields it three times).
    """
    count = len(sentences)
    if not count:
        return []
    return [sentences[pos] for pos in (0, count // 2, -1)]
126
+
127
def encode_sentence_chunks(sentences):
    """Tokenize sentences into a fixed-size stack of chunks.

    Each sentence becomes one chunk, padded or truncated to ``CHUNK_SIZE``
    tokens. The result always has exactly ``MAX_CHUNKS`` rows: sentences
    beyond the first ``MAX_CHUNKS`` are dropped, and missing rows are filled
    with all-zero ids and an all-zero attention mask.

    Args:
        sentences: the sentences to encode, in order.

    Returns:
        ``(input_ids, attention_mask)`` - two stacked tensors with one row
        per chunk.
    """
    chunks, masks = [], []

    # Cap the input so the stacked output never exceeds MAX_CHUNKS rows.
    for sent in list(sentences)[:MAX_CHUNKS]:
        enc = tokenizer(
            sent,
            max_length=CHUNK_SIZE,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch axis; take the single row.
        chunks.append(enc["input_ids"][0])
        masks.append(enc["attention_mask"][0])

    # Pad with empty chunks; their zero mask marks them as padding.
    for _ in range(MAX_CHUNKS - len(chunks)):
        chunks.append(torch.zeros(CHUNK_SIZE, dtype=torch.long))
        masks.append(torch.zeros(CHUNK_SIZE, dtype=torch.long))

    return torch.stack(chunks), torch.stack(masks)
146
+
147
def build_html(all_sents, selected):
    """Render the document as HTML with the selected sentences highlighted.

    Args:
        all_sents: every sentence of the document, in display order.
        selected: the sentences to highlight; membership is by string
            equality, so every occurrence of a selected sentence is marked.

    Returns:
        A single ``<div>`` containing one ``<p>`` per sentence. Sentence
        text is HTML-escaped because it comes straight from user input.
    """
    # Imported locally so the block does not depend on the file's import header.
    from html import escape

    highlighted = set(selected)
    parts = ["<div style='font-size:15px; line-height:1.6;'>"]
    for sent in all_sents:
        # escape() also handles "&" and quotes, which the previous manual
        # "<" / ">" replacement let through.
        safe = escape(sent)
        if sent in highlighted:
            parts.append(f"<p style='background:#d4edda; padding:4px;'><b>{safe}</b></p>")
        else:
            parts.append(f"<p>{safe}</p>")
    parts.append("</div>")
    return "".join(parts)
157
+
158
+ # ========================================================
159
+ # Prediction
160
+ # ========================================================
161
def chulo_predict(text):
    """Classify a news text and build the three outputs shown in the UI.

    Args:
        text: the raw text typed by the user; may be empty or ``None``.

    Returns:
        A 3-tuple ``(headline, table, html)``: the predicted label prefixed
        with a label emoji, a list of ``(label, probability)`` rows, and the
        document rendered as HTML with the selected sentences highlighted.
        For blank input the tuple is a warning message, ``[]`` and ``""``.
    """
    if not text or not text.strip():
        # U+26A0 U+FE0F is the warning-sign emoji; written as escapes because
        # the literal had been corrupted by an encoding round-trip.
        return "\u26a0\ufe0f Please enter Amharic text.", [], ""

    sents = split_sentences(text)
    selected = select_topk(sents)

    chunks, masks = encode_sentence_chunks(selected)

    # Add the batch axis the model expects and run without autograd.
    with torch.no_grad():
        logits = model(
            input_ids=chunks.unsqueeze(0).to(DEVICE),
            attention_mask=masks.unsqueeze(0).to(DEVICE),
        )
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    pred = id2label[int(np.argmax(probs))]
    table = [(id2label[idx], float(prob)) for idx, prob in enumerate(probs)]

    # U+1F3F7 U+FE0F is the label emoji (restored from mojibake).
    return f"\U0001f3f7\ufe0f {pred}", table, build_html(sents, selected)
181
+
182
+ # ========================================================
183
+ # Gradio UI (HF Friendly)
184
+ # ========================================================
185
# The non-ASCII UI strings below are written with escapes: the literals had
# been corrupted by an encoding round-trip and rendered as garbage.
demo = gr.Interface(
    fn=chulo_predict,
    inputs=gr.Textbox(
        lines=8,
        # Amharic for "Please enter Amharic news text here".
        placeholder=(
            "\u12a5\u1263\u12ad\u12ce "
            "\u12e8\u12a0\u121b\u122d\u129b "
            "\u12dc\u1293 "
            "\u133d\u1211\u134d "
            "\u12a5\u12da\u1205 "
            "\u12eb\u1235\u1308\u1261"
        ),
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(headers=["Label", "Probability"], label="Class Probabilities"),
        gr.HTML(label="Highlighted Document"),
    ],
    # \u2014 is an em dash, \u2013 an en dash.
    title="Sentence-ChuLo \u2014 Amharic News Classification",
    description="Uses EXACT Beginning\u2013Middle\u2013End sentence selection.",
)

demo.launch()