File size: 2,752 Bytes
b7c31ab edc8356 b7c31ab 5e08f57 edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 b7c31ab edc8356 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
from typing import Any, Dict, Optional

import torch
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
class EmailClassifier:
    """
    Email classification model to categorize emails into different support categories.

    Wraps an XLM-RoBERTa sequence-classification checkpoint: the email text is
    tokenized, run through the model, and the index of the highest logit is
    mapped onto ``CATEGORIES``.
    """

    # Category names, looked up by the class index the model predicts.
    # NOTE(review): the order is assumed to match the label ids the checkpoint
    # was trained with -- confirm against the model's training config.
    CATEGORIES = ['Change', 'Incident', 'Problem', 'Request']

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the email classifier with a pre-trained model.

        Args:
            model_path: Path or Hugging Face Hub model ID. When omitted (or
                empty), the ``MODEL_PATH`` environment variable is used, and
                if that is unset the default Hub model is loaded.
        """
        # Run on the GPU when one is available, otherwise on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Resolution order: explicit argument -> MODEL_PATH env var -> Hub default.
        # This allows for flexibility in deployment.
        model_path = model_path or os.environ.get(
            "MODEL_PATH", "Sparkonix11/email-classifier-model"
        )

        # Load the tokenizer and model from Hugging Face Hub or local path.
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
        self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        # Inference only: switch off dropout and other training-time behavior.
        self.model.eval()

    def classify(self, masked_email: str) -> str:
        """
        Classify a masked email into one of the predefined categories.

        Args:
            masked_email: The email content with PII masked.

        Returns:
            The predicted category as a string (one of ``CATEGORIES``).

        Raises:
            IndexError: If the model predicts a class index that has no entry
                in ``CATEGORIES``.
        """
        # Only one sequence is encoded per call, so nothing needs padding.
        # Padding to max_length would force every forward pass to process 512
        # tokens regardless of how short the email is; padding=True pads only
        # up to the longest sequence in the call (i.e. not at all here).
        # Emails longer than 512 tokens are still truncated.
        encoded = self.tokenizer(
            masked_email,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Move every input tensor to the device the model lives on.
        inputs = {key: val.to(self.device) for key, val in encoded.items()}

        # Perform inference without tracking gradients.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # The highest-scoring logit along the class dimension is the prediction.
        predicted_class_idx = torch.argmax(logits, dim=1).item()

        # Map the predicted class index to the category.
        return self.CATEGORIES[predicted_class_idx]

    def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an email by classifying it into a category.

        The input dictionary is updated in place and also returned.

        Args:
            masked_email_data: Dictionary containing the masked email under
                the ``"masked_email"`` key, plus any other data.

        Returns:
            The input dictionary with the classification added under the
            ``"category_of_the_email"`` key.

        Raises:
            KeyError: If ``"masked_email"`` is missing from the dictionary.
        """
        # Extract the masked email content and classify it.
        category = self.classify(masked_email_data["masked_email"])

        # Add the classification to the data.
        masked_email_data["category_of_the_email"] = category
        return masked_email_data
|