Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
|
|
| 12 |
from PIL import Image
|
| 13 |
|
| 14 |
device = torch.device('cpu')
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
#Spacy Initialization Section
|
|
@@ -117,10 +118,12 @@ def get_bert_embedding(review_text):
|
|
| 117 |
def get_spaBert_embedding(entity):
|
| 118 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
| 119 |
if entity_index is None:
|
| 120 |
-
|
|
|
|
| 121 |
return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
|
| 122 |
else:
|
| 123 |
-
|
|
|
|
| 124 |
return spaBERT_embeddings[entity_index]
|
| 125 |
|
| 126 |
|
|
@@ -135,7 +138,8 @@ def processSpatialEntities(review, nlp):
|
|
| 135 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
| 136 |
spaBert_emb = get_spaBert_embedding(text)
|
| 137 |
token_embeddings.append(spaBert_emb)
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
| 141 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
|
@@ -246,14 +250,14 @@ selected_review = example_reviews[user_input]
|
|
| 246 |
if st.button("Highlight Geo-Entities"):
|
| 247 |
if selected_review.strip():
|
| 248 |
bert_embedding = get_bert_embedding(selected_review)
|
| 249 |
-
st.write("Review Embedding Shape:", bert_embedding.shape)
|
| 250 |
-
|
| 251 |
spaBert_embedding = processSpatialEntities(selected_review,nlp)
|
| 252 |
-
st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
|
| 253 |
-
|
| 254 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
prediction = get_prediction(combined_embedding)
|
| 259 |
|
|
|
|
| 12 |
from PIL import Image
|
| 13 |
|
| 14 |
device = torch.device('cpu')
|
| 15 |
+
dev_mode = True
|
| 16 |
|
| 17 |
|
| 18 |
#Spacy Initialization Section
|
|
|
|
| 118 |
def get_spaBert_embedding(entity):
|
| 119 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
| 120 |
if entity_index is None:
|
| 121 |
+
if(dev_mode == True):
|
| 122 |
+
st.write("Got Bert embedding for: ", entity)
|
| 123 |
return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
|
| 124 |
else:
|
| 125 |
+
if(dev_mode == True):
|
| 126 |
+
st.write("Got SpaBert embedding for: ", entity)
|
| 127 |
return spaBERT_embeddings[entity_index]
|
| 128 |
|
| 129 |
|
|
|
|
| 138 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
| 139 |
spaBert_emb = get_spaBert_embedding(text)
|
| 140 |
token_embeddings.append(spaBert_emb)
|
| 141 |
+
if(dev_mode == True):
|
| 142 |
+
st.write("Geo-Entity Found in review: ", text)
|
| 143 |
|
| 144 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
| 145 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
|
|
|
| 250 |
if st.button("Highlight Geo-Entities"):
|
| 251 |
if selected_review.strip():
|
| 252 |
bert_embedding = get_bert_embedding(selected_review)
|
|
|
|
|
|
|
| 253 |
spaBert_embedding = processSpatialEntities(selected_review,nlp)
|
|
|
|
|
|
|
| 254 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
| 255 |
+
|
| 256 |
+
if(dev_mode == True):
|
| 257 |
+
st.write("Review Embedding Shape:", bert_embedding.shape)
|
| 258 |
+
st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
|
| 259 |
+
st.write("Concatenated Embedding Shape:", combined_embedding.shape)
|
| 260 |
+
st.write("Concatenated Embedding:", combined_embedding)
|
| 261 |
|
| 262 |
prediction = get_prediction(combined_embedding)
|
| 263 |
|