ValadisCERTH's picture
Create helper.py
2ae3a4d
raw
history blame
9.11 kB
import spacy
from geopy.geocoders import Nominatim
import geonamescache
import pycountry
from geotext import GeoText
import re
from transformers import BertTokenizer, BertModel
import torch
# initial loads
# Load the spaCy English model. NOTE: the download call fetches the model
# over the network every time this module is imported — importing helper.py
# therefore has a network side effect.
spacy.cli.download("en_core_web_lg")
nlp = spacy.load("en_core_web_lg")
# Load the pre-trained BERT tokenizer and model (downloaded and cached by
# the transformers hub machinery on first use).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# Load valid city names from geonamescache; is_city() uses this set as a
# fast local lookup before falling back to the Nominatim geocoder.
gc = geonamescache.GeonamesCache()
city_names = set([city['name'] for city in gc.get_cities().values()])
def flatten(lst):
    """
    Recursively flatten an arbitrarily nested list, yielding leaf items
    in left-to-right order.
    """
    for element in lst:
        if not isinstance(element, list):
            yield element
        else:
            yield from flatten(element)
def is_country(reference):
    """
    Check if a given reference is a valid country name.

    Uses pycountry's fuzzy search so close spellings still match.
    Returns True when a country is found, False when the lookup fails.
    """
    try:
        # search_fuzzy raises LookupError when nothing matches; we only
        # care whether a match exists, so the result itself is discarded
        # (the original bound it to an unused local).
        pycountry.countries.search_fuzzy(reference)
        return True
    except LookupError:
        return False
def is_city(reference):
    """
    Check if the given reference is a valid city name.

    First consults the local geonamescache-derived ``city_names`` set, then
    falls back to a live Nominatim (OpenStreetMap) query.
    Returns True/False.
    """
    # Fast path: known city name from the local geonamescache set.
    if reference in city_names:
        return True
    # Fall back to the Nominatim (OpenStreetMap) geocoder.
    geolocator = Nominatim(user_agent="certh_serco_validate_city_app")
    location = geolocator.geocode(reference, language="en")
    # Bug fix: geocode() returns None when nothing is found; the original
    # code crashed with AttributeError on `location.raw` in that case.
    if location is None:
        return False
    # Use .get() — not every Nominatim result is guaranteed to carry 'type'.
    osm_type = location.raw.get('type')
    # A result typed 'city', 'town' or 'village' is a city for our purposes.
    if osm_type in ['city', 'town', 'village']:
        return True
    # 'administrative' results can be either a country (single display token)
    # or a small city (comma-separated hierarchy); treat multi-part display
    # names as cities, mirroring the original heuristic.
    elif osm_type == 'administrative':
        if len(location.raw.get('display_name', '').split(",")) > 1:
            return True
    return False
def validate_locations(locations):
    """
    Keep only the references that really are a city or a country, tagging
    each confirmed name as ('name', 'city') or ('name', 'country').
    """
    confirmed = []
    for candidate in locations:
        # City check first, mirroring the priority of the original logic.
        if is_city(candidate):
            confirmed.append((candidate, 'city'))
            continue
        if is_country(candidate):
            confirmed.append((candidate, 'country'))
            continue
        # Multi-word candidate: try progressively shorter suffixes until
        # one resolves (country checked before city on this fallback path).
        tokens = candidate.split()
        if len(tokens) > 1:
            for start in range(len(tokens)):
                suffix = ' '.join(tokens[start:])
                if is_country(suffix):
                    confirmed.append((suffix, 'country'))
                    break
                if is_city(suffix):
                    confirmed.append((suffix, 'city'))
                    break
    return confirmed
def identify_loc_ner(sentence):
    """
    Collect geopolitical (GPE) and location (LOC) entities from the
    sentence using the module-level spaCy pipeline.
    """
    ner_locations = []
    for entity in nlp(sentence).ents:
        if entity.label_ not in ('GPE', 'LOC'):
            continue
        if len(entity.text.split()) > 1:
            # Multi-word entities are accepted as-is.
            ner_locations.append(entity.text)
        elif any(token.ent_type_ == 'GPE' for token in entity):
            # Single-word entities are kept only when the token itself
            # is tagged GPE (single-word LOCs are deliberately skipped,
            # matching the original behavior).
            ner_locations.append(entity.text)
    return ner_locations
def identify_loc_geoparselibs(sentence):
    """
    Identify cities and countries with 3 different geoparsing libraries.

    Returns a tuple (all_matches, geotext_countries, geotext_cities); the
    last two come from the GeoText pass and are reused by the caller for
    the embeddings-based matcher.
    """
    geoparse_locations = []
    # Geoparsing library 1: geonamescache.
    gc = geonamescache.GeonamesCache()
    # Perf fix: use sets for O(1) membership instead of O(n) list scans —
    # the nested loop below probes every word sub-sequence of the sentence,
    # so list membership made this accidentally quadratic-times-linear.
    city_name_set = {city['name'] for city in gc.get_cities().values()}
    country_name_set = {country['name'] for country in gc.get_countries().values()}
    # If any word sequence in the sentence is a known country/city, keep it.
    words = sentence.split()
    for i in range(len(words)):
        for j in range(i + 1, len(words) + 1):
            word_seq = ' '.join(words[i:j])
            if word_seq in city_name_set or word_seq in country_name_set:
                geoparse_locations.append(word_seq)
    # Geoparsing library 2: pycountry — substring match on official names.
    for country in pycountry.countries:
        if country.name in sentence:
            geoparse_locations.append(country.name)
    # Geoparsing library 3: geotext.
    places = GeoText(sentence)
    cities = list(places.cities)
    countries = list(places.countries)
    # Append order (cities first, then countries) matches the original.
    if cities:
        geoparse_locations += cities
    if countries:
        geoparse_locations += countries
    return (geoparse_locations, countries, cities)
def identify_loc_regex(sentence):
    """
    Identify candidate city/country references via regular-expression
    matching: capture the word run following 'in', 'from' or 'of'.
    """
    # Location references are typically preceded by 'in', 'from' or 'of';
    # the greedy [\w\s]+ deliberately grabs the trailing word run — the
    # caller's validation step narrows multi-word captures later.
    pattern = re.compile(r"\b(in|from|of)\b\s([\w\s]+)")
    return [captured for _prep, captured in pattern.findall(sentence)]
def identify_loc_embeddings(sentence, countries, cities):
    """
    Identify cities and countries with the BERT pre-trained embeddings matching.

    Tokenizes the sentence with the module-level BERT tokenizer, embeds it
    with the module-level BERT model, then reports tokens that either match
    a known country/city name directly or whose most-similar in-sentence
    tokens (by cosine similarity) match one.

    NOTE(review): BERT wordpiece tokens are lower-cased sub-words, while
    `countries`/`cities` hold capitalised full names, so the direct
    `token in countries_cities` match presumably fires rarely — confirm.
    """
    embd_locations = []
    # Define a list of country and city names (those are given by the geonamescache library before)
    countries_cities = countries + cities
    # Concatenate multi-word countries and cities into a single string
    # (joined with '_'; converted back to spaces at the end of the function).
    multiword_countries = [c.replace(' ', '_') for c in countries if ' ' in c]
    multiword_cities = [c.replace(' ', '_') for c in cities if ' ' in c]
    countries_cities += multiword_countries + multiword_cities
    # Preprocess the input sentence
    tokens = tokenizer.tokenize(sentence)
    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
    # Get the BERT embeddings for the input sentence (last hidden state of
    # the single batch element; inference only, so no gradients needed).
    with torch.no_grad():
        embeddings = model(input_ids)[0][0]
    # Find the country and city names in the input sentence
    for i in range(len(tokens)):
        token = tokens[i]
        if token in countries_cities:
            embd_locations.append(token)
        else:
            word_vector = embeddings[i]
            # Cosine similarity of this token against every token embedding
            # in the same sentence.
            similarity_scores = torch.nn.functional.cosine_similarity(word_vector.unsqueeze(0), embeddings)
            # Top-5 most similar tokens; index 0 is the token itself and is
            # skipped by the [1:6] slice.
            similar_tokens = [tokens[j] for j in similarity_scores.argsort(descending=True)[1:6]]
            for word in similar_tokens:
                # NOTE(review): tokens.index(word) returns the FIRST
                # occurrence of `word`, so for repeated tokens the score of
                # a different position may be tested — confirm intended.
                if word in countries_cities and similarity_scores[tokens.index(word)] > 0.5:
                    embd_locations.append(word)
    # Convert back multi-word country and city names to original form
    embd_locations = [loc.replace('_', ' ') if '_' in loc else loc for loc in embd_locations]
    return embd_locations
def identify_locations(sentence):
    """
    Identify all the possible Country and City references in the given
    sentence, combining NER, geoparsing libraries, regex matching and
    BERT-embedding matching in a hybrid manner.

    Returns a dict like {'city': [...], 'country': [...]} on success, or
    "" when any step fails (sentinel kept for backward compatibility).
    """
    try:
        candidates = []
        # ner
        candidates.append(identify_loc_ner(sentence))
        # geoparse libs (also yields the GeoText country/city lists that
        # the embeddings matcher reuses)
        geoparse_list, countries, cities = identify_loc_geoparselibs(sentence)
        candidates.append(geoparse_list)
        # regex
        candidates.append(identify_loc_regex(sentence))
        # embeddings
        candidates.append(identify_loc_embeddings(sentence, countries, cities))
        # Bug fix: the original re-flattened the initial `locations` list
        # after the regex and embeddings steps, so those two approaches'
        # results were silently dropped. Flatten everything once, then
        # deduplicate (different approaches often find the same places).
        unique_candidates = set(flatten(candidates))
        # Validate that each candidate really is a country or a city.
        validated_locations = validate_locations(unique_candidates)
        # Group the validated names under their 'city'/'country' tag.
        locations_dict = {}
        for location, loc_type in validated_locations:
            locations_dict.setdefault(loc_type, []).append(location)
        return locations_dict
    except Exception:
        # Narrowed from a bare `except:` (which also swallowed SystemExit
        # and KeyboardInterrupt); keep the "" sentinel for existing callers.
        print("An error occurred while checking if a city or country exists")
        return ""