File size: 5,214 Bytes
8defc61 f3b7541 8defc61 206ed66 673e2f9 f3b7541 206ed66 604feee 206ed66 ab3e62e f3b7541 43f6cad f3b7541 43f6cad f3b7541 43f6cad f3b7541 43f6cad ab3e62e 206ed66 8defc61 206ed66 673e2f9 206ed66 8defc61 206ed66 768757a f3b7541 768757a 3563942 768757a f3b7541 768757a 8defc61 604feee e8bd368 8defc61 e8bd368 8defc61 e8bd368 43f6cad 206ed66 e8bd368 206ed66 e8bd368 206ed66 e8bd368 206ed66 43f6cad 39ce3e3 1bb07fe e8bd368 f3b7541 43f6cad 090528d 43f6cad e8bd368 8defc61 e8bd368 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
import constants
import torch
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from constants import DIALECTS
import altair as alt
from altair import X, Y, Scale
import base64
import re
def predict_binary_outcomes(model, tokenizer, text, threshold=0.3):
"""Predict the validity in each dialect, by indepenently applying a sigmoid activation to each dialect's logit.
Dialects with probabilities (sigmoid activations) above a threshold (set by defauly to 0.3) are predicted as valid.
The model is expected to generate logits for each dialect of the following dialects in the same order:
Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encodings = tokenizer(
text, truncation=True, padding=True, max_length=128, return_tensors="pt"
)
## inputs
input_ids = encodings["input_ids"].to(device)
attention_mask = encodings["attention_mask"].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1).tolist()
return probabilities
# binary_predictions = probabilities.multiply((probabilities >= threshold).astype(int))
# return binary_predictions
# Map indices to actual labels
# predicted_dialects = [
# dialect
# for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
# if dialect_prediction == 1
# ]
# return predicted_dialects
def preprocess_text(arabic_text):
    """Strip URLs and Latin letters from Arabic text.

    Args:
        arabic_text: The Arabic text to be preprocessed.

    Returns:
        The text with URLs removed and every ASCII Latin letter (a-z, A-Z)
        deleted. All other characters (Arabic letters, digits, punctuation,
        whitespace) are kept.
    """
    # re.ASCII restricts \w and \b to ASCII, so a URL written directly against
    # Arabic text (no separating space) stops at the first Arabic letter instead
    # of swallowing the Arabic words. The class also accepts "-", "#", "~", "+",
    # "@" and ":" so hyphenated hosts and paths are removed whole rather than
    # leaving fragments such as "-./" behind.
    no_urls = re.sub(
        r"(?:https?)?://[\w./?=&%#~+@:-]*\b",
        "",
        arabic_text,
        flags=re.ASCII,
    )
    no_english = re.sub(r"[a-zA-Z]", "", no_urls)
    return no_english
@st.cache_data
def render_svg(svg):
    """Display an SVG document, centred, on the Streamlit page.

    The SVG markup is base64-encoded into a ``data:`` URI and embedded through
    an HTML ``<img>`` tag written into a new container.
    """
    payload = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    markup = f'<p align="center"> <img src="data:image/svg+xml;base64,{payload}"/> </p>'
    st.container().write(markup, unsafe_allow_html=True)
@st.cache_data
def convert_df(df):
    """Serialize ``df`` to UTF-8 encoded CSV bytes, omitting the index column.

    IMPORTANT: cached with ``st.cache_data`` to prevent recomputing the
    conversion on every rerun.
    """
    # pandas documents ``index`` as a bool; pass False explicitly instead of None.
    return df.to_csv(index=False).encode("utf-8")
@st.cache_resource
def load_model(model_name):
    """Load the sequence-classification model identified by ``model_name``.

    Decorated with ``st.cache_resource`` so repeated calls with the same name
    return the cached model object instead of reloading it.
    """
    return AutoModelForSequenceClassification.from_pretrained(model_name)
@st.cache_resource
def _load_tokenizer(model_name):
    """Load (once, via st.cache_resource) the tokenizer for ``model_name``."""
    return AutoTokenizer.from_pretrained(model_name)


# Streamlit re-executes this script on every interaction. Going through cached
# loaders keeps both the tokenizer and the model from being reloaded each time.
tokenizer = _load_tokenizer(constants.MODEL_NAME)
model = load_model(constants.MODEL_NAME)
@st.cache_data
def render_metadata():
    """Show a centred row of badges linking to the HuggingFace model, GitHub repo and arXiv paper."""
    # TODO: Update!
    badges_html = r"""<p align="center">
<a href="https://huggingface.co/AMR-KELEG/Sentence-ALDi"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
<a href="https://github.com/AMR-KELEG/ALDi"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
<a href="https://arxiv.org/abs/2310.13747"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2310.13747-b31b1b.svg"></a>
</p>"""
    st.container().write(badges_html, unsafe_allow_html=True)
# TODO: Update!
# render_svg(open("assets/ALDi_logo.svg").read())

ORANGE_COLOR = "#FF8000"


def _plot_dialect_probabilities(probabilities):
    """Build a one-row heatmap figure of per-dialect probabilities.

    Args:
        probabilities: Floats in [0, 1], expected to follow the order of DIALECTS.

    Returns:
        The matplotlib Figure. The caller must close it after rendering.
    """
    fig, ax = plt.subplots(figsize=(8, 1))
    # Transparent backgrounds so the chart blends into the page.
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")
    ax.spines["left"].set_color(ORANGE_COLOR)
    ax.spines["bottom"].set_color(ORANGE_COLOR)
    ax.tick_params(axis="x", colors=ORANGE_COLOR)
    ax.spines[["right", "top"]].set_visible(False)

    ax.imshow([probabilities], cmap="vanimo", alpha=0.5, vmin=0, vmax=1)
    ax.set_xticks(range(len(DIALECTS)))
    # With a 90-degree rotation, ha="center" keeps each label centred under its
    # tick; ha="right" shifted the labels sideways.
    ax.set_xticklabels(DIALECTS, fontsize=8, rotation=90, ha="center")
    ax.set_yticks([])
    ax.set_title("Valid Dialects", color=ORANGE_COLOR)

    # Annotate every heatmap cell with its probability rounded to 2 decimals.
    for column, probability in enumerate(probabilities):
        ax.text(column, 0, round(probability, 2), ha="center", va="center", color="w")
    return fig


render_metadata()

sent = st.text_input(
    "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
)
# TODO: Check if this is needed!
clicked = st.button("Submit")

if sent:
    probabilities = predict_binary_outcomes(model, tokenizer, sent)
    fig = _plot_dialect_probabilities(probabilities)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed, and Streamlit reruns
    # this script on each interaction, so close the figure to avoid a leak.
    plt.close(fig)
|