# ADI / app.py — Arabic Dialect Identification demo (Hugging Face Space by AMR-KELEG).
# NOTE(review): the original lines here were Hugging Face web-page residue
# ("AMR-KELEG's picture / Update app.py / 090528d verified") captured during
# extraction; they are not valid Python and have been turned into this comment.
# Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
import constants
import torch
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from constants import DIALECTS
import altair as alt
from altair import X, Y, Scale
import base64
import re
def predict_binary_outcomes(model, tokenizer, text, threshold=0.3):
"""Predict the validity in each dialect, by indepenently applying a sigmoid activation to each dialect's logit.
Dialects with probabilities (sigmoid activations) above a threshold (set by defauly to 0.3) are predicted as valid.
The model is expected to generate logits for each dialect of the following dialects in the same order:
Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encodings = tokenizer(
text, truncation=True, padding=True, max_length=128, return_tensors="pt"
)
## inputs
input_ids = encodings["input_ids"].to(device)
attention_mask = encodings["attention_mask"].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1).tolist()
return probabilities
# binary_predictions = probabilities.multiply((probabilities >= threshold).astype(int))
# return binary_predictions
# Map indices to actual labels
# predicted_dialects = [
# dialect
# for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
# if dialect_prediction == 1
# ]
# return predicted_dialects
def preprocess_text(arabic_text):
    """Apply preprocessing to the given Arabic text.

    Strips URLs first, then removes all Latin (English) letters.

    Args:
        arabic_text: The Arabic text to be preprocessed.

    Returns:
        The preprocessed Arabic text.
    """
    url_pattern = r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b"
    without_urls = re.sub(url_pattern, "", arabic_text, flags=re.MULTILINE)
    return re.sub(r"[a-zA-Z]", "", without_urls)
@st.cache_data
def render_svg(svg):
    """Render the given SVG string as a centered, base64-embedded image."""
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    markup = rf'<p align="center"> <img src="data:image/svg+xml;base64,{encoded}"/> </p>'
    container = st.container()
    container.write(markup, unsafe_allow_html=True)
@st.cache_data
def convert_df(df):
    """Serialize a DataFrame to UTF-8 CSV bytes (no index column).

    Cached so the conversion is not recomputed on every Streamlit rerun.
    """
    csv_text = df.to_csv(index=None)
    return csv_text.encode("utf-8")
@st.cache_resource
def load_model(model_name):
    """Load (and cache across reruns) the sequence-classification model."""
    return AutoModelForSequenceClassification.from_pretrained(model_name)
# Load the shared tokenizer and model once at import time. load_model is
# wrapped in st.cache_resource, so Streamlit reruns reuse the same instance.
tokenizer = AutoTokenizer.from_pretrained(constants.MODEL_NAME)
model = load_model(constants.MODEL_NAME)
@st.cache_data
def render_metadata():
    """Render the centered badge links (model, code, paper) for the app."""
    # TODO: Update!
    badges_html = r"""<p align="center">
<a href="https://huggingface.co/AMR-KELEG/Sentence-ALDi"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
<a href="https://github.com/AMR-KELEG/ALDi"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
<a href="https://arxiv.org/abs/2310.13747"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2310.13747-b31b1b.svg"></a>
</p>"""
    container = st.container()
    container.write(badges_html, unsafe_allow_html=True)
# TODO: Update!
# render_svg(open("assets/ALDi_logo.svg").read())
render_metadata()

sent = st.text_input(
    "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
)

# TODO: Check if this is needed! (text_input already triggers a rerun on Enter)
clicked = st.button("Submit")

if sent:
    # NOTE(review): the raw input is scored directly; preprocess_text is
    # defined above but never applied — confirm whether that is intentional.
    probabilities = predict_binary_outcomes(model, tokenizer, sent)

    ORANGE_COLOR = "#FF8000"
    fig, ax = plt.subplots(figsize=(8, 1))
    # Transparent background so the figure blends into the Streamlit theme.
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")
    ax.spines["left"].set_color(ORANGE_COLOR)
    ax.spines["bottom"].set_color(ORANGE_COLOR)
    ax.tick_params(axis="x", colors=ORANGE_COLOR)
    ax.spines[["right", "top"]].set_visible(False)

    # One-row heat strip: one cell per dialect, colored by its probability.
    ax.imshow([probabilities], cmap="vanimo", alpha=0.5, vmin=0, vmax=1)
    ax.set_xticks(range(len(DIALECTS)))
    ax.set_xticklabels(DIALECTS, fontsize=8, rotation=90, ha="right")
    ax.set_yticks([])
    ax.set_title("Valid Dialects", color=ORANGE_COLOR)

    # Overlay each cell with its numeric probability.
    for i, probability in enumerate(probabilities):
        ax.text(i, 0, round(probability, 2), ha="center", va="center", color="w")

    st.pyplot(fig)
    print(sent)  # NOTE(review): debug logging of user input — consider removing.