Commit: Update generic_ner.py
File changed: generic_ner.py (+31 lines, −22 lines)
Status: CHANGED
Diff hunk: @@ -238,38 +238,47 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 238 |
return outputs, text_chunks, text
|
| 239 |
|
| 240 |
def postprocess(self, outputs, **kwargs):
    """Postprocess model outputs for the NER tasks.

    Flattens the per-chunk logits into per-task prediction and
    confidence lists, realigns them with the words of each text chunk,
    and extracts entities per task.

    Args:
        outputs: Tuple ``(tokens_result, text_chunks, text)`` — per-chunk
            model results carrying a ``.logits`` dict keyed by task, the
            chunked input texts, and the full original text.
        **kwargs: Unused; accepted for pipeline-API compatibility.

    Returns:
        dict: Mapping of task name to the list of entities extracted
        from all chunks for that task.
    """
    tokens_result, text_chunks, text = outputs

    # Collect predictions and confidences across all chunks, per task.
    predictions = {task: [] for task in self.label_map}
    confidence_scores = {task: [] for task in self.label_map}

    for chunk_result in tokens_result:
        for task, logits in chunk_result.logits.items():
            predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
            confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())

    # Previously a `decoded_predictions` mapping was built here and never
    # used, and `entities` was returned empty; entities are now actually
    # extracted per chunk below.
    entities = {}
    for task, preds in predictions.items():
        for idx, text_chunk in enumerate(text_chunks):
            # NOTE(review): slicing by whitespace word count assumes the
            # flattened predictions are word-aligned per chunk — confirm
            # against realign()'s expectations.
            n_words = len(text_chunk.split())
            lo, hi = idx * n_words, (idx + 1) * n_words
            words_list, preds_list, confidence_list = realign(
                text_chunk,
                preds[lo:hi],
                confidence_scores[task][lo:hi],
                self.tokenizer,
                self.id2label[task],
            )

            chunk_entities = get_entities(words_list, preds_list, confidence_list, text)
            entities.setdefault(task, []).extend(chunk_entities)

    return entities
|
|
|
|
| 238 |
return outputs, text_chunks, text
|
| 239 |
|
| 240 |
def postprocess(self, outputs, **kwargs):
    """Postprocess the outputs of the model for NER tasks.

    Merges the per-chunk logits into flat per-task prediction and
    confidence lists, realigns them with the words of each text chunk,
    and extracts entities per task.

    Args:
        outputs: Tuple ``(tokens_result, text_chunks, text)`` — per-chunk
            model results carrying a ``.logits`` dict keyed by task, the
            chunked input texts, and the full original text.
        **kwargs: Unused; accepted for pipeline-API compatibility.

    Returns:
        dict: Mapping of task name to the list of entities extracted
        from all chunks for that task.
    """
    tokens_result, text_chunks, text = outputs

    # Collect predictions and confidences across all chunks, per task.
    predictions = {task: [] for task in self.label_map}
    confidence_scores = {task: [] for task in self.label_map}

    for chunk_result in tokens_result:
        for task, logits in chunk_result.logits.items():
            predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
            confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())

    entities = {}
    for task, preds in predictions.items():
        # Process each chunk individually (debug prints removed).
        for idx, text_chunk in enumerate(text_chunks):
            # NOTE(review): slicing by whitespace word count assumes the
            # flattened predictions are word-aligned per chunk — confirm
            # against realign()'s expectations.
            n_words = len(text_chunk.split())
            lo, hi = idx * n_words, (idx + 1) * n_words
            words_list, preds_list, confidence_list = realign(
                text_chunk,
                preds[lo:hi],
                confidence_scores[task][lo:hi],
                self.tokenizer,
                self.id2label[task],
            )

            # Get entities for this chunk and fold them into the
            # task-level result list.
            chunk_entities = get_entities(words_list, preds_list, confidence_list, text)
            entities.setdefault(task, []).extend(chunk_entities)

    return entities
|