yeomtong commited on
Commit
b14c6f6
·
verified ·
1 Parent(s): 45de895

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +31 -30
predictor.py CHANGED
@@ -59,6 +59,35 @@ def srl_init(model_path, bert_name="bert-base-cased", spacy_model="en_core_web_m
59
  })
60
  torch.set_grad_enabled(False)
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def _predict_cached(sentence):
63
  """Internal: uses cached objects set by srl_init()."""
64
  if _CACHE["model"] is None:
@@ -69,7 +98,8 @@ def _predict_cached(sentence):
69
  nlp = _CACHE["nlp"]
70
  device = "cuda" if (_CACHE["device"].type == "cuda") else "cpu"
71
 
72
-
 
73
  return predict_srl_allennlp_like_spacy(
74
  model, tokenizer, nlp, sentence, id2label,
75
  device=device, prob_threshold=0.40, top_k=None, pick_best_if_none=True
@@ -137,35 +167,6 @@ def _ensure_loaded(model_path, bert_name, spacy_model, model_cls): # NEW
137
  torch.set_grad_enabled(False)
138
 
139
 
140
- def normalize_whitespace(s: str) -> str:
141
- if s is None: return ""
142
- return s.replace("\u00A0", " ").replace("\u2009", " ").strip()
143
-
144
- def spacy_verb_indices(nlp, sentence: str):
145
- doc = nlp(sentence)
146
- return [i for i, t in enumerate(doc) if t.pos_ in ("VERB","AUX") or t.tag_.startswith("VB")]
147
-
148
- def words_and_spans_spacy(sentence: str, nlp):
149
- doc = nlp(sentence)
150
- words = [t.text for t in doc]
151
- spans = [(t.idx, t.idx + len(t.text)) for t in doc]
152
- return words, spans
153
-
154
- def bio_to_spans(tags):
155
- spans = []; i = 0
156
- while i < len(tags):
157
- t = tags[i]
158
- if t == "O" or t.endswith("-V"):
159
- i += 1; continue
160
- if t.startswith("B-"):
161
- role = t[2:]; j = i+1
162
- while j < len(tags) and tags[j] == f"I-{role}": j += 1
163
- spans.append((role, i, j-1)); i = j
164
- else:
165
- i += 1
166
- return spans
167
-
168
-
169
  @torch.no_grad()
170
  def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, device="cuda"):
171
  model.eval()
 
59
  })
60
  torch.set_grad_enabled(False)
61
 
62
+ def normalize_whitespace(s: str) -> str:
63
+ if s is None: return ""
64
+ return s.replace("\u00A0", " ").replace("\u2009", " ").strip()
65
+
66
+ def spacy_verb_indices(nlp, sentence: str):
67
+ doc = nlp(sentence)
68
+ return [i for i, t in enumerate(doc) if t.pos_ in ("VERB","AUX") or t.tag_.startswith("VB")]
69
+
70
+ def words_and_spans_spacy(sentence: str, nlp):
71
+ doc = nlp(sentence)
72
+ words = [t.text for t in doc]
73
+ spans = [(t.idx, t.idx + len(t.text)) for t in doc]
74
+ return words, spans
75
+
76
+ def bio_to_spans(tags):
77
+ spans = []; i = 0
78
+ while i < len(tags):
79
+ t = tags[i]
80
+ if t == "O" or t.endswith("-V"):
81
+ i += 1; continue
82
+ if t.startswith("B-"):
83
+ role = t[2:]; j = i+1
84
+ while j < len(tags) and tags[j] == f"I-{role}": j += 1
85
+ spans.append((role, i, j-1)); i = j
86
+ else:
87
+ i += 1
88
+ return spans
89
+
90
+
91
  def _predict_cached(sentence):
92
  """Internal: uses cached objects set by srl_init()."""
93
  if _CACHE["model"] is None:
 
98
  nlp = _CACHE["nlp"]
99
  device = "cuda" if (_CACHE["device"].type == "cuda") else "cpu"
100
 
101
+ sentence = normalize_whitespace(sentence)
102
+
103
  return predict_srl_allennlp_like_spacy(
104
  model, tokenizer, nlp, sentence, id2label,
105
  device=device, prob_threshold=0.40, top_k=None, pick_best_if_none=True
 
167
  torch.set_grad_enabled(False)
168
 
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  @torch.no_grad()
171
  def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, device="cuda"):
172
  model.eval()