jentegeo commited on
Commit
e1992da
·
verified ·
1 Parent(s): b81a1d3

Upload 8 files

Browse files
Files changed (8) hide show
  1. README.md +0 -3
  2. app.py +81 -0
  3. email_classifier.joblib +3 -0
  4. models.py +133 -0
  5. requirements.txt +12 -0
  6. test.py +4 -0
  7. train_model.py +21 -0
  8. utils.py +133 -0
README.md CHANGED
@@ -1,3 +0,0 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict
from models import EmailClassifier
from utils import PIIDetector
import joblib
import os

app = FastAPI(
    title="Email Classification and PII Masking API",
    description="API for classifying support emails and masking PII information",
    version="1.0.0"
)

# Components shared by all requests: a regex-based PII detector and the
# trained email classifier (loaded from disk below).
pii_detector = PIIDetector()
email_classifier = EmailClassifier()

# Fail fast at startup if the model artifact is missing. Chain the original
# exception (`from e`) so logs show the root cause, not just the wrapper.
try:
    email_classifier.load_model("email_classifier.joblib")
except Exception as e:
    print("Model loading failed:", e)
    raise RuntimeError("Pre-trained model not found. Please train it using train_model.py") from e
24
+
25
+
26
class EmailRequest(BaseModel):
    """Request payload for /classify_email: the raw email text."""
    email_body: str
28
+
29
class MaskedEntity(BaseModel):
    """One PII entity detected in the email."""
    position: List[int]  # [start, end] character offsets in the original text
    classification: str  # PII type label, e.g. "email" or "phone_number"
    entity: str  # the original (unmasked) text span
33
+
34
class EmailResponse(BaseModel):
    """Response payload: original text, masking details, and predicted category."""
    input_email_body: str
    list_of_masked_entities: List[MaskedEntity]
    masked_email: str
    category_of_the_email: str
39
+
40
@app.post("/classify_email", response_model=EmailResponse)
async def classify_email(request: EmailRequest):
    """
    Classify an email and mask the PII it contains.

    Args:
        request: EmailRequest containing the email body

    Returns:
        EmailResponse with classification and PII masking information

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Step 1: Detect PII in the raw email text
        email_text = request.email_body
        detected_entities = pii_detector.detect_pii(email_text)

        # Step 2: Mask the PII before classification so the model never
        # sees raw personal data
        masked_email, masked_entities = pii_detector.mask_pii(email_text, detected_entities)

        # Step 3: Classify using the masked text
        category = email_classifier.predict(masked_email)

        # FastAPI validates this dict against EmailResponse
        return {
            "input_email_body": email_text,
            "list_of_masked_entities": masked_entities,
            "masked_email": masked_email,
            "category_of_the_email": category
        }
    except Exception as e:
        # Chain the original exception so server logs retain the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
73
+
74
@app.get("/health")
async def health_check():
    """Liveness probe: returns a static payload confirming the API is up."""
    return {"status": "healthy"}
78
+
79
if __name__ == "__main__":
    # Run the ASGI app with uvicorn when this file is executed directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=5000)
email_classifier.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9ed8331348cd41157590442a442dacfad9221129defd3c40a907755fcc4d149
3
+ size 116355553
models.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd
import numpy as np
import os
import re
import joblib
import nltk

from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords

from sklearn.feature_extraction.text import TfidfVectorizer
#from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report
#from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier

from imblearn.over_sampling import RandomOverSampler
from imblearn.pipeline import Pipeline as ImbPipeline

# Fetch required NLTK corpora at import time (no-op when already cached);
# quiet=True stops the download progress noise on every import.
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
24
+
25
+
26
class EmailClassifier:
    """
    Support-email classifier: TF-IDF features over preprocessed text feeding
    a RandomForest, with random oversampling to balance minority classes.
    """

    def __init__(self):
        self.model = None        # fitted pipeline, set by train() or load_model()
        self.vectorizer = None   # unused; retained for backward compatibility
        self.classes = None      # label set, populated once a model exists
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))

    def preprocess(self, text: str) -> str:
        """Normalize raw email text for vectorization.

        Lowercases, strips email addresses and most punctuation, removes
        stopwords (keeping negations, which carry signal), and lemmatizes.
        """
        text = text.lower()

        # Remove email addresses
        text = re.sub(r'\S+@\S+', ' ', text)

        # Keep alphanumerics, dots, underscores, hyphens (useful in tech terms)
        text = re.sub(r'[^a-zA-Z0-9\s._-]', ' ', text)

        tokens = text.split()

        # Custom stopwords: remove common words but retain negations
        custom_stop_words = self.stop_words - {'no', 'not', 'nor', 'against', 'aren', "aren't", 'isn', "isn't"}

        # Lemmatize and drop stopwords / single characters
        tokens = [
            self.lemmatizer.lemmatize(word)
            for word in tokens
            if word not in custom_stop_words and len(word) > 1
        ]

        return ' '.join(tokens)

    def train(self, X, y, use_grid_search=False):
        """Fit the classification pipeline.

        Args:
            X: iterable of raw email strings.
            y: iterable of class labels aligned with X.
            use_grid_search: when True, tune hyperparameters with GridSearchCV.
        """
        print("Preprocessing data...")
        X_processed = [self.preprocess(text) for text in X]

        print("Oversampling minority classes...")
        ros = RandomOverSampler(random_state=42)
        # RandomOverSampler expects a 2-D feature matrix, hence the reshape.
        X_resampled, y_resampled = ros.fit_resample(np.array(X_processed).reshape(-1, 1), y)
        X_resampled = X_resampled.ravel()  # flatten back to a 1-D text array

        print("Initializing pipeline...")
        pipeline = ImbPipeline([
            ('tfidf', TfidfVectorizer(
                stop_words='english',
                max_features=15000,
                ngram_range=(1, 3),
                sublinear_tf=True
            )),
            ('clf', RandomForestClassifier(n_estimators=100, class_weight='balanced_subsample', random_state=42))
        ])

        if use_grid_search:
            print("Running Grid Search...")
            # BUGFIX: the previous grid searched clf__alpha / clf__penalty,
            # which are SGDClassifier parameters; RandomForestClassifier does
            # not accept them, so GridSearchCV crashed. Search forest
            # hyperparameters instead.
            params = {
                'clf__n_estimators': [100, 200],
                'clf__max_depth': [None, 20, 40]
            }
            grid = GridSearchCV(pipeline, param_grid=params, scoring='f1_weighted', cv=5, verbose=2)
            grid.fit(X_resampled, y_resampled)
            self.model = grid.best_estimator_
            print("Best Params:", grid.best_params_)
        else:
            print("Fitting model...")
            pipeline.fit(X_resampled, y_resampled)
            self.model = pipeline

        print("Model trained.")
        self.classes = self.model.named_steps['clf'].classes_

    def predict(self, text: str) -> str:
        """Classify a single email and return its predicted label.

        Raises:
            ValueError: if no model has been trained or loaded.
        """
        if not self.model:
            raise ValueError("Model not trained or loaded")
        processed_text = self.preprocess(text)
        return self.model.predict([processed_text])[0]

    def save_model(self, model_path: str):
        """Persist the fitted pipeline to disk with joblib."""
        if not self.model:
            raise ValueError("Model not trained")
        joblib.dump(self.model, model_path)

    def load_model(self, model_path: str):
        """Load a previously saved pipeline and restore the class labels."""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")
        self.model = joblib.load(model_path)
        self.classes = self.model.named_steps['clf'].classes_

    @staticmethod
    def load_data_from_csv(csv_path: str, text_col: str = "email", label_col: str = "type"):
        """Read the dataset and drop rows missing either text or label."""
        df = pd.read_csv(csv_path)
        return df[[text_col, label_col]].dropna()

    def train_from_csv(self, csv_path: str, text_col: str = "email", label_col: str = "type", use_grid_search=False):
        """Train from a CSV file, report held-out metrics, and save the model."""
        df = self.load_data_from_csv(csv_path, text_col, label_col)
        # Stratify so the test split preserves the class distribution.
        X_train, X_test, y_train, y_test = train_test_split(
            df[text_col], df[label_col], test_size=0.2, random_state=42, stratify=df[label_col]
        )
        self.train(X_train, y_train, use_grid_search=use_grid_search)
        X_test_processed = [self.preprocess(text) for text in X_test]
        y_pred = self.model.predict(X_test_processed)
        print(classification_report(y_test, y_pred))
        self.save_model("email_classifier.joblib")
        print("Model trained and saved to email_classifier.joblib")
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.68.0
2
+ uvicorn>=0.15.0
3
+ pydantic>=1.8.0
4
+ scikit-learn>=0.24.0
5
+ pandas>=1.2.0
6
+ numpy>=1.20.0
7
+ joblib>=1.0.0
8
+ python-dateutil>=2.8.0
9
+ nltk
10
+ imbalanced-learn
11
+ huggingface_hub
12
+ scikit-learn
test.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from models import EmailClassifier

# Smoke test: train the classifier end to end from the bundled dataset.
# Guarded so merely importing this module does not kick off a full
# (long-running) training job.
if __name__ == "__main__":
    clf = EmailClassifier()
    clf.train_from_csv("data/combined_emails_with_natural_pii.csv", use_grid_search=True)
train_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from models import EmailClassifier
import argparse

def main():
    """Parse CLI arguments and train the email classification model."""
    parser = argparse.ArgumentParser(description="Train email classification model")
    parser.add_argument("--csv_path", type=str, required=True, help="Path to CSV dataset")
    parser.add_argument("--text_col", type=str, default="email", help="Name of text column")
    parser.add_argument("--label_col", type=str, default="type", help="Name of label column")
    # New (backward-compatible) flag: EmailClassifier.train_from_csv already
    # supports grid search, but the CLI had no way to enable it.
    parser.add_argument("--use_grid_search", action="store_true",
                        help="Tune hyperparameters with GridSearchCV before the final fit")

    args = parser.parse_args()

    # Train, evaluate on a held-out split, and save the fitted model
    # (train_from_csv writes email_classifier.joblib).
    classifier = EmailClassifier()
    classifier.train_from_csv(
        csv_path=args.csv_path,
        text_col=args.text_col,
        label_col=args.label_col,
        use_grid_search=args.use_grid_search
    )

if __name__ == "__main__":
    main()
utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Dict, Tuple
3
+ from datetime import datetime
4
+
5
class PIIDetector:
    """
    Class for detecting and masking Personally Identifiable Information (PII) in text.

    Detection is regex-based; card numbers and dates get extra semantic
    validation (Luhn checksum, date parsing) to reduce false positives.
    """

    def __init__(self):
        # Compile regex patterns once; reused for every call.
        self.patterns = {
            "full_name": re.compile(r'\b([A-Z][a-z]+(\s[A-Z][a-z]+)+)\b'),
            "email": re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'),
            "phone_number": re.compile(r'(\+?\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b'),
            "dob": re.compile(r'\b(\d{1,2}[-/]\d{1,2}[-/]\d{2,4}|(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]* \d{1,2}, \d{4})\b'),
            "aadhar_num": re.compile(r'\b\d{4}[ -]?\d{4}[ -]?\d{4}\b'),
            "credit_debit_no": re.compile(r'\b(?:\d[ -]*?){13,16}\b'),
            "cvv_no": re.compile(r'\b\d{3,4}\b'),
            "expiry_no": re.compile(r'\b(0[1-9]|1[0-2])[-/]\d{2}\b')
        }

    def detect_pii(self, text: str) -> List[Dict]:
        """
        Detect all PII entities in the given text.

        Args:
            text: Input text to scan for PII

        Returns:
            List of dicts with keys "position" ([start, end] offsets),
            "classification" (PII type), and "entity" (matched text),
            sorted by start offset (longer match first on ties).
        """
        entities = []

        for entity_type, pattern in self.patterns.items():
            for match in pattern.finditer(text):
                start, end = match.span()
                entity_value = match.group()

                # Extra validation to cut false positives for specific types
                if entity_type == "credit_debit_no" and not self._validate_luhn(entity_value):
                    continue
                if entity_type == "dob" and not self._validate_date(entity_value):
                    continue

                entities.append({
                    "position": [start, end],
                    "classification": entity_type,
                    "entity": entity_value
                })

        # Sort by start; on ties prefer the longer match so e.g. a full card
        # number wins over a 4-digit "cvv" hit inside it.
        entities.sort(key=lambda x: (x["position"][0], -x["position"][1]))
        return entities

    def mask_pii(self, text: str, entities: List[Dict]) -> Tuple[str, List[Dict]]:
        """
        Mask detected PII entities in the text.

        BUGFIX: entities whose spans overlap (e.g. a phone-number match inside
        a credit-card digit run) previously produced corrupted output because
        both replacements were applied. Now only the first (left-most, longest)
        entity of an overlapping group is masked; overlapped ones are skipped.

        Args:
            text: Original text containing PII
            entities: List of detected PII entities

        Returns:
            Tuple of (masked_text, list_of_entities_actually_masked)
        """
        masked_text = text
        offset = 0
        masked_entities = []
        last_end = -1  # end offset (original coordinates) of the last masked span

        # Sort defensively; the offset bookkeeping requires left-to-right order.
        for entity in sorted(entities, key=lambda e: (e["position"][0], -e["position"][1])):
            start, end = entity["position"]

            # Skip anything overlapping a span that was already masked
            if start < last_end:
                continue

            entity_type = entity["classification"]
            masked_token = f"[{entity_type}]"

            # Adjust positions for length changes from earlier replacements
            adj_start = start + offset
            adj_end = end + offset
            masked_text = masked_text[:adj_start] + masked_token + masked_text[adj_end:]
            offset += len(masked_token) - (end - start)
            last_end = end

            masked_entities.append({
                "position": [start, end],
                "classification": entity_type,
                "entity": entity["entity"]
            })

        return masked_text, masked_entities

    def _validate_luhn(self, card_number: str) -> bool:
        """Validate a candidate card number using the Luhn checksum.

        BUGFIX: the previous implementation doubled digits counting from the
        LEFT, which is only correct for even-length numbers; Luhn doubles
        every second digit counting from the RIGHT, so odd-length numbers
        (e.g. 15-digit Amex) were wrongly rejected.
        """
        # Remove separators; only digits take part in the checksum
        card_number = re.sub(r'[^0-9]', '', card_number)

        if not 13 <= len(card_number) <= 19:
            return False

        total = 0
        for i, ch in enumerate(reversed(card_number)):
            digit = int(ch)
            if i % 2 == 1:  # every second digit from the right
                digit *= 2
                if digit > 9:
                    digit -= 9
            total += digit

        return total % 10 == 0

    def _validate_date(self, date_str: str) -> bool:
        """Return True if date_str parses as one of the supported DOB formats."""
        # Try each supported format; strptime raises ValueError on mismatch.
        for fmt in ('%m/%d/%Y', '%m-%d-%Y', '%d/%m/%Y', '%d-%m-%Y',
                    '%b %d, %Y', '%B %d, %Y'):
            try:
                datetime.strptime(date_str, fmt)
                return True
            except ValueError:
                continue
        return False