|
|
from abc import ABC, abstractmethod |
|
|
|
|
|
import logging |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort |
|
|
|
|
|
"""
A base classifier type.

Currently, it can process text or other inputs in whatever manner suits
the particular model type.  For example, the CNNClassifier processes lists
of words, while the ConstituencyClassifier processes trees.
"""

# shared logger for the classifier models
logger = logging.getLogger(name='stanza')
|
|
|
|
|
class BaseClassifier(ABC, nn.Module):
    """
    Abstract base for classifiers that assign a label to each sentence.

    Subclasses must implement ``extract_sentences`` and the module's
    ``forward`` (``label_sentences`` invokes the model via ``self(...)``).
    They may override ``preprocess_sentences`` to turn raw sentences into
    whatever input form their ``forward`` expects.
    """

    @abstractmethod
    def extract_sentences(self, doc):
        """
        Extract the sentences, or the relevant information in the sentences, from a document.

        Must be implemented by subclasses.
        """

    def preprocess_sentences(self, sentences):
        """
        Hook for converting sentences into model input before labeling.

        Called by ``label_sentences`` before any batching.  By default,
        don't do anything: the input is returned unchanged.
        """
        return sentences

    def label_sentences(self, sentences, batch_size=None):
        """
        Given a list of sentences, return the model's results on that text.

        The model is switched to eval mode (and left there afterwards), and
        inference runs without gradient tracking.

        Args:
          sentences: the sentences to label; passed through
            ``preprocess_sentences`` before being fed to the model
          batch_size: if None, all sentences go through the model in a
            single call.  Otherwise the sentences are sorted longest-first,
            split into batches with ``split_into_batches``, and the labels
            are put back into the input order afterwards.

        Returns:
          a list of int labels (argmax over dim 1 of the model output),
          one per sentence
        """
        self.eval()

        sentences = self.preprocess_sentences(sentences)

        if batch_size is None:
            intervals = [(0, len(sentences))]
            orig_idx = None
        else:
            # sort longest-first so each batch holds sentences of similar
            # length; orig_idx lets unsort() restore the caller's order
            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
            intervals = split_into_batches(sentences, batch_size)

        labels = []
        # inference only: don't build an autograd graph, which would
        # otherwise keep every batch's activations alive for no reason
        with torch.no_grad():
            for interval in intervals:
                start, end = interval[0], interval[1]
                if end - start == 0:
                    # an empty interval, e.g. when the input text is empty
                    continue
                output = self(sentences[start:end])
                # presumably output is (batch, num_classes) -- TODO confirm
                predicted = torch.argmax(output, dim=1)
                labels.extend(predicted.tolist())

        # orig_idx is None when unbatched; an empty list means no sentences,
        # so there is nothing to restore in either case
        if orig_idx:
            sentences = unsort(sentences, orig_idx)
            labels = unsort(labels, orig_idx)

        # skip the per-sentence loop entirely unless debug output is wanted
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Found labels")
            for (label, sentence) in zip(labels, sentences):
                logger.debug((label, sentence))

        return labels
|
|
|