AMR-KELEG commited on
Commit
f3b7541
·
1 Parent(s): 3563942

Update the model name

Browse files
Files changed (2) hide show
  1. app.py +52 -85
  2. constants.py +23 -2
app.py CHANGED
@@ -1,9 +1,11 @@
1
  # Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
2
  import constants
 
3
  import pandas as pd
4
  import streamlit as st
5
  import matplotlib.pyplot as plt
6
  from transformers import BertForSequenceClassification, AutoTokenizer
 
7
 
8
  import altair as alt
9
  from altair import X, Y, Scale
@@ -11,6 +13,38 @@ import base64
11
 
12
  import re
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def preprocess_text(arabic_text):
16
  """Apply preprocessing to the given Arabic text.
@@ -57,42 +91,10 @@ tokenizer = AutoTokenizer.from_pretrained(constants.MODEL_NAME)
57
  model = load_model(constants.MODEL_NAME)
58
 
59
 
60
- def compute_ALDi(sentences):
61
- """Computes the ALDi score for the given sentences.
62
-
63
- Args:
64
- sentences: A list of Arabic sentences.
65
-
66
- Returns:
67
- A list of ALDi scores for the given sentences.
68
- """
69
- progress_text = "Computing ALDi..."
70
- my_bar = st.progress(0, text=progress_text)
71
-
72
- BATCH_SIZE = 4
73
- output_logits = []
74
-
75
- preprocessed_sentences = [preprocess_text(s) for s in sentences]
76
-
77
- for first_index in range(0, len(preprocessed_sentences), BATCH_SIZE):
78
- inputs = tokenizer(
79
- preprocessed_sentences[first_index : first_index + BATCH_SIZE],
80
- return_tensors="pt",
81
- padding=True,
82
- )
83
- outputs = model(**inputs).logits.reshape(-1).tolist()
84
- output_logits = output_logits + [max(min(o, 1), 0) for o in outputs]
85
- my_bar.progress(
86
- min((first_index + BATCH_SIZE) / len(preprocessed_sentences), 1),
87
- text=progress_text,
88
- )
89
- my_bar.empty()
90
- return output_logits
91
-
92
-
93
  @st.cache_data
94
  def render_metadata():
95
  """Renders the metadata."""
 
96
  html = r"""<p align="center">
97
  <a href="https://huggingface.co/AMR-KELEG/Sentence-ALDi"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
98
  <a href="https://github.com/AMR-KELEG/ALDi"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
@@ -101,10 +103,11 @@ def render_metadata():
101
  c = st.container()
102
  c.write(html, unsafe_allow_html=True)
103
 
104
- render_svg(open("assets/ALDi_logo.svg").read())
 
105
  render_metadata()
106
 
107
- tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])
108
 
109
  with tab1:
110
  sent = st.text_input(
@@ -115,7 +118,7 @@ with tab1:
115
  clicked = st.button("Submit")
116
 
117
  if sent:
118
- ALDi_score = compute_ALDi([sent])[0]
119
 
120
  ORANGE_COLOR = "#FF8000"
121
  fig, ax = plt.subplots(figsize=(8, 1))
@@ -128,55 +131,19 @@ with tab1:
128
 
129
  ax.spines[["right", "top"]].set_visible(False)
130
 
131
- ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
132
- ax.set_xlim(0, 1)
133
- ax.set_ylim(-1, 1)
134
- ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
135
- ax.get_yaxis().set_visible(False)
136
- ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
 
 
 
 
 
 
 
137
  st.pyplot(fig)
138
 
139
  print(sent)
140
- with open("logs.txt", "a") as f:
141
- f.write(sent + "\n")
142
-
143
- with tab2:
144
- file = st.file_uploader("Upload a file", type=["txt"])
145
- if file is not None:
146
- df = pd.read_csv(file, sep="\t", header=None)
147
- df.columns = ["Sentence"]
148
- df.reset_index(drop=True, inplace=True)
149
-
150
- # TODO: Run the model
151
- df["ALDi"] = compute_ALDi(df["Sentence"].tolist())
152
-
153
- # A horizontal rule
154
- st.markdown("""---""")
155
-
156
- chart = (
157
- alt.Chart(df.reset_index())
158
- .mark_area(color="darkorange", opacity=0.5)
159
- .encode(
160
- x=X(field="index", title="Sentence Index"),
161
- y=Y("ALDi", scale=Scale(domain=[0, 1])),
162
- )
163
- )
164
- st.altair_chart(chart.interactive(), use_container_width=True)
165
-
166
- col1, col2 = st.columns([4, 1])
167
-
168
- with col1:
169
- # Display the output
170
- st.table(
171
- df,
172
- )
173
-
174
- with col2:
175
- # Add a download button
176
- csv = convert_df(df)
177
- st.download_button(
178
- label=":file_folder: Download predictions as CSV",
179
- data=csv,
180
- file_name="ALDi_scores.csv",
181
- mime="text/csv",
182
- )
 
1
  # Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
2
  import constants
3
+ import torch
4
  import pandas as pd
5
  import streamlit as st
6
  import matplotlib.pyplot as plt
7
  from transformers import BertForSequenceClassification, AutoTokenizer
8
+ from constants import DIALECTS
9
 
10
  import altair as alt
11
  from altair import X, Y, Scale
 
13
 
14
  import re
15
 
16
+ def predict_binary_outcomes(model, tokenizer, text, threshold=0.3):
17
+ """Predict the validity in each dialect, by indepenently applying a sigmoid activation to each dialect's logit.
18
+ Dialects with probabilities (sigmoid activations) above a threshold (set by defauly to 0.3) are predicted as valid.
19
+ The model is expected to generate logits for each dialect of the following dialects in the same order:
20
+ Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
21
+ Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
22
+ """
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ encodings = tokenizer(
26
+ text, truncation=True, padding=True, max_length=128, return_tensors="pt"
27
+ )
28
+
29
+ ## inputs
30
+ input_ids = encodings["input_ids"].to(device)
31
+ attention_mask = encodings["attention_mask"].to(device)
32
+
33
+ with torch.no_grad():
34
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
35
+ logits = outputs.logits
36
+
37
+ probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
38
+ binary_predictions = (probabilities >= threshold).astype(int)
39
+
40
+ # Map indices to actual labels
41
+ predicted_dialects = [
42
+ dialect
43
+ for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
44
+ if dialect_prediction == 1
45
+ ]
46
+
47
+ return predicted_dialects
48
 
49
  def preprocess_text(arabic_text):
50
  """Apply preprocessing to the given Arabic text.
 
91
  model = load_model(constants.MODEL_NAME)
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  @st.cache_data
95
  def render_metadata():
96
  """Renders the metadata."""
97
+ # TODO: Update!
98
  html = r"""<p align="center">
99
  <a href="https://huggingface.co/AMR-KELEG/Sentence-ALDi"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
100
  <a href="https://github.com/AMR-KELEG/ALDi"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
 
103
  c = st.container()
104
  c.write(html, unsafe_allow_html=True)
105
 
106
+ # TODO: Update!
107
+ # render_svg(open("assets/ALDi_logo.svg").read())
108
  render_metadata()
109
 
110
+ tab1= st.tabs(["Input a Sentence"])
111
 
112
  with tab1:
113
  sent = st.text_input(
 
118
  clicked = st.button("Submit")
119
 
120
  if sent:
121
+ valid_dialects = predict_binary_outcomes(model, tokenizer, sent)
122
 
123
  ORANGE_COLOR = "#FF8000"
124
  fig, ax = plt.subplots(figsize=(8, 1))
 
131
 
132
  ax.spines[["right", "top"]].set_visible(False)
133
 
134
+ dialect_labels = [int(dialect in valid_dialects) for dialect in DIALECTS]
135
+ im = ax.imshow(dialect_labels, cmap="vanimo", alpha=0.5, vmin=0, vmax=1, annot=False)
136
+ ax.set_yticks(range(len(DIALECTS)))
137
+ ax.set_yticklabels(DIALECTS, fontsize=8)
138
+ ax.set_xticks([])
139
+ ax.set_title("Valid Dialects", color=ORANGE_COLOR)
140
+
141
+ # ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
142
+ # ax.set_xlim(0, 1)
143
+ # ax.set_ylim(-1, 1)
144
+ # ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
145
+ # ax.get_yaxis().set_visible(False)
146
+ # ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
147
  st.pyplot(fig)
148
 
149
  print(sent)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
constants.py CHANGED
@@ -1,4 +1,25 @@
1
  CHOICE_TEXT = "Input Text"
2
  CHOICE_FILE = "Upload File"
3
- TITLE = "ALDi: Arabic Level of Dialectness"
4
- MODEL_NAME = "AMR-KELEG/Sentence-ALDi"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  CHOICE_TEXT = "Input Text"
2
  CHOICE_FILE = "Upload File"
3
+ TITLE = "ADI: Arabic Dialect Idenitifcation"
4
+ MODEL_NAME = "AHAAM/B2BERT"
5
+
6
+ DIALECTS = [
7
+ "Algeria",
8
+ "Bahrain",
9
+ "Egypt",
10
+ "Iraq",
11
+ "Jordan",
12
+ "Kuwait",
13
+ "Lebanon",
14
+ "Libya",
15
+ "Morocco",
16
+ "Oman",
17
+ "Palestine",
18
+ "Qatar",
19
+ "Saudi_Arabia",
20
+ "Sudan",
21
+ "Syria",
22
+ "Tunisia",
23
+ "UAE",
24
+ "Yemen",
25
+ ]