# matin-ebrahimkhani's picture
# Upload the model
# 07b65ad verified
import os
import json
import torch
import argparse
from transformers import AutoTokenizer
from modeling_custom import BertForTokenClassificationWithFiveO
from labeler import Labeler
class ModelPipeline:
    """
    Pipeline for text processing using the BertForTokenClassificationWithFiveO model.

    Handles preprocessing (labelling and chunked tokenization), inference,
    and postprocessing (rebuilding text from per-character predictions).
    """

    def __init__(self, model_path=None):
        """
        Initialize the pipeline with model and tokenizer.

        Args:
            model_path: Path to the model directory. Defaults to the directory
                containing this file.
        """
        # Use this file's directory if no path specified
        if model_path is None:
            model_path = os.path.dirname(os.path.abspath(__file__))

        # Load tokenizer and model; prefer the GPU when one is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = BertForTokenClassificationWithFiveO.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load config for any custom settings (kept as a plain dict)
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Initialize labeler for postprocessing.
        # NOTE(review): tags/regexes/chars look positionally paired
        # (1 <-> horizontal whitespace " ", 2 <-> ZWNJ U+200C) - confirm in Labeler.
        self.labeler = Labeler(tags=(1, 2),
                               regexes=(r'[^\S\r\n\v\f]', r'\u200c'),
                               chars=(" ", "‌"),
                               class_count=2)

    def _tokenize(self, chars, chunk_size=512):
        """
        Convert characters to token ids and split them into fixed-size chunks.

        Each character is mapped to element 1 of ``self.tokenizer.encode(char)``
        (presumably element 0 is the special token the tokenizer prepends).
        The whole id sequence is wrapped in ids 101/102 and cut into
        ``chunk_size`` windows; the last window is zero-padded.

        Args:
            chars: Sequence of character strings.
            chunk_size: Length of every returned chunk.

        Returns:
            (input_ids, attention_masks): two lists of ``chunk_size``-long int
            lists; mask entries are 1 for real ids and 0 for padding.
        """
        # Encode each distinct character once; long texts repeat characters heavily.
        char_to_id = {char: self.tokenizer.encode(char)[1] for char in set(chars)}
        # NOTE(review): assumes 101/102 are this vocab's [CLS]/[SEP] ids - confirm.
        ids_ = [101] + [char_to_id[char] for char in chars] + [102]

        input_ids = []
        attention_masks = []
        for start in range(0, len(ids_), chunk_size):
            chunk = ids_[start:start + chunk_size]
            pad = chunk_size - len(chunk)
            # Only the final chunk can be short: pad ids with 0 and mask them out.
            input_ids.append(chunk + [0] * pad)
            attention_masks.append([1] * len(chunk) + [0] * pad)
        return input_ids, attention_masks

    def preprocess(self, text):
        """
        Label the raw text and turn it into model inputs.

        Args:
            text: Raw input text, forwarded to ``Labeler.label_text``.

        Returns:
            (encoded_inputs, chars, labels): ``encoded_inputs`` holds
            ``input_ids`` and ``attention_mask`` tensors of shape
            (num_chunks, 512); ``chars`` is the labeler's character output;
            ``labels`` is the labeler's ``labels[0]`` with a 0 added at each end
            to line up with the 101/102 boundary ids.
        """
        chars, labels = self.labeler.label_text(text, corpus_type=Labeler.WHOLE_RAW)
        input_ids, attention_mask = self._tokenize(chars[0], chunk_size=512)
        encoded_inputs = {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
        }
        labels = [0] + labels[0] + [0]
        return encoded_inputs, chars, labels

    def predict(self, encoded_inputs):
        """
        Run model inference on preprocessed inputs.

        Args:
            encoded_inputs: Dict with ``input_ids`` and ``attention_mask``
                tensors, as returned by :meth:`preprocess`.

        Returns:
            (predictions, logits): the argmax class per position and the raw
            model logits.
        """
        # Move input tensors to the same device as the model
        input_ids = encoded_inputs["input_ids"].to(self.device)
        attention_mask = encoded_inputs["attention_mask"].to(self.device)
        # Inference only: skip autograd bookkeeping
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        # Get predicted class for each token
        predictions = torch.argmax(outputs.logits, dim=-1)
        return predictions, outputs.logits

    def _generate_text(self, chars, predictions_flat):
        """Render ``chars[0]`` using one predicted class per position."""
        # The ' ' sentinels occupy the 101/102 boundary positions so characters
        # and predictions stay aligned; the padding tail is sliced off.
        return self.labeler.text_generator([' '] + chars[0] + [' '],
                                           predictions_flat[:len(chars[0]) + 2],
                                           corpus_type=Labeler.WHOLE_RAW).strip()

    def postprocess_spacing(self, predictions, chars):
        """
        Rebuild text using only the model's predicted classes.

        Args:
            predictions: Per-position class tensor from :meth:`predict`.
            chars: Character output of :meth:`preprocess`.

        Returns:
            The regenerated text.
        """
        # Move predictions back to CPU and flatten the chunks in order.
        predictions_flat = predictions.cpu().flatten().tolist()
        return self._generate_text(chars, predictions_flat)

    @staticmethod
    def _combine_predictions(logits, labels, alpha):
        """
        Return the flattened argmax of ``(1 - alpha) * logits + alpha * one_hot(labels)``.

        ``labels`` is laid out over every position of ``logits`` (all dims but
        the last) in row-major order, i.e. chunk after chunk, zero-padded or
        truncated to fit. Each chunk is therefore blended with its own labels.
        ``labels`` itself is not modified.
        """
        num_classes = logits.size(-1)
        position_shape = logits.shape[:-1]
        num_positions = position_shape.numel()

        # Copy instead of extending in place so the caller's list is untouched.
        padded = list(labels[:num_positions])
        padded += [0] * (num_positions - len(padded))

        user_labels = torch.tensor(padded, dtype=torch.long,
                                   device=logits.device).view(position_shape)
        user_labels_one_hot = torch.nn.functional.one_hot(
            user_labels, num_classes=num_classes).float()

        # Combine logits and user labels, then pick the most probable class
        combined_logits = logits * (1 - alpha) + user_labels_one_hot * alpha
        combined_probs = torch.nn.functional.softmax(combined_logits, dim=-1)
        final_predictions = torch.argmax(combined_probs, dim=-1)
        return final_predictions.cpu().flatten().tolist()

    def postprocess_correcting(self, logits, chars, labels, alpha):
        """
        Blend the model logits with the labels found in the original text and
        rebuild the text from the blended predictions.

        Args:
            logits: Model logits from :meth:`predict`; last dim = classes.
            chars: Character output of :meth:`preprocess`.
            labels: Per-position labels from :meth:`preprocess` (not modified).
            alpha: Weight of the original-text labels; 0 uses the model only,
                1 uses the original labels only.

        Returns:
            The corrected text.
        """
        predictions_flat = self._combine_predictions(logits, labels, alpha)
        return self._generate_text(chars, predictions_flat)

    def process_text(self, text, mode, alpha=0.5):
        """
        Process text through the entire pipeline.

        Args:
            text: Input text, forwarded to ``Labeler.label_text``.
                NOTE(review): only the labeler's first output entry
                (``chars[0]``) is processed - confirm list inputs are intended.
            mode: Processing mode ('space' or 'correct').
            alpha: Label weight used by 'correct' mode (default 0.5).

        Returns:
            Processed text with labels applied.

        Raises:
            ValueError: If ``mode`` is not recognized (checked before inference).
        """
        # Fail fast: reject bad modes before running the (expensive) model.
        if mode not in ('space', 'correct'):
            raise ValueError(f"Unrecognized mode: {mode}")

        encoded_inputs, chars, labels = self.preprocess(text)
        predictions, logits = self.predict(encoded_inputs)
        if mode == 'space':
            return self.postprocess_spacing(predictions, chars)
        return self.postprocess_correcting(logits, chars, labels, alpha)
def main():
    """
    Command line interface for the model pipeline.

    Reads text from ``--text`` (takes precedence) or ``--file``, runs it
    through ``ModelPipeline`` in the chosen ``--mode``, and writes the result
    to ``--output`` or stdout. Exits with an argparse usage error (status 2)
    when neither input option is given.
    """
    parser = argparse.ArgumentParser(description="Process text using token classification model")
    parser.add_argument("--text", type=str, help="Text to process")
    parser.add_argument("--file", type=str, help="Path to file containing text to process")
    parser.add_argument("--output", type=str, help="Path to output file")
    parser.add_argument("--mode", type=str, choices=['space', 'correct'], default='space',
                        help="Processing mode: 'space' uses model outputs only, 'correct' combines model results with "
                             "original text spacing")
    parser.add_argument("--model_path", type=str, default=None, help="Path to model directory")
    args = parser.parse_args()

    # Resolve the input before loading the model so bad invocations fail fast.
    if args.text:
        input_text = args.text
    elif args.file:
        with open(args.file, 'r', encoding='utf-8') as f:
            input_text = f.read()
    else:
        parser.error("Please provide either --text or --file")  # exits with status 2

    # Initialize pipeline and process the text
    pipeline = ModelPipeline(args.model_path)
    result = pipeline.process_text(input_text, args.mode)

    # Output results
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            if isinstance(result, list):
                f.write('\n'.join(result))
            else:
                f.write(result)
    else:
        print("\nProcessed Text:")
        if isinstance(result, list):
            for item in result:
                print(item)
        else:
            print(result)
# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()