|
|
|
|
|
import constants |
|
|
import torch |
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
import matplotlib.pyplot as plt |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
from constants import DIALECTS |
|
|
|
|
|
import altair as alt |
|
|
from altair import X, Y, Scale |
|
|
import base64 |
|
|
|
|
|
import re |
|
|
|
|
|
def predict_binary_outcomes(model, tokenizer, text, threshold=0.3):
    """Predict the validity of each dialect by independently applying a sigmoid to each dialect's logit.

    The model is expected to generate one logit per dialect, in this fixed order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman,
    Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Args:
        model: A sequence-classification model emitting per-dialect logits.
        tokenizer: The tokenizer matching `model`.
        text: The input text (a string, or a list of strings for a batch).
        threshold: Probability cutoff above which a dialect is considered valid
            (default 0.3). NOTE: kept for backward compatibility — it is not
            applied here; callers compare the returned probabilities against it.

    Returns:
        A flat list of per-dialect sigmoid probabilities.

    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
    """
    # Bug fix: the original selected "cuda" whenever available and moved only
    # the *inputs* there, while the model itself was never moved off the CPU,
    # crashing with a device mismatch on GPU machines. Use the device the
    # model's weights actually live on instead.
    device = next(model.parameters()).device

    encodings = tokenizer(
        text, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )

    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # Independent sigmoid per dialect (multi-label), flattened to a plain list.
    probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1).tolist()

    return probabilities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_text(arabic_text):
    """Apply preprocessing to the given Arabic text.

    Strips URLs first, then removes all ASCII Latin letters.

    Args:
        arabic_text: The Arabic text to be preprocessed.

    Returns:
        The preprocessed Arabic text.
    """
    url_pattern = re.compile(
        r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b", flags=re.MULTILINE
    )
    latin_pattern = re.compile(r"[a-zA-Z]")

    without_urls = url_pattern.sub("", arabic_text)
    return latin_pattern.sub("", without_urls)
|
|
|
|
|
|
|
|
@st.cache_data
def render_svg(svg):
    """Renders the given svg string as a base64-embedded <img> tag."""
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    markup = rf'<p align="center"> <img src="data:image/svg+xml;base64,{encoded}"/> </p>'
    container = st.container()
    container.write(markup, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
@st.cache_data
def convert_df(df):
    """Serialize the dataframe to UTF-8 encoded CSV bytes (no index column)."""
    csv_text = df.to_csv(index=None)
    return csv_text.encode("utf-8")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model(model_name):
    """Load (and cache across Streamlit reruns) the classification model."""
    return AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
|
|
|
|
|
|
# Load the tokenizer and model once at module import. Streamlit reruns this
# script on every interaction; the model load is cached via the
# @st.cache_resource-decorated load_model, while the tokenizer is re-created
# on each rerun (tokenizer loading is comparatively cheap).
tokenizer = AutoTokenizer.from_pretrained(constants.MODEL_NAME)

model = load_model(constants.MODEL_NAME)
|
|
|
|
|
|
|
|
@st.cache_data
def render_metadata():
    """Renders the metadata (badge links for the HF model, repo, and paper)."""
    badges_html = r"""<p align="center">
<a href="https://huggingface.co/AMR-KELEG/Sentence-ALDi"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
<a href="https://github.com/AMR-KELEG/ALDi"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
<a href="https://arxiv.org/abs/2310.13747"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2310.13747-b31b1b.svg"></a>
</p>"""
    container = st.container()
    container.write(badges_html, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
# Render the badge links at the top of the page.
render_metadata()


# Free-text input for the sentence whose dialects we want to predict.
sent = st.text_input(
    "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
)


# NOTE(review): `clicked` is never read below — the prediction is gated on
# `sent` being non-empty, so the button only serves as a manual rerun trigger.
clicked = st.button("Submit")
|
|
|
|
|
if sent:
    # Per-dialect sigmoid probabilities, in the same fixed order as DIALECTS.
    probabilities = predict_binary_outcomes(model, tokenizer, sent)

    # Render the probabilities as a single-row heatmap on a transparent figure
    # so it blends with the Streamlit page background.
    ORANGE_COLOR = "#FF8000"
    fig, ax = plt.subplots(figsize=(8, 1))
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")

    ax.spines["left"].set_color(ORANGE_COLOR)
    ax.spines["bottom"].set_color(ORANGE_COLOR)
    ax.tick_params(axis="x", colors=ORANGE_COLOR)

    ax.spines[["right", "top"]].set_visible(False)

    # NOTE(review): the "vanimo" colormap ships with matplotlib >= 3.9 —
    # confirm the deployed matplotlib version provides it.
    im = ax.imshow([probabilities], cmap="vanimo", alpha=0.5, vmin=0, vmax=1)
    ax.set_xticks(range(len(DIALECTS)))
    ax.set_xticklabels(DIALECTS, fontsize=8, rotation=90, ha="right")
    ax.set_yticks([])
    ax.set_title("Valid Dialects", color=ORANGE_COLOR)

    # Overlay each heatmap cell with its probability rounded to 2 decimals.
    for i in range(len(DIALECTS)):
        text = ax.text(i, 0, round(probabilities[i], 2),
                       ha="center", va="center", color="w")

    st.pyplot(fig)

    # NOTE(review): debug leftover — logs the raw input to stdout on every
    # prediction; consider removing or switching to the logging module.
    print(sent)
|
|
|