Update predictor.py
Browse files- predictor.py +7 -0
predictor.py
CHANGED
|
@@ -193,8 +193,15 @@ def predict_srl_allennlp_like_spacy(
|
|
| 193 |
|
| 194 |
return words, results
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
def main_predictor(model_path, bert_name, sentence, spacy_model="en_core_web_md"):
|
|
|
|
| 198 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 199 |
ckpt = torch.load(model_path, map_location=device)
|
| 200 |
hp = ckpt.get("hparams", ckpt.get("hyper_parameters", {}))
|
|
|
|
| 193 |
|
| 194 |
return words, results
|
| 195 |
|
| 196 |
+
def normalize_whitespace(s: str) -> str:
|
| 197 |
+
if s is None:
|
| 198 |
+
return ""
|
| 199 |
+
# strip leading/trailing spaces (incl. non-breaking etc.)
|
| 200 |
+
s = s.replace("\u00A0", " ").replace("\u2009", " ").strip()
|
| 201 |
+
return s
|
| 202 |
|
| 203 |
def main_predictor(model_path, bert_name, sentence, spacy_model="en_core_web_md"):
|
| 204 |
+
sentence = normalize_whitespace(sentence)
|
| 205 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 206 |
ckpt = torch.load(model_path, map_location=device)
|
| 207 |
hp = ckpt.get("hparams", ckpt.get("hyper_parameters", {}))
|