Spaces:
Sleeping
Sleeping
Chidam Gopal
committed on
included state and city in NER
Browse files - infer_location.py +5 -15
infer_location.py
CHANGED
|
@@ -22,20 +22,6 @@ class LocationFinder:
|
|
| 22 |
# Load the ONNX model
|
| 23 |
self.ort_session = ort.InferenceSession(model_path)
|
| 24 |
|
| 25 |
-
# State abbreviations list for post-processing
|
| 26 |
-
self.state_abbr = {
|
| 27 |
-
"AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", "HI", "ID", "IL", "IN", "IA", "KS", "KY",
|
| 28 |
-
"LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND",
|
| 29 |
-
"OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
# # Helper function to correct misclassified state abbreviations
|
| 33 |
-
# def correct_state_abbreviation(self, token, predicted_label):
|
| 34 |
-
# if token.upper() in self.state_abbr and predicted_label == "I-CITY":
|
| 35 |
-
# return "I-STATE"
|
| 36 |
-
# return predicted_label
|
| 37 |
-
|
| 38 |
-
|
| 39 |
def find_location(self, sequence, verbose=False):
|
| 40 |
inputs = self.tokenizer(sequence,
|
| 41 |
return_tensors="np", # ONNX requires inputs in NumPy format
|
|
@@ -80,6 +66,11 @@ class LocationFinder:
|
|
| 80 |
state_entities = []
|
| 81 |
org_entities = []
|
| 82 |
city_state_entities = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
|
| 84 |
if prob > threshold:
|
| 85 |
if token in ["[CLS]", "[SEP]", "[PAD]"]:
|
|
@@ -115,7 +106,6 @@ class LocationFinder:
|
|
| 115 |
return {
|
| 116 |
'city': city_res,
|
| 117 |
'state': state_res,
|
| 118 |
-
'organization': org_res,
|
| 119 |
}
|
| 120 |
|
| 121 |
if __name__ == '__main__':
|
|
|
|
| 22 |
# Load the ONNX model
|
| 23 |
self.ort_session = ort.InferenceSession(model_path)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def find_location(self, sequence, verbose=False):
|
| 26 |
inputs = self.tokenizer(sequence,
|
| 27 |
return_tensors="np", # ONNX requires inputs in NumPy format
|
|
|
|
| 66 |
state_entities = []
|
| 67 |
org_entities = []
|
| 68 |
city_state_entities = []
|
| 69 |
+
|
| 70 |
+
city_entities = []
|
| 71 |
+
state_entities = []
|
| 72 |
+
city_state_entities = []
|
| 73 |
+
org_entities = []
|
| 74 |
for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
|
| 75 |
if prob > threshold:
|
| 76 |
if token in ["[CLS]", "[SEP]", "[PAD]"]:
|
|
|
|
| 106 |
return {
|
| 107 |
'city': city_res,
|
| 108 |
'state': state_res,
|
|
|
|
| 109 |
}
|
| 110 |
|
| 111 |
if __name__ == '__main__':
|