Spaces:
Runtime error
Runtime error
Commit ·
7cd3150
1
Parent(s): b6076c4
Update countriesIdentification.py
Browse files
- countriesIdentification.py +3 -58
countriesIdentification.py
CHANGED
|
@@ -8,18 +8,11 @@ from geotext import GeoText
|
|
| 8 |
|
| 9 |
import re
|
| 10 |
|
| 11 |
-
from transformers import BertTokenizer, BertModel
|
| 12 |
-
import torch
|
| 13 |
-
|
| 14 |
spacy.cli.download("en_core_web_lg")
|
| 15 |
|
| 16 |
# Load the spacy model with GloVe embeddings
|
| 17 |
nlp = spacy.load("en_core_web_lg")
|
| 18 |
|
| 19 |
-
# load the pre-trained BERT tokenizer and model
|
| 20 |
-
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
| 21 |
-
model = BertModel.from_pretrained('bert-base-cased')
|
| 22 |
-
|
| 23 |
# Load valid city names from geonamescache
|
| 24 |
gc = geonamescache.GeonamesCache()
|
| 25 |
|
|
@@ -267,48 +260,6 @@ def identify_loc_regex(sentence):
|
|
| 267 |
return regex_locations
|
| 268 |
|
| 269 |
|
| 270 |
-
def identify_loc_embeddings(sentence, countries, cities):
|
| 271 |
-
"""
|
| 272 |
-
Identify cities and countries with the BERT pre-trained embeddings matching
|
| 273 |
-
"""
|
| 274 |
-
|
| 275 |
-
embd_locations = []
|
| 276 |
-
|
| 277 |
-
# Define a list of country and city names (those are given by the geonamescache library before)
|
| 278 |
-
countries_cities = countries + cities
|
| 279 |
-
|
| 280 |
-
# Concatenate multi-word countries and cities into a single string
|
| 281 |
-
multiword_countries = [c.replace(' ', '_') for c in countries if ' ' in c]
|
| 282 |
-
multiword_cities = [c.replace(' ', '_') for c in cities if ' ' in c]
|
| 283 |
-
countries_cities += multiword_countries + multiword_cities
|
| 284 |
-
|
| 285 |
-
# Preprocess the input sentence
|
| 286 |
-
tokens = tokenizer.tokenize(sentence)
|
| 287 |
-
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
|
| 288 |
-
|
| 289 |
-
# Get the BERT embeddings for the input sentence
|
| 290 |
-
with torch.no_grad():
|
| 291 |
-
embeddings = model(input_ids)[0][0]
|
| 292 |
-
|
| 293 |
-
# Find the country and city names in the input sentence
|
| 294 |
-
for i in range(len(tokens)):
|
| 295 |
-
token = tokens[i]
|
| 296 |
-
if token in countries_cities:
|
| 297 |
-
embd_locations.append(token)
|
| 298 |
-
else:
|
| 299 |
-
word_vector = embeddings[i]
|
| 300 |
-
similarity_scores = torch.nn.functional.cosine_similarity(word_vector.unsqueeze(0), embeddings)
|
| 301 |
-
similar_tokens = [tokens[j] for j in similarity_scores.argsort(descending=True)[1:6]]
|
| 302 |
-
for word in similar_tokens:
|
| 303 |
-
if word in countries_cities and similarity_scores[tokens.index(word)] > 0.5:
|
| 304 |
-
embd_locations.append(word)
|
| 305 |
-
|
| 306 |
-
# Convert back multi-word country and city names to original form
|
| 307 |
-
embd_locations = [loc.replace('_', ' ') if '_' in loc else loc for loc in embd_locations]
|
| 308 |
-
|
| 309 |
-
return embd_locations
|
| 310 |
-
|
| 311 |
-
|
| 312 |
|
| 313 |
def multiple_country_city_identifications_solve(country_city_dict):
|
| 314 |
"""
|
|
@@ -580,19 +531,13 @@ def identify_locations(sentence):
|
|
| 580 |
# flatten the regex list
|
| 581 |
locations_flat_2 = list(flatten(locations))
|
| 582 |
|
| 583 |
-
# embeddings
|
| 584 |
-
locations_flat_2.append(identify_loc_embeddings(sentence, countries, cities))
|
| 585 |
-
|
| 586 |
-
# flatten the embeddings list
|
| 587 |
-
locations_flat_3 = list(flatten(locations))
|
| 588 |
-
|
| 589 |
# remove duplicates while also taking under consideration capitalization (e.g. a reference of italy should be valid, while also a reference of Italy and italy)
|
| 590 |
# Lowercase the words and get their unique references using set()
|
| 591 |
-
loc_unique = set([loc.lower() for loc in
|
| 592 |
|
| 593 |
# Create a new list of locations with initial capitalization, removing duplicates
|
| 594 |
loc_capitalization = list(
|
| 595 |
-
set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in
|
| 596 |
|
| 597 |
# That calculation checks whether there are substrings contained in another string. E.g. for the case of [timor leste, timor], it should remove "timor"
|
| 598 |
if extra_serco_countries:
|
|
@@ -705,5 +650,5 @@ def identify_locations(sentence):
|
|
| 705 |
return (0, "LOCATION", "no_country")
|
| 706 |
|
| 707 |
except:
|
| 708 |
-
# handle the exception if any errors occur while
|
| 709 |
return (0, "LOCATION", "unknown_error")
|
|
|
|
| 8 |
|
| 9 |
import re
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
spacy.cli.download("en_core_web_lg")
|
| 12 |
|
| 13 |
# Load the spacy model with GloVe embeddings
|
| 14 |
nlp = spacy.load("en_core_web_lg")
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# Load valid city names from geonamescache
|
| 17 |
gc = geonamescache.GeonamesCache()
|
| 18 |
|
|
|
|
| 260 |
return regex_locations
|
| 261 |
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
def multiple_country_city_identifications_solve(country_city_dict):
|
| 265 |
"""
|
|
|
|
| 531 |
# flatten the regex list
|
| 532 |
locations_flat_2 = list(flatten(locations))
|
| 533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
# remove duplicates while also taking under consideration capitalization (e.g. a reference of italy should be valid, while also a reference of Italy and italy)
|
| 535 |
# Lowercase the words and get their unique references using set()
|
| 536 |
+
loc_unique = set([loc.lower() for loc in locations_flat_2])
|
| 537 |
|
| 538 |
# Create a new list of locations with initial capitalization, removing duplicates
|
| 539 |
loc_capitalization = list(
|
| 540 |
+
set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in locations_flat_2]))
|
| 541 |
|
| 542 |
# That calculation checks whether there are substrings contained in another string. E.g. for the case of [timor leste, timor], it should remove "timor"
|
| 543 |
if extra_serco_countries:
|
|
|
|
| 650 |
return (0, "LOCATION", "no_country")
|
| 651 |
|
| 652 |
except:
|
| 653 |
+
# handle the exception if any errors occur while identifying a country/city
|
| 654 |
return (0, "LOCATION", "unknown_error")
|