Spaces:
Sleeping
Sleeping
File size: 6,548 Bytes
f1ee6d0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | import os
from typing import Any
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Model / tokenizer identifiers, overridable through environment variables.
# ADAPTED_MODEL_ID falls back to the legacy MODEL_ID variable before the
# built-in default.
ADAPTED_MODEL_ID = os.environ.get(
    "ADAPTED_MODEL_ID",
    os.environ.get("MODEL_ID", "Rogendo/afribert-kenya-adapted"),
)
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "castorini/afriberta_large")
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "castorini/afriberta_large")
# Hub access token: HF_TOKEN wins; HUGGINGFACE_HUB_TOKEN is consulted when
# HF_TOKEN is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
def load_models() -> tuple[Any, Any, Any, torch.device]:
    """Load the shared tokenizer and the base / adapted masked-LM models.

    Both models are moved to CUDA when it is available (CPU otherwise) and
    switched to eval mode. ``HF_TOKEN`` is forwarded to every download so that
    private or gated repositories work for the tokenizer, the base model and
    the adapted model alike.

    Returns:
        ``(tokenizer, base_model, adapted_model, device)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The slow tokenizer implementation is requested explicitly.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, token=HF_TOKEN, use_fast=False)

    def _load_masked_lm(model_id: str) -> Any:
        # Load from safetensors, move to the target device, disable dropout.
        model = AutoModelForMaskedLM.from_pretrained(
            model_id,
            token=HF_TOKEN,
            use_safetensors=True,
        )
        model.to(device)
        model.eval()
        return model

    # Previously the base model was loaded without the token, which broke
    # private/gated BASE_MODEL_ID values; both loads now share one path.
    base_model = _load_masked_lm(BASE_MODEL_ID)
    adapted_model = _load_masked_lm(ADAPTED_MODEL_ID)
    return tokenizer, base_model, adapted_model, device
# Load the tokenizer and both models once, at import time; every request
# served by the UI reuses these objects.
(tokenizer, base_model, adapted_model, device) = load_models()

# Mask placeholder shown to users; "[MASK]" is used only when the tokenizer
# reports no mask token of its own.
MASK_TOKEN = tokenizer.mask_token or "[MASK]"

# Sample prompts (Sheng, M-PESA, institutional Swahili, code-switching),
# each containing exactly one mask placeholder.
EXAMPLES = [
    f"Oya, twendeni zetu, kuna {MASK_TOKEN} flani ameniudhi.",
    f"Tuma {MASK_TOKEN} kwa kutumia nambari ya simu kupitia huduma ya M-PESA.",
    f"Mtoto aliripotiwa kwa ofisi ya {MASK_TOKEN} wa jamii baada ya kudhulumiwa nyumbani.",
    f"Tulifanya meeting jana na manager akasema {MASK_TOKEN} itakuwa ready wiki ijayo.",
    f"Msee alikuwa poa sana, akanisaidia kupata {MASK_TOKEN} ya ofisi.",
]
def normalize_input(text: str) -> str:
    """Trim user input and rewrite literal ``[MASK]`` as ``MASK_TOKEN``.

    A ``None`` or empty value becomes ``""``. The replacement is skipped when
    ``MASK_TOKEN`` already is ``"[MASK]"``.
    """
    cleaned = text.strip() if text else ""
    if MASK_TOKEN == "[MASK]" or "[MASK]" not in cleaned:
        return cleaned
    return cleaned.replace("[MASK]", MASK_TOKEN)
def model_predictions(model, inputs, mask_positions, top_k: int, model_label: str) -> list[list[Any]]:
    """Return the top-k predictions of ``model`` at every mask position.

    Args:
        model: Masked-LM model whose call output exposes ``.logits``.
        inputs: Tokenizer output for one sentence; only batch row 0 is read.
        mask_positions: 1-D tensor of token indices that hold the mask token.
        top_k: Candidates per mask (clamped to the vocabulary size).
        model_label: Text written in the first column of every row.

    Returns:
        One row per (mask, rank):
        ``[model_label, mask_index, rank, token, score, completed_sentence]``.
        ``mask_index`` and ``rank`` are 1-based; ``score`` is the softmax
        probability rounded to 4 decimals; ``completed_sentence`` is the input
        with only this mask filled in.
    """
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    input_ids = inputs["input_ids"][0].tolist()
    # Drop special tokens (sentence markers, padding, ...) from the completed
    # sentence but KEEP the mask token: with several masks, the ones not being
    # filled in this row must stay visible. skip_special_tokens=True would
    # silently delete them, since the mask token is itself a special token.
    hidden_ids = set(tokenizer.all_special_ids) - {tokenizer.mask_token_id}
    k = min(int(top_k), logits.shape[-1])

    rows = []
    for mask_index, position in enumerate(mask_positions.tolist(), start=1):
        probabilities = torch.softmax(logits[position], dim=-1)
        scores, token_ids = torch.topk(probabilities, k=k)
        for rank, (score, token_id) in enumerate(zip(scores.tolist(), token_ids.tolist()), start=1):
            completed = list(input_ids)
            completed[position] = token_id
            sequence = tokenizer.decode([i for i in completed if i not in hidden_ids])
            rows.append([
                model_label,
                mask_index,
                rank,
                tokenizer.decode([token_id]).strip(),
                round(score, 4),
                sequence,
            ])
    return rows
def predict_masks(text: str, top_k: int) -> tuple[str, list[list[Any]], list[list[Any]]]:
    """Compare base and adapted model predictions for every mask in ``text``.

    Returns a ``(summary_markdown, comparison_rows, detail_rows)`` triple.
    When the input is empty or contains no usable mask token, the summary
    holds an explanatory message and both tables are empty.

    ``comparison_rows`` pairs the two models rank by rank:
    ``[mask, rank, base_token, base_score, adapted_token, adapted_score]``.
    ``detail_rows`` lists every base row followed by every adapted row.
    """
    cleaned = normalize_input(text)
    if not cleaned:
        return "Enter a sentence with a mask token.", [], []
    if MASK_TOKEN not in cleaned:
        return f"Add at least one mask token: `{MASK_TOKEN}`", [], []

    encoded = tokenizer(cleaned, return_tensors="pt").to(device)
    is_mask = encoded["input_ids"][0] == tokenizer.mask_token_id
    positions = is_mask.nonzero(as_tuple=True)[0]
    mask_count = len(positions)
    if mask_count == 0:
        return f"No valid mask token found. Use `{MASK_TOKEN}`.", [], []

    base_rows = model_predictions(base_model, encoded, positions, top_k, "Base AfriBERT")
    adapted_rows = model_predictions(adapted_model, encoded, positions, top_k, "Adapted AfriBERT Kenya")

    # Both models share the same mask positions and top_k, so their rows
    # line up one-to-one. Keep mask/rank/token/score from the base row and
    # token/score from the adapted row.
    comparison_rows = [
        [*base_row[1:5], *adapted_row[3:5]]
        for base_row, adapted_row in zip(base_rows, adapted_rows)
    ]

    plural = "s" if mask_count != 1 else ""
    summary = "\n\n".join([
        f"Base model: `{BASE_MODEL_ID}`",
        f"Adapted model: `{ADAPTED_MODEL_ID}`",
        f"Tokenizer: `{TOKENIZER_ID}`",
        f"Mask token: `{MASK_TOKEN}`",
        f"Found {mask_count} mask position{plural}.",
    ])
    return summary, comparison_rows, base_rows + adapted_rows
# Gradio UI. Components are laid out in the order they are created inside the
# Blocks context, so the creation order below is significant.
# NOTE(review): the nesting of the output tables relative to gr.Row was
# reconstructed from a source listing without indentation — confirm layout.
with gr.Blocks(title="AfriBERT Kenya Masked LM") as demo:
    gr.Markdown(
        """
# AfriBERT Kenya Masked Language Modeling
Compare base AfriBERT against the Kenya-adapted model on Swahili, Sheng,
Kenyan institutional text, M-PESA language, and English-Swahili code-switching.
"""
    )
    with gr.Row():
        # Left: input sentence, number of candidates, and the run button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                placeholder=f"Type a sentence containing {MASK_TOKEN}",
                label="Input text",
                value=EXAMPLES[0],
                lines=4,
            )
            top_k = gr.Slider(
                label="Top predictions",
                value=5,
                minimum=1,
                maximum=10,
                step=1,
            )
            predict_button = gr.Button("Compare masked-token predictions", variant="primary")
        # Right: usage notes.
        with gr.Column(scale=1):
            gr.Markdown(
                f"""
**How to use**
Add `{MASK_TOKEN}` where you want the model to predict a token.
`[MASK]` is also accepted and converted automatically.
For private models, set `HF_TOKEN` before launching the app.
The same base AfriBERT tokenizer is used for both models.
"""
            )

    # Outputs: run summary, side-by-side table, and per-model detail table.
    summary_output = gr.Markdown()
    comparison_output = gr.Dataframe(
        label="Side-by-side comparison",
        headers=["Mask", "Rank", "Base prediction", "Base score", "Adapted prediction", "Adapted score"],
        datatype=["number", "number", "str", "number", "str", "number"],
        wrap=True,
    )
    details_output = gr.Dataframe(
        label="Detailed predictions",
        headers=["Model", "Mask", "Rank", "Prediction", "Score", "Completed sentence"],
        datatype=["str", "number", "number", "str", "number", "str"],
        wrap=True,
    )
    gr.Examples(inputs=text_input, examples=EXAMPLES)

    # Clicking the button and pressing Enter in the textbox trigger the same
    # prediction with the same inputs and outputs.
    for _register in (predict_button.click, text_input.submit):
        _register(
            fn=predict_masks,
            inputs=[text_input, top_k],
            outputs=[summary_output, comparison_output, details_output],
        )
    del _register

if __name__ == "__main__":
    demo.launch()
|