ErikDaska commited on
Commit
c6ea9d6
·
verified ·
1 Parent(s): 5f5f3b2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +8 -42
src/streamlit_app.py CHANGED
@@ -12,7 +12,7 @@ def instantiate_gpt2(model_name: str,max_length_ : int, num_return_sequences : i
12
  results = pipe(text, max_length=max_length_, num_return_sequences=num_return_sequences)
13
  return results
14
 
15
- def instantiate_roberta(top_k : int, text : str) -> dict:
16
  pipe = pipeline("fill-mask", model="Iscte-Sintra/RoBERTa-Kriolu", tokenizer="Iscte-Sintra/RoBERTa-Kriolu", token=token)
17
  return pipe(text, top_k=top_k)
18
 
@@ -37,9 +37,9 @@ def build_decoder_page(model_name):
37
  st.warning('Max length must be greater than default sentence number of tokens!', icon="⚠️")
38
  st.warning(e)
39
 
40
- def build_roberta_page():
41
 
42
- st.title("RoBERTa : Encoder")
43
 
44
  top_k = st.sidebar.number_input('Number of predictions to return', min_value=1, max_value=5, value=1, step=1)
45
 
@@ -56,7 +56,7 @@ def build_roberta_page():
56
  submit = st.button("Submit")
57
  try:
58
  if submit and input_text:
59
- results = instantiate_roberta(top_k, input_text)
60
  except Exception as e:
61
  st.warning('There must be a special token "<mask>" in sentence!', icon="⚠️")
62
  st.warning(e)
@@ -70,48 +70,14 @@ def build_roberta_page():
70
  else:
71
  predicted_text = st.text_input("Predicted Token", disabled=True)
72
 
73
- def build_albertina_page():
74
- st.title("Albertina : Encoder")
75
-
76
- top_k = st.sidebar.number_input('Number of predictions to return', min_value=1, max_value=5, value=1, step=1)
77
-
78
- st.write("Enter a sentence with a **[MASK]** token, and the model will predict the missing word.")
79
-
80
- results = None
81
-
82
- col1, col2 = st.columns(2)
83
-
84
- with col1:
85
- st.subheader("Input")
86
- input_text = st.text_input("Input Sentence", "Katxor sta trás di [MASK].")
87
-
88
- submit = st.button("Submit")
89
- try:
90
- if submit and input_text:
91
- results = instantiate_albertina(top_k, input_text)
92
- except Exception as e:
93
- st.warning('There must be a special token "[MASK]" in sentence!', icon="⚠️")
94
- st.warning(e)
95
-
96
- with col2:
97
- st.subheader("Prediction")
98
- if results:
99
- predicted_text = st.text_input("Predicted Token", value=results[0]['sequence'], disabled=True)
100
- for result in results:
101
- st.write(f"**Prediction**: {result['token_str']} | **Confidence**: {round(result['score'], 4)}")
102
- else:
103
- predicted_text = st.text_input("Predicted Token", disabled=True)
104
-
105
 
106
  # Your dictionary of models
107
- model_dict = {'RoBERTa': 1,'Albertina':2 ,'GPT-2': 3}
108
 
109
  # Always appears at the top of the sidebar
110
  selected_model = st.sidebar.selectbox("Architecture", list(model_dict.keys()))
111
 
112
- if model_dict[selected_model] == 1:
113
- build_roberta_page()
114
- elif model_dict[selected_model] == 2:
115
- build_albertina_page()
116
  else:
117
- build_decoder_page('GPT2-Kriolu')
 
12
  results = pipe(text, max_length=max_length_, num_return_sequences=num_return_sequences)
13
  return results
14
 
15
+ def instantiate_encoder(model_name: str, top_k : int, text : str) -> dict:
16
  pipe = pipeline("fill-mask", model="Iscte-Sintra/RoBERTa-Kriolu", tokenizer="Iscte-Sintra/RoBERTa-Kriolu", token=token)
17
  return pipe(text, top_k=top_k)
18
 
 
37
  st.warning('Max length must be greater than default sentence number of tokens!', icon="⚠️")
38
  st.warning(e)
39
 
40
+ def build_encoder_page(model_name:str):
41
 
42
+ st.title(f"{model_name} : Encoder - Fill-Mask Task")
43
 
44
  top_k = st.sidebar.number_input('Number of predictions to return', min_value=1, max_value=5, value=1, step=1)
45
 
 
56
  submit = st.button("Submit")
57
  try:
58
  if submit and input_text:
59
+ results = instantiate_encoder(model_name, top_k, input_text)
60
  except Exception as e:
61
  st.warning('There must be a special token "<mask>" in sentence!', icon="⚠️")
62
  st.warning(e)
 
70
  else:
71
  predicted_text = st.text_input("Predicted Token", disabled=True)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Your dictionary of models
75
+ model_dict = {'RoBERTa-Kriolu': "Encoder",'Albertina-Kriolu':"Encoder" ,'GPT2-Kriolu': "Decoder"}
76
 
77
  # Always appears at the top of the sidebar
78
  selected_model = st.sidebar.selectbox("Architecture", list(model_dict.keys()))
79
 
80
+ if model_dict[selected_model] == "Encoder":
81
+ build_encoder_page(selected_model)
 
 
82
  else:
83
+ build_decoder_page(selected_model)