import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------------
# 1. Load from Hugging Face Hub
# ----------------------------------------

# Change this to YOUR pushed model repo
HUB_MODEL_ID = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext"  # <--- edit if needed

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # model context window in TOKENS

# Load tokenizer and model directly from HF Hub
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
model.to(DEVICE)
model.eval()

# Label mapping from config
id2label = {int(k): v for k, v in model.config.id2label.items()}
num_labels = len(id2label)
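
# Quick sanity check (assumes the Hub config carries meaningful label names;
# if it only exposes generic LABEL_0..LABEL_N entries, edit id2label here).
print(f"Loaded {HUB_MODEL_ID} on {DEVICE} with {num_labels} labels: {id2label}")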

# ----------------------------------------
# Helper: highlight tokens beyond the effective context window in red (HTML)
# ----------------------------------------
def highlight_token_overflow(text: str, max_tokens: int = MAX_LENGTH) -> str:
    """
    Tokenize the input text and generate HTML where tokens beyond the
    model's effective context window are wrapped in red. Note that
    `tokenizer.tokenize` does not emit special tokens (e.g. <s>/</s>),
    so the budget for content tokens is `max_tokens` minus the number
    of special tokens the tokenizer adds to a single sequence.
    """
    if not text.strip():
        return "<i>No text provided.</i>"

    # Tokenize without truncation (so we can see ALL tokens)
    tokens = tokenizer.tokenize(text)
    if len(tokens) == 0:
        return "<i>No tokens produced by tokenizer.</i>"

    # Room left for content tokens once special tokens are accounted for
    effective_max = max_tokens - tokenizer.num_special_tokens_to_add()

    spans = []
    for i, tok in enumerate(tokens):
        # minimal HTML escape
        safe_tok = (
            tok.replace("&", "&amp;")
               .replace("<", "&lt;")
               .replace(">", "&gt;")
        )

        if i >= effective_max:
            spans.append(f"<span style='color:red;font-weight:bold;'>{safe_tok}</span>")
        else:
            spans.append(f"<span>{safe_tok}</span>")

    html = " ".join(spans)

    if len(tokens) > effective_max:
        html += (
            f"<br><br>"
            f"<small style='color:red;'>"
            f"Note: Tokens in <b>red</b> are beyond the model context window "
            f"({max_tokens} tokens, including special tokens) and will be truncated."
            f"</small>"
        )
    else:
        html += (
            f"<br><br>"
            f"<small>Token count: {len(tokens)} (≀ {effective_max} content tokens; no truncation).</small>"
        )

    return html
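
# Example (illustrative only; the actual subword splits depend on the
# tokenizer's vocabulary):
#   highlight_token_overflow("αˆ°αˆ‹αˆ α‹“αˆˆαˆ", max_tokens=3)
#   -> the first content tokens render as plain spans; anything past the
#      effective window is wrapped in a red <span>.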

# ----------------------------------------
# 2. Prediction
# ----------------------------------------
def predict_amharic_news(text):
    if not text.strip():
        # Also return highlighted version (empty)
        return "Please enter text.", None, "<i>No text provided.</i>"

    # For actual model inference: truncate to MAX_LENGTH tokens
    encoded = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )

    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    pred_id = int(np.argmax(probs))
    pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")

    # Prepare probability table
    rows = []
    for i in range(num_labels):
        rows.append((id2label.get(i, f"LABEL_{i}"), float(probs[i])))

    rows = sorted(rows, key=lambda x: x[1], reverse=True)

    # Build HTML showing all tokens; those past the effective context window are red
    token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)

    # Now we return 3 outputs: prediction, probs table, token visualization
    return f"Predicted Label: {pred_label}", rows, token_highlight_html
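
# The repo name mentions sentence chunking for long texts, but the UI above
# simply truncates at MAX_LENGTH. Below is a minimal sliding-window sketch
# (my assumption, not wired into the interface): run the model over
# overlapping MAX_LENGTH-token windows and average the class probabilities.
def predict_long_text_sketch(text: str, stride: int = 256):
    """Hypothetical helper: chunked inference for texts longer than MAX_LENGTH."""
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        stride=stride,                    # overlap between consecutive windows
        return_overflowing_tokens=True,   # one row per window
        padding=True,
        return_tensors="pt",
    )
    encoded.pop("overflow_to_sample_mapping", None)  # not a model input
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits
        # Average probabilities across windows (a simple pooling choice)
        probs = torch.softmax(logits, dim=-1).mean(dim=0).cpu().numpy()

    pred_id = int(np.argmax(probs))
    return id2label.get(pred_id, f"LABEL_{pred_id}"), probs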

# ----------------------------------------
# 3. Gradio Interface
# ----------------------------------------
demo = gr.Interface(
    fn=predict_amharic_news,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Amharic News Text",
        placeholder="αŠ₯α‰£αŠ­α‹Ž α‹¨αŠ αˆ›αˆ­αŠ› α‹œαŠ“ αŒ½αˆ‘α α‹«αˆ΅αŒˆα‰‘..."
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Class Probabilities"
        ),
        gr.HTML(label="Tokenizer view (tokens beyond the context window are red)")
    ],
    title="Amharic News Classifier",
    description=(
        "XLM-RoBERTa model loaded directly from Hugging Face Hub (raw text input, no preprocessing). "
        "Below, tokenizer output shows which tokens are beyond the 512-token context window (in red)."
    )
)

demo.launch()
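
# Tip: demo.launch(share=True) creates a temporary public URL, which is
# handy when running from Colab or a remote machine.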