Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -116,14 +116,14 @@ def get_bert_embedding(review_text):
|
|
| 116 |
|
| 117 |
|
| 118 |
#Get SpaBERT Embedding for geo-entity
|
| 119 |
-
def get_spaBert_embedding(entity):
|
| 120 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
| 121 |
if entity_index is None:
|
| 122 |
if(dev_mode == True):
|
| 123 |
st.write("Got Bert embedding for: ", entity)
|
| 124 |
return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
|
| 125 |
else:
|
| 126 |
-
|
| 127 |
if(dev_mode == True):
|
| 128 |
st.write("Got SpaBert embedding for: ", entity)
|
| 129 |
return spaBERT_embeddings[entity_index]
|
|
@@ -134,11 +134,12 @@ def processSpatialEntities(review, nlp):
|
|
| 134 |
doc = nlp(review)
|
| 135 |
entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
|
| 136 |
token_embeddings = []
|
|
|
|
| 137 |
|
| 138 |
# Iterate over each entity span and process only geo entities
|
| 139 |
for start, end, text, label in entity_spans:
|
| 140 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
| 141 |
-
spaBert_emb = get_spaBert_embedding(text)
|
| 142 |
token_embeddings.append(spaBert_emb)
|
| 143 |
if(dev_mode == True):
|
| 144 |
st.write("Geo-Entity Found in review: ", text)
|
|
@@ -146,7 +147,7 @@ def processSpatialEntities(review, nlp):
|
|
| 146 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
| 147 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
| 148 |
#processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
|
| 149 |
-
return processed_embedding
|
| 150 |
|
| 151 |
|
| 152 |
#Initialize discriminator module
|
|
@@ -262,7 +263,7 @@ selected_review = example_reviews[selected_key]
|
|
| 262 |
if st.button("Process Review"):
|
| 263 |
if selected_review.strip():
|
| 264 |
bert_embedding = get_bert_embedding(selected_review)
|
| 265 |
-
spaBert_embedding = processSpatialEntities(selected_review,nlp)
|
| 266 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
| 267 |
|
| 268 |
if(dev_mode == True):
|
|
@@ -290,6 +291,10 @@ if st.button("Process Review"):
|
|
| 290 |
# Display the highlighted text with HTML support
|
| 291 |
st.markdown(highlighted_text, unsafe_allow_html=True)
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
#Display the models prediction
|
| 294 |
if(prediction == 0):
|
| 295 |
st.write("Prediction: Not Spam")
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
#Get SpaBERT Embedding for geo-entity
|
| 119 |
+
def get_spaBert_embedding(entity,current_pseudo_sentences):
|
| 120 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
| 121 |
if entity_index is None:
|
| 122 |
if(dev_mode == True):
|
| 123 |
st.write("Got Bert embedding for: ", entity)
|
| 124 |
return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
|
| 125 |
else:
|
| 126 |
+
current_pseudo_sentences.append(pseudo_sentences[entity_index])
|
| 127 |
if(dev_mode == True):
|
| 128 |
st.write("Got SpaBert embedding for: ", entity)
|
| 129 |
return spaBERT_embeddings[entity_index]
|
|
|
|
| 134 |
doc = nlp(review)
|
| 135 |
entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
|
| 136 |
token_embeddings = []
|
| 137 |
+
current_pseudo_sentences = []
|
| 138 |
|
| 139 |
# Iterate over each entity span and process only geo entities
|
| 140 |
for start, end, text, label in entity_spans:
|
| 141 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
| 142 |
+
spaBert_emb = get_spaBert_embedding(text,current_pseudo_sentences)
|
| 143 |
token_embeddings.append(spaBert_emb)
|
| 144 |
if(dev_mode == True):
|
| 145 |
st.write("Geo-Entity Found in review: ", text)
|
|
|
|
| 147 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
| 148 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
| 149 |
#processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
|
| 150 |
+
return processed_embedding,current_pseudo_sentences
|
| 151 |
|
| 152 |
|
| 153 |
#Initialize discriminator module
|
|
|
|
| 263 |
if st.button("Process Review"):
|
| 264 |
if selected_review.strip():
|
| 265 |
bert_embedding = get_bert_embedding(selected_review)
|
| 266 |
+
spaBert_embedding, current_pseudo_sentences = processSpatialEntities(selected_review,nlp)
|
| 267 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
| 268 |
|
| 269 |
if(dev_mode == True):
|
|
|
|
| 291 |
# Display the highlighted text with HTML support
|
| 292 |
st.markdown(highlighted_text, unsafe_allow_html=True)
|
| 293 |
|
| 294 |
+
#Display pseudo sentences found
|
| 295 |
+
for sentence in current_pseudo_sentences:
|
| 296 |
+
st.write("Pseudo-Sentence: ", sentence)
|
| 297 |
+
|
| 298 |
#Display the models prediction
|
| 299 |
if(prediction == 0):
|
| 300 |
st.write("Prediction: Not Spam")
|