File size: 1,938 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from abc import ABC, abstractmethod

import logging

import torch
import torch.nn as nn

from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort

"""
A base classifier type

Currently, has the ability to process text or other inputs in a manner
suitable for the particular model type.
In other words, the CNNClassifier processes lists of words,
and the ConstituencyClassifier processes trees
"""

logger = logging.getLogger('stanza')

class BaseClassifier(ABC, nn.Module):
    """
    Abstract base class for models which assign a label to each sentence.

    Subclasses must implement extract_sentences, and may override
    preprocess_sentences to convert the sentences before they are
    passed to the model itself (the module is called directly on
    slices of the preprocessed sentences).
    """

    @abstractmethod
    def extract_sentences(self, doc):
        """
        Extract the sentences or the relevant information in the sentences from a document
        """

    def preprocess_sentences(self, sentences):
        """
        Hook for transforming the sentences before they are labeled.

        By default, don't do anything: the sentences are returned unchanged.
        """
        return sentences

    def label_sentences(self, sentences, batch_size=None):
        """
        Given a list of sentences, return the model's results on that text.

        The model is put in eval mode (and left that way), the sentences
        are run through preprocess_sentences, and the model is called on
        them with gradient tracking turned off.

        sentences: the items to label, in whatever form preprocess_sentences accepts
        batch_size: if None, all of the sentences are processed in one call
          to the model.  Otherwise, the sentences are reordered with
          sort_with_indices (key=len, reverse=True, so presumably longest
          first) and divided into chunks by split_into_batches.

        Returns a list with one predicted label index per sentence,
        in the same order as the sentences were given.
        """
        self.eval()

        sentences = self.preprocess_sentences(sentences)

        if batch_size is None:
            intervals = [(0, len(sentences))]
            orig_idx = None
        else:
            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
            intervals = split_into_batches(sentences, batch_size)

        labels = []
        # Only the argmax of each output is kept, so there is no reason
        # to pay for building the autograd graph during the forward pass
        with torch.no_grad():
            for interval in intervals:
                start, end = interval[0], interval[1]
                if start == end:
                    # this can happen for empty text
                    continue
                output = self(sentences[start:end])
                # the predicted label is the index of the largest value along dim 1
                # (presumably output is (batch, num_classes) - TODO confirm)
                predicted = torch.argmax(output, dim=1)
                labels.extend(predicted.tolist())

        # orig_idx is None when no sorting happened.  An empty orig_idx
        # is skipped as well, since then there is nothing to put back in order.
        if orig_idx:
            sentences = unsort(sentences, orig_idx)
            labels = unsort(labels, orig_idx)

        # Skip the per-sentence loop entirely unless debug output will actually be emitted
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Found labels")
            for (label, sentence) in zip(labels, sentences):
                logger.debug((label, sentence))

        return labels