Sparkonix committed on
Commit b7c31ab · 0 Parent(s)

add remote to repo

Files changed (9)
  1. .gitignore +47 -0
  2. Dockerfile +23 -0
  3. README.md +182 -0
  4. docker-compose.yml +12 -0
  5. main.py +71 -0
  6. models.py +81 -0
  7. requirements.txt +40 -0
  8. upload_model.py +93 -0
  9. utils.py +331 -0
.gitignore ADDED
@@ -0,0 +1,47 @@
+ # Ignore the classification model folder with large files
+ classification_model/
+
+ # Python artifacts
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ venv/
+ ENV/
+ env/
+
+ # IDE files
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # OS specific files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application
+ COPY . .
+
+ # Set environment variables
+ ENV PORT=7860
+ ENV MODEL_PATH="Sparkonix/email-classifier-model"
+ # Replace with your own "<username>/email-classifier-model" path after uploading the model
+
+ # Expose the port
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,182 @@
+ # Email Classification for Support Team
+
+ ## Project Overview
+
+ This project implements an email classification system that categorizes support emails into predefined categories while ensuring that personally identifiable information (PII) is masked before processing. The system combines regex patterns and Named Entity Recognition (NER) for PII masking with a pre-trained XLM-RoBERTa model for email classification.
+
+ ## Key Features
+
+ 1. **Email Classification**: Classifies support emails into four categories:
+    - Incident
+    - Request
+    - Change
+    - Problem
+
+ 2. **Personal Information Masking**: Detects and masks the following types of PII:
+    - Full name ("full_name")
+    - Email address ("email")
+    - Phone number ("phone_number")
+    - Date of birth ("dob")
+    - Aadhaar card number ("aadhar_num")
+    - Credit/debit card number ("credit_debit_no")
+    - CVV number ("cvv_no")
+    - Card expiry number ("expiry_no")
+
+ 3. **API Interface**: Exposes the solution as a RESTful API endpoint.
+
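The masking step in `utils.py` replaces each detected entity with a bracketed placeholder derived from its type (e.g. `[EMAIL]`, `[FULL_NAME]`). A minimal, self-contained sketch of that behavior for the email pattern only — the real `PIIMasker` additionally runs NER for names and contextual checks for card numbers and CVVs:

```python
import re

# Email pattern matching the one in utils.py (with the stray "|" removed from the TLD class)
EMAIL_RE = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'

def mask_emails(text: str) -> str:
    """Replace every email address with the [EMAIL] placeholder."""
    return re.sub(EMAIL_RE, "[EMAIL]", text)

print(mask_emails("Please reach me at john.doe@example.com today."))
# Please reach me at [EMAIL] today.
```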
+ ## Project Structure
+
+ ```
+ .
+ ├── classification_model/   # Local model files (not used in deployment)
+ ├── docker-compose.yml      # Docker Compose configuration
+ ├── Dockerfile              # Docker configuration
+ ├── main.py                 # Main FastAPI application
+ ├── models.py               # Email classifier model implementation
+ ├── README.md               # Project documentation
+ ├── requirements.txt        # Python dependencies
+ ├── upload_model.py         # Script to upload the model to Hugging Face Hub
+ └── utils.py                # PII masker implementation
+ ```
+
+ ## Installation
+
+ ### Prerequisites
+
+ - Python 3.8+
+ - [Docker](https://www.docker.com/) (optional)
+ - Hugging Face account for model hosting
+
+ ### Setup
+
+ 1. Clone the repository:
+    ```
+    git clone <repository-url>
+    cd email_classifier_project
+    ```
+
+ 2. Install dependencies:
+    ```
+    pip install -r requirements.txt
+    ```
+
+ 3. Run the application:
+    ```
+    python main.py
+    ```
+
+ ### Using Docker
+
+ 1. Build and run with Docker Compose:
+    ```
+    docker-compose up
+    ```
+
+ ## Uploading the Model to Hugging Face Hub
+
+ Before deploying the application to Hugging Face Spaces, you need to upload the model to the Hugging Face Model Hub:
+
+ 1. Install the Hugging Face CLI if you haven't already:
+    ```
+    pip install huggingface_hub
+    ```
+
+ 2. Log in to Hugging Face:
+    ```
+    huggingface-cli login
+    ```
+
+ 3. Create a new model repository on Hugging Face:
+    ```
+    huggingface-cli repo create email-classifier-model
+    ```
+
+ 4. Upload the model using Python:
+    ```python
+    from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
+
+    # Load the local model
+    model = XLMRobertaForSequenceClassification.from_pretrained("classification_model")
+    tokenizer = XLMRobertaTokenizer.from_pretrained("classification_model")
+
+    # Push to Hugging Face Hub
+    model.push_to_hub("YourUsername/email-classifier-model")
+    tokenizer.push_to_hub("YourUsername/email-classifier-model")
+    ```
+
+ 5. Update the `MODEL_PATH` environment variable in the Dockerfile with your Hugging Face model path:
+    ```
+    ENV MODEL_PATH="YourUsername/email-classifier-model"
+    ```
+
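At runtime, `models.py` resolves the model location from this environment variable, falling back to a default Hub id when the variable is unset. The lookup amounts to (the default id mirrors the one hard-coded in `models.py`; `"YourUsername/..."` is a placeholder):

```python
import os

def resolve_model_path(default: str = "Sparkonix11/email-classifier-model") -> str:
    # MODEL_PATH from the environment takes precedence, as in models.py
    return os.environ.get("MODEL_PATH", default)

os.environ["MODEL_PATH"] = "YourUsername/email-classifier-model"
print(resolve_model_path())
# YourUsername/email-classifier-model
```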
+ ## API Usage
+
+ The API exposes a single endpoint for email classification:
+
+ - **Endpoint**: `/classify`
+ - **Method**: POST
+ - **Input Format**:
+   ```json
+   {
+     "input_email_body": "string containing the email"
+   }
+   ```
+ - **Output Format**:
+   ```json
+   {
+     "input_email_body": "string containing the email",
+     "list_of_masked_entities": [
+       {
+         "position": [start_index, end_index],
+         "classification": "entity_type",
+         "entity": "original_entity_value"
+       }
+     ],
+     "masked_email": "string containing the masked email",
+     "category_of_the_email": "string containing the class"
+   }
+   ```
+
+ ## Example
+
+ ```python
+ import requests
+
+ url = "https://username-space-name.hf.space/classify"
+ data = {
+     "input_email_body": "Hello, my name is John Doe, and I'm having issues with my account."
+ }
+
+ response = requests.post(url, json=data)
+ print(response.json())
+ ```
+
+ ## Deployment to Hugging Face Spaces
+
+ 1. Create a new Space on Hugging Face:
+    - Go to https://huggingface.co/spaces
+    - Click "Create new Space"
+    - Choose a name for your Space
+    - Select "Docker" as the Space SDK
+
+ 2. Connect your GitHub repository to the Space:
+    - In the Space settings, go to "Repository"
+    - Enter your GitHub repository URL
+    - Authenticate with GitHub if prompted
+
+ 3. Ensure your Hugging Face Space has access to the model:
+    - Go to your model on Hugging Face Hub
+    - Go to "Settings" > "Collaborators"
+    - Add your Space as a collaborator with "Read" access
+
+ 4. Your API will be available at:
+    ```
+    https://username-space-name.hf.space/classify
+    ```
+
+ ## Technologies Used
+
+ - **FastAPI**: Web framework for building the API
+ - **SpaCy**: NLP library for PII detection and masking
+ - **Transformers**: Hugging Face library for the email classification model
+ - **PyTorch**: Deep learning framework
+ - **Docker**: Containerization for deployment
docker-compose.yml ADDED
@@ -0,0 +1,12 @@
+ version: '3'
+
+ services:
+   api:
+     build: .
+     ports:
+       - "8000:7860"
+     volumes:
+       - .:/app
+     environment:
+       - PORT=7860
+     restart: unless-stopped
main.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import Dict, Any, List, Tuple
+ import uvicorn
+
+ from utils import PIIMasker
+ from models import EmailClassifier
+
+ # Initialize the FastAPI application
+ app = FastAPI(title="Email Classification API",
+               description="API for classifying support emails and masking PII",
+               version="1.0.0")
+
+ # Initialize the PII masker and email classifier
+ pii_masker = PIIMasker()
+ email_classifier = EmailClassifier()
+
+ class EmailInput(BaseModel):
+     """Input model for the email classification endpoint"""
+     input_email_body: str
+
+ class EntityInfo(BaseModel):
+     """Model for entity information"""
+     position: Tuple[int, int]
+     classification: str
+     entity: str
+
+ class EmailOutput(BaseModel):
+     """Output model for the email classification endpoint"""
+     input_email_body: str
+     list_of_masked_entities: List[EntityInfo]
+     masked_email: str
+     category_of_the_email: str
+
+ @app.post("/classify", response_model=EmailOutput)
+ async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
+     """
+     Classify an email into a support category while masking PII
+
+     Args:
+         email_input: The input email data
+
+     Returns:
+         The classified email data with masked PII
+     """
+     try:
+         # Process the email to mask PII
+         processed_data = pii_masker.process_email(email_input.input_email_body)
+
+         # Classify the masked email
+         classified_data = email_classifier.process_email(processed_data)
+
+         return classified_data
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing email: {str(e)}")
+
+ @app.get("/health")
+ async def health_check():
+     """
+     Health check endpoint
+
+     Returns:
+         Status message indicating the API is running
+     """
+     return {"status": "healthy", "message": "Email classification API is running"}
+
+ # For local development and testing
+ if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 8000))
+     uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)
models.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import torch
+ from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
+ from typing import Dict, Any, Optional
+
+ class EmailClassifier:
+     """
+     Email classification model to categorize emails into different support categories
+     """
+
+     CATEGORIES = ["Incident", "Request", "Change", "Problem"]
+
+     def __init__(self, model_path: Optional[str] = None):
+         """
+         Initialize the email classifier with a pre-trained model
+
+         Args:
+             model_path: Path or Hugging Face Hub model ID
+         """
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Use environment variable for model path or fall back to Hugging Face Hub model
+         # This allows for flexibility in deployment
+         model_path = model_path or os.environ.get("MODEL_PATH", "Sparkonix11/email-classifier-model")
+
+         # Load the tokenizer and model from Hugging Face Hub or local path
+         self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
+         self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
+         self.model.to(self.device)
+         self.model.eval()
+
+     def classify(self, masked_email: str) -> str:
+         """
+         Classify a masked email into one of the predefined categories
+
+         Args:
+             masked_email: The email content with PII masked
+
+         Returns:
+             The predicted category as a string
+         """
+         # Tokenize the masked email
+         inputs = self.tokenizer(
+             masked_email,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=512
+         )
+
+         inputs = {key: val.to(self.device) for key, val in inputs.items()}
+
+         # Perform inference
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+             logits = outputs.logits
+             predicted_class_idx = torch.argmax(logits, dim=1).item()
+
+         # Map the predicted class index to the category
+         return self.CATEGORIES[predicted_class_idx]
+
+     def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process an email by classifying it into a category
+
+         Args:
+             masked_email_data: Dictionary containing the masked email and other data
+
+         Returns:
+             The input dictionary with the classification added
+         """
+         # Extract masked email content
+         masked_email = masked_email_data["masked_email"]
+
+         # Classify the masked email
+         category = self.classify(masked_email)
+
+         # Add the classification to the data
+         masked_email_data["category_of_the_email"] = category
+
+         return masked_email_data
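The `classify` method above picks the category by taking the argmax over the model's output logits and indexing into the fixed `CATEGORIES` list. The mapping itself is plain index arithmetic; a torch-free sketch with illustrative (made-up) logits:

```python
CATEGORIES = ["Incident", "Request", "Change", "Problem"]

def category_from_logits(logits):
    """Pick the category whose logit is largest (argmax), as in EmailClassifier.classify."""
    predicted_idx = max(range(len(logits)), key=lambda i: logits[i])
    return CATEGORIES[predicted_idx]

print(category_from_logits([0.2, 3.1, -0.5, 1.4]))
# Request
```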
requirements.txt ADDED
@@ -0,0 +1,40 @@
+ # FastAPI for the API
+ fastapi>=0.95.1
+ uvicorn>=0.22.0
+
+ # Pydantic (FastAPI dependency; this range assumes a Pydantic v2-compatible FastAPI)
+ pydantic>=2.0.0
+
+ # Transformers for the classification model
+ # The model's config.json specifies "transformers_version": "4.51.3"
+ transformers>=4.30.0
+
+ # PyTorch (CPU build by default for Docker on Hugging Face Spaces);
+ # let pip pick a version compatible with transformers, PyTorch 2.x expected
+ torch>=2.0.0
+
+ # SpaCy for PII masking; the models (xx_ent_wiki_sm, en_core_web_sm)
+ # are installed below via direct wheel URLs
+ spacy>=3.5.0
+
+ # For loading model weights stored as .safetensors
+ safetensors
+
+ # Python's built-in 're' module needs no entry here; add the third-party
+ # 'regex' package or numpy only if your code imports them directly
+
+ # Additional dependencies
+ python-multipart>=0.0.6
+ huggingface-hub>=0.15.1
+ spacy-transformers>=1.2.5
+ xx-ent-wiki-sm @ https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.5.0/xx_ent_wiki_sm-3.5.0-py3-none-any.whl
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl
upload_model.py ADDED
@@ -0,0 +1,93 @@
+ """
+ Script to upload the email classification model to Hugging Face Hub
+ """
+
+ import sys
+ import argparse
+ import subprocess
+ import pkg_resources
+
+ def check_and_install_dependencies():
+     """Check for required libraries and install if missing"""
+     required_packages = ['torch', 'transformers', 'sentencepiece']
+     installed_packages = {pkg.key for pkg in pkg_resources.working_set}
+
+     missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages]
+
+     if missing_packages:
+         print(f"Installing missing dependencies: {', '.join(missing_packages)}")
+         subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing_packages)
+         print("Dependencies installed. You may need to restart the script.")
+         return False
+
+     return True
+
+ def get_huggingface_username(token=None):
+     """Get the username for the authenticated user"""
+     try:
+         from huggingface_hub import HfApi
+         api = HfApi(token=token)
+         user_info = api.whoami()
+         return user_info.get('name')
+     except Exception as e:
+         print(f"Error getting Hugging Face username: {e}")
+         return None
+
+ def main():
+     """Upload model to Hugging Face Hub"""
+     # Check dependencies first
+     if not check_and_install_dependencies():
+         return
+
+     # Import dependencies after installation check
+     from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
+     from huggingface_hub import login
+
+     parser = argparse.ArgumentParser(description="Upload email classification model to Hugging Face Hub")
+     parser.add_argument("--model_path", type=str, default="classification_model",
+                         help="Local path to the model files")
+     parser.add_argument("--hub_model_id", type=str,
+                         help="Hugging Face Hub model ID (e.g., 'username/email-classifier-model')")
+     parser.add_argument("--model_name", type=str, default="email-classifier-model",
+                         help="Name for the model repository (default: email-classifier-model)")
+     parser.add_argument("--token", type=str,
+                         help="Hugging Face API token (optional; can use environment variable or huggingface-cli login)")
+
+     args = parser.parse_args()
+
+     # Log in if a token is provided
+     if args.token:
+         login(token=args.token)
+
+     # If hub_model_id is not provided, try to get the username and construct it
+     if not args.hub_model_id:
+         username = get_huggingface_username(args.token)
+         if not username:
+             print("Could not determine Hugging Face username. Please provide --hub_model_id explicitly.")
+             return
+         args.hub_model_id = f"{username}/{args.model_name}"
+
+     print(f"Loading model from {args.model_path}...")
+     # Load the local model and tokenizer
+     model = XLMRobertaForSequenceClassification.from_pretrained(args.model_path)
+     tokenizer = XLMRobertaTokenizer.from_pretrained(args.model_path)
+
+     print(f"Uploading model to {args.hub_model_id}...")
+     try:
+         # Push to Hugging Face Hub
+         model.push_to_hub(args.hub_model_id)
+         tokenizer.push_to_hub(args.hub_model_id)
+
+         print("Model successfully uploaded to Hugging Face Hub!")
+         print(f"You can now use the model with the ID: {args.hub_model_id}")
+         print(f"Update the MODEL_PATH in Dockerfile to: {args.hub_model_id}")
+     except Exception as e:
+         print(f"Error uploading model: {e}")
+         print("\nPossible solutions:")
+         print("1. Make sure you're logged in with 'huggingface-cli login'")
+         print("2. Check that you have permission to create repos in the specified namespace")
+         print("3. Try using your own username: --hub_model_id yourusername/email-classifier-model")
+
+ if __name__ == "__main__":
+     main()
utils.py ADDED
@@ -0,0 +1,331 @@
+ import re
+ import spacy
+ from typing import List, Dict, Tuple, Any
+
+ class Entity:
+     def __init__(self, start: int, end: int, entity_type: str, value: str):
+         self.start = start
+         self.end = end
+         self.entity_type = entity_type
+         self.value = value
+
+     def to_dict(self):
+         return {
+             "position": [self.start, self.end],
+             "classification": self.entity_type,
+             "entity": self.value
+         }
+
+     def __repr__(self):  # Added for easier debugging
+         return f"Entity(type='{self.entity_type}', value='{self.value}', start={self.start}, end={self.end})"
+
+ class PIIMasker:
+     def __init__(self, spacy_model_name: str = "xx_ent_wiki_sm"):  # Allow model choice
+         # Load SpaCy model
+         try:
+             self.nlp = spacy.load(spacy_model_name)
+         except OSError:
+             print(f"SpaCy model '{spacy_model_name}' not found. Downloading...")
+             try:
+                 spacy.cli.download(spacy_model_name)
+                 self.nlp = spacy.load(spacy_model_name)
+             except Exception as e:
+                 print(f"Failed to download or load {spacy_model_name}. Error: {e}")
+                 print("Attempting to load 'en_core_web_sm' as a fallback for English.")
+                 try:
+                     self.nlp = spacy.load("en_core_web_sm")
+                 except OSError:
+                     print("Downloading 'en_core_web_sm'...")
+                     spacy.cli.download("en_core_web_sm")
+                     self.nlp = spacy.load("en_core_web_sm")
+
+         # Initialize regex patterns
+         self._initialize_patterns()
+
+     def _initialize_patterns(self):
+         # Define regex patterns for different entity types
+         self.patterns = {
+             "email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
+             "phone_number": r'\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b',
+             # Card number regex: common formats, allows optional spaces/hyphens
+             "credit_debit_no": r'\b(?:(?:\d{4}[\s-]?){3}\d{4}|\d{13,19})\b',
+             # CVV: 3 or 4 digits, ensuring it's a standalone number (word boundary)
+             "cvv_no": r'\b\d{3,4}\b',
+             # Expiry: MM/YY or MM/YYYY, common separators
+             "expiry_no": r'\b(0[1-9]|1[0-2])[/\s-]([0-9]{2}|20[0-9]{2})\b',
+             "aadhar_num": r'\b\d{4}\s?\d{4}\s?\d{4}\b',
+             # DOB: DD/MM/YYYY or DD-MM-YYYY etc.
+             "dob": r'\b(0[1-9]|[12][0-9]|3[01])[/\s-](0[1-9]|1[0-2])[/\s-](?:19|20)\d\d\b'
+         }
+
+     def detect_regex_entities(self, text: str) -> List[Entity]:
+         """Detect entities using regex patterns"""
+         entities = []
+
+         for entity_type, pattern in self.patterns.items():
+             for match in re.finditer(pattern, text):
+                 start, end = match.span()
+                 value = match.group()
+
+                 # Specific verifications
+                 if entity_type == "credit_debit_no":
+                     if not self.verify_credit_card(text, match):
+                         continue
+                 elif entity_type == "cvv_no":
+                     if not self.verify_cvv(text, match):
+                         continue
+                 elif entity_type == "dob":  # Using the generic context verifier for DOB
+                     if not self._verify_with_context(text, start, end, ["birth", "dob", "born"]):
+                         continue
+
+                 # Avoid detecting parts of already matched longer entities (e.g. year within a DOB)
+                 # This is a simple check; more robust overlap handling is done later
+                 is_substring_of_existing = False
+                 for existing_entity in entities:
+                     if existing_entity.start <= start and existing_entity.end >= end and existing_entity.value != value:
+                         is_substring_of_existing = True
+                         break
+                 if is_substring_of_existing:
+                     continue
+
+                 entities.append(Entity(start, end, entity_type, value))
+         return entities
+
+     def _verify_with_context(self, text: str, start: int, end: int, keywords: List[str], window: int = 50) -> bool:
+         """Verify an entity match using surrounding context"""
+         context_before = text[max(0, start - window):start].lower()
+         context_after = text[end:min(len(text), end + window)].lower()
+
+         for keyword in keywords:
+             if keyword in context_before or keyword in context_after:
+                 return True
+         return False
+
+     def verify_credit_card(self, text: str, match: re.Match) -> bool:
+         """Verify if a match is actually a credit card number using contextual clues"""
+         context_window = 50
+         start, end = match.span()
+
+         context_before = text[max(0, start - context_window):start].lower()
+         context_after = text[end:min(len(text), end + context_window)].lower()
+
+         card_keywords = ["card", "credit", "debit", "visa", "mastercard", "payment", "amex", "account no", "card no"]
+         for keyword in card_keywords:
+             if keyword in context_before or keyword in context_after:
+                 return True
+         # A Luhn checksum would be a stronger test; for simplicity we rely on context here
+         return False
+
+     def verify_cvv(self, text: str, match: re.Match) -> bool:
+         """Verify if a 3-4 digit number is actually a CVV using contextual clues"""
+         context_window = 30
+         start, end = match.span()
+         value = match.group()
+
+         # If it's part of a longer digit sequence (like a phone number or ID), it's likely not a CVV
+         char_before = text[start-1:start] if start > 0 else ""
+         char_after = text[end:end+1] if end < len(text) else ""
+         if char_before.isdigit() or char_after.isdigit():
+             return False  # Part of a larger number
+
+         context_before = text[max(0, start - context_window):start].lower()
+         context_after = text[end:min(len(text), end + context_window)].lower()
+
+         cvv_keywords = ["cvv", "cvc", "csc", "security code", "card verification", "verification no"]
+         date_keywords = ["date", "year", "/", "-", "born", "age", "since", "established", "version", "model", "grade"]
+
+         is_cvv_context = any(keyword in context_before or keyword in context_after for keyword in cvv_keywords)
+
+         # If it looks like a year in common contexts (e.g. "since 2023", "born 1990"), it's probably not a CVV
+         if value.isdigit() and (1900 <= int(value) <= 2100 if len(value) == 4 else False):
+             year_context_keywords = ["year", "born", "fiscal", "established", "since", "class of", "ended", "began", "joined"]
+             if any(kw in context_before for kw in year_context_keywords):
+                 return False  # Likely a year
+         # If it follows an MM/ prefix, it's part of an expiry date, not a CVV
+         if re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
+             return False  # Part of an expiry date
+
+         is_date_context = any(keyword in context_before or keyword in context_after for keyword in date_keywords)
+
+         # Check whether the number itself looks like a year at typical CVV lengths
+         looks_like_year = False
+         if len(value) == 2 and value.isdigit():  # e.g. "23" as the YY part of an expiry
+             if any(k in context_before for k in ["expiry", "exp", "valid thru", "good thru"]) or \
+                re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
+                 looks_like_year = True
+         elif len(value) == 4 and value.isdigit() and (1900 <= int(value) <= 2100):
+             if any(k in (context_before + context_after) for k in ["year", "born", "fiscal"]):
+                 looks_like_year = True
+
+         return is_cvv_context and not (is_date_context and looks_like_year)
+
+     def detect_name_entities(self, text: str) -> List[Entity]:
+         """Detect name entities using SpaCy NER"""
+         entities = []
+         doc = self.nlp(text)
+
+         for ent in doc.ents:
+             # "PER" is the person label in models like xx_ent_wiki_sm;
+             # "PERSON" is used by others such as en_core_web_sm
+             if ent.label_ in ["PER", "PERSON"]:
+                 entities.append(Entity(ent.start_char, ent.end_char, "full_name", ent.text))
+         return entities
+
+     def detect_all_entities(self, text: str) -> List[Entity]:
+         """Detect all types of entities in the text"""
+         # Get regex-based entities first
+         entities = self.detect_regex_entities(text)
+
+         # Add SpaCy-based name entities second and let overlap resolution handle
+         # conflicts, since NER is more reliable for names than a generic regex
+         name_entities = self.detect_name_entities(text)
+         entities.extend(name_entities)
+
+         # Sort entities by their starting position
+         entities.sort(key=lambda x: x.start)
+
+         # Resolve overlaps: prioritize NER entities (like names) or longer regex matches
+         entities = self._resolve_overlaps(entities)
+         return entities
+
+     def _resolve_overlaps(self, entities: List[Entity]) -> List[Entity]:
+         """Resolve overlapping entities.
+         Prioritize:
+         1. NER entities (e.g., "full_name") if they overlap with regex matches.
+         2. Longer entities over shorter ones.
+         3. If same length, keep the first one encountered.
+         """
+         if not entities:
+             return []
+
+         # A simple greedy approach: iterate and remove/adjust overlaps
+         resolved_entities: List[Entity] = []
+         for current_entity in sorted(entities, key=lambda e: (e.start, -(e.end - e.start))):  # By start, then longest
+             is_overlapped_or_contained = False
+             temp_resolved = []
+             for res_entity in resolved_entities:
+                 overlap = max(0, min(current_entity.end, res_entity.end) - max(current_entity.start, res_entity.start))
+
+                 if overlap > 0:
+                     is_overlapped_or_contained = True
+                     current_len = current_entity.end - current_entity.start
+                     res_len = res_entity.end - res_entity.start
+
+                     # If current is a name and the existing entity is not, prefer current
+                     # unless current is fully contained by the existing entity
+                     if current_entity.entity_type == "full_name" and res_entity.entity_type != "full_name":
+                         if not (res_entity.start <= current_entity.start and res_entity.end >= current_entity.end):
+                             continue  # Drop res_entity; current will be added later
+                     elif res_entity.entity_type == "full_name" and current_entity.entity_type != "full_name":
+                         # The existing entity is a name; keep it unless it is fully contained by current
+                         if not (current_entity.start <= res_entity.start and current_entity.end >= res_entity.end):
+                             temp_resolved.append(res_entity)
+                             break  # Current is dominated
+
+                     # General case: the longer entity wins
+                     if current_len > res_len:
+                         pass  # res_entity is dropped in favor of current
+                     elif res_len > current_len:
+                         temp_resolved.append(res_entity)
+                         break  # Current is dominated
+                     else:  # Same length: keep the existing one
+                         temp_resolved.append(res_entity)
+                         break
+                 else:  # No overlap
+                     temp_resolved.append(res_entity)
+
+             if not is_overlapped_or_contained:
+                 temp_resolved.append(current_entity)
+
+             resolved_entities = sorted(temp_resolved, key=lambda e: (e.start, -(e.end - e.start)))
+
+         # Final pass to remove fully contained entities if a larger one exists
+         final_entities = []
+         if not resolved_entities:
+             return []
+
+         for i, entity in enumerate(resolved_entities):
+             is_contained = False
+             for j, other_entity in enumerate(resolved_entities):
+                 if i == j:
+                     continue
+                 # If 'entity' is strictly contained within 'other_entity'
+                 if other_entity.start <= entity.start and other_entity.end >= entity.end and \
+                    (other_entity.end - other_entity.start > entity.end - entity.start):
+                     is_contained = True
+                     break
+             if not is_contained:
+                 final_entities.append(entity)
+
+         return final_entities
+
+     def mask_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
+         """
+         Mask PII entities in the text and return masked text and entity information
+         """
+         entities = self.detect_all_entities(text)
+         entity_info = [entity.to_dict() for entity in entities]
+
+         # Sort entities by start position to ensure correct masking,
+         # longest first at same start to prevent partial masking by shorter entities
+         entities.sort(key=lambda x: (x.start, -(x.end - x.start)))
+
+         new_text_parts = []
+         current_pos = 0
+
+         for entity in entities:
+             # Add text before the entity
+             if entity.start > current_pos:
+                 new_text_parts.append(text[current_pos:entity.start])
+
+             # Add the mask, e.g. [FULL_NAME]
+             mask = f"[{entity.entity_type.upper()}]"
+             new_text_parts.append(mask)
+
+             current_pos = entity.end
+
+         # Add any remaining text after the last entity
+         if current_pos < len(text):
+             new_text_parts.append(text[current_pos:])
+
+         return "".join(new_text_parts), entity_info
+
+     def process_email(self, email_text: str) -> Dict[str, Any]:
+         """
+         Process an email by detecting and masking PII entities
+         """
+         masked_email, entity_info = self.mask_text(email_text)
+         return {
+             "input_email_body": email_text,
+             "list_of_masked_entities": entity_info,
+             "masked_email": masked_email,
+             "category_of_the_email": ""
+         }
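The contextual verification used throughout `PIIMasker` (for card numbers, CVVs, and DOBs) boils down to scanning a window of text on either side of a match for trigger keywords. A standalone sketch of the `_verify_with_context` logic, with a made-up example string:

```python
def verify_with_context(text: str, start: int, end: int, keywords, window: int = 50) -> bool:
    """Return True if any keyword appears within `window` chars before or after the match."""
    before = text[max(0, start - window):start].lower()
    after = text[end:min(len(text), end + window)].lower()
    return any(kw in before or kw in after for kw in keywords)

text = "My card number is 4111 1111 1111 1111, thanks."
# The digit run spans indices 18..37; "card" appears in the window before it
print(verify_with_context(text, 18, 37, ["card", "credit", "debit"]))
# True
```

Without such a check, the bare `\b\d{3,4}\b` CVV pattern would flag every short number in an email, which is why `verify_cvv` layers several of these context tests.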