Albin Thörn Cleland
Clean initial commit with LFS
19b8775
from abc import ABC, abstractmethod
import logging
import torch
import torch.nn as nn
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
"""
A base classifier type
Currently, has the ability to process text or other inputs in a manner
suitable for the particular model type.
In other words, the CNNClassifier processes lists of words,
and the ConstituencyClassifier processes trees
"""
# Logger named 'stanza'; label_sentences writes its debug output here.
logger = logging.getLogger('stanza')


class BaseClassifier(ABC, nn.Module):
    """
    Abstract base for classifiers that assign one label per input sentence.

    Subclasses must implement ``extract_sentences``.  ``label_sentences``
    calls the module itself, so subclasses must also define ``forward``.
    ``preprocess_sentences`` may be overridden to convert the raw inputs
    into whatever form ``forward`` expects.
    """

    @abstractmethod
    def extract_sentences(self, doc):
        """
        Extract the sentences or the relevant information in the sentences from a document
        """

    def preprocess_sentences(self, sentences):
        """
        Hook applied by label_sentences before any sorting or batching.

        By default, don't do anything: the sentences are returned unchanged.
        """
        return sentences

    def label_sentences(self, sentences, batch_size=None):
        """
        Given a list of sentences, return the model's results on that text.

        Args:
          sentences: the inputs to classify, in whatever form
            preprocess_sentences accepts.
          batch_size: if None, all sentences go through the model in a
            single call.  Otherwise the sentences are sorted by ``len``
            (longest first), batch boundaries are chosen by
            split_into_batches, and the labels are mapped back to the
            input order with unsort.

        Returns:
          A list of ints: the argmax over dim 1 of the model output for each
          row.  Assumes forward returns (batch, num_classes) scores, i.e.
          one label per sentence -- confirm for new subclasses.

        Side effects:
          Switches the module to eval mode and leaves it there.
        """
        self.eval()
        sentences = self.preprocess_sentences(sentences)

        if batch_size is None:
            intervals = [(0, len(sentences))]
            orig_idx = None
        else:
            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
            intervals = split_into_batches(sentences, batch_size)

        labels = []
        # Inference only: the result is plain Python ints, so there is
        # nothing to backpropagate through.  Disabling autograd avoids
        # building (and holding memory for) a graph for every batch.
        with torch.no_grad():
            for start, end in intervals:
                if end - start == 0:
                    # this can happen for empty text
                    continue
                output = self(sentences[start:end])
                predicted = torch.argmax(output, dim=1)
                labels.extend(predicted.tolist())

        # orig_idx is None when no sorting happened; an empty orig_idx
        # (empty input) is also skipped since there is nothing to reorder.
        if orig_idx:
            labels = unsort(labels, orig_idx)

        logger.debug("Found labels")
        if logger.isEnabledFor(logging.DEBUG):
            # The sentences are only needed in input order for this debug
            # output, so skip the extra unsort when DEBUG is disabled.
            if orig_idx:
                sentences = unsort(sentences, orig_idx)
            for label, sentence in zip(labels, sentences):
                logger.debug((label, sentence))

        return labels