ValadisCERTH committed on
Commit
13d7c9d
·
1 Parent(s): 99568fb

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +112 -19
helper.py CHANGED
@@ -11,7 +11,6 @@ import re
11
  from transformers import BertTokenizer, BertModel
12
  import torch
13
 
14
-
15
  # initial loads
16
 
17
  # load the spacy model
@@ -19,8 +18,8 @@ spacy.cli.download("en_core_web_lg")
19
  nlp = spacy.load("en_core_web_lg")
20
 
21
  # load the pre-trained BERT tokenizer and model
22
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
23
- model = BertModel.from_pretrained('bert-base-uncased')
24
 
25
  # Load valid city names from geonamescache
26
  gc = geonamescache.GeonamesCache()
@@ -54,7 +53,7 @@ def is_country(reference):
54
 
55
  def is_city(reference):
56
  """
57
- Check if the given reference is a valid city name
58
  """
59
 
60
  # Check if the reference is a valid city name
@@ -77,7 +76,7 @@ def is_city(reference):
77
  return True
78
 
79
  return False
80
-
81
 
82
  def validate_locations(locations):
83
  """
@@ -87,20 +86,28 @@ def validate_locations(locations):
87
  validated_loc = []
88
 
89
  for location in locations:
 
 
90
  if is_city(location):
91
  validated_loc.append((location, 'city'))
 
 
92
  elif is_country(location):
93
  validated_loc.append((location, 'country'))
 
94
  else:
95
  # Check if the location is a multi-word name
96
  words = location.split()
97
  if len(words) > 1:
 
98
  # Try to find the country or city name among the words
99
  for i in range(len(words)):
100
  name = ' '.join(words[i:])
 
101
  if is_country(name):
102
  validated_loc.append((name, 'country'))
103
  break
 
104
  elif is_city(name):
105
  validated_loc.append((name, 'city'))
106
  break
@@ -120,10 +127,11 @@ def identify_loc_ner(sentence):
120
  # GPE and LOC are the labels for location entities in spaCy
121
  for ent in doc.ents:
122
  if ent.label_ in ['GPE', 'LOC']:
 
123
  if len(ent.text.split()) > 1:
124
  ner_locations.append(ent.text)
125
  else:
126
- for token in ent:
127
  if token.ent_type_ == 'GPE':
128
  ner_locations.append(ent.text)
129
  break
@@ -187,7 +195,7 @@ def identify_loc_regex(sentence):
187
 
188
  regex_locations = []
189
 
190
- # Country references can be preceded by 'in', 'from' or 'of'
191
  pattern = r"\b(in|from|of)\b\s([\w\s]+)"
192
  additional_refs = re.findall(pattern, sentence)
193
 
@@ -246,8 +254,6 @@ def identify_locations(sentence):
246
 
247
  locations = []
248
 
249
- # add all the identified country/cities results in a list
250
-
251
  try:
252
 
253
  # ner
@@ -272,24 +278,111 @@ def identify_locations(sentence):
272
  # flatten the embeddings list
273
  locations_flat_3 = list(flatten(locations))
274
 
275
- # acquire the unique country/city names (because it is possible that many different approaches will capture the same countries/cities)
276
- flat_loc_list = set(locations_flat_3)
277
-
 
 
 
 
278
  # validate that indeed each one of the countries/cities are indeed countries/cities
279
- validated_locations = validate_locations(flat_loc_list)
280
 
281
  # create a proper dictionary with country/city tags and the relevant entries as a result
282
  locations_dict = {}
283
-
284
  for location, loc_type in validated_locations:
285
  if loc_type not in locations_dict:
286
  locations_dict[loc_type] = []
287
  locations_dict[loc_type].append(location)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
- return locations_dict
 
 
290
 
291
- except:
292
 
293
- # handle the exception if any errors occur while identifying a country/city
294
- print(f"An error occurred while checking if a city or country exists")
295
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from transformers import BertTokenizer, BertModel
12
  import torch
13
 
 
14
  # initial loads
15
 
16
  # load the spacy model
 
18
  nlp = spacy.load("en_core_web_lg")
19
 
20
  # load the pre-trained BERT tokenizer and model
21
+ tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
22
+ model = BertModel.from_pretrained('bert-base-cased')
23
 
24
  # Load valid city names from geonamescache
25
  gc = geonamescache.GeonamesCache()
 
53
 
54
  def is_city(reference):
55
  """
56
+ Check if a given reference is a valid city name
57
  """
58
 
59
  # Check if the reference is a valid city name
 
76
  return True
77
 
78
  return False
79
+
80
 
81
  def validate_locations(locations):
82
  """
 
86
  validated_loc = []
87
 
88
  for location in locations:
89
+
90
+ # validate whether it is a city
91
  if is_city(location):
92
  validated_loc.append((location, 'city'))
93
+
94
+ # validate whether it is a country
95
  elif is_country(location):
96
  validated_loc.append((location, 'country'))
97
+
98
  else:
99
  # Check if the location is a multi-word name
100
  words = location.split()
101
  if len(words) > 1:
102
+
103
  # Try to find the country or city name among the words
104
  for i in range(len(words)):
105
  name = ' '.join(words[i:])
106
+
107
  if is_country(name):
108
  validated_loc.append((name, 'country'))
109
  break
110
+
111
  elif is_city(name):
112
  validated_loc.append((name, 'city'))
113
  break
 
127
  # GPE and LOC are the labels for location entities in spaCy
128
  for ent in doc.ents:
129
  if ent.label_ in ['GPE', 'LOC']:
130
+
131
  if len(ent.text.split()) > 1:
132
  ner_locations.append(ent.text)
133
  else:
134
+ for token in ent:
135
  if token.ent_type_ == 'GPE':
136
  ner_locations.append(ent.text)
137
  break
 
195
 
196
  regex_locations = []
197
 
198
 + # Country and city references can be preceded by 'in', 'from' or 'of'
199
  pattern = r"\b(in|from|of)\b\s([\w\s]+)"
200
  additional_refs = re.findall(pattern, sentence)
201
 
 
254
 
255
  locations = []
256
 
 
 
257
  try:
258
 
259
  # ner
 
278
  # flatten the embeddings list
279
  locations_flat_3 = list(flatten(locations))
280
 
281
 + # remove duplicates while taking capitalization into account (e.g. 'Italy' and 'italy' refer to the same location)
282
+ # Lowercase the words and get their unique references using set()
283
+ loc_unique = set([loc.lower() for loc in locations_flat_3])
284
+
285
+ # Create a new list of locations with initial capitalization, removing duplicates
286
+ loc_capitalization = list(set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in locations_flat_3]))
287
+
288
  # validate that indeed each one of the countries/cities are indeed countries/cities
289
+ validated_locations = validate_locations(loc_capitalization)
290
 
291
  # create a proper dictionary with country/city tags and the relevant entries as a result
292
  locations_dict = {}
 
293
  for location, loc_type in validated_locations:
294
  if loc_type not in locations_dict:
295
  locations_dict[loc_type] = []
296
  locations_dict[loc_type].append(location)
297
+
298
+ # conditions for multiple references
299
+ # it is mandatory that a country will exist
300
+ if locations_dict['country']:
301
+
302
+ # if a city exists
303
+ if 'city' in locations_dict:
304
+
305
+ # we accept one country and one city
306
+ if len(locations_dict['country']) == 1 and len(locations_dict['city']) == 1:
307
+
308
+ # capitalize because there may be cases that it will return 'italy'
309
+ locations_dict['country'][0] = locations_dict['country'][0].capitalize()
310
+ return locations_dict
311
+
312
+ # we can accept an absence of city but a country is always mandatory
313
+ elif len(locations_dict['country']) == 1 and len(locations_dict['city']) == 0:
314
+ locations_dict['country'][0] = locations_dict['country'][0].capitalize()
315
+ return locations_dict
316
+
317
+ # error if more than one country or city
318
+ else:
319
+ return (0, "LOCATION", "more_city_or_country")
320
+
321
+
322
+ # if a city does not exist
323
+ else:
324
+ # we only accept for one country
325
+ if len(locations_dict['country']) == 1:
326
+ locations_dict['country'][0] = locations_dict['country'][0].capitalize()
327
+ return locations_dict
328
+
329
+ # error if more than one country
330
+ else:
331
+ return (0, "LOCATION", "more_country")
332
+
333
+ # error if no country is referred
334
+ else:
335
+ return (0, "LOCATION", "no_country")
336
 
337
+ except:
338
+ # handle the exception if any errors occur while identifying a country/city
339
+ return (0, "LOCATION", "unknown_error")
340
 
 
341
 
342
def identify_locations2(sentence):
    """
    Identify all the possible Country and City references in the given
    sentence, using different approaches in a hybrid manner.

    Combines four extraction strategies (spaCy NER, geoparsing libraries,
    regex patterns and BERT embeddings), deduplicates the candidates
    case-insensitively, validates each candidate against the known
    country/city data and groups the survivors by type.

    Parameters
    ----------
    sentence : str
        Free-text sentence to scan for location references.

    Returns
    -------
    dict
        Mapping of 'country' / 'city' to lists of validated names.  A key
        is present only when at least one location of that type was found.
    """
    candidates = []

    # NER-based extraction
    candidates.append(identify_loc_ner(sentence))

    # geoparsing libraries (also yields the country/city reference lists
    # that the embeddings approach needs)
    geoparse_list, countries, cities = identify_loc_geoparselibs(sentence)
    candidates.append(geoparse_list)

    # regex-based extraction
    candidates.append(identify_loc_regex(sentence))

    # embeddings-based extraction
    candidates.append(identify_loc_embeddings(sentence, countries, cities))

    # BUGFIX: the previous version re-flattened the original `locations`
    # list after appending the regex and embeddings results to intermediate
    # flat lists (`locations_flat_1` / `locations_flat_2`), silently
    # discarding both of those result sets.  Flattening once, after every
    # approach has contributed, keeps all candidates.
    flat_candidates = list(flatten(candidates))

    # Remove duplicates while taking capitalization into account (e.g.
    # 'Italy' and 'italy' are the same reference).  str.title() is used
    # instead of str.capitalize() so multi-word names keep per-word
    # capitalization ('new york' -> 'New York', not 'New york'); for
    # single-word names the two behave identically.  The previous
    # `loc.lower() in loc_unique` conditional was always True (the set was
    # built from those very lowercased values), so it has been removed.
    loc_normalized = list({loc.title() for loc in flat_candidates})

    # validate that each candidate is indeed a country or a city
    validated_locations = validate_locations(loc_normalized)

    # create a proper dictionary with country/city tags and the relevant
    # entries as a result
    locations_dict = {}
    for location, loc_type in validated_locations:
        locations_dict.setdefault(loc_type, []).append(location)

    return locations_dict