Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,8 @@ from transformers.models.bert.modeling_bert import BertForMaskedLM
|
|
| 6 |
|
| 7 |
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
|
| 8 |
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from PIL import Image
|
| 11 |
|
|
@@ -35,6 +37,19 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
|
|
| 35 |
spaBERT_model.to(device)
|
| 36 |
spaBERT_model.eval()
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
#Get BERT Embedding for review
|
| 40 |
def get_bert_embedding(review_text):
|
|
|
|
| 6 |
|
| 7 |
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
|
| 8 |
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
|
| 9 |
+
from models.spabert.datasets.osm_sample_loader import PbfMapDataset
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
|
| 12 |
from PIL import Image
|
| 13 |
|
|
|
|
| 37 |
spaBERT_model.to(device)
|
| 38 |
spaBERT_model.eval()
|
| 39 |
|
| 40 |
+
# Load data using SpatialDataset
|
| 41 |
+
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
| 42 |
+
tokenizer = tokenizer,
|
| 43 |
+
#max_token_len = 256, #Originally 300
|
| 44 |
+
max_token_len = max_seq_length, #Originally 300
|
| 45 |
+
distance_norm_factor = 0.0001,
|
| 46 |
+
spatial_dist_fill = 20,
|
| 47 |
+
with_type = False,
|
| 48 |
+
sep_between_neighbors = True, #Initially false, play around with this potentially?
|
| 49 |
+
label_encoder = None, #Initially None, potentially change this because we do have real/fake reviews.
|
| 50 |
+
mode = None) #If set to None it will use the full dataset for mlm
|
| 51 |
+
|
| 52 |
+
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
|
| 53 |
|
| 54 |
#Get BERT Embedding for review
|
| 55 |
def get_bert_embedding(review_text):
|