Commit ·
b7c31ab
0
Parent(s):
add remote to repo
Browse files- .gitignore +47 -0
- Dockerfile +23 -0
- README.md +182 -0
- docker-compose.yml +12 -0
- main.py +71 -0
- models.py +81 -0
- requirements.txt +40 -0
- upload_model.py +93 -0
- utils.py +331 -0
.gitignore
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore the classification model folder with large files
|
| 2 |
+
classification_model/
|
| 3 |
+
|
| 4 |
+
# Python artifacts
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
*.so
|
| 9 |
+
.Python
|
| 10 |
+
env/
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
*.egg-info/
|
| 23 |
+
.installed.cfg
|
| 24 |
+
*.egg
|
| 25 |
+
|
| 26 |
+
# Virtual environments
|
| 27 |
+
venv/
|
| 28 |
+
ENV/
|
| 29 |
+
env/
|
| 30 |
+
|
| 31 |
+
# IDE files
|
| 32 |
+
.idea/
|
| 33 |
+
.vscode/
|
| 34 |
+
*.swp
|
| 35 |
+
*.swo
|
| 36 |
+
|
| 37 |
+
# Jupyter Notebook
|
| 38 |
+
.ipynb_checkpoints
|
| 39 |
+
|
| 40 |
+
# OS specific files
|
| 41 |
+
.DS_Store
|
| 42 |
+
.DS_Store?
|
| 43 |
+
._*
|
| 44 |
+
.Spotlight-V100
|
| 45 |
+
.Trashes
|
| 46 |
+
ehthumbs.db
|
| 47 |
+
Thumbs.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Copy requirements first for better caching
|
| 6 |
+
COPY requirements.txt .
|
| 7 |
+
|
| 8 |
+
# Install dependencies
|
| 9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
+
|
| 11 |
+
# Copy the rest of the application
|
| 12 |
+
COPY . .
|
| 13 |
+
|
| 14 |
+
# Set environment variables
|
| 15 |
+
ENV PORT=7860
|
| 16 |
+
ENV MODEL_PATH="Sparkonix/email-classifier-model"
|
| 17 |
+
# Replace YOUR_ACTUAL_USERNAME with your Hugging Face username after uploading the model
|
| 18 |
+
|
| 19 |
+
# Expose the port
|
| 20 |
+
EXPOSE 7860
|
| 21 |
+
|
| 22 |
+
# Command to run the application
|
| 23 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Email Classification for Support Team
|
| 2 |
+
|
| 3 |
+
## Project Overview
|
| 4 |
+
|
| 5 |
+
This project implements an email classification system that categorizes support emails into predefined categories while ensuring that personal information (PII) is masked before processing. The system uses a combination of Named Entity Recognition (NER) techniques for PII masking and a pre-trained XLM-RoBERTa model for email classification.
|
| 6 |
+
|
| 7 |
+
## Key Features
|
| 8 |
+
|
| 9 |
+
1. **Email Classification**: Classifies support emails into four categories:
|
| 10 |
+
- Incident
|
| 11 |
+
- Request
|
| 12 |
+
- Change
|
| 13 |
+
- Problem
|
| 14 |
+
|
| 15 |
+
2. **Personal Information Masking**: Detects and masks the following types of PII:
|
| 16 |
+
- Full Name ("full_name")
|
| 17 |
+
- Email Address ("email")
|
| 18 |
+
- Phone number ("phone_number")
|
| 19 |
+
- Date of birth ("dob")
|
| 20 |
+
- Aadhar card number ("aadhar_num")
|
| 21 |
+
- Credit/Debit Card Number ("credit_debit_no")
|
| 22 |
+
- CVV number ("cvv_no")
|
| 23 |
+
- Card expiry number ("expiry_no")
|
| 24 |
+
|
| 25 |
+
3. **API Interface**: Exposes the solution as a RESTful API endpoint.
|
| 26 |
+
|
| 27 |
+
## Project Structure
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
.
|
| 31 |
+
├── classification_model/ # Local model files (not used in deployment)
|
| 32 |
+
├── docker-compose.yml # Docker Compose configuration
|
| 33 |
+
├── Dockerfile # Docker configuration
|
| 34 |
+
├── main.py # Main FastAPI application
|
| 35 |
+
├── models.py # Email classifier model implementation
|
| 36 |
+
├── README.md # Project documentation
|
| 37 |
+
├── requirements.txt # Python dependencies
|
| 38 |
+
└── utils.py # PII masker implementation
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Installation
|
| 42 |
+
|
| 43 |
+
### Prerequisites
|
| 44 |
+
|
| 45 |
+
- Python 3.8+
|
| 46 |
+
- [Docker](https://www.docker.com/) (optional)
|
| 47 |
+
- Hugging Face account for model hosting
|
| 48 |
+
|
| 49 |
+
### Setup
|
| 50 |
+
|
| 51 |
+
1. Clone the repository:
|
| 52 |
+
```
|
| 53 |
+
git clone <repository-url>
|
| 54 |
+
cd email_classifier_project
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
2. Install dependencies:
|
| 58 |
+
```
|
| 59 |
+
pip install -r requirements.txt
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
3. Run the application:
|
| 63 |
+
```
|
| 64 |
+
python main.py
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### Using Docker
|
| 68 |
+
|
| 69 |
+
1. Build and run with Docker Compose:
|
| 70 |
+
```
|
| 71 |
+
docker-compose up
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Uploading the Model to Hugging Face Hub
|
| 75 |
+
|
| 76 |
+
Before deploying the application to Hugging Face Spaces, you need to upload the model to the Hugging Face Model Hub:
|
| 77 |
+
|
| 78 |
+
1. Install the Hugging Face CLI if you haven't already:
|
| 79 |
+
```
|
| 80 |
+
pip install huggingface_hub
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
2. Log in to Hugging Face:
|
| 84 |
+
```
|
| 85 |
+
huggingface-cli login
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
3. Create a new model repository on Hugging Face:
|
| 89 |
+
```
|
| 90 |
+
huggingface-cli repo create email-classifier-model
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
4. Upload the model using Python:
|
| 94 |
+
```python
|
| 95 |
+
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 96 |
+
|
| 97 |
+
# Load the local model
|
| 98 |
+
model = XLMRobertaForSequenceClassification.from_pretrained("classification_model")
|
| 99 |
+
tokenizer = XLMRobertaTokenizer.from_pretrained("classification_model")
|
| 100 |
+
|
| 101 |
+
# Push to Hugging Face Hub
|
| 102 |
+
model.push_to_hub("YourUsername/email-classifier-model")
|
| 103 |
+
tokenizer.push_to_hub("YourUsername/email-classifier-model")
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
5. Update the `MODEL_PATH` environment variable in the Dockerfile with your Hugging Face model path:
|
| 107 |
+
```
|
| 108 |
+
ENV MODEL_PATH="YourUsername/email-classifier-model"
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
## API Usage
|
| 112 |
+
|
| 113 |
+
The API exposes a single endpoint for email classification:
|
| 114 |
+
|
| 115 |
+
- **Endpoint**: `/classify`
|
| 116 |
+
- **Method**: POST
|
| 117 |
+
- **Input Format**:
|
| 118 |
+
```json
|
| 119 |
+
{
|
| 120 |
+
"input_email_body": "string containing the email"
|
| 121 |
+
}
|
| 122 |
+
```
|
| 123 |
+
- **Output Format**:
|
| 124 |
+
```json
|
| 125 |
+
{
|
| 126 |
+
"input_email_body": "string containing the email",
|
| 127 |
+
"list_of_masked_entities": [
|
| 128 |
+
{
|
| 129 |
+
"position": [start_index, end_index],
|
| 130 |
+
"classification": "entity_type",
|
| 131 |
+
"entity": "original_entity_value"
|
| 132 |
+
}
|
| 133 |
+
],
|
| 134 |
+
"masked_email": "string containing the masked email",
|
| 135 |
+
"category_of_the_email": "string containing the class"
|
| 136 |
+
}
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
## Example
|
| 140 |
+
|
| 141 |
+
```python
|
| 142 |
+
import requests
|
| 143 |
+
|
| 144 |
+
url = "https://username-space-name.hf.space/classify"
|
| 145 |
+
data = {
|
| 146 |
+
"input_email_body": "Hello, my name is John Doe, and I'm having issues with my account."
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
response = requests.post(url, json=data)
|
| 150 |
+
print(response.json())
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## Deployment to Hugging Face Spaces
|
| 154 |
+
|
| 155 |
+
1. Create a new Space on Hugging Face:
|
| 156 |
+
- Go to https://huggingface.co/spaces
|
| 157 |
+
- Click "Create new Space"
|
| 158 |
+
- Choose a name for your Space
|
| 159 |
+
- Select "Docker" as the Space SDK
|
| 160 |
+
|
| 161 |
+
2. Connect your GitHub repository to the Space:
|
| 162 |
+
- In the Space settings, go to "Repository"
|
| 163 |
+
- Enter your GitHub repository URL
|
| 164 |
+
- Authenticate with GitHub if prompted
|
| 165 |
+
|
| 166 |
+
3. Ensure your Hugging Face Space has access to the model:
|
| 167 |
+
- Go to your model on Hugging Face Hub
|
| 168 |
+
- Go to "Settings" > "Collaborators"
|
| 169 |
+
- Add your Space as a collaborator with "Read" access
|
| 170 |
+
|
| 171 |
+
4. Your API will be available at:
|
| 172 |
+
```
|
| 173 |
+
https://username-space-name.hf.space/classify
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
## Technologies Used
|
| 177 |
+
|
| 178 |
+
- **FastAPI**: Web framework for building the API
|
| 179 |
+
- **SpaCy**: NLP library for PII detection and masking
|
| 180 |
+
- **Transformers**: Hugging Face library for the email classification model
|
| 181 |
+
- **PyTorch**: Deep learning framework
|
| 182 |
+
- **Docker**: Containerization for deployment
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3'
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
api:
|
| 5 |
+
build: .
|
| 6 |
+
ports:
|
| 7 |
+
- "8000:7860"
|
| 8 |
+
volumes:
|
| 9 |
+
- .:/app
|
| 10 |
+
environment:
|
| 11 |
+
- PORT=7860
|
| 12 |
+
restart: unless-stopped
|
main.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from fastapi import FastAPI, HTTPException
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from typing import Dict, Any, List, Tuple, Optional
|
| 5 |
+
import uvicorn
|
| 6 |
+
|
| 7 |
+
from utils import PIIMasker
|
| 8 |
+
from models import EmailClassifier
|
| 9 |
+
|
| 10 |
+
# Initialize the FastAPI application
|
| 11 |
+
app = FastAPI(title="Email Classification API",
|
| 12 |
+
description="API for classifying support emails and masking PII",
|
| 13 |
+
version="1.0.0")
|
| 14 |
+
|
| 15 |
+
# Initialize the PII masker and email classifier
|
| 16 |
+
pii_masker = PIIMasker()
|
| 17 |
+
email_classifier = EmailClassifier()
|
| 18 |
+
|
| 19 |
+
class EmailInput(BaseModel):
|
| 20 |
+
"""Input model for the email classification endpoint"""
|
| 21 |
+
input_email_body: str
|
| 22 |
+
|
| 23 |
+
class EntityInfo(BaseModel):
|
| 24 |
+
"""Model for entity information"""
|
| 25 |
+
position: Tuple[int, int]
|
| 26 |
+
classification: str
|
| 27 |
+
entity: str
|
| 28 |
+
|
| 29 |
+
class EmailOutput(BaseModel):
|
| 30 |
+
"""Output model for the email classification endpoint"""
|
| 31 |
+
input_email_body: str
|
| 32 |
+
list_of_masked_entities: List[EntityInfo]
|
| 33 |
+
masked_email: str
|
| 34 |
+
category_of_the_email: str
|
| 35 |
+
|
| 36 |
+
@app.post("/classify", response_model=EmailOutput)
|
| 37 |
+
async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
|
| 38 |
+
"""
|
| 39 |
+
Classify an email into a support category while masking PII
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
email_input: The input email data
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
The classified email data with masked PII
|
| 46 |
+
"""
|
| 47 |
+
try:
|
| 48 |
+
# Process the email to mask PII
|
| 49 |
+
processed_data = pii_masker.process_email(email_input.input_email_body)
|
| 50 |
+
|
| 51 |
+
# Classify the masked email
|
| 52 |
+
classified_data = email_classifier.process_email(processed_data)
|
| 53 |
+
|
| 54 |
+
return classified_data
|
| 55 |
+
except Exception as e:
|
| 56 |
+
raise HTTPException(status_code=500, detail=f"Error processing email: {str(e)}")
|
| 57 |
+
|
| 58 |
+
@app.get("/health")
|
| 59 |
+
async def health_check():
|
| 60 |
+
"""
|
| 61 |
+
Health check endpoint
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Status message indicating the API is running
|
| 65 |
+
"""
|
| 66 |
+
return {"status": "healthy", "message": "Email classification API is running"}
|
| 67 |
+
|
| 68 |
+
# For local development and testing
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
port = int(os.environ.get("PORT", 8000))
|
| 71 |
+
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)
|
models.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 4 |
+
from typing import Dict, Any
|
| 5 |
+
|
| 6 |
+
class EmailClassifier:
|
| 7 |
+
"""
|
| 8 |
+
Email classification model to categorize emails into different support categories
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
CATEGORIES = ["Incident", "Request", "Change", "Problem"]
|
| 12 |
+
|
| 13 |
+
def __init__(self, model_path: str = None):
|
| 14 |
+
"""
|
| 15 |
+
Initialize the email classifier with a pre-trained model
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
model_path: Path or Hugging Face Hub model ID
|
| 19 |
+
"""
|
| 20 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
+
# Use environment variable for model path or fall back to Hugging Face Hub model
|
| 23 |
+
# This allows for flexibility in deployment
|
| 24 |
+
model_path = model_path or os.environ.get("MODEL_PATH", "Sparkonix11/email-classifier-model")
|
| 25 |
+
|
| 26 |
+
# Load the tokenizer and model from Hugging Face Hub or local path
|
| 27 |
+
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
|
| 28 |
+
self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
|
| 29 |
+
self.model.to(self.device)
|
| 30 |
+
self.model.eval()
|
| 31 |
+
|
| 32 |
+
def classify(self, masked_email: str) -> str:
|
| 33 |
+
"""
|
| 34 |
+
Classify a masked email into one of the predefined categories
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
masked_email: The email content with PII masked
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
The predicted category as a string
|
| 41 |
+
"""
|
| 42 |
+
# Tokenize the masked email
|
| 43 |
+
inputs = self.tokenizer(
|
| 44 |
+
masked_email,
|
| 45 |
+
return_tensors="pt",
|
| 46 |
+
padding="max_length",
|
| 47 |
+
truncation=True,
|
| 48 |
+
max_length=512
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
inputs = {key: val.to(self.device) for key, val in inputs.items()}
|
| 52 |
+
|
| 53 |
+
# Perform inference
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
outputs = self.model(**inputs)
|
| 56 |
+
logits = outputs.logits
|
| 57 |
+
predicted_class_idx = torch.argmax(logits, dim=1).item()
|
| 58 |
+
|
| 59 |
+
# Map the predicted class index to the category
|
| 60 |
+
return self.CATEGORIES[predicted_class_idx]
|
| 61 |
+
|
| 62 |
+
def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 63 |
+
"""
|
| 64 |
+
Process an email by classifying it into a category
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
masked_email_data: Dictionary containing the masked email and other data
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
The input dictionary with the classification added
|
| 71 |
+
"""
|
| 72 |
+
# Extract masked email content
|
| 73 |
+
masked_email = masked_email_data["masked_email"]
|
| 74 |
+
|
| 75 |
+
# Classify the masked email
|
| 76 |
+
category = self.classify(masked_email)
|
| 77 |
+
|
| 78 |
+
# Add the classification to the data
|
| 79 |
+
masked_email_data["category_of_the_email"] = category
|
| 80 |
+
|
| 81 |
+
return masked_email_data
|
requirements.txt
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI for the API
|
| 2 |
+
fastapi>=0.95.1
|
| 3 |
+
uvicorn>=0.22.0
|
| 4 |
+
|
| 5 |
+
# Pydantic (FastAPI dependency, ensure compatibility with your FastAPI version)
|
| 6 |
+
# Usually, if FastAPI is Pydantic v2 compatible, this is fine:
|
| 7 |
+
pydantic>=2.0.0
|
| 8 |
+
|
| 9 |
+
# Transformers for the classification model
|
| 10 |
+
# Your model's config.json specifies "transformers_version": "4.51.3"
|
| 11 |
+
transformers>=4.30.0
|
| 12 |
+
|
| 13 |
+
# PyTorch (CPU version by default for Docker on Hugging Face Spaces)
|
| 14 |
+
# Let pip choose a compatible version with transformers 4.51.3.
|
| 15 |
+
# PyTorch 2.x is generally expected.
|
| 16 |
+
torch>=2.0.0
|
| 17 |
+
|
| 18 |
+
# SpaCy for PII Masking
|
| 19 |
+
# Choose a version you've tested or a recent stable one.
|
| 20 |
+
# If you resolved build issues for 3.8.x, you can use that.
|
| 21 |
+
# Otherwise, 3.7.x is also robust.
|
| 22 |
+
spacy>=3.5.0
|
| 23 |
+
# The SpaCy model (e.g., xx_ent_wiki_sm) should be downloaded in the Dockerfile.
|
| 24 |
+
|
| 25 |
+
# For model weights loading (if you use .safetensors, which is likely with transformers 4.51.3)
|
| 26 |
+
safetensors
|
| 27 |
+
|
| 28 |
+
# Regex library (only if you are using the third-party 'regex' package;
|
| 29 |
+
# Python's built-in 're' module doesn't need to be listed)
|
| 30 |
+
# regex
|
| 31 |
+
|
| 32 |
+
# If your PII masking code directly uses numpy (less likely if encapsulated in SpaCy)
|
| 33 |
+
# numpy
|
| 34 |
+
|
| 35 |
+
# Additional dependencies
|
| 36 |
+
python-multipart>=0.0.6
|
| 37 |
+
huggingface-hub>=0.15.1
|
| 38 |
+
spacy-transformers>=1.2.5
|
| 39 |
+
xx-ent-wiki-sm @ https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.5.0/xx_ent_wiki_sm-3.5.0-py3-none-any.whl
|
| 40 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl
|
upload_model.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to upload the email classification model to Hugging Face Hub
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import argparse
|
| 8 |
+
import subprocess
|
| 9 |
+
import pkg_resources
|
| 10 |
+
|
| 11 |
+
def check_and_install_dependencies():
|
| 12 |
+
"""Check for required libraries and install if missing"""
|
| 13 |
+
required_packages = ['torch', 'transformers', 'sentencepiece']
|
| 14 |
+
installed_packages = {pkg.key for pkg in pkg_resources.working_set}
|
| 15 |
+
|
| 16 |
+
missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages]
|
| 17 |
+
|
| 18 |
+
if missing_packages:
|
| 19 |
+
print(f"Installing missing dependencies: {', '.join(missing_packages)}")
|
| 20 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing_packages)
|
| 21 |
+
print("Dependencies installed. You may need to restart the script.")
|
| 22 |
+
return False
|
| 23 |
+
|
| 24 |
+
return True
|
| 25 |
+
|
| 26 |
+
def get_huggingface_username(token=None):
|
| 27 |
+
"""Get the username for the authenticated user"""
|
| 28 |
+
try:
|
| 29 |
+
from huggingface_hub import HfApi
|
| 30 |
+
api = HfApi(token=token)
|
| 31 |
+
user_info = api.whoami()
|
| 32 |
+
return user_info.get('name')
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"Error getting Hugging Face username: {e}")
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
def main():
|
| 38 |
+
"""Upload model to Hugging Face Hub"""
|
| 39 |
+
# Check dependencies first
|
| 40 |
+
if not check_and_install_dependencies():
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
# Import dependencies after installation check
|
| 44 |
+
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 45 |
+
from huggingface_hub import login
|
| 46 |
+
|
| 47 |
+
parser = argparse.ArgumentParser(description="Upload email classification model to Hugging Face Hub")
|
| 48 |
+
parser.add_argument("--model_path", type=str, default="classification_model",
|
| 49 |
+
help="Local path to the model files")
|
| 50 |
+
parser.add_argument("--hub_model_id", type=str,
|
| 51 |
+
help="Hugging Face Hub model ID (e.g., 'username/email-classifier-model')")
|
| 52 |
+
parser.add_argument("--model_name", type=str, default="email-classifier-model",
|
| 53 |
+
help="Name for the model repository (default: email-classifier-model)")
|
| 54 |
+
parser.add_argument("--token", type=str,
|
| 55 |
+
help="Hugging Face API token (optional, can use environment variable or huggingface-cli login)")
|
| 56 |
+
|
| 57 |
+
args = parser.parse_args()
|
| 58 |
+
|
| 59 |
+
# Login if token is provided
|
| 60 |
+
if args.token:
|
| 61 |
+
login(token=args.token)
|
| 62 |
+
|
| 63 |
+
# If hub_model_id is not provided, try to get username and construct it
|
| 64 |
+
if not args.hub_model_id:
|
| 65 |
+
username = get_huggingface_username(args.token)
|
| 66 |
+
if not username:
|
| 67 |
+
print("Could not determine Hugging Face username. Please provide --hub_model_id explicitly.")
|
| 68 |
+
return
|
| 69 |
+
args.hub_model_id = f"{username}/{args.model_name}"
|
| 70 |
+
|
| 71 |
+
print(f"Loading model from {args.model_path}...")
|
| 72 |
+
# Load the local model and tokenizer
|
| 73 |
+
model = XLMRobertaForSequenceClassification.from_pretrained(args.model_path)
|
| 74 |
+
tokenizer = XLMRobertaTokenizer.from_pretrained(args.model_path)
|
| 75 |
+
|
| 76 |
+
print(f"Uploading model to {args.hub_model_id}...")
|
| 77 |
+
try:
|
| 78 |
+
# Push to Hugging Face Hub
|
| 79 |
+
model.push_to_hub(args.hub_model_id)
|
| 80 |
+
tokenizer.push_to_hub(args.hub_model_id)
|
| 81 |
+
|
| 82 |
+
print("Model successfully uploaded to Hugging Face Hub!")
|
| 83 |
+
print(f"You can now use the model with the ID: {args.hub_model_id}")
|
| 84 |
+
print(f"Update the MODEL_PATH in Dockerfile to: {args.hub_model_id}")
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error uploading model: {e}")
|
| 87 |
+
print("\nPossible solutions:")
|
| 88 |
+
print("1. Make sure you're logged in with 'huggingface-cli login'")
|
| 89 |
+
print("2. Check that you have permission to create repos in the specified namespace")
|
| 90 |
+
print("3. Try using your own username: --hub_model_id yourusername/email-classifier-model")
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
utils.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import spacy
|
| 3 |
+
from typing import List, Dict, Tuple, Any
|
| 4 |
+
|
| 5 |
+
class Entity:
|
| 6 |
+
def __init__(self, start: int, end: int, entity_type: str, value: str):
|
| 7 |
+
self.start = start
|
| 8 |
+
self.end = end
|
| 9 |
+
self.entity_type = entity_type
|
| 10 |
+
self.value = value
|
| 11 |
+
|
| 12 |
+
def to_dict(self):
|
| 13 |
+
return {
|
| 14 |
+
"position": [self.start, self.end],
|
| 15 |
+
"classification": self.entity_type,
|
| 16 |
+
"entity": self.value
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
def __repr__(self): # Added for easier debugging
|
| 20 |
+
return f"Entity(type='{self.entity_type}', value='{self.value}', start={self.start}, end={self.end})"
|
| 21 |
+
|
| 22 |
+
class PIIMasker:
|
| 23 |
+
def __init__(self, spacy_model_name: str = "xx_ent_wiki_sm"): # Allow model choice
|
| 24 |
+
# Load SpaCy model
|
| 25 |
+
try:
|
| 26 |
+
self.nlp = spacy.load(spacy_model_name)
|
| 27 |
+
except OSError:
|
| 28 |
+
print(f"SpaCy model '{spacy_model_name}' not found. Downloading...")
|
| 29 |
+
try:
|
| 30 |
+
spacy.cli.download(spacy_model_name)
|
| 31 |
+
self.nlp = spacy.load(spacy_model_name)
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"Failed to download or load {spacy_model_name}. Error: {e}")
|
| 34 |
+
print("Attempting to load 'en_core_web_sm' as a fallback for English.")
|
| 35 |
+
try:
|
| 36 |
+
self.nlp = spacy.load("en_core_web_sm")
|
| 37 |
+
except OSError:
|
| 38 |
+
print("Downloading 'en_core_web_sm'...")
|
| 39 |
+
spacy.cli.download("en_core_web_sm")
|
| 40 |
+
self.nlp = spacy.load("en_core_web_sm")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Initialize regex patterns
|
| 44 |
+
self._initialize_patterns()
|
| 45 |
+
|
| 46 |
+
def _initialize_patterns(self):
|
| 47 |
+
# Define regex patterns for different entity types
|
| 48 |
+
self.patterns = {
|
| 49 |
+
"email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
| 50 |
+
"phone_number": r'\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b',
|
| 51 |
+
# Card number regex: common formats, allows optional spaces/hyphens
|
| 52 |
+
"credit_debit_no": r'\b(?:(?:\d{4}[\s-]?){3}\d{4}|\d{13,19})\b',
|
| 53 |
+
# CVV: 3 or 4 digits, ensuring it's a standalone number (word boundary)
|
| 54 |
+
"cvv_no": r'\b\d{3,4}\b',
|
| 55 |
+
# Expiry: MM/YY or MM/YYYY, common separators
|
| 56 |
+
"expiry_no": r'\b(0[1-9]|1[0-2])[/\s-]([0-9]{2}|20[0-9]{2})\b',
|
| 57 |
+
"aadhar_num": r'\b\d{4}\s?\d{4}\s?\d{4}\b',
|
| 58 |
+
# DOB: DD/MM/YYYY or DD-MM-YYYY etc.
|
| 59 |
+
"dob": r'\b(0[1-9]|[12][0-9]|3[01])[/\s-](0[1-9]|1[0-2])[/\s-](?:19|20)\d\d\b'
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
def detect_regex_entities(self, text: str) -> List[Entity]:
|
| 63 |
+
"""Detect entities using regex patterns"""
|
| 64 |
+
entities = []
|
| 65 |
+
|
| 66 |
+
for entity_type, pattern in self.patterns.items():
|
| 67 |
+
for match in re.finditer(pattern, text):
|
| 68 |
+
start, end = match.span()
|
| 69 |
+
value = match.group()
|
| 70 |
+
|
| 71 |
+
# Specific verifications
|
| 72 |
+
if entity_type == "credit_debit_no":
|
| 73 |
+
if not self.verify_credit_card(text, match):
|
| 74 |
+
continue
|
| 75 |
+
elif entity_type == "cvv_no":
|
| 76 |
+
if not self.verify_cvv(text, match):
|
| 77 |
+
continue
|
| 78 |
+
elif entity_type == "dob": # Using the generic context verifier for DOB
|
| 79 |
+
if not self._verify_with_context(text, start, end, ["birth", "dob", "born"]):
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
# Avoid detecting parts of already matched longer entities (e.g. year within a DOB)
|
| 83 |
+
# This is a simple check; more robust overlap handling is done later
|
| 84 |
+
is_substring_of_existing = False
|
| 85 |
+
for existing_entity in entities:
|
| 86 |
+
if existing_entity.start <= start and existing_entity.end >= end and existing_entity.value != value :
|
| 87 |
+
is_substring_of_existing = True
|
| 88 |
+
break
|
| 89 |
+
if is_substring_of_existing:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
entities.append(Entity(start, end, entity_type, value))
|
| 93 |
+
return entities
|
| 94 |
+
|
| 95 |
+
def _verify_with_context(self, text: str, start: int, end: int, keywords: List[str], window: int = 50) -> bool:
|
| 96 |
+
"""Verify an entity match using surrounding context"""
|
| 97 |
+
context_before = text[max(0, start - window):start].lower()
|
| 98 |
+
context_after = text[end:min(len(text), end + window)].lower()
|
| 99 |
+
|
| 100 |
+
for keyword in keywords:
|
| 101 |
+
if keyword in context_before or keyword in context_after:
|
| 102 |
+
return True
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def verify_credit_card(self, text: str, match: re.Match) -> bool:
|
| 106 |
+
"""Verify if a match is actually a credit card number using contextual clues"""
|
| 107 |
+
context_window = 50
|
| 108 |
+
start, end = match.span()
|
| 109 |
+
|
| 110 |
+
context_before = text[max(0, start - context_window):start].lower()
|
| 111 |
+
context_after = text[end:min(len(text), end + context_window)].lower()
|
| 112 |
+
|
| 113 |
+
card_keywords = ["card", "credit", "debit", "visa", "mastercard", "payment", "amex", "account no", "card no"]
|
| 114 |
+
for keyword in card_keywords:
|
| 115 |
+
if keyword in context_before or keyword in context_after:
|
| 116 |
+
return True
|
| 117 |
+
# Basic Luhn algorithm check (optional, can be computationally more intensive)
|
| 118 |
+
# For simplicity, we'll rely on context here. If needed, Luhn can be added.
|
| 119 |
+
return False
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def verify_cvv(self, text: str, match: re.Match) -> bool:
    """Verify if a 3-4 digit number is actually a CVV using contextual clues.

    Heuristics applied, in order:
      1. Reject digits embedded in a longer digit run (phone numbers, IDs).
      2. Reject 4-digit values in 1900-2100 preceded by year-like keywords.
      3. Reject values directly following an "MM/"-style expiry prefix.
      4. Otherwise require a CVV keyword nearby, and veto the match when it
         simultaneously reads as a date/year.
    """
    context_window = 30  # characters inspected on each side of the match
    start, end = match.span()
    value = match.group()

    # If it's part of a longer number sequence (like a phone number or ID), it's likely not a CVV
    # Check character immediately before and after
    char_before = text[start-1:start] if start > 0 else ""
    char_after = text[end:end+1] if end < len(text) else ""
    if char_before.isdigit() or char_after.isdigit():
        return False # It's part of a larger number

    context_before = text[max(0, start - context_window):start].lower()
    context_after = text[end:min(len(text), end + context_window)].lower()

    cvv_keywords = ["cvv", "cvc", "csc", "security code", "card verification", "verification no"]
    # NOTE(review): "/" and "-" match very broadly, so is_date_context is
    # often True for ordinary prose — confirm this is intended.
    date_keywords = ["date", "year", "/", "-", "born", "age", "since", "established", "version", "model", "grade"] # More exhaustive

    is_cvv_context = any(keyword in context_before or keyword in context_after for keyword in cvv_keywords)

    # If it looks like a year in common contexts, it's probably not a CVV
    # e.g. "since 2023", "class of 99", "born 1990"
    # The conditional expression only applies the 1900-2100 range test to 4-digit values.
    if value.isdigit() and (1900 <= int(value) <= 2100 if len(value) == 4 else False):
        year_context_keywords = ["year", "born", "fiscal", "established", "since", "class of", "ended", "began", "joined"]
        if any(kw in context_before for kw in year_context_keywords):
            return False # Likely a year
    # If it's MM/YY or MM/YYYY context, it's expiry, not CVV
    if re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()): # Ends with MM/
        return False # Part of an expiry date

    is_date_context = any(keyword in context_before or keyword in context_after for keyword in date_keywords)

    # Check if the number itself looks like a year in typical CVV lengths
    looks_like_year = False
    if len(value) == 2 and value.isdigit(): # e.g. "23" for year in expiry
        # NOTE(review): the MM/ prefix case here is largely unreachable —
        # the re.search guard above already returned False for it; verify.
        if any(k in context_before for k in ["expiry", "exp", "valid thru", "good thru"]) or \
           re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
            looks_like_year = True # It's the YY part of an expiry
    elif len(value) == 4 and value.isdigit() and (1900 <= int(value) <= 2100):
        if any(k in (context_before + context_after) for k in ["year", "born", "fiscal"]):
            looks_like_year = True

    # Accept only a confirmed CVV context that is not also a date/year reading.
    return is_cvv_context and not (is_date_context and looks_like_year)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def detect_name_entities(self, text: str) -> List[Entity]:
    """Run the SpaCy pipeline over *text* and collect person-name entities."""
    doc = self.nlp(text)
    # Models differ in labelling: e.g. xx_ent_wiki_sm emits "PER" while
    # English models emit "PERSON" — accept both.
    return [
        Entity(span.start_char, span.end_char, "full_name", span.text)
        for span in doc.ents
        if span.label_ in ("PER", "PERSON")
    ]
|
| 180 |
+
|
| 181 |
+
def detect_all_entities(self, text: str) -> List[Entity]:
    """Detect every supported PII entity type in *text*.

    Regex-based detections are gathered first, SpaCy name detections are
    appended afterwards, and conflicts between the two are settled by the
    overlap-resolution pass (NER for names can be more reliable than a
    generic regex).
    """
    detected = self.detect_regex_entities(text)
    detected.extend(self.detect_name_entities(text))

    # Order candidates by position in the text before resolving conflicts.
    detected.sort(key=lambda ent: ent.start)

    return self._resolve_overlaps(detected)
|
| 198 |
+
|
| 199 |
+
def _resolve_overlaps(self, entities: List[Entity]) -> List[Entity]:
|
| 200 |
+
"""Resolve overlapping entities.
|
| 201 |
+
Prioritize:
|
| 202 |
+
1. NER entities (e.g., "full_name") if they overlap with regex.
|
| 203 |
+
2. Longer entities over shorter ones.
|
| 204 |
+
3. If same length and type, no change (first one encountered).
|
| 205 |
+
"""
|
| 206 |
+
if not entities:
|
| 207 |
+
return []
|
| 208 |
+
|
| 209 |
+
# A simple greedy approach: iterate and remove/adjust overlaps
|
| 210 |
+
# This can be made more sophisticated
|
| 211 |
+
resolved_entities: List[Entity] = []
|
| 212 |
+
for current_entity in sorted(entities, key=lambda e: (e.start, -(e.end - e.start))): # Process by start, then by longest
|
| 213 |
+
is_overlapped_or_contained = False
|
| 214 |
+
temp_resolved = []
|
| 215 |
+
for i, res_entity in enumerate(resolved_entities):
|
| 216 |
+
# Check for overlap:
|
| 217 |
+
# Current: |----|
|
| 218 |
+
# Res: |----| or |----| or |--| or |------|
|
| 219 |
+
overlap = max(0, min(current_entity.end, res_entity.end) - max(current_entity.start, res_entity.start))
|
| 220 |
+
|
| 221 |
+
if overlap > 0:
|
| 222 |
+
is_overlapped_or_contained = True
|
| 223 |
+
# Preference:
|
| 224 |
+
# 1. NER names often trump regex if they are the ones causing overlap
|
| 225 |
+
# 2. Longer entity wins
|
| 226 |
+
current_len = current_entity.end - current_entity.start
|
| 227 |
+
res_len = res_entity.end - res_entity.start
|
| 228 |
+
|
| 229 |
+
# If current is a name and overlaps, and previous is not a name, prefer current if it's not fully contained
|
| 230 |
+
if current_entity.entity_type == "full_name" and res_entity.entity_type != "full_name":
|
| 231 |
+
if not (res_entity.start <= current_entity.start and res_entity.end >= current_entity.end): # current not fully contained by res
|
| 232 |
+
# remove res_entity, current will be added later
|
| 233 |
+
continue # go to next res_entity, this one is marked for removal
|
| 234 |
+
elif res_entity.entity_type == "full_name" and current_entity.entity_type != "full_name":
|
| 235 |
+
# res_entity is a name, current is not. Prefer res_entity if it's not fully contained
|
| 236 |
+
if not (current_entity.start <= res_entity.start and current_entity.end >= res_entity.end):
|
| 237 |
+
# current entity is subsumed or less important, so don't add current
|
| 238 |
+
# and keep res_entity
|
| 239 |
+
temp_resolved.append(res_entity)
|
| 240 |
+
is_overlapped_or_contained = True # Mark current as handled
|
| 241 |
+
break # Current is dominated
|
| 242 |
+
|
| 243 |
+
# General case: longer entity wins
|
| 244 |
+
if current_len > res_len:
|
| 245 |
+
# current is longer, res_entity is removed from consideration for this current_entity
|
| 246 |
+
pass # res_entity will not be added to temp_resolved if it's fully replaced
|
| 247 |
+
elif res_len > current_len:
|
| 248 |
+
# res is longer, current is dominated
|
| 249 |
+
temp_resolved.append(res_entity)
|
| 250 |
+
is_overlapped_or_contained = True # Mark current as handled
|
| 251 |
+
break # Current is dominated
|
| 252 |
+
else: # Same length, keep existing one (res_entity)
|
| 253 |
+
temp_resolved.append(res_entity)
|
| 254 |
+
is_overlapped_or_contained = True # Mark current as handled
|
| 255 |
+
break
|
| 256 |
+
else: # No overlap
|
| 257 |
+
temp_resolved.append(res_entity)
|
| 258 |
+
|
| 259 |
+
if not is_overlapped_or_contained:
|
| 260 |
+
temp_resolved.append(current_entity)
|
| 261 |
+
|
| 262 |
+
resolved_entities = sorted(temp_resolved, key=lambda e: (e.start, -(e.end - e.start)))
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Final pass to remove fully contained entities if a larger one exists
|
| 266 |
+
final_entities = []
|
| 267 |
+
if not resolved_entities:
|
| 268 |
+
return []
|
| 269 |
+
|
| 270 |
+
for i, entity in enumerate(resolved_entities):
|
| 271 |
+
is_contained = False
|
| 272 |
+
for j, other_entity in enumerate(resolved_entities):
|
| 273 |
+
if i == j:
|
| 274 |
+
continue
|
| 275 |
+
# If 'entity' is strictly contained within 'other_entity'
|
| 276 |
+
if other_entity.start <= entity.start and other_entity.end >= entity.end and \
|
| 277 |
+
(other_entity.end - other_entity.start > entity.end - entity.start):
|
| 278 |
+
is_contained = True
|
| 279 |
+
break
|
| 280 |
+
if not is_contained:
|
| 281 |
+
final_entities.append(entity)
|
| 282 |
+
|
| 283 |
+
return final_entities
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def mask_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Mask PII entities in the text.

    Args:
        text: Raw input text.

    Returns:
        A tuple of (masked text, list of entity dicts). Each detected span
        is replaced by a placeholder of the form "[ENTITY_TYPE]".
    """
    entities = self.detect_all_entities(text)
    entity_info = [entity.to_dict() for entity in entities]

    # Mask in positional order; at equal starts take the longest first so
    # a shorter entity cannot partially mask a longer one.
    # (Removed dead locals from the previous version: an unused
    # list-of-chars copy of `text` and an unused `offset` counter.)
    entities.sort(key=lambda x: (x.start, -(x.end - x.start)))

    new_text_parts: List[str] = []
    current_pos = 0

    for entity in entities:
        # Add text before the entity.
        if entity.start > current_pos:
            new_text_parts.append(text[current_pos:entity.start])

        # Upper-cased placeholder, e.g. "[FULL_NAME]".
        new_text_parts.append(f"[{entity.entity_type.upper()}]")

        current_pos = entity.end

    # Add any remaining text after the last entity.
    if current_pos < len(text):
        new_text_parts.append(text[current_pos:])

    return "".join(new_text_parts), entity_info
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def process_email(self, email_text: str) -> Dict[str, Any]:
    """Detect and mask PII in an email body and build the response payload.

    Args:
        email_text: Raw email body.

    Returns:
        Dict containing the original body, the masked body, the detected
        entity metadata, and an (initially empty) category slot.
    """
    masked_body, detected_entities = self.mask_text(email_text)
    # Category is left blank here; classification happens elsewhere.
    return {
        "input_email_body": email_text,
        "list_of_masked_entities": detected_entities,
        "masked_email": masked_body,
        "category_of_the_email": "",
    }
|