|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import argparse |
|
|
from transformers import AutoTokenizer |
|
|
from modeling_custom import BertForTokenClassificationWithFiveO |
|
|
from labeler import Labeler |
|
|
|
|
|
|
|
|
class ModelPipeline:
    """
    Pipeline for text processing using the BertForTokenClassificationWithFiveO model.

    Handles preprocessing, inference, and postprocessing.
    """

    def __init__(self, model_path=None):
        """
        Initialize the pipeline with model and tokenizer.

        Args:
            model_path: Path to the model directory. Defaults to the directory
                containing this file.
        """
        if model_path is None:
            model_path = os.path.dirname(os.path.abspath(__file__))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = BertForTokenClassificationWithFiveO.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Keep the raw JSON config available alongside the loaded model.
        # JSON is UTF-8 by specification, so pin the encoding explicitly.
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Labeler maps characters to/from 2 label classes: tag 1 for plain
        # whitespace (minus line/page separators), tag 2 for ZWNJ (U+200C).
        self.labeler = Labeler(tags=(1, 2),
                               regexes=(r'[^\S\r\n\v\f]', r'\u200c'),
                               chars=(" ", ""),
                               class_count=2)

    def _tokenize(self, chars, chunk_size=512):
        """
        Convert a list of characters into fixed-size chunks of token ids.

        Each character is encoded individually and its single wordpiece id is
        taken; the sequence is wrapped in [CLS]/[SEP] ids and split into chunks
        of ``chunk_size``, zero-padding the final short chunk.

        Args:
            chars: List of single-character strings.
            chunk_size: Tokens per chunk (the model's sequence length).

        Returns:
            Tuple ``(input_ids, attention_masks)``, each a list of
            ``chunk_size``-long lists of ints.
        """
        input_ids = []
        attention_masks = []

        # encode(char) yields [CLS, piece, SEP]; index 1 is the character's id.
        # NOTE(review): assumes a BERT vocab where CLS=101 and SEP=102 —
        # confirm this matches the tokenizer shipped with the model.
        ids_ = [101] + [self.tokenizer.encode(char)[1] for char in chars] + [102]
        for i in range(0, len(ids_), chunk_size):
            chunked_ids = ids_[i:i + chunk_size]
            attention_mask = [1] * len(chunked_ids)

            # Zero-pad the last (short) chunk. The mask must be padded before
            # the ids, since its padding length is computed from the unpadded
            # id count.
            if len(chunked_ids) != chunk_size:
                attention_mask += [0] * (chunk_size - len(chunked_ids))
                chunked_ids += [0] * (chunk_size - len(chunked_ids))

            input_ids.append(chunked_ids)
            attention_masks.append(attention_mask)

        return input_ids, attention_masks

    def preprocess(self, text):
        """
        Label and tokenize raw text for the model.

        Args:
            text: Input text passed to ``Labeler.label_text`` with
                ``corpus_type=Labeler.WHOLE_RAW``.

        Returns:
            Tuple ``(encoded_inputs, chars, labels)``: a dict with
            ``input_ids``/``attention_mask`` tensors, the labeler's character
            sequences, and the first sample's labels padded with a 0 on each
            side for the [CLS]/[SEP] positions.
        """
        chars, labels = self.labeler.label_text(text, corpus_type=Labeler.WHOLE_RAW)
        input_ids, attention_mask = self._tokenize(chars[0], chunk_size=512)
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        # Dummy labels for the special tokens added by _tokenize.
        labels = [0] + labels[0] + [0]
        return {"input_ids": input_ids, "attention_mask": attention_mask}, chars, labels

    def predict(self, encoded_inputs):
        """
        Run model inference on preprocessed inputs.

        Args:
            encoded_inputs: Tokenized inputs from preprocess method

        Returns:
            Tuple ``(predictions, logits)``: per-token argmax label ids and
            the raw model logits.
        """
        input_ids = encoded_inputs["input_ids"].to(self.device)
        attention_mask = encoded_inputs["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        predictions = torch.argmax(outputs.logits, dim=-1)
        return predictions, outputs.logits

    def postprocess_spacing(self, predictions, chars):
        """
        Decode model predictions into re-spaced text.

        Args:
            predictions: Tensor of predicted label ids (chunks x seq_len).
            chars: Character sequences from preprocess; only chars[0] is used.

        Returns:
            The regenerated text with surrounding padding stripped.
        """
        predictions_cpu = predictions.cpu()
        # Flatten chunked predictions back into one label-per-token sequence.
        predictions_flat = [label for sample in predictions_cpu.tolist() for label in sample]
        # Pad the characters so the [CLS]/[SEP] prediction slots align, then
        # strip the artificial spaces afterwards.
        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()

    def postprocess_correcting(self, logits, chars, labels, alpha):
        """
        Blend model logits with the input text's own spacing labels.

        Args:
            logits: Model logits (batch x seq_len x num_classes).
            chars: Character sequences from preprocess; only chars[0] is used.
            labels: Per-character labels (with special-token padding) derived
                from the original text's spacing.
            alpha: Weight given to the original labels; the model's logits get
                ``1 - alpha``.

        Returns:
            The corrected text with surrounding padding stripped.
        """
        # Work on a copy so the caller's label list is not mutated in place.
        labels = list(labels)

        # Align the label length with the model's sequence length.
        # NOTE(review): only logits.size(1) (one chunk) is considered, so
        # multi-chunk inputs are truncated here — confirm this is intended.
        sequence_length = logits.size(1)
        if len(labels) < sequence_length:
            labels.extend([0] * (sequence_length - len(labels)))
        else:
            labels = labels[:sequence_length]

        num_classes = logits.size(-1)
        user_labels_tensor = torch.tensor([labels], device=self.device)
        user_labels_one_hot = torch.nn.functional.one_hot(user_labels_tensor, num_classes=num_classes).float()

        # Defensive rank alignment before broadcasting (normally both are 3-D).
        if logits.dim() > user_labels_one_hot.dim():
            user_labels_one_hot = user_labels_one_hot.unsqueeze(0)

        # Convex combination of model evidence and the user's own spacing.
        combined_logits = logits * (1 - alpha) + user_labels_one_hot * alpha

        combined_probs = torch.nn.functional.softmax(combined_logits, dim=-1)
        final_predictions = torch.argmax(combined_probs, dim=-1)

        final_predictions_cpu = final_predictions.cpu()
        predictions_flat = [label for sample in final_predictions_cpu.tolist() for label in sample]

        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()

    def process_text(self, text, mode, alpha=0.5):
        """
        Process text through the entire pipeline.

        Args:
            text: Input text as string or list of strings
            mode: Processing mode ('space' or 'correct')
            alpha: Blend weight for 'correct' mode (0 = model output only,
                1 = original spacing only). Defaults to 0.5.

        Returns:
            Processed text with labels applied

        Raises:
            ValueError: If ``mode`` is not 'space' or 'correct'.
        """
        encoded_inputs, chars, labels = self.preprocess(text)

        predictions, logits = self.predict(encoded_inputs)

        if mode == 'space':
            result = self.postprocess_spacing(predictions, chars)
        elif mode == 'correct':
            result = self.postprocess_correcting(logits, chars, labels, alpha)
        else:
            raise ValueError(f"Unrecognized mode: {mode}")
        return result
|
|
|
|
|
|
|
|
def main():
    """Command line interface for the model pipeline.

    Reads input from --text or --file, processes it in the chosen --mode,
    and writes the result to --output or stdout.
    """
    parser = argparse.ArgumentParser(description="Process text using token classification model")
    parser.add_argument("--text", type=str, help="Text to process")
    parser.add_argument("--file", type=str, help="Path to file containing text to process")
    parser.add_argument("--output", type=str, help="Path to output file")
    parser.add_argument("--mode", type=str, choices=['space', 'correct'], default='space',
                        help="Processing mode: 'space' uses model outputs only, 'correct' combines model results with "
                             "original text spacing")
    parser.add_argument("--model_path", type=str, default=None, help="Path to model directory")
    args = parser.parse_args()

    pipeline = ModelPipeline(args.model_path)

    # Explicit None check: an explicitly passed empty --text "" is still a
    # provided argument and must not fall through to --file / the error path.
    if args.text is not None:
        input_text = args.text
    elif args.file:
        with open(args.file, 'r', encoding='utf-8') as f:
            input_text = f.read()
    else:
        print("Please provide either --text or --file")
        return

    result = pipeline.process_text(input_text, args.mode)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            # Lists are joined one item per line; strings are written verbatim.
            if isinstance(result, list):
                f.write('\n'.join(result))
            else:
                f.write(result)
    else:
        print("\nProcessed Text:")
        if isinstance(result, list):
            for item in result:
                print(item)
        else:
            print(result)


if __name__ == "__main__":
    main()
|
|
|