Sparkonix committed on
Commit
edc8356
·
1 Parent(s): adf8222

refactored the code

Browse files
Files changed (7) hide show
  1. .gitignore +1 -1
  2. database.py +60 -54
  3. main.py +35 -19
  4. models.py +22 -20
  5. requirements.in +0 -381
  6. upload_model.py +34 -22
  7. utils.py +149 -88
.gitignore CHANGED
@@ -33,7 +33,7 @@ env/
33
  .vscode/
34
  *.swp
35
  *.swo
36
-
37
  # Jupyter Notebook
38
  .ipynb_checkpoints
39
 
 
33
  .vscode/
34
  *.swp
35
  *.swo
36
+ .flake8
37
  # Jupyter Notebook
38
  .ipynb_checkpoints
39
 
database.py CHANGED
@@ -4,7 +4,7 @@ Database module for handling email storage operations with SQLite.
4
  import os
5
  import json
6
  import sqlite3
7
- from typing import Dict, Any, Optional, List, Tuple
8
  from datetime import datetime
9
  import uuid
10
 
@@ -14,29 +14,32 @@ class EmailDatabase:
14
  Database class for storing and retrieving email data with PII masking information.
15
  Uses SQLite for storage in Hugging Face's persistent directory.
16
  """
17
-
18
  def __init__(self, connection_string: str = None):
19
  """
20
  Initialize the database connection.
21
-
22
  Args:
23
- connection_string: Database connection string or path.
24
- For SQLite, this will be treated as a file path.
25
  """
26
  # Hugging Face Spaces has a /data directory that persists between restarts
27
  self.db_path = connection_string or os.environ.get(
28
- "DATABASE_PATH",
29
  "/data/emails.db" # This path persists in Hugging Face Spaces
30
  )
31
-
32
  # Get the global access key from environment variables
33
- self.access_key = os.environ.get("EMAIL_ACCESS_KEY", "default_secure_access_key")
34
-
 
 
 
35
  # Ensure the data directory exists
36
  self._ensure_data_directory()
37
-
38
  self._create_tables()
39
-
40
  def _ensure_data_directory(self):
41
  """Ensure the data directory exists, and use a fallback if needed."""
42
  try:
@@ -47,74 +50,77 @@ class EmailDatabase:
47
  # If we can't write to /data, fall back to the current directory
48
  self.db_path = "emails.db"
49
  print(f"Warning: Using fallback database path: {self.db_path}")
50
-
51
  def _get_connection(self):
52
  """Get a database connection."""
53
  return sqlite3.connect(self.db_path)
54
-
55
  def _create_tables(self):
56
  """Create the necessary tables if they don't exist."""
57
  conn = self._get_connection()
58
  try:
59
  cursor = conn.cursor()
60
-
61
  # Create the emails table to store original emails and their masked versions
62
  cursor.execute('''
63
  CREATE TABLE IF NOT EXISTS emails (
64
  id TEXT PRIMARY KEY,
65
  original_email TEXT NOT NULL,
66
  masked_email TEXT NOT NULL,
67
- masked_entities TEXT NOT NULL,
68
  category TEXT,
69
  created_at TEXT NOT NULL
70
  )
71
  ''')
72
-
73
  conn.commit()
74
  except Exception as e:
75
  conn.rollback()
76
  raise e
77
  finally:
78
  conn.close()
79
-
80
  def _generate_id(self) -> str:
81
  """Generate a unique ID for the email record."""
82
  return str(uuid.uuid4())
83
-
84
- def store_email(self, original_email: str, masked_email: str,
85
- masked_entities: List[Dict[str, Any]], category: Optional[str] = None) -> str:
 
 
86
  """
87
  Store the original email along with its masked version and related information.
88
-
89
  Args:
90
  original_email: The original email with PII
91
  masked_email: The masked version of the email
92
  masked_entities: List of entities that were masked
93
  category: Optional category of the email
94
-
95
  Returns:
96
  email_id for future reference
97
  """
98
  conn = self._get_connection()
99
  try:
100
  cursor = conn.cursor()
101
-
102
  email_id = self._generate_id()
103
-
104
  # Store the email data
105
  cursor.execute(
106
- 'INSERT INTO emails (id, original_email, masked_email, masked_entities, category, created_at) '
 
107
  'VALUES (?, ?, ?, ?, ?, ?)',
108
  (
109
  email_id,
110
  original_email,
111
  masked_email,
112
- json.dumps(masked_entities), # Convert to JSON string for SQLite
113
  category,
114
  datetime.now().isoformat()
115
  )
116
  )
117
-
118
  conn.commit()
119
  return email_id
120
  except Exception as e:
@@ -122,112 +128,112 @@ class EmailDatabase:
122
  raise e
123
  finally:
124
  conn.close()
125
-
126
  def get_original_email(self, email_id: str, access_key: str) -> Optional[Dict[str, Any]]:
127
  """
128
  Retrieve the original email with PII using the access key.
129
-
130
  Args:
131
  email_id: The ID of the email record
132
  access_key: The security key required to access the original email
133
-
134
  Returns:
135
  Dictionary with email data or None if not found or access_key is invalid
136
  """
137
  # Verify the access key matches the global access key
138
  if access_key != self.access_key:
139
  return None
140
-
141
  conn = self._get_connection()
142
  try:
143
  cursor = conn.cursor()
144
-
145
  cursor.execute(
146
- 'SELECT id, original_email, masked_email, masked_entities, category, created_at '
147
- 'FROM emails WHERE id = ?',
148
  (email_id,)
149
  )
150
-
151
  row = cursor.fetchone()
152
  if not row:
153
  return None
154
-
155
  return {
156
  "id": row[0],
157
  "original_email": row[1],
158
  "masked_email": row[2],
159
- "masked_entities": json.loads(row[3]), # Convert from JSON string back to Python dict
160
  "category": row[4],
161
  "created_at": row[5]
162
  }
163
  finally:
164
  conn.close()
165
-
166
  def get_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
167
  """
168
  Retrieve the masked email data without the original PII-containing email.
169
-
170
  Args:
171
  email_id: The ID of the email
172
-
173
  Returns:
174
  Dictionary with masked email data or None if not found
175
  """
176
  conn = self._get_connection()
177
  try:
178
  cursor = conn.cursor()
179
-
180
  cursor.execute(
181
  'SELECT id, masked_email, masked_entities, category, created_at '
182
  'FROM emails WHERE id = ?',
183
  (email_id,)
184
  )
185
-
186
  row = cursor.fetchone()
187
  if not row:
188
  return None
189
-
190
  return {
191
  "id": row[0],
192
  "masked_email": row[1],
193
- "masked_entities": json.loads(row[2]), # Convert from JSON string back to Python dict
194
  "category": row[3],
195
  "created_at": row[4]
196
  }
197
  finally:
198
  conn.close()
199
-
200
  def get_email_by_masked_content(self, masked_email: str) -> Optional[Dict[str, Any]]:
201
  """
202
  Retrieve the original email using the masked email content.
203
-
204
  Args:
205
  masked_email: The masked version of the email to search for
206
-
207
  Returns:
208
  Dictionary with full email data or None if not found
209
  """
210
  conn = self._get_connection()
211
  try:
212
  cursor = conn.cursor()
213
-
214
  cursor.execute(
215
- 'SELECT id, original_email, masked_email, masked_entities, category, created_at '
216
- 'FROM emails WHERE masked_email = ?',
217
  (masked_email,)
218
  )
219
-
220
  row = cursor.fetchone()
221
  if not row:
222
  return None
223
-
224
  return {
225
  "id": row[0],
226
  "original_email": row[1],
227
  "masked_email": row[2],
228
- "masked_entities": json.loads(row[3]), # Convert from JSON string back to Python dict
229
  "category": row[4],
230
  "created_at": row[5]
231
  }
232
  finally:
233
- conn.close()
 
4
  import os
5
  import json
6
  import sqlite3
7
+ from typing import Dict, Any, Optional, List
8
  from datetime import datetime
9
  import uuid
10
 
 
14
  Database class for storing and retrieving email data with PII masking information.
15
  Uses SQLite for storage in Hugging Face's persistent directory.
16
  """
17
+
18
  def __init__(self, connection_string: str = None):
19
  """
20
  Initialize the database connection.
21
+
22
  Args:
23
+ connection_string: Database connection string or path.
24
+ For SQLite, this will be treated as a file path.
25
  """
26
  # Hugging Face Spaces has a /data directory that persists between restarts
27
  self.db_path = connection_string or os.environ.get(
28
+ "DATABASE_PATH",
29
  "/data/emails.db" # This path persists in Hugging Face Spaces
30
  )
31
+
32
  # Get the global access key from environment variables
33
+ self.access_key = os.environ.get(
34
+ "EMAIL_ACCESS_KEY",
35
+ "default_secure_access_key"
36
+ )
37
+
38
  # Ensure the data directory exists
39
  self._ensure_data_directory()
40
+
41
  self._create_tables()
42
+
43
  def _ensure_data_directory(self):
44
  """Ensure the data directory exists, and use a fallback if needed."""
45
  try:
 
50
  # If we can't write to /data, fall back to the current directory
51
  self.db_path = "emails.db"
52
  print(f"Warning: Using fallback database path: {self.db_path}")
53
+
54
  def _get_connection(self):
55
  """Get a database connection."""
56
  return sqlite3.connect(self.db_path)
57
+
58
  def _create_tables(self):
59
  """Create the necessary tables if they don't exist."""
60
  conn = self._get_connection()
61
  try:
62
  cursor = conn.cursor()
63
+
64
  # Create the emails table to store original emails and their masked versions
65
  cursor.execute('''
66
  CREATE TABLE IF NOT EXISTS emails (
67
  id TEXT PRIMARY KEY,
68
  original_email TEXT NOT NULL,
69
  masked_email TEXT NOT NULL,
70
+ masked_entities TEXT NOT NULL,
71
  category TEXT,
72
  created_at TEXT NOT NULL
73
  )
74
  ''')
75
+
76
  conn.commit()
77
  except Exception as e:
78
  conn.rollback()
79
  raise e
80
  finally:
81
  conn.close()
82
+
83
  def _generate_id(self) -> str:
84
  """Generate a unique ID for the email record."""
85
  return str(uuid.uuid4())
86
+
87
+ def store_email(
88
+ self, original_email: str, masked_email: str,
89
+ masked_entities: List[Dict[str, Any]], category: Optional[str] = None
90
+ ) -> str:
91
  """
92
  Store the original email along with its masked version and related information.
93
+
94
  Args:
95
  original_email: The original email with PII
96
  masked_email: The masked version of the email
97
  masked_entities: List of entities that were masked
98
  category: Optional category of the email
99
+
100
  Returns:
101
  email_id for future reference
102
  """
103
  conn = self._get_connection()
104
  try:
105
  cursor = conn.cursor()
106
+
107
  email_id = self._generate_id()
108
+
109
  # Store the email data
110
  cursor.execute(
111
+ 'INSERT INTO emails '
112
+ '(id, original_email, masked_email, masked_entities, category, created_at) '
113
  'VALUES (?, ?, ?, ?, ?, ?)',
114
  (
115
  email_id,
116
  original_email,
117
  masked_email,
118
+ json.dumps(masked_entities), # JSON string for SQLite
119
  category,
120
  datetime.now().isoformat()
121
  )
122
  )
123
+
124
  conn.commit()
125
  return email_id
126
  except Exception as e:
 
128
  raise e
129
  finally:
130
  conn.close()
131
+
132
  def get_original_email(self, email_id: str, access_key: str) -> Optional[Dict[str, Any]]:
133
  """
134
  Retrieve the original email with PII using the access key.
135
+
136
  Args:
137
  email_id: The ID of the email record
138
  access_key: The security key required to access the original email
139
+
140
  Returns:
141
  Dictionary with email data or None if not found or access_key is invalid
142
  """
143
  # Verify the access key matches the global access key
144
  if access_key != self.access_key:
145
  return None
146
+
147
  conn = self._get_connection()
148
  try:
149
  cursor = conn.cursor()
150
+
151
  cursor.execute(
152
+ 'SELECT id, original_email, masked_email, masked_entities, category, '
153
+ 'created_at FROM emails WHERE id = ?',
154
  (email_id,)
155
  )
156
+
157
  row = cursor.fetchone()
158
  if not row:
159
  return None
160
+
161
  return {
162
  "id": row[0],
163
  "original_email": row[1],
164
  "masked_email": row[2],
165
+ "masked_entities": json.loads(row[3]), # Convert JSON to dict
166
  "category": row[4],
167
  "created_at": row[5]
168
  }
169
  finally:
170
  conn.close()
171
+
172
  def get_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
173
  """
174
  Retrieve the masked email data without the original PII-containing email.
175
+
176
  Args:
177
  email_id: The ID of the email
178
+
179
  Returns:
180
  Dictionary with masked email data or None if not found
181
  """
182
  conn = self._get_connection()
183
  try:
184
  cursor = conn.cursor()
185
+
186
  cursor.execute(
187
  'SELECT id, masked_email, masked_entities, category, created_at '
188
  'FROM emails WHERE id = ?',
189
  (email_id,)
190
  )
191
+
192
  row = cursor.fetchone()
193
  if not row:
194
  return None
195
+
196
  return {
197
  "id": row[0],
198
  "masked_email": row[1],
199
+ "masked_entities": json.loads(row[2]), # Convert JSON to dict
200
  "category": row[3],
201
  "created_at": row[4]
202
  }
203
  finally:
204
  conn.close()
205
+
206
  def get_email_by_masked_content(self, masked_email: str) -> Optional[Dict[str, Any]]:
207
  """
208
  Retrieve the original email using the masked email content.
209
+
210
  Args:
211
  masked_email: The masked version of the email to search for
212
+
213
  Returns:
214
  Dictionary with full email data or None if not found
215
  """
216
  conn = self._get_connection()
217
  try:
218
  cursor = conn.cursor()
219
+
220
  cursor.execute(
221
+ 'SELECT id, original_email, masked_email, masked_entities, category, '
222
+ 'created_at FROM emails WHERE masked_email = ?',
223
  (masked_email,)
224
  )
225
+
226
  row = cursor.fetchone()
227
  if not row:
228
  return None
229
+
230
  return {
231
  "id": row[0],
232
  "original_email": row[1],
233
  "masked_email": row[2],
234
+ "masked_entities": json.loads(row[3]), # Convert JSON to dict
235
  "category": row[4],
236
  "created_at": row[5]
237
  }
238
  finally:
239
+ conn.close()
main.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
- from typing import Dict, Any, List, Tuple, Optional
5
  import uvicorn
6
  from dotenv import load_dotenv
7
 
@@ -21,7 +21,7 @@ else:
21
  db_path = "emails.db" # Fallback to local directory
22
 
23
  # Initialize the FastAPI application
24
- app = FastAPI(title="Email Classification API",
25
  description="API for classifying support emails and masking PII",
26
  version="1.0.0")
27
 
@@ -29,16 +29,19 @@ app = FastAPI(title="Email Classification API",
29
  pii_masker = PIIMasker(db_path=db_path)
30
  email_classifier = EmailClassifier()
31
 
 
32
  class EmailInput(BaseModel):
33
  """Input model for the email classification endpoint"""
34
  input_email_body: str
35
 
 
36
  class EntityInfo(BaseModel):
37
  """Model for entity information"""
38
  position: Tuple[int, int]
39
- classification: str
40
  entity: str
41
 
 
42
  class EmailOutput(BaseModel):
43
  """Output model for the email classification endpoint"""
44
  input_email_body: str
@@ -46,29 +49,30 @@ class EmailOutput(BaseModel):
46
  masked_email: str
47
  category_of_the_email: str
48
 
 
49
  class MaskedEmailInput(BaseModel):
50
  """Input model for retrieving original email by masked email content"""
51
  masked_email: str
52
  access_key: str
53
 
 
54
  @app.post("/classify", response_model=EmailOutput)
55
  async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
56
  """
57
  Classify an email into a support category while masking PII
58
-
59
  Args:
60
  email_input: The input email data
61
-
62
  Returns:
63
  The classified email data with masked PII
64
  """
65
  try:
66
  # Process the email to mask PII and store original in database
67
  processed_data = pii_masker.process_email(email_input.input_email_body)
68
-
69
  # Classify the masked email
70
  classified_data = email_classifier.process_email(processed_data)
71
-
72
  # Make sure we return only the fields expected in the response model
73
  return {
74
  "input_email_body": email_input.input_email_body,
@@ -79,28 +83,35 @@ async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
79
  except Exception as e:
80
  raise HTTPException(status_code=500, detail=f"Error processing email: {str(e)}")
81
 
 
82
  @app.post("/api/v1/unmask-email", response_model=Dict[str, Any])
83
  async def unmask_email(masked_email_input: MaskedEmailInput) -> Dict[str, Any]:
84
  """
85
- Retrieve the original unmasked email using the masked email content from the classify response.
86
-
87
  Args:
88
  masked_email_input: Contains the masked email and access key
89
-
90
  Returns:
91
  The original email data with PII information
92
  """
93
  try:
94
  # Verify access key matches the global access key
95
- if masked_email_input.access_key != os.environ.get("EMAIL_ACCESS_KEY", "default_secure_access_key"):
 
96
  raise HTTPException(status_code=401, detail="Invalid access key")
97
-
98
  # Retrieve the original email using the masked content
99
- email_data = pii_masker.get_original_by_masked_email(masked_email_input.masked_email)
100
-
 
 
101
  if not email_data:
102
- raise HTTPException(status_code=404, detail="Original email not found for the provided masked email")
103
-
 
 
 
104
  return {
105
  "status": "success",
106
  "data": {
@@ -116,19 +127,24 @@ async def unmask_email(masked_email_input: MaskedEmailInput) -> Dict[str, Any]:
116
  except Exception as e:
117
  if isinstance(e, HTTPException):
118
  raise e
119
- raise HTTPException(status_code=500, detail=f"Error retrieving original email: {str(e)}")
 
 
 
 
120
 
121
  @app.get("/health")
122
  async def health_check():
123
  """
124
  Health check endpoint
125
-
126
  Returns:
127
  Status message indicating the API is running
128
  """
129
  return {"status": "healthy", "message": "Email classification API is running"}
130
 
 
131
  # For local development and testing
132
  if __name__ == "__main__":
133
  port = int(os.environ.get("PORT", 8000))
134
- uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)
 
1
  import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
+ from typing import Dict, Any, List, Tuple
5
  import uvicorn
6
  from dotenv import load_dotenv
7
 
 
21
  db_path = "emails.db" # Fallback to local directory
22
 
23
  # Initialize the FastAPI application
24
+ app = FastAPI(title="Email Classification API",
25
  description="API for classifying support emails and masking PII",
26
  version="1.0.0")
27
 
 
29
  pii_masker = PIIMasker(db_path=db_path)
30
  email_classifier = EmailClassifier()
31
 
32
+
33
  class EmailInput(BaseModel):
34
  """Input model for the email classification endpoint"""
35
  input_email_body: str
36
 
37
+
38
  class EntityInfo(BaseModel):
39
  """Model for entity information"""
40
  position: Tuple[int, int]
41
+ classification: str
42
  entity: str
43
 
44
+
45
  class EmailOutput(BaseModel):
46
  """Output model for the email classification endpoint"""
47
  input_email_body: str
 
49
  masked_email: str
50
  category_of_the_email: str
51
 
52
+
53
  class MaskedEmailInput(BaseModel):
54
  """Input model for retrieving original email by masked email content"""
55
  masked_email: str
56
  access_key: str
57
 
58
+
59
  @app.post("/classify", response_model=EmailOutput)
60
  async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
61
  """
62
  Classify an email into a support category while masking PII
63
+
64
  Args:
65
  email_input: The input email data
66
+
67
  Returns:
68
  The classified email data with masked PII
69
  """
70
  try:
71
  # Process the email to mask PII and store original in database
72
  processed_data = pii_masker.process_email(email_input.input_email_body)
 
73
  # Classify the masked email
74
  classified_data = email_classifier.process_email(processed_data)
75
+
76
  # Make sure we return only the fields expected in the response model
77
  return {
78
  "input_email_body": email_input.input_email_body,
 
83
  except Exception as e:
84
  raise HTTPException(status_code=500, detail=f"Error processing email: {str(e)}")
85
 
86
+
87
  @app.post("/api/v1/unmask-email", response_model=Dict[str, Any])
88
  async def unmask_email(masked_email_input: MaskedEmailInput) -> Dict[str, Any]:
89
  """
90
+ Retrieve the original unmasked email.
91
+
92
  Args:
93
  masked_email_input: Contains the masked email and access key
94
+
95
  Returns:
96
  The original email data with PII information
97
  """
98
  try:
99
  # Verify access key matches the global access key
100
+ if masked_email_input.access_key != os.environ.get(
101
+ "EMAIL_ACCESS_KEY", "default_secure_access_key"):
102
  raise HTTPException(status_code=401, detail="Invalid access key")
103
+
104
  # Retrieve the original email using the masked content
105
+ email_data = pii_masker.get_original_by_masked_email(
106
+ masked_email_input.masked_email
107
+ )
108
+
109
  if not email_data:
110
+ raise HTTPException(
111
+ status_code=404,
112
+ detail="Original email not found for the provided masked email"
113
+ )
114
+
115
  return {
116
  "status": "success",
117
  "data": {
 
127
  except Exception as e:
128
  if isinstance(e, HTTPException):
129
  raise e
130
+ raise HTTPException(
131
+ status_code=500,
132
+ detail=f"Error retrieving original email: {str(e)}"
133
+ )
134
+
135
 
136
  @app.get("/health")
137
  async def health_check():
138
  """
139
  Health check endpoint
140
+
141
  Returns:
142
  Status message indicating the API is running
143
  """
144
  return {"status": "healthy", "message": "Email classification API is running"}
145
 
146
+
147
  # For local development and testing
148
  if __name__ == "__main__":
149
  port = int(os.environ.get("PORT", 8000))
150
+ uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)
models.py CHANGED
@@ -3,79 +3,81 @@ import torch
3
  from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
4
  from typing import Dict, Any
5
 
 
6
  class EmailClassifier:
7
  """
8
  Email classification model to categorize emails into different support categories
9
  """
10
-
11
  CATEGORIES = ['Change', 'Incident', 'Problem', 'Request']
12
-
13
  def __init__(self, model_path: str = None):
14
  """
15
  Initialize the email classifier with a pre-trained model
16
-
17
  Args:
18
  model_path: Path or Hugging Face Hub model ID
19
  """
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
-
22
  # Use environment variable for model path or fall back to Hugging Face Hub model
23
  # This allows for flexibility in deployment
24
- model_path = model_path or os.environ.get("MODEL_PATH", "Sparkonix11/email-classifier-model")
25
-
 
 
26
  # Load the tokenizer and model from Hugging Face Hub or local path
27
  self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
28
  self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
29
  self.model.to(self.device)
30
  self.model.eval()
31
-
32
  def classify(self, masked_email: str) -> str:
33
  """
34
  Classify a masked email into one of the predefined categories
35
-
36
  Args:
37
  masked_email: The email content with PII masked
38
-
39
  Returns:
40
  The predicted category as a string
41
  """
42
  # Tokenize the masked email
43
  inputs = self.tokenizer(
44
- masked_email,
45
  return_tensors="pt",
46
  padding="max_length",
47
  truncation=True,
48
  max_length=512
49
  )
50
-
51
  inputs = {key: val.to(self.device) for key, val in inputs.items()}
52
-
53
  # Perform inference
54
  with torch.no_grad():
55
  outputs = self.model(**inputs)
56
  logits = outputs.logits
57
  predicted_class_idx = torch.argmax(logits, dim=1).item()
58
-
59
  # Map the predicted class index to the category
60
  return self.CATEGORIES[predicted_class_idx]
61
-
62
  def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
63
  """
64
  Process an email by classifying it into a category
65
-
66
  Args:
67
  masked_email_data: Dictionary containing the masked email and other data
68
-
69
  Returns:
70
  The input dictionary with the classification added
71
  """
72
  # Extract masked email content
73
  masked_email = masked_email_data["masked_email"]
74
-
75
  # Classify the masked email
76
  category = self.classify(masked_email)
77
-
78
  # Add the classification to the data
79
  masked_email_data["category_of_the_email"] = category
80
-
81
- return masked_email_data
 
3
  from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
4
  from typing import Dict, Any
5
 
6
+
7
  class EmailClassifier:
8
  """
9
  Email classification model to categorize emails into different support categories
10
  """
 
11
  CATEGORIES = ['Change', 'Incident', 'Problem', 'Request']
12
+
13
  def __init__(self, model_path: str = None):
14
  """
15
  Initialize the email classifier with a pre-trained model
16
+
17
  Args:
18
  model_path: Path or Hugging Face Hub model ID
19
  """
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
  # Use environment variable for model path or fall back to Hugging Face Hub model
23
  # This allows for flexibility in deployment
24
+ model_path = model_path or os.environ.get(
25
+ "MODEL_PATH", "Sparkonix11/email-classifier-model"
26
+ )
27
+
28
  # Load the tokenizer and model from Hugging Face Hub or local path
29
  self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
30
  self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
31
  self.model.to(self.device)
32
  self.model.eval()
33
+
34
  def classify(self, masked_email: str) -> str:
35
  """
36
  Classify a masked email into one of the predefined categories
37
+
38
  Args:
39
  masked_email: The email content with PII masked
40
+
41
  Returns:
42
  The predicted category as a string
43
  """
44
  # Tokenize the masked email
45
  inputs = self.tokenizer(
46
+ masked_email,
47
  return_tensors="pt",
48
  padding="max_length",
49
  truncation=True,
50
  max_length=512
51
  )
52
+
53
  inputs = {key: val.to(self.device) for key, val in inputs.items()}
54
+
55
  # Perform inference
56
  with torch.no_grad():
57
  outputs = self.model(**inputs)
58
  logits = outputs.logits
59
  predicted_class_idx = torch.argmax(logits, dim=1).item()
60
+
61
  # Map the predicted class index to the category
62
  return self.CATEGORIES[predicted_class_idx]
63
+
64
  def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
65
  """
66
  Process an email by classifying it into a category
67
+
68
  Args:
69
  masked_email_data: Dictionary containing the masked email and other data
70
+
71
  Returns:
72
  The input dictionary with the classification added
73
  """
74
  # Extract masked email content
75
  masked_email = masked_email_data["masked_email"]
76
+
77
  # Classify the masked email
78
  category = self.classify(masked_email)
79
+
80
  # Add the classification to the data
81
  masked_email_data["category_of_the_email"] = category
82
+
83
+ return masked_email_data
requirements.in DELETED
@@ -1,381 +0,0 @@
1
- # This file was autogenerated by uv via the following command:
2
- # uv pip compile requirements.in -o requirements.txt
3
- annotated-types==0.7.0
4
- # via
5
- # -r requirements.in
6
- # pydantic
7
- anyio==4.9.0
8
- # via
9
- # -r requirements.in
10
- # starlette
11
- blis==1.3.0
12
- # via
13
- # -r requirements.in
14
- # thinc
15
- catalogue==2.0.10
16
- # via
17
- # -r requirements.in
18
- # spacy
19
- # srsly
20
- # thinc
21
- certifi==2025.4.26
22
- # via
23
- # -r requirements.in
24
- # requests
25
- charset-normalizer==3.4.2
26
- # via
27
- # -r requirements.in
28
- # requests
29
- click==8.2.0
30
- # via
31
- # -r requirements.in
32
- # typer
33
- # uvicorn
34
- cloudpathlib==0.21.1
35
- # via
36
- # -r requirements.in
37
- # weasel
38
- confection==0.1.5
39
- # via
40
- # -r requirements.in
41
- # thinc
42
- # weasel
43
- cymem==2.0.11
44
- # via
45
- # -r requirements.in
46
- # preshed
47
- # spacy
48
- # thinc
49
- en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
50
- # via -r requirements.in
51
- exceptiongroup==1.3.0
52
- # via
53
- # -r requirements.in
54
- # anyio
55
- fastapi==0.115.12
56
- # via -r requirements.in
57
- filelock==3.18.0
58
- # via
59
- # -r requirements.in
60
- # huggingface-hub
61
- # torch
62
- # transformers
63
- fsspec==2025.3.2
64
- # via
65
- # -r requirements.in
66
- # huggingface-hub
67
- # torch
68
- h11==0.16.0
69
- # via
70
- # -r requirements.in
71
- # uvicorn
72
- huggingface-hub==0.31.2
73
- # via
74
- # -r requirements.in
75
- # tokenizers
76
- # transformers
77
- idna==3.10
78
- # via
79
- # -r requirements.in
80
- # anyio
81
- # requests
82
- jinja2==3.1.6
83
- # via
84
- # -r requirements.in
85
- # spacy
86
- # torch
87
- langcodes==3.5.0
88
- # via
89
- # -r requirements.in
90
- # spacy
91
- language-data==1.3.0
92
- # via
93
- # -r requirements.in
94
- # langcodes
95
- marisa-trie==1.2.1
96
- # via
97
- # -r requirements.in
98
- # language-data
99
- markdown-it-py==3.0.0
100
- # via
101
- # -r requirements.in
102
- # rich
103
- markupsafe==3.0.2
104
- # via
105
- # -r requirements.in
106
- # jinja2
107
- mdurl==0.1.2
108
- # via
109
- # -r requirements.in
110
- # markdown-it-py
111
- mpmath==1.3.0
112
- # via
113
- # -r requirements.in
114
- # sympy
115
- murmurhash==1.0.12
116
- # via
117
- # -r requirements.in
118
- # preshed
119
- # spacy
120
- # thinc
121
- networkx==3.4.2
122
- # via
123
- # -r requirements.in
124
- # torch
125
- numpy==2.2.5
126
- # via
127
- # -r requirements.in
128
- # blis
129
- # spacy
130
- # spacy-transformers
131
- # thinc
132
- # transformers
133
-
134
- # SQLite is included in Python standard library
135
- python-dotenv
136
- # for environment variable management
137
-
138
- nvidia-cublas-cu12==12.6.4.1
139
- # via
140
- # -r requirements.in
141
- # nvidia-cudnn-cu12
142
- # nvidia-cusolver-cu12
143
- # torch
144
- nvidia-cuda-cupti-cu12==12.6.80
145
- # via
146
- # -r requirements.in
147
- # torch
148
- nvidia-cuda-nvrtc-cu12==12.6.77
149
- # via
150
- # -r requirements.in
151
- # torch
152
- nvidia-cuda-runtime-cu12==12.6.77
153
- # via
154
- # -r requirements.in
155
- # torch
156
- nvidia-cudnn-cu12==9.5.1.17
157
- # via
158
- # -r requirements.in
159
- # torch
160
- nvidia-cufft-cu12==11.3.0.4
161
- # via
162
- # -r requirements.in
163
- # torch
164
- nvidia-cufile-cu12==1.11.1.6
165
- # via
166
- # -r requirements.in
167
- # torch
168
- nvidia-curand-cu12==10.3.7.77
169
- # via
170
- # -r requirements.in
171
- # torch
172
- nvidia-cusolver-cu12==11.7.1.2
173
- # via
174
- # -r requirements.in
175
- # torch
176
- nvidia-cusparse-cu12==12.5.4.2
177
- # via
178
- # -r requirements.in
179
- # nvidia-cusolver-cu12
180
- # torch
181
- nvidia-cusparselt-cu12==0.6.3
182
- # via
183
- # -r requirements.in
184
- # torch
185
- nvidia-nccl-cu12==2.26.2
186
- # via
187
- # -r requirements.in
188
- # torch
189
- nvidia-nvjitlink-cu12==12.6.85
190
- # via
191
- # -r requirements.in
192
- # nvidia-cufft-cu12
193
- # nvidia-cusolver-cu12
194
- # nvidia-cusparse-cu12
195
- # torch
196
- nvidia-nvtx-cu12==12.6.77
197
- # via
198
- # -r requirements.in
199
- # torch
200
- packaging==25.0
201
- # via
202
- # -r requirements.in
203
- # huggingface-hub
204
- # spacy
205
- # thinc
206
- # transformers
207
- # weasel
208
- preshed==3.0.9
209
- # via
210
- # -r requirements.in
211
- # spacy
212
- # thinc
213
- pydantic==2.11.4
214
- # via
215
- # -r requirements.in
216
- # confection
217
- # fastapi
218
- # spacy
219
- # thinc
220
- # weasel
221
- pydantic-core==2.33.2
222
- # via
223
- # -r requirements.in
224
- # pydantic
225
- pygments==2.19.1
226
- # via
227
- # -r requirements.in
228
- # rich
229
- python-multipart==0.0.20
230
- # via -r requirements.in
231
- pyyaml==6.0.2
232
- # via
233
- # -r requirements.in
234
- # huggingface-hub
235
- # transformers
236
- regex==2024.11.6
237
- # via
238
- # -r requirements.in
239
- # transformers
240
- requests==2.32.3
241
- # via
242
- # -r requirements.in
243
- # huggingface-hub
244
- # spacy
245
- # transformers
246
- # weasel
247
- rich==14.0.0
248
- # via
249
- # -r requirements.in
250
- # typer
251
- safetensors==0.5.3
252
- # via
253
- # -r requirements.in
254
- # transformers
255
- sentencepiece==0.2.0
256
- # via -r requirements.in
257
- setuptools==80.7.1
258
- # via
259
- # -r requirements.in
260
- # marisa-trie
261
- # spacy
262
- # thinc
263
- # triton
264
- shellingham==1.5.4
265
- # via
266
- # -r requirements.in
267
- # typer
268
- smart-open==7.1.0
269
- # via
270
- # -r requirements.in
271
- # weasel
272
- sniffio==1.3.1
273
- # via
274
- # -r requirements.in
275
- # anyio
276
- spacy==3.8.5
277
- # via
278
- # -r requirements.in
279
- # spacy-transformers
280
- spacy-alignments==0.9.1
281
- # via
282
- # -r requirements.in
283
- # spacy-transformers
284
- spacy-legacy==3.0.12
285
- # via
286
- # -r requirements.in
287
- # spacy
288
- spacy-loggers==1.0.5
289
- # via
290
- # -r requirements.in
291
- # spacy
292
- spacy-transformers==1.3.8
293
- # via -r requirements.in
294
- srsly==2.5.1
295
- # via
296
- # -r requirements.in
297
- # confection
298
- # spacy
299
- # spacy-transformers
300
- # thinc
301
- # weasel
302
- starlette==0.46.2
303
- # via
304
- # -r requirements.in
305
- # fastapi
306
- sympy==1.14.0
307
- # via
308
- # -r requirements.in
309
- # torch
310
- thinc==8.3.6
311
- # via
312
- # -r requirements.in
313
- # spacy
314
- tokenizers==0.21.1
315
- # via
316
- # -r requirements.in
317
- # transformers
318
- torch==2.7.0
319
- # via
320
- # -r requirements.in
321
- # spacy-transformers
322
- tqdm==4.67.1
323
- # via
324
- # -r requirements.in
325
- # huggingface-hub
326
- # spacy
327
- # transformers
328
- transformers==4.49.0
329
- # via
330
- # -r requirements.in
331
- # spacy-transformers
332
- triton==3.3.0
333
- # via
334
- # -r requirements.in
335
- # torch
336
- typer==0.15.3
337
- # via
338
- # -r requirements.in
339
- # spacy
340
- # weasel
341
- typing-extensions==4.13.2
342
- # via
343
- # -r requirements.in
344
- # anyio
345
- # cloudpathlib
346
- # exceptiongroup
347
- # fastapi
348
- # huggingface-hub
349
- # pydantic
350
- # pydantic-core
351
- # rich
352
- # torch
353
- # typer
354
- # typing-inspection
355
- # uvicorn
356
- typing-inspection==0.4.0
357
- # via
358
- # -r requirements.in
359
- # pydantic
360
- urllib3==2.4.0
361
- # via
362
- # -r requirements.in
363
- # requests
364
- uvicorn==0.34.2
365
- # via -r requirements.in
366
- wasabi==1.1.3
367
- # via
368
- # -r requirements.in
369
- # spacy
370
- # thinc
371
- # weasel
372
- weasel==0.4.1
373
- # via
374
- # -r requirements.in
375
- # spacy
376
- wrapt==1.17.2
377
- # via
378
- # -r requirements.in
379
- # smart-open
380
- xx-ent-wiki-sm @ https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.8.0/xx_ent_wiki_sm-3.8.0-py3-none-any.whl
381
- # via -r requirements.in
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
upload_model.py CHANGED
@@ -2,27 +2,30 @@
2
  Script to upload the email classification model to Hugging Face Hub
3
  """
4
 
5
- import os
6
  import sys
7
  import argparse
8
  import subprocess
9
  import pkg_resources
10
 
 
11
  def check_and_install_dependencies():
12
  """Check for required libraries and install if missing"""
13
  required_packages = ['torch', 'transformers', 'sentencepiece']
14
  installed_packages = {pkg.key for pkg in pkg_resources.working_set}
15
-
16
  missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages]
17
-
18
  if missing_packages:
19
- print(f"Installing missing dependencies: {', '.join(missing_packages)}")
20
- subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing_packages)
 
 
21
  print("Dependencies installed. You may need to restart the script.")
22
  return False
23
-
24
  return True
25
 
 
26
  def get_huggingface_username(token=None):
27
  """Get the username for the authenticated user"""
28
  try:
@@ -34,51 +37,57 @@ def get_huggingface_username(token=None):
34
  print(f"Error getting Hugging Face username: {e}")
35
  return None
36
 
 
37
  def main():
38
  """Upload model to Hugging Face Hub"""
39
  # Check dependencies first
40
  if not check_and_install_dependencies():
41
  return
42
-
43
  # Import dependencies after installation check
44
  from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
45
  from huggingface_hub import login
46
-
47
- parser = argparse.ArgumentParser(description="Upload email classification model to Hugging Face Hub")
 
48
  parser.add_argument("--model_path", type=str, default="classification_model",
49
  help="Local path to the model files")
50
  parser.add_argument("--hub_model_id", type=str,
51
- help="Hugging Face Hub model ID (e.g., 'username/email-classifier-model')")
 
52
  parser.add_argument("--model_name", type=str, default="email-classifier-model",
53
- help="Name for the model repository (default: email-classifier-model)")
 
54
  parser.add_argument("--token", type=str,
55
- help="Hugging Face API token (optional, can use environment variable or huggingface-cli login)")
56
-
 
57
  args = parser.parse_args()
58
-
59
  # Login if token is provided
60
  if args.token:
61
  login(token=args.token)
62
-
63
  # If hub_model_id is not provided, try to get username and construct it
64
  if not args.hub_model_id:
65
  username = get_huggingface_username(args.token)
66
  if not username:
67
- print("Could not determine Hugging Face username. Please provide --hub_model_id explicitly.")
 
68
  return
69
  args.hub_model_id = f"{username}/{args.model_name}"
70
-
71
  print(f"Loading model from {args.model_path}...")
72
  # Load the local model and tokenizer
73
  model = XLMRobertaForSequenceClassification.from_pretrained(args.model_path)
74
  tokenizer = XLMRobertaTokenizer.from_pretrained(args.model_path)
75
-
76
  print(f"Uploading model to {args.hub_model_id}...")
77
  try:
78
  # Push to Hugging Face Hub
79
  model.push_to_hub(args.hub_model_id)
80
  tokenizer.push_to_hub(args.hub_model_id)
81
-
82
  print("Model successfully uploaded to Hugging Face Hub!")
83
  print(f"You can now use the model with the ID: {args.hub_model_id}")
84
  print(f"Update the MODEL_PATH in Dockerfile to: {args.hub_model_id}")
@@ -86,8 +95,11 @@ def main():
86
  print(f"Error uploading model: {e}")
87
  print("\nPossible solutions:")
88
  print("1. Make sure you're logged in with 'huggingface-cli login'")
89
- print("2. Check that you have permission to create repos in the specified namespace")
90
- print("3. Try using your own username: --hub_model_id yourusername/email-classifier-model")
 
 
 
91
 
92
  if __name__ == "__main__":
93
- main()
 
2
  Script to upload the email classification model to Hugging Face Hub
3
  """
4
 
 
5
  import sys
6
  import argparse
7
  import subprocess
8
  import pkg_resources
9
 
10
+
11
  def check_and_install_dependencies():
12
  """Check for required libraries and install if missing"""
13
  required_packages = ['torch', 'transformers', 'sentencepiece']
14
  installed_packages = {pkg.key for pkg in pkg_resources.working_set}
15
+
16
  missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages]
17
+
18
  if missing_packages:
19
+ missing_packages_str = ", ".join(missing_packages)
20
+ print(f"Installing missing dependencies: {missing_packages_str}")
21
+ subprocess.check_call([sys.executable, "-m", "pip", "install"]
22
+ + missing_packages)
23
  print("Dependencies installed. You may need to restart the script.")
24
  return False
25
+
26
  return True
27
 
28
+
29
  def get_huggingface_username(token=None):
30
  """Get the username for the authenticated user"""
31
  try:
 
37
  print(f"Error getting Hugging Face username: {e}")
38
  return None
39
 
40
+
41
  def main():
42
  """Upload model to Hugging Face Hub"""
43
  # Check dependencies first
44
  if not check_and_install_dependencies():
45
  return
46
+
47
  # Import dependencies after installation check
48
  from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
49
  from huggingface_hub import login
50
+
51
+ parser = argparse.ArgumentParser(
52
+ description="Upload email classification model to Hugging Face Hub")
53
  parser.add_argument("--model_path", type=str, default="classification_model",
54
  help="Local path to the model files")
55
  parser.add_argument("--hub_model_id", type=str,
56
+ help="Hugging Face Hub model ID (e.g., "
57
+ "'username/email-classifier-model')")
58
  parser.add_argument("--model_name", type=str, default="email-classifier-model",
59
+ help="Name for the model repository "
60
+ "(default: email-classifier-model)")
61
  parser.add_argument("--token", type=str,
62
+ help="Hugging Face API token (optional, can use "
63
+ "environment variable or huggingface-cli login)")
64
+
65
  args = parser.parse_args()
66
+
67
  # Login if token is provided
68
  if args.token:
69
  login(token=args.token)
70
+
71
  # If hub_model_id is not provided, try to get username and construct it
72
  if not args.hub_model_id:
73
  username = get_huggingface_username(args.token)
74
  if not username:
75
+ print("Could not determine Hugging Face username. "
76
+ "Please provide --hub_model_id explicitly.")
77
  return
78
  args.hub_model_id = f"{username}/{args.model_name}"
79
+
80
  print(f"Loading model from {args.model_path}...")
81
  # Load the local model and tokenizer
82
  model = XLMRobertaForSequenceClassification.from_pretrained(args.model_path)
83
  tokenizer = XLMRobertaTokenizer.from_pretrained(args.model_path)
84
+
85
  print(f"Uploading model to {args.hub_model_id}...")
86
  try:
87
  # Push to Hugging Face Hub
88
  model.push_to_hub(args.hub_model_id)
89
  tokenizer.push_to_hub(args.hub_model_id)
90
+
91
  print("Model successfully uploaded to Hugging Face Hub!")
92
  print(f"You can now use the model with the ID: {args.hub_model_id}")
93
  print(f"Update the MODEL_PATH in Dockerfile to: {args.hub_model_id}")
 
95
  print(f"Error uploading model: {e}")
96
  print("\nPossible solutions:")
97
  print("1. Make sure you're logged in with 'huggingface-cli login'")
98
+ print("2. Check that you have permission to create repos in the "
99
+ "specified namespace")
100
+ print("3. Try using your own username: "
101
+ "--hub_model_id yourusername/email-classifier-model")
102
+
103
 
104
  if __name__ == "__main__":
105
+ main()
utils.py CHANGED
@@ -1,8 +1,10 @@
1
  import re
2
  import spacy
3
  from typing import List, Dict, Tuple, Any, Optional
 
4
  from database import EmailDatabase
5
 
 
6
  class Entity:
7
  def __init__(self, start: int, end: int, entity_type: str, value: str):
8
  self.start = start
@@ -17,11 +19,19 @@ class Entity:
17
  "entity": self.value
18
  }
19
 
20
- def __repr__(self): # Added for easier debugging
21
- return f"Entity(type='{self.entity_type}', value='{self.value}', start={self.start}, end={self.end})"
 
 
 
 
22
 
23
  class PIIMasker:
24
- def __init__(self, spacy_model_name: str = "xx_ent_wiki_sm", db_path: str = None): # Allow model choice
 
 
 
 
25
  # Load SpaCy model
26
  try:
27
  self.nlp = spacy.load(spacy_model_name)
@@ -42,7 +52,7 @@ class PIIMasker:
42
 
43
  # Initialize database connection with SQLite path
44
  self.db = EmailDatabase(connection_string=db_path)
45
-
46
  # Initialize regex patterns
47
  self._initialize_patterns()
48
 
@@ -51,7 +61,11 @@ class PIIMasker:
51
  self.patterns = {
52
  "email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
53
  # Simplified phone regex to capture both standard and international formats
54
- "phone_number": r'\b(?:(?:\+|00)[1-9]\d{0,3}[-\s.]?)?(?:\(?\d{1,5}\)?[-\s.]?)?\d{1,5}(?:[-\s.]\d{1,5}){1,4}\b',
 
 
 
 
55
  # Card number regex: common formats, allows optional spaces/hyphens
56
  "credit_debit_no": r'\b(?:(?:\d{4}[\s-]?){3}\d{4}|\d{13,19})\b',
57
  # CVV: 3 or 4 digits, ensuring it's a standalone number (word boundary)
@@ -60,7 +74,10 @@ class PIIMasker:
60
  "expiry_no": r'\b(0[1-9]|1[0-2])[/\s-]([0-9]{2}|20[0-9]{2})\b',
61
  "aadhar_num": r'\b\d{4}\s?\d{4}\s?\d{4}\b',
62
  # DOB: DD/MM/YYYY or DD-MM-YYYY etc.
63
- "dob": r'\b(0[1-9]|[12][0-9]|3[01])[/\s-](0[1-9]|1[0-2])[/\s-](?:19|20)\d\d\b'
 
 
 
64
  }
65
 
66
  def detect_regex_entities(self, text: str) -> List[Entity]:
@@ -83,14 +100,19 @@ class PIIMasker:
83
  if not self.verify_phone_number(text, match):
84
  continue
85
  elif entity_type == "dob":
86
- if not self._verify_with_context(text, start, end, ["birth", "dob", "born"]):
 
 
87
  continue
88
 
89
- # Avoid detecting parts of already matched longer entities (e.g. year within a DOB)
 
90
  # This is a simple check; more robust overlap handling is done later
91
  is_substring_of_existing = False
92
  for existing_entity in entities:
93
- if existing_entity.start <= start and existing_entity.end >= end and existing_entity.value != value:
 
 
94
  is_substring_of_existing = True
95
  break
96
  if is_substring_of_existing:
@@ -99,7 +121,9 @@ class PIIMasker:
99
  entities.append(Entity(start, end, entity_type, value))
100
  return entities
101
 
102
- def _verify_with_context(self, text: str, start: int, end: int, keywords: List[str], window: int = 50) -> bool:
 
 
103
  """Verify an entity match using surrounding context"""
104
  context_before = text[max(0, start - window):start].lower()
105
  context_after = text[end:min(len(text), end + window)].lower()
@@ -117,7 +141,10 @@ class PIIMasker:
117
  context_before = text[max(0, start - context_window):start].lower()
118
  context_after = text[end:min(len(text), end + context_window)].lower()
119
 
120
- card_keywords = ["card", "credit", "debit", "visa", "mastercard", "payment", "amex", "account no", "card no"]
 
 
 
121
  for keyword in card_keywords:
122
  if keyword in context_before or keyword in context_after:
123
  return True
@@ -125,19 +152,19 @@ class PIIMasker:
125
  # For simplicity, we'll rely on context here. If needed, Luhn can be added.
126
  return False
127
 
128
-
129
  def verify_cvv(self, text: str, match: re.Match) -> bool:
130
  """Verify if a 3-4 digit number is actually a CVV using contextual clues"""
131
  context_window = 50
132
  start, end = match.span()
133
  value = match.group()
134
 
135
- # If it's part of a longer number sequence (like a phone number or ID), it's likely not a CVV
 
136
  # Check character immediately before and after
137
- char_before = text[start-1:start] if start > 0 else ""
138
- char_after = text[end:end+1] if end < len(text) else ""
139
  if char_before.isdigit() or char_after.isdigit():
140
- return False # It's part of a larger number
141
 
142
  # Only consider 3-4 digit numbers
143
  if not value.isdigit() or len(value) < 3 or len(value) > 4:
@@ -148,14 +175,16 @@ class PIIMasker:
148
 
149
  # Expanded list of CVV-related keywords to improve detection
150
  cvv_keywords = [
151
- "cvv", "cvc", "csc", "security code", "card verification", "verification no",
152
- "security", "security number", "cv2", "card code", "security value"
 
153
  ]
154
-
155
- date_keywords = ["date", "year", "/", "born", "age", "since", "established"]
156
 
157
  # Look for CVV context clues
158
- is_cvv_context = any(keyword in context_before or keyword in context_after for keyword in cvv_keywords)
 
 
 
159
 
160
  # If explicitly mentioned as a CVV, immediately return true
161
  if is_cvv_context:
@@ -163,17 +192,23 @@ class PIIMasker:
163
 
164
  # If it looks like a year, reject it
165
  if len(value) == 4 and 1900 <= int(value) <= 2100:
166
- if any(k in context_before or k in context_after for k in ["year", "born", "established", "since"]):
 
 
 
167
  return False
168
 
169
  # If in expiry date context, reject it
170
  if re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
171
  return False
172
-
173
- # If no context clues but we have a credit card mention nearby, it could be a CVV
174
- card_context = any(k in context_before or k in context_after
175
- for k in ["card", "credit", "visa", "mastercard", "amex", "discover"])
176
-
 
 
 
177
  return is_cvv_context or (card_context and len(value) in [3, 4])
178
 
179
  def verify_phone_number(self, text: str, match: re.Match) -> bool:
@@ -182,56 +217,62 @@ class PIIMasker:
182
  """
183
  value = match.group()
184
  start, end = match.span()
185
-
186
  # Extract only digits to count them
187
  digits = ''.join(c for c in value if c.isdigit())
188
  digit_count = len(digits)
189
-
190
  # Most phone numbers worldwide have between 7 and 15 digits
191
  if digit_count < 7 or digit_count > 15:
192
  return False
193
-
194
  # Check for common phone number indicators
195
  context_window = 50
196
  context_before = text[max(0, start - context_window):start].lower()
197
  context_after = text[end:min(len(text), end + context_window)].lower()
198
-
199
  # Expanded phone keywords
200
  phone_keywords = [
201
- "phone", "call", "tel", "telephone", "contact", "dial", "mobile", "cell",
202
- "number", "direct", "office", "fax", "reach me at", "call me", "contact me",
203
- "line", "extension", "ext", "phone number"
204
  ]
205
-
206
  # Check for phone context
207
- has_phone_context = any(kw in context_before or kw in context_after for kw in phone_keywords)
208
-
 
 
209
  # Check for formatting that indicates a phone number
210
- has_phone_formatting = bool(re.search(r'[-\s.()\+]', value))
211
-
212
  # Check for international prefix
213
  has_intl_prefix = value.startswith('+') or value.startswith('00')
214
-
215
  # Return true if any of these conditions are met:
216
  # 1. Has explicit phone context
217
  # 2. Has phone-like formatting AND reasonable digit count
218
  # 3. Has international prefix AND reasonable digit count
219
  # 4. Has 10 digits exactly (common in many countries) with formatting
220
- return has_phone_context or \
221
- (has_phone_formatting and digit_count >= 7) or \
222
- (has_intl_prefix) or \
223
- (digit_count == 10 and has_phone_formatting)
 
 
224
 
225
  def detect_name_entities(self, text: str) -> List[Entity]:
226
  """Detect name entities using SpaCy NER"""
227
  entities = []
228
  doc = self.nlp(text)
229
-
230
  for ent in doc.ents:
231
  # Use PER for person, common in many models like xx_ent_wiki_sm
232
  # Also checking for PERSON as some models might use it.
233
  if ent.label_ in ["PER", "PERSON"]:
234
- entities.append(Entity(ent.start_char, ent.end_char, "full_name", ent.text))
 
 
235
  return entities
236
 
237
  def detect_all_entities(self, text: str) -> List[Entity]:
@@ -265,58 +306,74 @@ class PIIMasker:
265
  # A simple greedy approach: iterate and remove/adjust overlaps
266
  # This can be made more sophisticated
267
  resolved_entities: List[Entity] = []
268
- for current_entity in sorted(entities, key=lambda e: (e.start, -(e.end - e.start))): # Process by start, then by longest
 
 
 
269
  is_overlapped_or_contained = False
270
  temp_resolved = []
271
  for i, res_entity in enumerate(resolved_entities):
272
  # Check for overlap:
273
  # Current: |----|
274
  # Res: |----| or |----| or |--| or |------|
275
- overlap = max(0, min(current_entity.end, res_entity.end) - max(current_entity.start, res_entity.start))
 
 
 
 
276
 
277
  if overlap > 0:
278
  is_overlapped_or_contained = True
279
  # Preference:
280
- # 1. NER names often trump regex if they are the ones causing overlap
281
  # 2. Longer entity wins
282
  current_len = current_entity.end - current_entity.start
283
  res_len = res_entity.end - res_entity.start
284
 
285
- # If current is a name and overlaps, and previous is not a name, prefer current if it's not fully contained
286
- if current_entity.entity_type == "full_name" and res_entity.entity_type != "full_name":
287
- if not (res_entity.start <= current_entity.start and res_entity.end >= current_entity.end): # current not fully contained by res
288
- # remove res_entity, current will be added later
289
- continue # go to next res_entity, this one is marked for removal
290
- elif res_entity.entity_type == "full_name" and current_entity.entity_type != "full_name":
291
- # res_entity is a name, current is not. Prefer res_entity if it's not fully contained
292
- if not (current_entity.start <= res_entity.start and current_entity.end >= res_entity.end):
293
- # current entity is subsumed or less important, so don't add current
294
- # and keep res_entity
 
 
 
 
 
 
 
295
  temp_resolved.append(res_entity)
296
- is_overlapped_or_contained = True # Mark current as handled
297
- break # Current is dominated
298
 
299
  # General case: longer entity wins
300
  if current_len > res_len:
301
- # current is longer, res_entity is removed from consideration for this current_entity
302
- pass # res_entity will not be added to temp_resolved if it's fully replaced
 
303
  elif res_len > current_len:
304
  # res is longer, current is dominated
305
  temp_resolved.append(res_entity)
306
- is_overlapped_or_contained = True # Mark current as handled
307
  break
308
- else: # Same length, keep existing one (res_entity)
309
  temp_resolved.append(res_entity)
310
- is_overlapped_or_contained = True # Mark current as handled
311
  break
312
- else: # No overlap
313
  temp_resolved.append(res_entity)
314
 
315
  if not is_overlapped_or_contained:
316
  temp_resolved.append(current_entity)
317
 
318
- resolved_entities = sorted(temp_resolved, key=lambda e: (e.start, -(e.end - e.start)))
319
-
 
320
 
321
  # Final pass to remove fully contained entities if a larger one exists
322
  final_entities = []
@@ -329,8 +386,10 @@ class PIIMasker:
329
  if i == j:
330
  continue
331
  # If 'entity' is strictly contained within 'other_entity'
332
- if other_entity.start <= entity.start and other_entity.end >= entity.end and \
333
- (other_entity.end - other_entity.start > entity.end - entity.start):
 
 
334
  is_contained = True
335
  break
336
  if not is_contained:
@@ -338,7 +397,6 @@ class PIIMasker:
338
 
339
  return final_entities
340
 
341
-
342
  def mask_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
343
  """
344
  Mask PII entities in the text and return masked text and entity information
@@ -370,7 +428,6 @@ class PIIMasker:
370
 
371
  return "".join(new_text_parts), entity_info
372
 
373
-
374
  def process_email(self, email_text: str) -> Dict[str, Any]:
375
  """
376
  Process an email by detecting and masking PII entities.
@@ -378,56 +435,60 @@ class PIIMasker:
378
  """
379
  # Mask the email
380
  masked_email, entity_info = self.mask_text(email_text)
381
-
382
  # Store the email in the SQLite database - only get back email_id now
383
  email_id = self.db.store_email(
384
  original_email=email_text,
385
  masked_email=masked_email,
386
  masked_entities=entity_info
387
  )
388
-
389
  # Return the processed data with just the email_id
390
  return {
391
- "input_email_body": email_text, # Return original input for API compatibility
392
  "list_of_masked_entities": entity_info,
393
  "masked_email": masked_email,
394
  "category_of_the_email": "",
395
  "email_id": email_id
396
  }
397
-
398
- def get_original_email(self, email_id: str, access_key: str) -> Optional[Dict[str, Any]]:
 
 
399
  """
400
  Retrieve the original email with PII using the email ID and access key.
401
-
402
  Args:
403
  email_id: The ID of the stored email
404
  access_key: The security key for accessing the original email
405
-
406
  Returns:
407
  The original email data or None if not found or access_key is invalid
408
  """
409
  return self.db.get_original_email(email_id, access_key)
410
-
411
  def get_masked_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
412
  """
413
  Retrieve a masked email by its ID (without the original PII-containing email).
414
-
415
  Args:
416
  email_id: The ID of the stored email
417
-
418
  Returns:
419
  The masked email data or None if not found
420
  """
421
  return self.db.get_email_by_id(email_id)
422
-
423
- def get_original_by_masked_email(self, masked_email: str) -> Optional[Dict[str, Any]]:
 
 
424
  """
425
  Retrieve the original unmasked email using the masked email content.
426
-
427
  Args:
428
  masked_email: The masked version of the email to search for
429
-
430
  Returns:
431
  The original email data or None if not found
432
  """
433
- return self.db.get_email_by_masked_content(masked_email)
 
1
  import re
2
  import spacy
3
  from typing import List, Dict, Tuple, Any, Optional
4
+
5
  from database import EmailDatabase
6
 
7
+
8
  class Entity:
9
  def __init__(self, start: int, end: int, entity_type: str, value: str):
10
  self.start = start
 
19
  "entity": self.value
20
  }
21
 
22
+ def __repr__(self): # Added for easier debugging
23
+ return (
24
+ f"Entity(type='{self.entity_type}', value='{self.value}', "
25
+ f"start={self.start}, end={self.end})"
26
+ )
27
+
28
 
29
  class PIIMasker:
30
+ def __init__(
31
+ self,
32
+ spacy_model_name: str = "xx_ent_wiki_sm",
33
+ db_path: str = None
34
+ ): # Allow model choice
35
  # Load SpaCy model
36
  try:
37
  self.nlp = spacy.load(spacy_model_name)
 
52
 
53
  # Initialize database connection with SQLite path
54
  self.db = EmailDatabase(connection_string=db_path)
55
+
56
  # Initialize regex patterns
57
  self._initialize_patterns()
58
 
 
61
  self.patterns = {
62
  "email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
63
  # Simplified phone regex to capture both standard and international formats
64
+ "phone_number": (
65
+ r'\b(?:(?:\+|00)[1-9]\d{0,3}[-\s.]?)?'
66
+ r'(?:\(?\d{1,5}\)?[-\s.]?)?\d{1,5}'
67
+ r'(?:[-\s.]\d{1,5}){1,4}\b'
68
+ ),
69
  # Card number regex: common formats, allows optional spaces/hyphens
70
  "credit_debit_no": r'\b(?:(?:\d{4}[\s-]?){3}\d{4}|\d{13,19})\b',
71
  # CVV: 3 or 4 digits, ensuring it's a standalone number (word boundary)
 
74
  "expiry_no": r'\b(0[1-9]|1[0-2])[/\s-]([0-9]{2}|20[0-9]{2})\b',
75
  "aadhar_num": r'\b\d{4}\s?\d{4}\s?\d{4}\b',
76
  # DOB: DD/MM/YYYY or DD-MM-YYYY etc.
77
+ "dob": (
78
+ r'\b(0[1-9]|[12][0-9]|3[01])[/\s-]'
79
+ r'(0[1-9]|1[0-2])[/\s-](?:19|20)\d\d\b'
80
+ )
81
  }
82
 
83
  def detect_regex_entities(self, text: str) -> List[Entity]:
 
100
  if not self.verify_phone_number(text, match):
101
  continue
102
  elif entity_type == "dob":
103
+ if not self._verify_with_context(
104
+ text, start, end, ["birth", "dob", "born"]
105
+ ):
106
  continue
107
 
108
+ # Avoid detecting parts of already matched longer entities
109
+ # (e.g. year within a DOB)
110
  # This is a simple check; more robust overlap handling is done later
111
  is_substring_of_existing = False
112
  for existing_entity in entities:
113
+ if (existing_entity.start <= start
114
+ and existing_entity.end >= end # W504 corrected
115
+ and existing_entity.value != value): # W504 corrected
116
  is_substring_of_existing = True
117
  break
118
  if is_substring_of_existing:
 
121
  entities.append(Entity(start, end, entity_type, value))
122
  return entities
123
 
124
+ def _verify_with_context(
125
+ self, text: str, start: int, end: int, keywords: List[str], window: int = 50
126
+ ) -> bool:
127
  """Verify an entity match using surrounding context"""
128
  context_before = text[max(0, start - window):start].lower()
129
  context_after = text[end:min(len(text), end + window)].lower()
 
141
  context_before = text[max(0, start - context_window):start].lower()
142
  context_after = text[end:min(len(text), end + context_window)].lower()
143
 
144
+ card_keywords = [
145
+ "card", "credit", "debit", "visa", "mastercard",
146
+ "payment", "amex", "account no", "card no"
147
+ ]
148
  for keyword in card_keywords:
149
  if keyword in context_before or keyword in context_after:
150
  return True
 
152
  # For simplicity, we'll rely on context here. If needed, Luhn can be added.
153
  return False
154
 
 
155
  def verify_cvv(self, text: str, match: re.Match) -> bool:
156
  """Verify if a 3-4 digit number is actually a CVV using contextual clues"""
157
  context_window = 50
158
  start, end = match.span()
159
  value = match.group()
160
 
161
+ # If it's part of a longer number sequence (like a phone number or ID),
162
+ # it's likely not a CVV
163
  # Check character immediately before and after
164
+ char_before = text[start - 1:start] if start > 0 else ""
165
+ char_after = text[end:end + 1] if end < len(text) else ""
166
  if char_before.isdigit() or char_after.isdigit():
167
+ return False # It's part of a larger number
168
 
169
  # Only consider 3-4 digit numbers
170
  if not value.isdigit() or len(value) < 3 or len(value) > 4:
 
175
 
176
  # Expanded list of CVV-related keywords to improve detection
177
  cvv_keywords = [
178
+ "cvv", "cvc", "csc", "security code", "card verification",
179
+ "verification no", "security", "security number", "cv2",
180
+ "card code", "security value"
181
  ]
 
 
182
 
183
  # Look for CVV context clues
184
+ is_cvv_context = any(
185
+ keyword in context_before or keyword in context_after
186
+ for keyword in cvv_keywords
187
+ )
188
 
189
  # If explicitly mentioned as a CVV, immediately return true
190
  if is_cvv_context:
 
192
 
193
  # If it looks like a year, reject it
194
  if len(value) == 4 and 1900 <= int(value) <= 2100:
195
+ if any(
196
+ k in context_before or k in context_after
197
+ for k in ["year", "born", "established", "since"]
198
+ ):
199
  return False
200
 
201
  # If in expiry date context, reject it
202
  if re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
203
  return False
204
+
205
+ # If no context clues but we have a credit card mention nearby,
206
+ # it could be a CVV
207
+ card_context = any(
208
+ k in context_before or k in context_after for k in
209
+ ["card", "credit", "visa", "mastercard", "amex", "discover"]
210
+ )
211
+
212
  return is_cvv_context or (card_context and len(value) in [3, 4])
213
 
214
  def verify_phone_number(self, text: str, match: re.Match) -> bool:
 
217
  """
218
  value = match.group()
219
  start, end = match.span()
220
+
221
  # Extract only digits to count them
222
  digits = ''.join(c for c in value if c.isdigit())
223
  digit_count = len(digits)
224
+
225
  # Most phone numbers worldwide have between 7 and 15 digits
226
  if digit_count < 7 or digit_count > 15:
227
  return False
228
+
229
  # Check for common phone number indicators
230
  context_window = 50
231
  context_before = text[max(0, start - context_window):start].lower()
232
  context_after = text[end:min(len(text), end + context_window)].lower()
233
+
234
  # Expanded phone keywords
235
  phone_keywords = [
236
+ "phone", "call", "tel", "telephone", "contact", "dial", "mobile",
237
+ "cell", "number", "direct", "office", "fax", "reach me at",
238
+ "call me", "contact me", "line", "extension", "ext", "phone number"
239
  ]
240
+
241
  # Check for phone context
242
+ has_phone_context = any(
243
+ kw in context_before or kw in context_after for kw in phone_keywords
244
+ )
245
+
246
  # Check for formatting that indicates a phone number
247
+ has_phone_formatting = bool(re.search(r'[-\s.()\\+]', value))
248
+
249
  # Check for international prefix
250
  has_intl_prefix = value.startswith('+') or value.startswith('00')
251
+
252
  # Return true if any of these conditions are met:
253
  # 1. Has explicit phone context
254
  # 2. Has phone-like formatting AND reasonable digit count
255
  # 3. Has international prefix AND reasonable digit count
256
  # 4. Has 10 digits exactly (common in many countries) with formatting
257
+ return (
258
+ has_phone_context
259
+ or (has_phone_formatting and digit_count >= 7)
260
+ or (has_intl_prefix)
261
+ or (digit_count == 10 and has_phone_formatting)
262
+ )
263
 
264
  def detect_name_entities(self, text: str) -> List[Entity]:
265
  """Detect name entities using SpaCy NER"""
266
  entities = []
267
  doc = self.nlp(text)
268
+
269
  for ent in doc.ents:
270
  # Use PER for person, common in many models like xx_ent_wiki_sm
271
  # Also checking for PERSON as some models might use it.
272
  if ent.label_ in ["PER", "PERSON"]:
273
+ entities.append(
274
+ Entity(ent.start_char, ent.end_char, "full_name", ent.text)
275
+ )
276
  return entities
277
 
278
  def detect_all_entities(self, text: str) -> List[Entity]:
 
306
  # A simple greedy approach: iterate and remove/adjust overlaps
307
  # This can be made more sophisticated
308
  resolved_entities: List[Entity] = []
309
+ # Process by start, then by longest
310
+ for current_entity in sorted(
311
+ entities, key=lambda e: (e.start, -(e.end - e.start))
312
+ ):
313
  is_overlapped_or_contained = False
314
  temp_resolved = []
315
  for i, res_entity in enumerate(resolved_entities):
316
  # Check for overlap:
317
  # Current: |----|
318
  # Res: |----| or |----| or |--| or |------|
319
+ overlap = max(
320
+ 0,
321
+ min(current_entity.end, res_entity.end) # Fixed W504 line break
322
+ - max(current_entity.start, res_entity.start)
323
+ )
324
 
325
  if overlap > 0:
326
  is_overlapped_or_contained = True
327
  # Preference:
328
+ # 1. NER often trump regex if they are the ones causing overlap
329
  # 2. Longer entity wins
330
  current_len = current_entity.end - current_entity.start
331
  res_len = res_entity.end - res_entity.start
332
 
333
+ # If current is a name and overlaps, and previous is not a name,
334
+ # prefer current if it's not fully contained
335
+ if (current_entity.entity_type == "full_name" # E501 corrected
336
+ and res_entity.entity_type != "full_name"):
337
+ # current not fully contained by res
338
+ if not (res_entity.start <= current_entity.start
339
+ and res_entity.end >= current_entity.end):
340
+ # remove res_entity, current will be added later
341
+ continue # go to next res_entity, marked for removal
342
+ elif (res_entity.entity_type == "full_name"
343
+ and current_entity.entity_type != "full_name"):
344
+ # res_entity is a name, current is not. Prefer res_entity
345
+ # if it's not fully contained
346
+ if not (current_entity.start <= res_entity.start
347
+ and current_entity.end >= res_entity.end):
348
+ # current entity is subsumed or less important,
349
+ # so don't add current and keep res_entity
350
  temp_resolved.append(res_entity)
351
+ is_overlapped_or_contained = True # Mark current as handled
352
+ break # Current is dominated
353
 
354
  # General case: longer entity wins
355
  if current_len > res_len:
356
+ # current is longer, res_entity is removed from
357
+ # consideration for this current_entity
358
+ pass # res_entity not added to temp_resolved if fully replaced
359
  elif res_len > current_len:
360
  # res is longer, current is dominated
361
  temp_resolved.append(res_entity)
362
+ is_overlapped_or_contained = True # Mark current as handled
363
  break
364
+ else: # Same length, keep existing one (res_entity)
365
  temp_resolved.append(res_entity)
366
+ is_overlapped_or_contained = True # Mark current as handled
367
  break
368
+ else: # No overlap
369
  temp_resolved.append(res_entity)
370
 
371
  if not is_overlapped_or_contained:
372
  temp_resolved.append(current_entity)
373
 
374
+ resolved_entities = sorted(
375
+ temp_resolved, key=lambda e: (e.start, -(e.end - e.start))
376
+ )
377
 
378
  # Final pass to remove fully contained entities if a larger one exists
379
  final_entities = []
 
386
  if i == j:
387
  continue
388
  # If 'entity' is strictly contained within 'other_entity'
389
+ if (other_entity.start <= entity.start
390
+ and other_entity.end >= entity.end
391
+ and (other_entity.end - other_entity.start
392
+ > entity.end - entity.start)):
393
  is_contained = True
394
  break
395
  if not is_contained:
 
397
 
398
  return final_entities
399
 
 
400
  def mask_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
401
  """
402
  Mask PII entities in the text and return masked text and entity information
 
428
 
429
  return "".join(new_text_parts), entity_info
430
 
 
431
  def process_email(self, email_text: str) -> Dict[str, Any]:
432
  """
433
  Process an email by detecting and masking PII entities.
 
435
  """
436
  # Mask the email
437
  masked_email, entity_info = self.mask_text(email_text)
438
+
439
  # Store the email in the SQLite database - only get back email_id now
440
  email_id = self.db.store_email(
441
  original_email=email_text,
442
  masked_email=masked_email,
443
  masked_entities=entity_info
444
  )
445
+
446
  # Return the processed data with just the email_id
447
  return {
448
+ "input_email_body": email_text, # Return original for API compatibility
449
  "list_of_masked_entities": entity_info,
450
  "masked_email": masked_email,
451
  "category_of_the_email": "",
452
  "email_id": email_id
453
  }
454
+
455
def get_original_email(
    self, email_id: str, access_key: str
) -> Optional[Dict[str, Any]]:
    """Fetch the unmasked email for an ID, guarded by an access key.

    Delegates to ``self.db.get_original_email``.

    Args:
        email_id: The ID of the stored email.
        access_key: The security key for accessing the original email.

    Returns:
        The original email data, or None if the email is not found or
        the access key is invalid.
    """
    original_record = self.db.get_original_email(email_id, access_key)
    return original_record
469
+
470
  def get_masked_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
471
  """
472
  Retrieve a masked email by its ID (without the original PII-containing email).
473
+
474
  Args:
475
  email_id: The ID of the stored email
476
+
477
  Returns:
478
  The masked email data or None if not found
479
  """
480
  return self.db.get_email_by_id(email_id)
481
+
482
def get_original_by_masked_email(
    self, masked_email: str
) -> Optional[Dict[str, Any]]:
    """Find the original email that corresponds to a masked email body.

    Delegates to ``self.db.get_email_by_masked_content``.

    Args:
        masked_email: The masked version of the email to search for.

    Returns:
        The original email data, or None if not found.
    """
    matching_record = self.db.get_email_by_masked_content(masked_email)
    return matching_record