import os
import json
import torch
import argparse
from transformers import AutoTokenizer
from modeling_custom import BertForTokenClassificationWithFiveO
from labeler import Labeler


class ModelPipeline:
    """
    Pipeline for text processing using the BertForTokenClassificationWithFiveO model.
    Handles preprocessing, inference, and postprocessing.
    """

    def __init__(self, model_path=None):
        """
        Initialize the pipeline with model and tokenizer.

        Args:
            model_path: Path to the model directory. Defaults to the directory
                containing this file.
        """
        # Use the directory containing this file if no path is specified
        if model_path is None:
            model_path = os.path.dirname(os.path.abspath(__file__))

        # Load tokenizer and model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = BertForTokenClassificationWithFiveO.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load config for any custom settings
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        # Initialize labeler for postprocessing ("\u200c" is the zero-width non-joiner)
        self.labeler = Labeler(tags=(1, 2),
                               regexes=(r'[^\S\r\n\v\f]', r'\u200c'),
                               chars=(" ", "\u200c"),
                               class_count=2)

    def _tokenize(self, chars, chunk_size=512):
        input_ids = []
        attention_masks = []
        # Encode each character individually; encode() wraps the character in
        # [CLS] ... [SEP], so index 1 is the character's own token id.
        # 101 and 102 are BERT's [CLS] and [SEP] token ids.
        ids_ = [101] + [self.tokenizer.encode(char)[1] for char in chars] + [102]
        for i in range(0, len(ids_), chunk_size):
            chunked_ids = ids_[i:i + chunk_size]
            attention_mask = [1] * len(chunked_ids)
            if len(chunked_ids) != chunk_size:
                attention_mask += [0] * (chunk_size - len(chunked_ids))  # pad the attention mask accordingly
                chunked_ids += [0] * (chunk_size - len(chunked_ids))  # pad the last chunk to chunk_size
            input_ids.append(chunked_ids)
            attention_masks.append(attention_mask)
        return input_ids, attention_masks

    def preprocess(self, text):
        chars, labels = self.labeler.label_text(text, corpus_type=Labeler.WHOLE_RAW)
        input_ids, attention_mask = self._tokenize(chars[0], chunk_size=512)
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        # Pad the labels with 0 at both ends to account for [CLS] and [SEP]
        labels = [0] + labels[0] + [0]
        return {"input_ids": input_ids, "attention_mask": attention_mask}, chars, labels
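
    # A minimal sketch of the chunking behavior of `_tokenize` above
    # (hypothetical input; the actual token ids depend on the loaded tokenizer):
    #
    #     pipeline = ModelPipeline()
    #     ids, masks = pipeline._tokenize(list("abc"), chunk_size=8)
    #     # ids   -> [[101, id_a, id_b, id_c, 102, 0, 0, 0]]  (padded to chunk_size)
    #     # masks -> [[1,   1,    1,    1,    1,   0, 0, 0]]  (padding masked out)
    #
    # Texts longer than chunk_size - 2 characters are split across several
    # chunks, each padded to exactly chunk_size tokens.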

    def predict(self, encoded_inputs):
        """
        Run model inference on preprocessed inputs.

        Args:
            encoded_inputs: Tokenized inputs from the preprocess method

        Returns:
            Predicted class ids and the raw logits
        """
        # Move input tensors to the same device as the model
        input_ids = encoded_inputs["input_ids"].to(self.device)
        attention_mask = encoded_inputs["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        # Get the predicted class for each token
        predictions = torch.argmax(outputs.logits, dim=-1)
        return predictions, outputs.logits

    def postprocess_spacing(self, predictions, chars):
        # Move predictions back to the CPU for postprocessing
        predictions_cpu = predictions.cpu()
        predictions_flat = [label for sample in predictions_cpu.tolist() for label in sample]
        # Wrap the characters in spaces to mirror the [CLS]/[SEP] positions
        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()

    def postprocess_correcting(self, logits, chars, labels, alpha):
        # Pad or truncate the labels to match the model's sequence length
        sequence_length = logits.size(1)
        if len(labels) < sequence_length:
            labels.extend([0] * (sequence_length - len(labels)))  # pad with 0
        else:
            labels = labels[:sequence_length]  # truncate if longer

        # Convert labels to one-hot encoding
        num_classes = logits.size(-1)
        user_labels_tensor = torch.tensor([labels], device=self.device)
        user_labels_one_hot = torch.nn.functional.one_hot(user_labels_tensor,
                                                          num_classes=num_classes).float()

        # Expand dimensions to match the logits shape if needed
        if logits.dim() > user_labels_one_hot.dim():
            user_labels_one_hot = user_labels_one_hot.unsqueeze(0)

        # Blend the model logits with the user's original labels; note that with
        # multiple chunks the user labels broadcast over every chunk
        combined_logits = logits * (1 - alpha) + user_labels_one_hot * alpha

        # Apply softmax to get probabilities
        combined_probs = torch.nn.functional.softmax(combined_logits, dim=-1)

        # Get the final predictions
        final_predictions = torch.argmax(combined_probs, dim=-1)

        # Move to the CPU and flatten for postprocessing
        final_predictions_cpu = final_predictions.cpu()
        predictions_flat = [label for sample in final_predictions_cpu.tolist() for label in sample]

        # Generate text from the combined predictions
        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()
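
    # A minimal sketch of the alpha blending in `postprocess_correcting`
    # (hypothetical numbers): with alpha = 0.5, a token whose model logits are
    # [2.0, -1.0, 0.5] and whose user label is class 2 (one-hot [0, 0, 1])
    # blends to
    #
    #     [2.0, -1.0, 0.5] * 0.5 + [0.0, 0.0, 1.0] * 0.5 = [1.0, -0.5, 0.75]
    #
    # so the user's original spacing nudges, but does not necessarily override,
    # the model's prediction (the argmax here is still class 0).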

    def process_text(self, text, mode):
        """
        Process text through the entire pipeline.

        Args:
            text: Input text as a string or list of strings
            mode: Processing mode ('space' or 'correct')

        Returns:
            Processed text with labels applied
        """
        # Run the full pipeline
        encoded_inputs, chars, labels = self.preprocess(text)
        predictions, logits = self.predict(encoded_inputs)

        if mode == 'space':
            result = self.postprocess_spacing(predictions, chars)
        elif mode == 'correct':
            result = self.postprocess_correcting(logits, chars, labels, 0.5)
        else:
            raise ValueError(f"Unrecognized mode: {mode}")
        return result


def main():
    """Command line interface for the model pipeline."""
    parser = argparse.ArgumentParser(description="Process text using a token classification model")
    parser.add_argument("--text", type=str, help="Text to process")
    parser.add_argument("--file", type=str, help="Path to a file containing text to process")
    parser.add_argument("--output", type=str, help="Path to the output file")
    parser.add_argument("--mode", type=str, choices=['space', 'correct'], default='space',
                        help="Processing mode: 'space' uses model outputs only, 'correct' combines "
                             "model results with the original text's spacing")
    parser.add_argument("--model_path", type=str, default=None, help="Path to the model directory")
    args = parser.parse_args()

    # Initialize the pipeline
    pipeline = ModelPipeline(args.model_path)

    # Get the input text
    if args.text:
        input_text = args.text
    elif args.file:
        with open(args.file, 'r', encoding='utf-8') as f:
            input_text = f.read()
    else:
        print("Please provide either --text or --file")
        return

    # Process the text
    result = pipeline.process_text(input_text, args.mode)

    # Output the results
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            if isinstance(result, list):
                f.write('\n'.join(result))
            else:
                f.write(result)
    else:
        print("\nProcessed Text:")
        if isinstance(result, list):
            for item in result:
                print(item)
        else:
            print(result)


if __name__ == "__main__":
    main()
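
# Example invocations (hypothetical script and file names; a minimal sketch):
#
#     python pipeline.py --text "some raw text" --mode space
#     python pipeline.py --file input.txt --output fixed.txt --mode correct --model_path ./model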