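"""Inference pipeline for the BertForTokenClassificationWithSixO model.

Loads the model and tokenizer, tokenizes input text character by character,
runs token classification, and reconstructs the spaced or corrected text.
"""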
import os
import json
import torch
import argparse
from transformers import AutoTokenizer
from modeling_custom import BertForTokenClassificationWithSixO
from labeler import Labeler


class ModelPipeline:
    """
    Pipeline for text processing using the BertForTokenClassificationWithSixO model.
    Handles preprocessing, inference, and postprocessing.
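
    Example (illustrative; assumes the model files sit beside this script):
        >>> pipeline = ModelPipeline()
        >>> spaced = pipeline.process_text("some input text", mode="space")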
    """

    def __init__(self, model_path=None):
        """
        Initialize the pipeline with model and tokenizer.

        Args:
            model_path: Path to the model directory. Defaults to current directory.
        """
        # Use current directory if no path specified
        if model_path is None:
            model_path = os.path.dirname(os.path.abspath(__file__))

        # Load tokenizer and model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = BertForTokenClassificationWithSixO.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load config for any custom settings
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        # Initialize labeler for postprocessing. The tuples are parallel:
        # label 1 maps to a plain space and label 2 to a zero-width
        # non-joiner (U+200C, the Persian half-space); the first regex
        # matches horizontal whitespace, the second matches the ZWNJ.
        self.labeler = Labeler(tags=(1, 2),
                               regexes=(r'[^\S\r\n\v\f]', r'\u200c'),
                               chars=(" ", "\u200c"),
                               class_count=2)

    def _tokenize(self, chars, chunk_size=512):
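        """
        Encode characters individually and split the ids into fixed-size chunks.

        The id sequence is wrapped in BERT's [CLS]/[SEP] ids (101/102) before
        chunking; the final chunk and its attention mask are zero-padded up to
        `chunk_size`.
        """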
        input_ids = []
        attention_masks = []

        # encode() wraps a single character as [CLS, token, SEP], so index 1
        # picks out the character's own token id
        ids_ = [101] + [self.tokenizer.encode(char)[1] for char in chars] + [102]
        for i in range(0, len(ids_), chunk_size):
            chunked_ids = ids_[i:i + chunk_size]
            attention_mask = [1] * len(chunked_ids)

            if len(chunked_ids) != chunk_size:
                attention_mask += [0] * (chunk_size - len(chunked_ids))  # padding the attention mask accordingly
                chunked_ids += [0] * (chunk_size - len(chunked_ids))  # padding the last chunk to chunk size

            input_ids.append(chunked_ids)
            attention_masks.append(attention_mask)

        return input_ids, attention_masks

    def preprocess(self, text):
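        """
        Label and tokenize raw text for inference.

        Returns the encoded model inputs along with the character sequence and
        its labels, padded with a 0 label on each side for the [CLS]/[SEP]
        positions.
        """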
        chars, labels = self.labeler.label_text(text, corpus_type=Labeler.WHOLE_RAW)
        input_ids, attention_mask = self._tokenize(chars[0], chunk_size=512)
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = [0] + labels[0] + [0]  # 0-labels for the [CLS]/[SEP] positions
        return {"input_ids": input_ids, "attention_mask": attention_mask}, chars, labels

    def predict(self, encoded_inputs):
        """
        Run model inference on preprocessed inputs.

        Args:
            encoded_inputs: Tokenized inputs from preprocess method

        Returns:
            Tuple of (predicted label ids per token, raw logits)
        """
        # Move input tensors to the same device as the model
        input_ids = encoded_inputs["input_ids"].to(self.device)
        attention_mask = encoded_inputs["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        # Get predicted class for each token
        predictions = torch.argmax(outputs.logits, dim=-1)
        return predictions, outputs.logits

    def postprocess_spacing(self, predictions, chars):
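        """
        Rebuild the output text from the model's own predictions.

        The character sequence is padded with a space on each side so it lines
        up with the predictions at the [CLS]/[SEP] positions.
        """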
        # Move predictions back to CPU if needed for postprocessing
        predictions_cpu = predictions.cpu()
        predictions_flat = [label for sample in predictions_cpu.tolist() for label in sample]
        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()

    def postprocess_correcting(self, logits, chars, labels, alpha):
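        """
        Blend the model's logits with the input text's own spacing labels.

        The original labels are one-hot encoded and combined with the logits
        as `logits * (1 - alpha) + one_hot * alpha`, so `alpha` controls how
        strongly the input's existing spacing is preserved.
        """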
        # Pad (or truncate) the labels so they cover every chunk in the batch;
        # padding only to one chunk's length would wrongly broadcast the first
        # chunk's labels across all chunks of a long input
        total_length = logits.size(0) * logits.size(1)
        if len(labels) < total_length:
            labels = labels + [0] * (total_length - len(labels))  # pad with the 0 (no-tag) label
        else:
            labels = labels[:total_length]  # truncate if longer

        # Convert labels to one-hot encoding, reshaped to line up with the logits
        num_classes = logits.size(-1)
        user_labels_tensor = torch.tensor(labels, device=self.device).view(logits.size(0), logits.size(1))
        user_labels_one_hot = torch.nn.functional.one_hot(user_labels_tensor, num_classes=num_classes).float()

        # Combine logits and user labels
        combined_logits = logits * (1 - alpha) + user_labels_one_hot * alpha

        # Apply softmax to get probabilities
        combined_probs = torch.nn.functional.softmax(combined_logits, dim=-1)

        # Get the final predictions
        final_predictions = torch.argmax(combined_probs, dim=-1)

        # Move to CPU for postprocessing
        final_predictions_cpu = final_predictions.cpu()

        # Flatten predictions
        predictions_flat = [label for sample in final_predictions_cpu.tolist() for label in sample]

        # Generate text with combined predictions
        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()

    def process_text(self, text, mode):
        """
        Process text through the entire pipeline.

        Args:
            text: Input text as string or list of strings
            mode: Processing mode ('space' or 'correct')

        Returns:
            Processed text with labels applied
        """
        # Run the full pipeline
        encoded_inputs, chars, labels = self.preprocess(text)

        predictions, logits = self.predict(encoded_inputs)

        if mode == 'space':
            result = self.postprocess_spacing(predictions, chars)
        elif mode == 'correct':
            result = self.postprocess_correcting(logits, chars, labels, 0.5)
        else:
            raise ValueError(f"Unrecognized mode: {mode}")
        return result


def main():
    """Command line interface for the model pipeline."""
    parser = argparse.ArgumentParser(description="Process text using token classification model")
    parser.add_argument("--text", type=str, help="Text to process")
    parser.add_argument("--file", type=str, help="Path to file containing text to process")
    parser.add_argument("--output", type=str, help="Path to output file")
    parser.add_argument("--mode", type=str, choices=['space', 'correct'], default='space',
                        help="Processing mode: 'space' uses model outputs only, 'correct' combines model results with "
                             "original text spacing")
    parser.add_argument("--model_path", type=str, default=None, help="Path to model directory")
    args = parser.parse_args()

    # Initialize pipeline
    pipeline = ModelPipeline(args.model_path)

    # Get input text
    if args.text:
        input_text = args.text
    elif args.file:
        with open(args.file, 'r', encoding='utf-8') as f:
            input_text = f.read()
    else:
        print("Please provide either --text or --file")
        return

    # Process the text
    result = pipeline.process_text(input_text, args.mode)

    # Output results
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            if isinstance(result, list):
                f.write('\n'.join(result))
            else:
                f.write(result)
    else:
        print("\nProcessed Text:")
        if isinstance(result, list):
            for item in result:
                print(item)
        else:
            print(result)


if __name__ == "__main__":
    main()
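
# Example invocations (illustrative; assumes this file is saved as pipeline.py
# alongside the model files):
#   python pipeline.py --text "some input text" --mode space
#   python pipeline.py --file input.txt --output fixed.txt --mode correct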