refactored the code
Browse files- .gitignore +1 -1
- database.py +60 -54
- main.py +35 -19
- models.py +22 -20
- requirements.in +0 -381
- upload_model.py +34 -22
- utils.py +149 -88
.gitignore
CHANGED
|
@@ -33,7 +33,7 @@ env/
|
|
| 33 |
.vscode/
|
| 34 |
*.swp
|
| 35 |
*.swo
|
| 36 |
-
|
| 37 |
# Jupyter Notebook
|
| 38 |
.ipynb_checkpoints
|
| 39 |
|
|
|
|
| 33 |
.vscode/
|
| 34 |
*.swp
|
| 35 |
*.swo
|
| 36 |
+
.flake8
|
| 37 |
# Jupyter Notebook
|
| 38 |
.ipynb_checkpoints
|
| 39 |
|
database.py
CHANGED
|
@@ -4,7 +4,7 @@ Database module for handling email storage operations with SQLite.
|
|
| 4 |
import os
|
| 5 |
import json
|
| 6 |
import sqlite3
|
| 7 |
-
from typing import Dict, Any, Optional, List
|
| 8 |
from datetime import datetime
|
| 9 |
import uuid
|
| 10 |
|
|
@@ -14,29 +14,32 @@ class EmailDatabase:
|
|
| 14 |
Database class for storing and retrieving email data with PII masking information.
|
| 15 |
Uses SQLite for storage in Hugging Face's persistent directory.
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
def __init__(self, connection_string: str = None):
|
| 19 |
"""
|
| 20 |
Initialize the database connection.
|
| 21 |
-
|
| 22 |
Args:
|
| 23 |
-
connection_string: Database connection string or path.
|
| 24 |
-
|
| 25 |
"""
|
| 26 |
# Hugging Face Spaces has a /data directory that persists between restarts
|
| 27 |
self.db_path = connection_string or os.environ.get(
|
| 28 |
-
"DATABASE_PATH",
|
| 29 |
"/data/emails.db" # This path persists in Hugging Face Spaces
|
| 30 |
)
|
| 31 |
-
|
| 32 |
# Get the global access key from environment variables
|
| 33 |
-
self.access_key = os.environ.get(
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
| 35 |
# Ensure the data directory exists
|
| 36 |
self._ensure_data_directory()
|
| 37 |
-
|
| 38 |
self._create_tables()
|
| 39 |
-
|
| 40 |
def _ensure_data_directory(self):
|
| 41 |
"""Ensure the data directory exists, and use a fallback if needed."""
|
| 42 |
try:
|
|
@@ -47,74 +50,77 @@ class EmailDatabase:
|
|
| 47 |
# If we can't write to /data, fall back to the current directory
|
| 48 |
self.db_path = "emails.db"
|
| 49 |
print(f"Warning: Using fallback database path: {self.db_path}")
|
| 50 |
-
|
| 51 |
def _get_connection(self):
|
| 52 |
"""Get a database connection."""
|
| 53 |
return sqlite3.connect(self.db_path)
|
| 54 |
-
|
| 55 |
def _create_tables(self):
|
| 56 |
"""Create the necessary tables if they don't exist."""
|
| 57 |
conn = self._get_connection()
|
| 58 |
try:
|
| 59 |
cursor = conn.cursor()
|
| 60 |
-
|
| 61 |
# Create the emails table to store original emails and their masked versions
|
| 62 |
cursor.execute('''
|
| 63 |
CREATE TABLE IF NOT EXISTS emails (
|
| 64 |
id TEXT PRIMARY KEY,
|
| 65 |
original_email TEXT NOT NULL,
|
| 66 |
masked_email TEXT NOT NULL,
|
| 67 |
-
masked_entities TEXT NOT NULL,
|
| 68 |
category TEXT,
|
| 69 |
created_at TEXT NOT NULL
|
| 70 |
)
|
| 71 |
''')
|
| 72 |
-
|
| 73 |
conn.commit()
|
| 74 |
except Exception as e:
|
| 75 |
conn.rollback()
|
| 76 |
raise e
|
| 77 |
finally:
|
| 78 |
conn.close()
|
| 79 |
-
|
| 80 |
def _generate_id(self) -> str:
|
| 81 |
"""Generate a unique ID for the email record."""
|
| 82 |
return str(uuid.uuid4())
|
| 83 |
-
|
| 84 |
-
def store_email(
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
"""
|
| 87 |
Store the original email along with its masked version and related information.
|
| 88 |
-
|
| 89 |
Args:
|
| 90 |
original_email: The original email with PII
|
| 91 |
masked_email: The masked version of the email
|
| 92 |
masked_entities: List of entities that were masked
|
| 93 |
category: Optional category of the email
|
| 94 |
-
|
| 95 |
Returns:
|
| 96 |
email_id for future reference
|
| 97 |
"""
|
| 98 |
conn = self._get_connection()
|
| 99 |
try:
|
| 100 |
cursor = conn.cursor()
|
| 101 |
-
|
| 102 |
email_id = self._generate_id()
|
| 103 |
-
|
| 104 |
# Store the email data
|
| 105 |
cursor.execute(
|
| 106 |
-
'INSERT INTO emails
|
|
|
|
| 107 |
'VALUES (?, ?, ?, ?, ?, ?)',
|
| 108 |
(
|
| 109 |
email_id,
|
| 110 |
original_email,
|
| 111 |
masked_email,
|
| 112 |
-
json.dumps(masked_entities), #
|
| 113 |
category,
|
| 114 |
datetime.now().isoformat()
|
| 115 |
)
|
| 116 |
)
|
| 117 |
-
|
| 118 |
conn.commit()
|
| 119 |
return email_id
|
| 120 |
except Exception as e:
|
|
@@ -122,112 +128,112 @@ class EmailDatabase:
|
|
| 122 |
raise e
|
| 123 |
finally:
|
| 124 |
conn.close()
|
| 125 |
-
|
| 126 |
def get_original_email(self, email_id: str, access_key: str) -> Optional[Dict[str, Any]]:
|
| 127 |
"""
|
| 128 |
Retrieve the original email with PII using the access key.
|
| 129 |
-
|
| 130 |
Args:
|
| 131 |
email_id: The ID of the email record
|
| 132 |
access_key: The security key required to access the original email
|
| 133 |
-
|
| 134 |
Returns:
|
| 135 |
Dictionary with email data or None if not found or access_key is invalid
|
| 136 |
"""
|
| 137 |
# Verify the access key matches the global access key
|
| 138 |
if access_key != self.access_key:
|
| 139 |
return None
|
| 140 |
-
|
| 141 |
conn = self._get_connection()
|
| 142 |
try:
|
| 143 |
cursor = conn.cursor()
|
| 144 |
-
|
| 145 |
cursor.execute(
|
| 146 |
-
'SELECT id, original_email, masked_email, masked_entities, category,
|
| 147 |
-
'FROM emails WHERE id = ?',
|
| 148 |
(email_id,)
|
| 149 |
)
|
| 150 |
-
|
| 151 |
row = cursor.fetchone()
|
| 152 |
if not row:
|
| 153 |
return None
|
| 154 |
-
|
| 155 |
return {
|
| 156 |
"id": row[0],
|
| 157 |
"original_email": row[1],
|
| 158 |
"masked_email": row[2],
|
| 159 |
-
"masked_entities": json.loads(row[3]), # Convert
|
| 160 |
"category": row[4],
|
| 161 |
"created_at": row[5]
|
| 162 |
}
|
| 163 |
finally:
|
| 164 |
conn.close()
|
| 165 |
-
|
| 166 |
def get_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
|
| 167 |
"""
|
| 168 |
Retrieve the masked email data without the original PII-containing email.
|
| 169 |
-
|
| 170 |
Args:
|
| 171 |
email_id: The ID of the email
|
| 172 |
-
|
| 173 |
Returns:
|
| 174 |
Dictionary with masked email data or None if not found
|
| 175 |
"""
|
| 176 |
conn = self._get_connection()
|
| 177 |
try:
|
| 178 |
cursor = conn.cursor()
|
| 179 |
-
|
| 180 |
cursor.execute(
|
| 181 |
'SELECT id, masked_email, masked_entities, category, created_at '
|
| 182 |
'FROM emails WHERE id = ?',
|
| 183 |
(email_id,)
|
| 184 |
)
|
| 185 |
-
|
| 186 |
row = cursor.fetchone()
|
| 187 |
if not row:
|
| 188 |
return None
|
| 189 |
-
|
| 190 |
return {
|
| 191 |
"id": row[0],
|
| 192 |
"masked_email": row[1],
|
| 193 |
-
"masked_entities": json.loads(row[2]), # Convert
|
| 194 |
"category": row[3],
|
| 195 |
"created_at": row[4]
|
| 196 |
}
|
| 197 |
finally:
|
| 198 |
conn.close()
|
| 199 |
-
|
| 200 |
def get_email_by_masked_content(self, masked_email: str) -> Optional[Dict[str, Any]]:
|
| 201 |
"""
|
| 202 |
Retrieve the original email using the masked email content.
|
| 203 |
-
|
| 204 |
Args:
|
| 205 |
masked_email: The masked version of the email to search for
|
| 206 |
-
|
| 207 |
Returns:
|
| 208 |
Dictionary with full email data or None if not found
|
| 209 |
"""
|
| 210 |
conn = self._get_connection()
|
| 211 |
try:
|
| 212 |
cursor = conn.cursor()
|
| 213 |
-
|
| 214 |
cursor.execute(
|
| 215 |
-
'SELECT id, original_email, masked_email, masked_entities, category,
|
| 216 |
-
'FROM emails WHERE masked_email = ?',
|
| 217 |
(masked_email,)
|
| 218 |
)
|
| 219 |
-
|
| 220 |
row = cursor.fetchone()
|
| 221 |
if not row:
|
| 222 |
return None
|
| 223 |
-
|
| 224 |
return {
|
| 225 |
"id": row[0],
|
| 226 |
"original_email": row[1],
|
| 227 |
"masked_email": row[2],
|
| 228 |
-
"masked_entities": json.loads(row[3]), # Convert
|
| 229 |
"category": row[4],
|
| 230 |
"created_at": row[5]
|
| 231 |
}
|
| 232 |
finally:
|
| 233 |
-
conn.close()
|
|
|
|
| 4 |
import os
|
| 5 |
import json
|
| 6 |
import sqlite3
|
| 7 |
+
from typing import Dict, Any, Optional, List
|
| 8 |
from datetime import datetime
|
| 9 |
import uuid
|
| 10 |
|
|
|
|
| 14 |
Database class for storing and retrieving email data with PII masking information.
|
| 15 |
Uses SQLite for storage in Hugging Face's persistent directory.
|
| 16 |
"""
|
| 17 |
+
|
| 18 |
def __init__(self, connection_string: str = None):
|
| 19 |
"""
|
| 20 |
Initialize the database connection.
|
| 21 |
+
|
| 22 |
Args:
|
| 23 |
+
connection_string: Database connection string or path.
|
| 24 |
+
For SQLite, this will be treated as a file path.
|
| 25 |
"""
|
| 26 |
# Hugging Face Spaces has a /data directory that persists between restarts
|
| 27 |
self.db_path = connection_string or os.environ.get(
|
| 28 |
+
"DATABASE_PATH",
|
| 29 |
"/data/emails.db" # This path persists in Hugging Face Spaces
|
| 30 |
)
|
| 31 |
+
|
| 32 |
# Get the global access key from environment variables
|
| 33 |
+
self.access_key = os.environ.get(
|
| 34 |
+
"EMAIL_ACCESS_KEY",
|
| 35 |
+
"default_secure_access_key"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
# Ensure the data directory exists
|
| 39 |
self._ensure_data_directory()
|
| 40 |
+
|
| 41 |
self._create_tables()
|
| 42 |
+
|
| 43 |
def _ensure_data_directory(self):
|
| 44 |
"""Ensure the data directory exists, and use a fallback if needed."""
|
| 45 |
try:
|
|
|
|
| 50 |
# If we can't write to /data, fall back to the current directory
|
| 51 |
self.db_path = "emails.db"
|
| 52 |
print(f"Warning: Using fallback database path: {self.db_path}")
|
| 53 |
+
|
| 54 |
def _get_connection(self):
|
| 55 |
"""Get a database connection."""
|
| 56 |
return sqlite3.connect(self.db_path)
|
| 57 |
+
|
| 58 |
def _create_tables(self):
|
| 59 |
"""Create the necessary tables if they don't exist."""
|
| 60 |
conn = self._get_connection()
|
| 61 |
try:
|
| 62 |
cursor = conn.cursor()
|
| 63 |
+
|
| 64 |
# Create the emails table to store original emails and their masked versions
|
| 65 |
cursor.execute('''
|
| 66 |
CREATE TABLE IF NOT EXISTS emails (
|
| 67 |
id TEXT PRIMARY KEY,
|
| 68 |
original_email TEXT NOT NULL,
|
| 69 |
masked_email TEXT NOT NULL,
|
| 70 |
+
masked_entities TEXT NOT NULL,
|
| 71 |
category TEXT,
|
| 72 |
created_at TEXT NOT NULL
|
| 73 |
)
|
| 74 |
''')
|
| 75 |
+
|
| 76 |
conn.commit()
|
| 77 |
except Exception as e:
|
| 78 |
conn.rollback()
|
| 79 |
raise e
|
| 80 |
finally:
|
| 81 |
conn.close()
|
| 82 |
+
|
| 83 |
def _generate_id(self) -> str:
|
| 84 |
"""Generate a unique ID for the email record."""
|
| 85 |
return str(uuid.uuid4())
|
| 86 |
+
|
| 87 |
+
def store_email(
|
| 88 |
+
self, original_email: str, masked_email: str,
|
| 89 |
+
masked_entities: List[Dict[str, Any]], category: Optional[str] = None
|
| 90 |
+
) -> str:
|
| 91 |
"""
|
| 92 |
Store the original email along with its masked version and related information.
|
| 93 |
+
|
| 94 |
Args:
|
| 95 |
original_email: The original email with PII
|
| 96 |
masked_email: The masked version of the email
|
| 97 |
masked_entities: List of entities that were masked
|
| 98 |
category: Optional category of the email
|
| 99 |
+
|
| 100 |
Returns:
|
| 101 |
email_id for future reference
|
| 102 |
"""
|
| 103 |
conn = self._get_connection()
|
| 104 |
try:
|
| 105 |
cursor = conn.cursor()
|
| 106 |
+
|
| 107 |
email_id = self._generate_id()
|
| 108 |
+
|
| 109 |
# Store the email data
|
| 110 |
cursor.execute(
|
| 111 |
+
'INSERT INTO emails '
|
| 112 |
+
'(id, original_email, masked_email, masked_entities, category, created_at) '
|
| 113 |
'VALUES (?, ?, ?, ?, ?, ?)',
|
| 114 |
(
|
| 115 |
email_id,
|
| 116 |
original_email,
|
| 117 |
masked_email,
|
| 118 |
+
json.dumps(masked_entities), # JSON string for SQLite
|
| 119 |
category,
|
| 120 |
datetime.now().isoformat()
|
| 121 |
)
|
| 122 |
)
|
| 123 |
+
|
| 124 |
conn.commit()
|
| 125 |
return email_id
|
| 126 |
except Exception as e:
|
|
|
|
| 128 |
raise e
|
| 129 |
finally:
|
| 130 |
conn.close()
|
| 131 |
+
|
| 132 |
def get_original_email(self, email_id: str, access_key: str) -> Optional[Dict[str, Any]]:
|
| 133 |
"""
|
| 134 |
Retrieve the original email with PII using the access key.
|
| 135 |
+
|
| 136 |
Args:
|
| 137 |
email_id: The ID of the email record
|
| 138 |
access_key: The security key required to access the original email
|
| 139 |
+
|
| 140 |
Returns:
|
| 141 |
Dictionary with email data or None if not found or access_key is invalid
|
| 142 |
"""
|
| 143 |
# Verify the access key matches the global access key
|
| 144 |
if access_key != self.access_key:
|
| 145 |
return None
|
| 146 |
+
|
| 147 |
conn = self._get_connection()
|
| 148 |
try:
|
| 149 |
cursor = conn.cursor()
|
| 150 |
+
|
| 151 |
cursor.execute(
|
| 152 |
+
'SELECT id, original_email, masked_email, masked_entities, category, '
|
| 153 |
+
'created_at FROM emails WHERE id = ?',
|
| 154 |
(email_id,)
|
| 155 |
)
|
| 156 |
+
|
| 157 |
row = cursor.fetchone()
|
| 158 |
if not row:
|
| 159 |
return None
|
| 160 |
+
|
| 161 |
return {
|
| 162 |
"id": row[0],
|
| 163 |
"original_email": row[1],
|
| 164 |
"masked_email": row[2],
|
| 165 |
+
"masked_entities": json.loads(row[3]), # Convert JSON to dict
|
| 166 |
"category": row[4],
|
| 167 |
"created_at": row[5]
|
| 168 |
}
|
| 169 |
finally:
|
| 170 |
conn.close()
|
| 171 |
+
|
| 172 |
def get_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
|
| 173 |
"""
|
| 174 |
Retrieve the masked email data without the original PII-containing email.
|
| 175 |
+
|
| 176 |
Args:
|
| 177 |
email_id: The ID of the email
|
| 178 |
+
|
| 179 |
Returns:
|
| 180 |
Dictionary with masked email data or None if not found
|
| 181 |
"""
|
| 182 |
conn = self._get_connection()
|
| 183 |
try:
|
| 184 |
cursor = conn.cursor()
|
| 185 |
+
|
| 186 |
cursor.execute(
|
| 187 |
'SELECT id, masked_email, masked_entities, category, created_at '
|
| 188 |
'FROM emails WHERE id = ?',
|
| 189 |
(email_id,)
|
| 190 |
)
|
| 191 |
+
|
| 192 |
row = cursor.fetchone()
|
| 193 |
if not row:
|
| 194 |
return None
|
| 195 |
+
|
| 196 |
return {
|
| 197 |
"id": row[0],
|
| 198 |
"masked_email": row[1],
|
| 199 |
+
"masked_entities": json.loads(row[2]), # Convert JSON to dict
|
| 200 |
"category": row[3],
|
| 201 |
"created_at": row[4]
|
| 202 |
}
|
| 203 |
finally:
|
| 204 |
conn.close()
|
| 205 |
+
|
| 206 |
def get_email_by_masked_content(self, masked_email: str) -> Optional[Dict[str, Any]]:
|
| 207 |
"""
|
| 208 |
Retrieve the original email using the masked email content.
|
| 209 |
+
|
| 210 |
Args:
|
| 211 |
masked_email: The masked version of the email to search for
|
| 212 |
+
|
| 213 |
Returns:
|
| 214 |
Dictionary with full email data or None if not found
|
| 215 |
"""
|
| 216 |
conn = self._get_connection()
|
| 217 |
try:
|
| 218 |
cursor = conn.cursor()
|
| 219 |
+
|
| 220 |
cursor.execute(
|
| 221 |
+
'SELECT id, original_email, masked_email, masked_entities, category, '
|
| 222 |
+
'created_at FROM emails WHERE masked_email = ?',
|
| 223 |
(masked_email,)
|
| 224 |
)
|
| 225 |
+
|
| 226 |
row = cursor.fetchone()
|
| 227 |
if not row:
|
| 228 |
return None
|
| 229 |
+
|
| 230 |
return {
|
| 231 |
"id": row[0],
|
| 232 |
"original_email": row[1],
|
| 233 |
"masked_email": row[2],
|
| 234 |
+
"masked_entities": json.loads(row[3]), # Convert JSON to dict
|
| 235 |
"category": row[4],
|
| 236 |
"created_at": row[5]
|
| 237 |
}
|
| 238 |
finally:
|
| 239 |
+
conn.close()
|
main.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
from fastapi import FastAPI, HTTPException
|
| 3 |
from pydantic import BaseModel
|
| 4 |
-
from typing import Dict, Any, List, Tuple
|
| 5 |
import uvicorn
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
|
@@ -21,7 +21,7 @@ else:
|
|
| 21 |
db_path = "emails.db" # Fallback to local directory
|
| 22 |
|
| 23 |
# Initialize the FastAPI application
|
| 24 |
-
app = FastAPI(title="Email Classification API",
|
| 25 |
description="API for classifying support emails and masking PII",
|
| 26 |
version="1.0.0")
|
| 27 |
|
|
@@ -29,16 +29,19 @@ app = FastAPI(title="Email Classification API",
|
|
| 29 |
pii_masker = PIIMasker(db_path=db_path)
|
| 30 |
email_classifier = EmailClassifier()
|
| 31 |
|
|
|
|
| 32 |
class EmailInput(BaseModel):
|
| 33 |
"""Input model for the email classification endpoint"""
|
| 34 |
input_email_body: str
|
| 35 |
|
|
|
|
| 36 |
class EntityInfo(BaseModel):
|
| 37 |
"""Model for entity information"""
|
| 38 |
position: Tuple[int, int]
|
| 39 |
-
classification: str
|
| 40 |
entity: str
|
| 41 |
|
|
|
|
| 42 |
class EmailOutput(BaseModel):
|
| 43 |
"""Output model for the email classification endpoint"""
|
| 44 |
input_email_body: str
|
|
@@ -46,29 +49,30 @@ class EmailOutput(BaseModel):
|
|
| 46 |
masked_email: str
|
| 47 |
category_of_the_email: str
|
| 48 |
|
|
|
|
| 49 |
class MaskedEmailInput(BaseModel):
|
| 50 |
"""Input model for retrieving original email by masked email content"""
|
| 51 |
masked_email: str
|
| 52 |
access_key: str
|
| 53 |
|
|
|
|
| 54 |
@app.post("/classify", response_model=EmailOutput)
|
| 55 |
async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
|
| 56 |
"""
|
| 57 |
Classify an email into a support category while masking PII
|
| 58 |
-
|
| 59 |
Args:
|
| 60 |
email_input: The input email data
|
| 61 |
-
|
| 62 |
Returns:
|
| 63 |
The classified email data with masked PII
|
| 64 |
"""
|
| 65 |
try:
|
| 66 |
# Process the email to mask PII and store original in database
|
| 67 |
processed_data = pii_masker.process_email(email_input.input_email_body)
|
| 68 |
-
|
| 69 |
# Classify the masked email
|
| 70 |
classified_data = email_classifier.process_email(processed_data)
|
| 71 |
-
|
| 72 |
# Make sure we return only the fields expected in the response model
|
| 73 |
return {
|
| 74 |
"input_email_body": email_input.input_email_body,
|
|
@@ -79,28 +83,35 @@ async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
|
|
| 79 |
except Exception as e:
|
| 80 |
raise HTTPException(status_code=500, detail=f"Error processing email: {str(e)}")
|
| 81 |
|
|
|
|
| 82 |
@app.post("/api/v1/unmask-email", response_model=Dict[str, Any])
|
| 83 |
async def unmask_email(masked_email_input: MaskedEmailInput) -> Dict[str, Any]:
|
| 84 |
"""
|
| 85 |
-
Retrieve the original unmasked email
|
| 86 |
-
|
| 87 |
Args:
|
| 88 |
masked_email_input: Contains the masked email and access key
|
| 89 |
-
|
| 90 |
Returns:
|
| 91 |
The original email data with PII information
|
| 92 |
"""
|
| 93 |
try:
|
| 94 |
# Verify access key matches the global access key
|
| 95 |
-
if masked_email_input.access_key != os.environ.get(
|
|
|
|
| 96 |
raise HTTPException(status_code=401, detail="Invalid access key")
|
| 97 |
-
|
| 98 |
# Retrieve the original email using the masked content
|
| 99 |
-
email_data = pii_masker.get_original_by_masked_email(
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
if not email_data:
|
| 102 |
-
raise HTTPException(
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
return {
|
| 105 |
"status": "success",
|
| 106 |
"data": {
|
|
@@ -116,19 +127,24 @@ async def unmask_email(masked_email_input: MaskedEmailInput) -> Dict[str, Any]:
|
|
| 116 |
except Exception as e:
|
| 117 |
if isinstance(e, HTTPException):
|
| 118 |
raise e
|
| 119 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
@app.get("/health")
|
| 122 |
async def health_check():
|
| 123 |
"""
|
| 124 |
Health check endpoint
|
| 125 |
-
|
| 126 |
Returns:
|
| 127 |
Status message indicating the API is running
|
| 128 |
"""
|
| 129 |
return {"status": "healthy", "message": "Email classification API is running"}
|
| 130 |
|
|
|
|
| 131 |
# For local development and testing
|
| 132 |
if __name__ == "__main__":
|
| 133 |
port = int(os.environ.get("PORT", 8000))
|
| 134 |
-
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)
|
|
|
|
| 1 |
import os
|
| 2 |
from fastapi import FastAPI, HTTPException
|
| 3 |
from pydantic import BaseModel
|
| 4 |
+
from typing import Dict, Any, List, Tuple
|
| 5 |
import uvicorn
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
|
|
|
| 21 |
db_path = "emails.db" # Fallback to local directory
|
| 22 |
|
| 23 |
# Initialize the FastAPI application
|
| 24 |
+
app = FastAPI(title="Email Classification API",
|
| 25 |
description="API for classifying support emails and masking PII",
|
| 26 |
version="1.0.0")
|
| 27 |
|
|
|
|
| 29 |
pii_masker = PIIMasker(db_path=db_path)
|
| 30 |
email_classifier = EmailClassifier()
|
| 31 |
|
| 32 |
+
|
| 33 |
class EmailInput(BaseModel):
|
| 34 |
"""Input model for the email classification endpoint"""
|
| 35 |
input_email_body: str
|
| 36 |
|
| 37 |
+
|
| 38 |
class EntityInfo(BaseModel):
|
| 39 |
"""Model for entity information"""
|
| 40 |
position: Tuple[int, int]
|
| 41 |
+
classification: str
|
| 42 |
entity: str
|
| 43 |
|
| 44 |
+
|
| 45 |
class EmailOutput(BaseModel):
|
| 46 |
"""Output model for the email classification endpoint"""
|
| 47 |
input_email_body: str
|
|
|
|
| 49 |
masked_email: str
|
| 50 |
category_of_the_email: str
|
| 51 |
|
| 52 |
+
|
| 53 |
class MaskedEmailInput(BaseModel):
|
| 54 |
"""Input model for retrieving original email by masked email content"""
|
| 55 |
masked_email: str
|
| 56 |
access_key: str
|
| 57 |
|
| 58 |
+
|
| 59 |
@app.post("/classify", response_model=EmailOutput)
|
| 60 |
async def classify_email(email_input: EmailInput) -> Dict[str, Any]:
|
| 61 |
"""
|
| 62 |
Classify an email into a support category while masking PII
|
| 63 |
+
|
| 64 |
Args:
|
| 65 |
email_input: The input email data
|
| 66 |
+
|
| 67 |
Returns:
|
| 68 |
The classified email data with masked PII
|
| 69 |
"""
|
| 70 |
try:
|
| 71 |
# Process the email to mask PII and store original in database
|
| 72 |
processed_data = pii_masker.process_email(email_input.input_email_body)
|
|
|
|
| 73 |
# Classify the masked email
|
| 74 |
classified_data = email_classifier.process_email(processed_data)
|
| 75 |
+
|
| 76 |
# Make sure we return only the fields expected in the response model
|
| 77 |
return {
|
| 78 |
"input_email_body": email_input.input_email_body,
|
|
|
|
| 83 |
except Exception as e:
|
| 84 |
raise HTTPException(status_code=500, detail=f"Error processing email: {str(e)}")
|
| 85 |
|
| 86 |
+
|
| 87 |
@app.post("/api/v1/unmask-email", response_model=Dict[str, Any])
|
| 88 |
async def unmask_email(masked_email_input: MaskedEmailInput) -> Dict[str, Any]:
|
| 89 |
"""
|
| 90 |
+
Retrieve the original unmasked email.
|
| 91 |
+
|
| 92 |
Args:
|
| 93 |
masked_email_input: Contains the masked email and access key
|
| 94 |
+
|
| 95 |
Returns:
|
| 96 |
The original email data with PII information
|
| 97 |
"""
|
| 98 |
try:
|
| 99 |
# Verify access key matches the global access key
|
| 100 |
+
if masked_email_input.access_key != os.environ.get(
|
| 101 |
+
"EMAIL_ACCESS_KEY", "default_secure_access_key"):
|
| 102 |
raise HTTPException(status_code=401, detail="Invalid access key")
|
| 103 |
+
|
| 104 |
# Retrieve the original email using the masked content
|
| 105 |
+
email_data = pii_masker.get_original_by_masked_email(
|
| 106 |
+
masked_email_input.masked_email
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
if not email_data:
|
| 110 |
+
raise HTTPException(
|
| 111 |
+
status_code=404,
|
| 112 |
+
detail="Original email not found for the provided masked email"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
return {
|
| 116 |
"status": "success",
|
| 117 |
"data": {
|
|
|
|
| 127 |
except Exception as e:
|
| 128 |
if isinstance(e, HTTPException):
|
| 129 |
raise e
|
| 130 |
+
raise HTTPException(
|
| 131 |
+
status_code=500,
|
| 132 |
+
detail=f"Error retrieving original email: {str(e)}"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
|
| 136 |
@app.get("/health")
|
| 137 |
async def health_check():
|
| 138 |
"""
|
| 139 |
Health check endpoint
|
| 140 |
+
|
| 141 |
Returns:
|
| 142 |
Status message indicating the API is running
|
| 143 |
"""
|
| 144 |
return {"status": "healthy", "message": "Email classification API is running"}
|
| 145 |
|
| 146 |
+
|
| 147 |
# For local development and testing
|
| 148 |
if __name__ == "__main__":
|
| 149 |
port = int(os.environ.get("PORT", 8000))
|
| 150 |
+
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)
|
models.py
CHANGED
|
@@ -3,79 +3,81 @@ import torch
|
|
| 3 |
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 4 |
from typing import Dict, Any
|
| 5 |
|
|
|
|
| 6 |
class EmailClassifier:
|
| 7 |
"""
|
| 8 |
Email classification model to categorize emails into different support categories
|
| 9 |
"""
|
| 10 |
-
|
| 11 |
CATEGORIES = ['Change', 'Incident', 'Problem', 'Request']
|
| 12 |
-
|
| 13 |
def __init__(self, model_path: str = None):
|
| 14 |
"""
|
| 15 |
Initialize the email classifier with a pre-trained model
|
| 16 |
-
|
| 17 |
Args:
|
| 18 |
model_path: Path or Hugging Face Hub model ID
|
| 19 |
"""
|
| 20 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
-
|
| 22 |
# Use environment variable for model path or fall back to Hugging Face Hub model
|
| 23 |
# This allows for flexibility in deployment
|
| 24 |
-
model_path = model_path or os.environ.get(
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
# Load the tokenizer and model from Hugging Face Hub or local path
|
| 27 |
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
|
| 28 |
self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
|
| 29 |
self.model.to(self.device)
|
| 30 |
self.model.eval()
|
| 31 |
-
|
| 32 |
def classify(self, masked_email: str) -> str:
|
| 33 |
"""
|
| 34 |
Classify a masked email into one of the predefined categories
|
| 35 |
-
|
| 36 |
Args:
|
| 37 |
masked_email: The email content with PII masked
|
| 38 |
-
|
| 39 |
Returns:
|
| 40 |
The predicted category as a string
|
| 41 |
"""
|
| 42 |
# Tokenize the masked email
|
| 43 |
inputs = self.tokenizer(
|
| 44 |
-
masked_email,
|
| 45 |
return_tensors="pt",
|
| 46 |
padding="max_length",
|
| 47 |
truncation=True,
|
| 48 |
max_length=512
|
| 49 |
)
|
| 50 |
-
|
| 51 |
inputs = {key: val.to(self.device) for key, val in inputs.items()}
|
| 52 |
-
|
| 53 |
# Perform inference
|
| 54 |
with torch.no_grad():
|
| 55 |
outputs = self.model(**inputs)
|
| 56 |
logits = outputs.logits
|
| 57 |
predicted_class_idx = torch.argmax(logits, dim=1).item()
|
| 58 |
-
|
| 59 |
# Map the predicted class index to the category
|
| 60 |
return self.CATEGORIES[predicted_class_idx]
|
| 61 |
-
|
| 62 |
def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 63 |
"""
|
| 64 |
Process an email by classifying it into a category
|
| 65 |
-
|
| 66 |
Args:
|
| 67 |
masked_email_data: Dictionary containing the masked email and other data
|
| 68 |
-
|
| 69 |
Returns:
|
| 70 |
The input dictionary with the classification added
|
| 71 |
"""
|
| 72 |
# Extract masked email content
|
| 73 |
masked_email = masked_email_data["masked_email"]
|
| 74 |
-
|
| 75 |
# Classify the masked email
|
| 76 |
category = self.classify(masked_email)
|
| 77 |
-
|
| 78 |
# Add the classification to the data
|
| 79 |
masked_email_data["category_of_the_email"] = category
|
| 80 |
-
|
| 81 |
-
return masked_email_data
|
|
|
|
| 3 |
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 4 |
from typing import Dict, Any
|
| 5 |
|
| 6 |
+
|
| 7 |
class EmailClassifier:
|
| 8 |
"""
|
| 9 |
Email classification model to categorize emails into different support categories
|
| 10 |
"""
|
|
|
|
| 11 |
CATEGORIES = ['Change', 'Incident', 'Problem', 'Request']
|
| 12 |
+
|
| 13 |
def __init__(self, model_path: str = None):
|
| 14 |
"""
|
| 15 |
Initialize the email classifier with a pre-trained model
|
| 16 |
+
|
| 17 |
Args:
|
| 18 |
model_path: Path or Hugging Face Hub model ID
|
| 19 |
"""
|
| 20 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
# Use environment variable for model path or fall back to Hugging Face Hub model
|
| 23 |
# This allows for flexibility in deployment
|
| 24 |
+
model_path = model_path or os.environ.get(
|
| 25 |
+
"MODEL_PATH", "Sparkonix11/email-classifier-model"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
# Load the tokenizer and model from Hugging Face Hub or local path
|
| 29 |
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
|
| 30 |
self.model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
|
| 31 |
self.model.to(self.device)
|
| 32 |
self.model.eval()
|
| 33 |
+
|
| 34 |
def classify(self, masked_email: str) -> str:
|
| 35 |
"""
|
| 36 |
Classify a masked email into one of the predefined categories
|
| 37 |
+
|
| 38 |
Args:
|
| 39 |
masked_email: The email content with PII masked
|
| 40 |
+
|
| 41 |
Returns:
|
| 42 |
The predicted category as a string
|
| 43 |
"""
|
| 44 |
# Tokenize the masked email
|
| 45 |
inputs = self.tokenizer(
|
| 46 |
+
masked_email,
|
| 47 |
return_tensors="pt",
|
| 48 |
padding="max_length",
|
| 49 |
truncation=True,
|
| 50 |
max_length=512
|
| 51 |
)
|
| 52 |
+
|
| 53 |
inputs = {key: val.to(self.device) for key, val in inputs.items()}
|
| 54 |
+
|
| 55 |
# Perform inference
|
| 56 |
with torch.no_grad():
|
| 57 |
outputs = self.model(**inputs)
|
| 58 |
logits = outputs.logits
|
| 59 |
predicted_class_idx = torch.argmax(logits, dim=1).item()
|
| 60 |
+
|
| 61 |
# Map the predicted class index to the category
|
| 62 |
return self.CATEGORIES[predicted_class_idx]
|
| 63 |
+
|
| 64 |
def process_email(self, masked_email_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 65 |
"""
|
| 66 |
Process an email by classifying it into a category
|
| 67 |
+
|
| 68 |
Args:
|
| 69 |
masked_email_data: Dictionary containing the masked email and other data
|
| 70 |
+
|
| 71 |
Returns:
|
| 72 |
The input dictionary with the classification added
|
| 73 |
"""
|
| 74 |
# Extract masked email content
|
| 75 |
masked_email = masked_email_data["masked_email"]
|
| 76 |
+
|
| 77 |
# Classify the masked email
|
| 78 |
category = self.classify(masked_email)
|
| 79 |
+
|
| 80 |
# Add the classification to the data
|
| 81 |
masked_email_data["category_of_the_email"] = category
|
| 82 |
+
|
| 83 |
+
return masked_email_data
|
requirements.in
DELETED
|
@@ -1,381 +0,0 @@
|
|
| 1 |
-
# This file was autogenerated by uv via the following command:
|
| 2 |
-
# uv pip compile requirements.in -o requirements.txt
|
| 3 |
-
annotated-types==0.7.0
|
| 4 |
-
# via
|
| 5 |
-
# -r requirements.in
|
| 6 |
-
# pydantic
|
| 7 |
-
anyio==4.9.0
|
| 8 |
-
# via
|
| 9 |
-
# -r requirements.in
|
| 10 |
-
# starlette
|
| 11 |
-
blis==1.3.0
|
| 12 |
-
# via
|
| 13 |
-
# -r requirements.in
|
| 14 |
-
# thinc
|
| 15 |
-
catalogue==2.0.10
|
| 16 |
-
# via
|
| 17 |
-
# -r requirements.in
|
| 18 |
-
# spacy
|
| 19 |
-
# srsly
|
| 20 |
-
# thinc
|
| 21 |
-
certifi==2025.4.26
|
| 22 |
-
# via
|
| 23 |
-
# -r requirements.in
|
| 24 |
-
# requests
|
| 25 |
-
charset-normalizer==3.4.2
|
| 26 |
-
# via
|
| 27 |
-
# -r requirements.in
|
| 28 |
-
# requests
|
| 29 |
-
click==8.2.0
|
| 30 |
-
# via
|
| 31 |
-
# -r requirements.in
|
| 32 |
-
# typer
|
| 33 |
-
# uvicorn
|
| 34 |
-
cloudpathlib==0.21.1
|
| 35 |
-
# via
|
| 36 |
-
# -r requirements.in
|
| 37 |
-
# weasel
|
| 38 |
-
confection==0.1.5
|
| 39 |
-
# via
|
| 40 |
-
# -r requirements.in
|
| 41 |
-
# thinc
|
| 42 |
-
# weasel
|
| 43 |
-
cymem==2.0.11
|
| 44 |
-
# via
|
| 45 |
-
# -r requirements.in
|
| 46 |
-
# preshed
|
| 47 |
-
# spacy
|
| 48 |
-
# thinc
|
| 49 |
-
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
|
| 50 |
-
# via -r requirements.in
|
| 51 |
-
exceptiongroup==1.3.0
|
| 52 |
-
# via
|
| 53 |
-
# -r requirements.in
|
| 54 |
-
# anyio
|
| 55 |
-
fastapi==0.115.12
|
| 56 |
-
# via -r requirements.in
|
| 57 |
-
filelock==3.18.0
|
| 58 |
-
# via
|
| 59 |
-
# -r requirements.in
|
| 60 |
-
# huggingface-hub
|
| 61 |
-
# torch
|
| 62 |
-
# transformers
|
| 63 |
-
fsspec==2025.3.2
|
| 64 |
-
# via
|
| 65 |
-
# -r requirements.in
|
| 66 |
-
# huggingface-hub
|
| 67 |
-
# torch
|
| 68 |
-
h11==0.16.0
|
| 69 |
-
# via
|
| 70 |
-
# -r requirements.in
|
| 71 |
-
# uvicorn
|
| 72 |
-
huggingface-hub==0.31.2
|
| 73 |
-
# via
|
| 74 |
-
# -r requirements.in
|
| 75 |
-
# tokenizers
|
| 76 |
-
# transformers
|
| 77 |
-
idna==3.10
|
| 78 |
-
# via
|
| 79 |
-
# -r requirements.in
|
| 80 |
-
# anyio
|
| 81 |
-
# requests
|
| 82 |
-
jinja2==3.1.6
|
| 83 |
-
# via
|
| 84 |
-
# -r requirements.in
|
| 85 |
-
# spacy
|
| 86 |
-
# torch
|
| 87 |
-
langcodes==3.5.0
|
| 88 |
-
# via
|
| 89 |
-
# -r requirements.in
|
| 90 |
-
# spacy
|
| 91 |
-
language-data==1.3.0
|
| 92 |
-
# via
|
| 93 |
-
# -r requirements.in
|
| 94 |
-
# langcodes
|
| 95 |
-
marisa-trie==1.2.1
|
| 96 |
-
# via
|
| 97 |
-
# -r requirements.in
|
| 98 |
-
# language-data
|
| 99 |
-
markdown-it-py==3.0.0
|
| 100 |
-
# via
|
| 101 |
-
# -r requirements.in
|
| 102 |
-
# rich
|
| 103 |
-
markupsafe==3.0.2
|
| 104 |
-
# via
|
| 105 |
-
# -r requirements.in
|
| 106 |
-
# jinja2
|
| 107 |
-
mdurl==0.1.2
|
| 108 |
-
# via
|
| 109 |
-
# -r requirements.in
|
| 110 |
-
# markdown-it-py
|
| 111 |
-
mpmath==1.3.0
|
| 112 |
-
# via
|
| 113 |
-
# -r requirements.in
|
| 114 |
-
# sympy
|
| 115 |
-
murmurhash==1.0.12
|
| 116 |
-
# via
|
| 117 |
-
# -r requirements.in
|
| 118 |
-
# preshed
|
| 119 |
-
# spacy
|
| 120 |
-
# thinc
|
| 121 |
-
networkx==3.4.2
|
| 122 |
-
# via
|
| 123 |
-
# -r requirements.in
|
| 124 |
-
# torch
|
| 125 |
-
numpy==2.2.5
|
| 126 |
-
# via
|
| 127 |
-
# -r requirements.in
|
| 128 |
-
# blis
|
| 129 |
-
# spacy
|
| 130 |
-
# spacy-transformers
|
| 131 |
-
# thinc
|
| 132 |
-
# transformers
|
| 133 |
-
|
| 134 |
-
# SQLite is included in Python standard library
|
| 135 |
-
python-dotenv
|
| 136 |
-
# for environment variable management
|
| 137 |
-
|
| 138 |
-
nvidia-cublas-cu12==12.6.4.1
|
| 139 |
-
# via
|
| 140 |
-
# -r requirements.in
|
| 141 |
-
# nvidia-cudnn-cu12
|
| 142 |
-
# nvidia-cusolver-cu12
|
| 143 |
-
# torch
|
| 144 |
-
nvidia-cuda-cupti-cu12==12.6.80
|
| 145 |
-
# via
|
| 146 |
-
# -r requirements.in
|
| 147 |
-
# torch
|
| 148 |
-
nvidia-cuda-nvrtc-cu12==12.6.77
|
| 149 |
-
# via
|
| 150 |
-
# -r requirements.in
|
| 151 |
-
# torch
|
| 152 |
-
nvidia-cuda-runtime-cu12==12.6.77
|
| 153 |
-
# via
|
| 154 |
-
# -r requirements.in
|
| 155 |
-
# torch
|
| 156 |
-
nvidia-cudnn-cu12==9.5.1.17
|
| 157 |
-
# via
|
| 158 |
-
# -r requirements.in
|
| 159 |
-
# torch
|
| 160 |
-
nvidia-cufft-cu12==11.3.0.4
|
| 161 |
-
# via
|
| 162 |
-
# -r requirements.in
|
| 163 |
-
# torch
|
| 164 |
-
nvidia-cufile-cu12==1.11.1.6
|
| 165 |
-
# via
|
| 166 |
-
# -r requirements.in
|
| 167 |
-
# torch
|
| 168 |
-
nvidia-curand-cu12==10.3.7.77
|
| 169 |
-
# via
|
| 170 |
-
# -r requirements.in
|
| 171 |
-
# torch
|
| 172 |
-
nvidia-cusolver-cu12==11.7.1.2
|
| 173 |
-
# via
|
| 174 |
-
# -r requirements.in
|
| 175 |
-
# torch
|
| 176 |
-
nvidia-cusparse-cu12==12.5.4.2
|
| 177 |
-
# via
|
| 178 |
-
# -r requirements.in
|
| 179 |
-
# nvidia-cusolver-cu12
|
| 180 |
-
# torch
|
| 181 |
-
nvidia-cusparselt-cu12==0.6.3
|
| 182 |
-
# via
|
| 183 |
-
# -r requirements.in
|
| 184 |
-
# torch
|
| 185 |
-
nvidia-nccl-cu12==2.26.2
|
| 186 |
-
# via
|
| 187 |
-
# -r requirements.in
|
| 188 |
-
# torch
|
| 189 |
-
nvidia-nvjitlink-cu12==12.6.85
|
| 190 |
-
# via
|
| 191 |
-
# -r requirements.in
|
| 192 |
-
# nvidia-cufft-cu12
|
| 193 |
-
# nvidia-cusolver-cu12
|
| 194 |
-
# nvidia-cusparse-cu12
|
| 195 |
-
# torch
|
| 196 |
-
nvidia-nvtx-cu12==12.6.77
|
| 197 |
-
# via
|
| 198 |
-
# -r requirements.in
|
| 199 |
-
# torch
|
| 200 |
-
packaging==25.0
|
| 201 |
-
# via
|
| 202 |
-
# -r requirements.in
|
| 203 |
-
# huggingface-hub
|
| 204 |
-
# spacy
|
| 205 |
-
# thinc
|
| 206 |
-
# transformers
|
| 207 |
-
# weasel
|
| 208 |
-
preshed==3.0.9
|
| 209 |
-
# via
|
| 210 |
-
# -r requirements.in
|
| 211 |
-
# spacy
|
| 212 |
-
# thinc
|
| 213 |
-
pydantic==2.11.4
|
| 214 |
-
# via
|
| 215 |
-
# -r requirements.in
|
| 216 |
-
# confection
|
| 217 |
-
# fastapi
|
| 218 |
-
# spacy
|
| 219 |
-
# thinc
|
| 220 |
-
# weasel
|
| 221 |
-
pydantic-core==2.33.2
|
| 222 |
-
# via
|
| 223 |
-
# -r requirements.in
|
| 224 |
-
# pydantic
|
| 225 |
-
pygments==2.19.1
|
| 226 |
-
# via
|
| 227 |
-
# -r requirements.in
|
| 228 |
-
# rich
|
| 229 |
-
python-multipart==0.0.20
|
| 230 |
-
# via -r requirements.in
|
| 231 |
-
pyyaml==6.0.2
|
| 232 |
-
# via
|
| 233 |
-
# -r requirements.in
|
| 234 |
-
# huggingface-hub
|
| 235 |
-
# transformers
|
| 236 |
-
regex==2024.11.6
|
| 237 |
-
# via
|
| 238 |
-
# -r requirements.in
|
| 239 |
-
# transformers
|
| 240 |
-
requests==2.32.3
|
| 241 |
-
# via
|
| 242 |
-
# -r requirements.in
|
| 243 |
-
# huggingface-hub
|
| 244 |
-
# spacy
|
| 245 |
-
# transformers
|
| 246 |
-
# weasel
|
| 247 |
-
rich==14.0.0
|
| 248 |
-
# via
|
| 249 |
-
# -r requirements.in
|
| 250 |
-
# typer
|
| 251 |
-
safetensors==0.5.3
|
| 252 |
-
# via
|
| 253 |
-
# -r requirements.in
|
| 254 |
-
# transformers
|
| 255 |
-
sentencepiece==0.2.0
|
| 256 |
-
# via -r requirements.in
|
| 257 |
-
setuptools==80.7.1
|
| 258 |
-
# via
|
| 259 |
-
# -r requirements.in
|
| 260 |
-
# marisa-trie
|
| 261 |
-
# spacy
|
| 262 |
-
# thinc
|
| 263 |
-
# triton
|
| 264 |
-
shellingham==1.5.4
|
| 265 |
-
# via
|
| 266 |
-
# -r requirements.in
|
| 267 |
-
# typer
|
| 268 |
-
smart-open==7.1.0
|
| 269 |
-
# via
|
| 270 |
-
# -r requirements.in
|
| 271 |
-
# weasel
|
| 272 |
-
sniffio==1.3.1
|
| 273 |
-
# via
|
| 274 |
-
# -r requirements.in
|
| 275 |
-
# anyio
|
| 276 |
-
spacy==3.8.5
|
| 277 |
-
# via
|
| 278 |
-
# -r requirements.in
|
| 279 |
-
# spacy-transformers
|
| 280 |
-
spacy-alignments==0.9.1
|
| 281 |
-
# via
|
| 282 |
-
# -r requirements.in
|
| 283 |
-
# spacy-transformers
|
| 284 |
-
spacy-legacy==3.0.12
|
| 285 |
-
# via
|
| 286 |
-
# -r requirements.in
|
| 287 |
-
# spacy
|
| 288 |
-
spacy-loggers==1.0.5
|
| 289 |
-
# via
|
| 290 |
-
# -r requirements.in
|
| 291 |
-
# spacy
|
| 292 |
-
spacy-transformers==1.3.8
|
| 293 |
-
# via -r requirements.in
|
| 294 |
-
srsly==2.5.1
|
| 295 |
-
# via
|
| 296 |
-
# -r requirements.in
|
| 297 |
-
# confection
|
| 298 |
-
# spacy
|
| 299 |
-
# spacy-transformers
|
| 300 |
-
# thinc
|
| 301 |
-
# weasel
|
| 302 |
-
starlette==0.46.2
|
| 303 |
-
# via
|
| 304 |
-
# -r requirements.in
|
| 305 |
-
# fastapi
|
| 306 |
-
sympy==1.14.0
|
| 307 |
-
# via
|
| 308 |
-
# -r requirements.in
|
| 309 |
-
# torch
|
| 310 |
-
thinc==8.3.6
|
| 311 |
-
# via
|
| 312 |
-
# -r requirements.in
|
| 313 |
-
# spacy
|
| 314 |
-
tokenizers==0.21.1
|
| 315 |
-
# via
|
| 316 |
-
# -r requirements.in
|
| 317 |
-
# transformers
|
| 318 |
-
torch==2.7.0
|
| 319 |
-
# via
|
| 320 |
-
# -r requirements.in
|
| 321 |
-
# spacy-transformers
|
| 322 |
-
tqdm==4.67.1
|
| 323 |
-
# via
|
| 324 |
-
# -r requirements.in
|
| 325 |
-
# huggingface-hub
|
| 326 |
-
# spacy
|
| 327 |
-
# transformers
|
| 328 |
-
transformers==4.49.0
|
| 329 |
-
# via
|
| 330 |
-
# -r requirements.in
|
| 331 |
-
# spacy-transformers
|
| 332 |
-
triton==3.3.0
|
| 333 |
-
# via
|
| 334 |
-
# -r requirements.in
|
| 335 |
-
# torch
|
| 336 |
-
typer==0.15.3
|
| 337 |
-
# via
|
| 338 |
-
# -r requirements.in
|
| 339 |
-
# spacy
|
| 340 |
-
# weasel
|
| 341 |
-
typing-extensions==4.13.2
|
| 342 |
-
# via
|
| 343 |
-
# -r requirements.in
|
| 344 |
-
# anyio
|
| 345 |
-
# cloudpathlib
|
| 346 |
-
# exceptiongroup
|
| 347 |
-
# fastapi
|
| 348 |
-
# huggingface-hub
|
| 349 |
-
# pydantic
|
| 350 |
-
# pydantic-core
|
| 351 |
-
# rich
|
| 352 |
-
# torch
|
| 353 |
-
# typer
|
| 354 |
-
# typing-inspection
|
| 355 |
-
# uvicorn
|
| 356 |
-
typing-inspection==0.4.0
|
| 357 |
-
# via
|
| 358 |
-
# -r requirements.in
|
| 359 |
-
# pydantic
|
| 360 |
-
urllib3==2.4.0
|
| 361 |
-
# via
|
| 362 |
-
# -r requirements.in
|
| 363 |
-
# requests
|
| 364 |
-
uvicorn==0.34.2
|
| 365 |
-
# via -r requirements.in
|
| 366 |
-
wasabi==1.1.3
|
| 367 |
-
# via
|
| 368 |
-
# -r requirements.in
|
| 369 |
-
# spacy
|
| 370 |
-
# thinc
|
| 371 |
-
# weasel
|
| 372 |
-
weasel==0.4.1
|
| 373 |
-
# via
|
| 374 |
-
# -r requirements.in
|
| 375 |
-
# spacy
|
| 376 |
-
wrapt==1.17.2
|
| 377 |
-
# via
|
| 378 |
-
# -r requirements.in
|
| 379 |
-
# smart-open
|
| 380 |
-
xx-ent-wiki-sm @ https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.8.0/xx_ent_wiki_sm-3.8.0-py3-none-any.whl
|
| 381 |
-
# via -r requirements.in
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
upload_model.py
CHANGED
|
@@ -2,27 +2,30 @@
|
|
| 2 |
Script to upload the email classification model to Hugging Face Hub
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
import os
|
| 6 |
import sys
|
| 7 |
import argparse
|
| 8 |
import subprocess
|
| 9 |
import pkg_resources
|
| 10 |
|
|
|
|
| 11 |
def check_and_install_dependencies():
|
| 12 |
"""Check for required libraries and install if missing"""
|
| 13 |
required_packages = ['torch', 'transformers', 'sentencepiece']
|
| 14 |
installed_packages = {pkg.key for pkg in pkg_resources.working_set}
|
| 15 |
-
|
| 16 |
missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages]
|
| 17 |
-
|
| 18 |
if missing_packages:
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
print("Dependencies installed. You may need to restart the script.")
|
| 22 |
return False
|
| 23 |
-
|
| 24 |
return True
|
| 25 |
|
|
|
|
| 26 |
def get_huggingface_username(token=None):
|
| 27 |
"""Get the username for the authenticated user"""
|
| 28 |
try:
|
|
@@ -34,51 +37,57 @@ def get_huggingface_username(token=None):
|
|
| 34 |
print(f"Error getting Hugging Face username: {e}")
|
| 35 |
return None
|
| 36 |
|
|
|
|
| 37 |
def main():
|
| 38 |
"""Upload model to Hugging Face Hub"""
|
| 39 |
# Check dependencies first
|
| 40 |
if not check_and_install_dependencies():
|
| 41 |
return
|
| 42 |
-
|
| 43 |
# Import dependencies after installation check
|
| 44 |
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 45 |
from huggingface_hub import login
|
| 46 |
-
|
| 47 |
-
parser = argparse.ArgumentParser(
|
|
|
|
| 48 |
parser.add_argument("--model_path", type=str, default="classification_model",
|
| 49 |
help="Local path to the model files")
|
| 50 |
parser.add_argument("--hub_model_id", type=str,
|
| 51 |
-
help="Hugging Face Hub model ID (e.g.,
|
|
|
|
| 52 |
parser.add_argument("--model_name", type=str, default="email-classifier-model",
|
| 53 |
-
help="Name for the model repository
|
|
|
|
| 54 |
parser.add_argument("--token", type=str,
|
| 55 |
-
help="Hugging Face API token (optional, can use
|
| 56 |
-
|
|
|
|
| 57 |
args = parser.parse_args()
|
| 58 |
-
|
| 59 |
# Login if token is provided
|
| 60 |
if args.token:
|
| 61 |
login(token=args.token)
|
| 62 |
-
|
| 63 |
# If hub_model_id is not provided, try to get username and construct it
|
| 64 |
if not args.hub_model_id:
|
| 65 |
username = get_huggingface_username(args.token)
|
| 66 |
if not username:
|
| 67 |
-
print("Could not determine Hugging Face username.
|
|
|
|
| 68 |
return
|
| 69 |
args.hub_model_id = f"{username}/{args.model_name}"
|
| 70 |
-
|
| 71 |
print(f"Loading model from {args.model_path}...")
|
| 72 |
# Load the local model and tokenizer
|
| 73 |
model = XLMRobertaForSequenceClassification.from_pretrained(args.model_path)
|
| 74 |
tokenizer = XLMRobertaTokenizer.from_pretrained(args.model_path)
|
| 75 |
-
|
| 76 |
print(f"Uploading model to {args.hub_model_id}...")
|
| 77 |
try:
|
| 78 |
# Push to Hugging Face Hub
|
| 79 |
model.push_to_hub(args.hub_model_id)
|
| 80 |
tokenizer.push_to_hub(args.hub_model_id)
|
| 81 |
-
|
| 82 |
print("Model successfully uploaded to Hugging Face Hub!")
|
| 83 |
print(f"You can now use the model with the ID: {args.hub_model_id}")
|
| 84 |
print(f"Update the MODEL_PATH in Dockerfile to: {args.hub_model_id}")
|
|
@@ -86,8 +95,11 @@ def main():
|
|
| 86 |
print(f"Error uploading model: {e}")
|
| 87 |
print("\nPossible solutions:")
|
| 88 |
print("1. Make sure you're logged in with 'huggingface-cli login'")
|
| 89 |
-
print("2. Check that you have permission to create repos in the
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
if __name__ == "__main__":
|
| 93 |
-
main()
|
|
|
|
| 2 |
Script to upload the email classification model to Hugging Face Hub
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
import sys
|
| 6 |
import argparse
|
| 7 |
import subprocess
|
| 8 |
import pkg_resources
|
| 9 |
|
| 10 |
+
|
| 11 |
def check_and_install_dependencies():
|
| 12 |
"""Check for required libraries and install if missing"""
|
| 13 |
required_packages = ['torch', 'transformers', 'sentencepiece']
|
| 14 |
installed_packages = {pkg.key for pkg in pkg_resources.working_set}
|
| 15 |
+
|
| 16 |
missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages]
|
| 17 |
+
|
| 18 |
if missing_packages:
|
| 19 |
+
missing_packages_str = ", ".join(missing_packages)
|
| 20 |
+
print(f"Installing missing dependencies: {missing_packages_str}")
|
| 21 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install"]
|
| 22 |
+
+ missing_packages)
|
| 23 |
print("Dependencies installed. You may need to restart the script.")
|
| 24 |
return False
|
| 25 |
+
|
| 26 |
return True
|
| 27 |
|
| 28 |
+
|
| 29 |
def get_huggingface_username(token=None):
|
| 30 |
"""Get the username for the authenticated user"""
|
| 31 |
try:
|
|
|
|
| 37 |
print(f"Error getting Hugging Face username: {e}")
|
| 38 |
return None
|
| 39 |
|
| 40 |
+
|
| 41 |
def main():
|
| 42 |
"""Upload model to Hugging Face Hub"""
|
| 43 |
# Check dependencies first
|
| 44 |
if not check_and_install_dependencies():
|
| 45 |
return
|
| 46 |
+
|
| 47 |
# Import dependencies after installation check
|
| 48 |
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
| 49 |
from huggingface_hub import login
|
| 50 |
+
|
| 51 |
+
parser = argparse.ArgumentParser(
|
| 52 |
+
description="Upload email classification model to Hugging Face Hub")
|
| 53 |
parser.add_argument("--model_path", type=str, default="classification_model",
|
| 54 |
help="Local path to the model files")
|
| 55 |
parser.add_argument("--hub_model_id", type=str,
|
| 56 |
+
help="Hugging Face Hub model ID (e.g., "
|
| 57 |
+
"'username/email-classifier-model')")
|
| 58 |
parser.add_argument("--model_name", type=str, default="email-classifier-model",
|
| 59 |
+
help="Name for the model repository "
|
| 60 |
+
"(default: email-classifier-model)")
|
| 61 |
parser.add_argument("--token", type=str,
|
| 62 |
+
help="Hugging Face API token (optional, can use "
|
| 63 |
+
"environment variable or huggingface-cli login)")
|
| 64 |
+
|
| 65 |
args = parser.parse_args()
|
| 66 |
+
|
| 67 |
# Login if token is provided
|
| 68 |
if args.token:
|
| 69 |
login(token=args.token)
|
| 70 |
+
|
| 71 |
# If hub_model_id is not provided, try to get username and construct it
|
| 72 |
if not args.hub_model_id:
|
| 73 |
username = get_huggingface_username(args.token)
|
| 74 |
if not username:
|
| 75 |
+
print("Could not determine Hugging Face username. "
|
| 76 |
+
"Please provide --hub_model_id explicitly.")
|
| 77 |
return
|
| 78 |
args.hub_model_id = f"{username}/{args.model_name}"
|
| 79 |
+
|
| 80 |
print(f"Loading model from {args.model_path}...")
|
| 81 |
# Load the local model and tokenizer
|
| 82 |
model = XLMRobertaForSequenceClassification.from_pretrained(args.model_path)
|
| 83 |
tokenizer = XLMRobertaTokenizer.from_pretrained(args.model_path)
|
| 84 |
+
|
| 85 |
print(f"Uploading model to {args.hub_model_id}...")
|
| 86 |
try:
|
| 87 |
# Push to Hugging Face Hub
|
| 88 |
model.push_to_hub(args.hub_model_id)
|
| 89 |
tokenizer.push_to_hub(args.hub_model_id)
|
| 90 |
+
|
| 91 |
print("Model successfully uploaded to Hugging Face Hub!")
|
| 92 |
print(f"You can now use the model with the ID: {args.hub_model_id}")
|
| 93 |
print(f"Update the MODEL_PATH in Dockerfile to: {args.hub_model_id}")
|
|
|
|
| 95 |
print(f"Error uploading model: {e}")
|
| 96 |
print("\nPossible solutions:")
|
| 97 |
print("1. Make sure you're logged in with 'huggingface-cli login'")
|
| 98 |
+
print("2. Check that you have permission to create repos in the "
|
| 99 |
+
"specified namespace")
|
| 100 |
+
print("3. Try using your own username: "
|
| 101 |
+
"--hub_model_id yourusername/email-classifier-model")
|
| 102 |
+
|
| 103 |
|
| 104 |
if __name__ == "__main__":
|
| 105 |
+
main()
|
utils.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
import re
|
| 2 |
import spacy
|
| 3 |
from typing import List, Dict, Tuple, Any, Optional
|
|
|
|
| 4 |
from database import EmailDatabase
|
| 5 |
|
|
|
|
| 6 |
class Entity:
|
| 7 |
def __init__(self, start: int, end: int, entity_type: str, value: str):
|
| 8 |
self.start = start
|
|
@@ -17,11 +19,19 @@ class Entity:
|
|
| 17 |
"entity": self.value
|
| 18 |
}
|
| 19 |
|
| 20 |
-
def __repr__(self):
|
| 21 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
class PIIMasker:
|
| 24 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Load SpaCy model
|
| 26 |
try:
|
| 27 |
self.nlp = spacy.load(spacy_model_name)
|
|
@@ -42,7 +52,7 @@ class PIIMasker:
|
|
| 42 |
|
| 43 |
# Initialize database connection with SQLite path
|
| 44 |
self.db = EmailDatabase(connection_string=db_path)
|
| 45 |
-
|
| 46 |
# Initialize regex patterns
|
| 47 |
self._initialize_patterns()
|
| 48 |
|
|
@@ -51,7 +61,11 @@ class PIIMasker:
|
|
| 51 |
self.patterns = {
|
| 52 |
"email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
| 53 |
# Simplified phone regex to capture both standard and international formats
|
| 54 |
-
"phone_number":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
# Card number regex: common formats, allows optional spaces/hyphens
|
| 56 |
"credit_debit_no": r'\b(?:(?:\d{4}[\s-]?){3}\d{4}|\d{13,19})\b',
|
| 57 |
# CVV: 3 or 4 digits, ensuring it's a standalone number (word boundary)
|
|
@@ -60,7 +74,10 @@ class PIIMasker:
|
|
| 60 |
"expiry_no": r'\b(0[1-9]|1[0-2])[/\s-]([0-9]{2}|20[0-9]{2})\b',
|
| 61 |
"aadhar_num": r'\b\d{4}\s?\d{4}\s?\d{4}\b',
|
| 62 |
# DOB: DD/MM/YYYY or DD-MM-YYYY etc.
|
| 63 |
-
"dob":
|
|
|
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
def detect_regex_entities(self, text: str) -> List[Entity]:
|
|
@@ -83,14 +100,19 @@ class PIIMasker:
|
|
| 83 |
if not self.verify_phone_number(text, match):
|
| 84 |
continue
|
| 85 |
elif entity_type == "dob":
|
| 86 |
-
if not self._verify_with_context(
|
|
|
|
|
|
|
| 87 |
continue
|
| 88 |
|
| 89 |
-
# Avoid detecting parts of already matched longer entities
|
|
|
|
| 90 |
# This is a simple check; more robust overlap handling is done later
|
| 91 |
is_substring_of_existing = False
|
| 92 |
for existing_entity in entities:
|
| 93 |
-
if existing_entity.start <= start
|
|
|
|
|
|
|
| 94 |
is_substring_of_existing = True
|
| 95 |
break
|
| 96 |
if is_substring_of_existing:
|
|
@@ -99,7 +121,9 @@ class PIIMasker:
|
|
| 99 |
entities.append(Entity(start, end, entity_type, value))
|
| 100 |
return entities
|
| 101 |
|
| 102 |
-
def _verify_with_context(
|
|
|
|
|
|
|
| 103 |
"""Verify an entity match using surrounding context"""
|
| 104 |
context_before = text[max(0, start - window):start].lower()
|
| 105 |
context_after = text[end:min(len(text), end + window)].lower()
|
|
@@ -117,7 +141,10 @@ class PIIMasker:
|
|
| 117 |
context_before = text[max(0, start - context_window):start].lower()
|
| 118 |
context_after = text[end:min(len(text), end + context_window)].lower()
|
| 119 |
|
| 120 |
-
card_keywords = [
|
|
|
|
|
|
|
|
|
|
| 121 |
for keyword in card_keywords:
|
| 122 |
if keyword in context_before or keyword in context_after:
|
| 123 |
return True
|
|
@@ -125,19 +152,19 @@ class PIIMasker:
|
|
| 125 |
# For simplicity, we'll rely on context here. If needed, Luhn can be added.
|
| 126 |
return False
|
| 127 |
|
| 128 |
-
|
| 129 |
def verify_cvv(self, text: str, match: re.Match) -> bool:
|
| 130 |
"""Verify if a 3-4 digit number is actually a CVV using contextual clues"""
|
| 131 |
context_window = 50
|
| 132 |
start, end = match.span()
|
| 133 |
value = match.group()
|
| 134 |
|
| 135 |
-
# If it's part of a longer number sequence (like a phone number or ID),
|
|
|
|
| 136 |
# Check character immediately before and after
|
| 137 |
-
char_before = text[start-1:start] if start > 0 else ""
|
| 138 |
-
char_after = text[end:end+1] if end < len(text) else ""
|
| 139 |
if char_before.isdigit() or char_after.isdigit():
|
| 140 |
-
return False
|
| 141 |
|
| 142 |
# Only consider 3-4 digit numbers
|
| 143 |
if not value.isdigit() or len(value) < 3 or len(value) > 4:
|
|
@@ -148,14 +175,16 @@ class PIIMasker:
|
|
| 148 |
|
| 149 |
# Expanded list of CVV-related keywords to improve detection
|
| 150 |
cvv_keywords = [
|
| 151 |
-
"cvv", "cvc", "csc", "security code", "card verification",
|
| 152 |
-
"security", "security number", "cv2",
|
|
|
|
| 153 |
]
|
| 154 |
-
|
| 155 |
-
date_keywords = ["date", "year", "/", "born", "age", "since", "established"]
|
| 156 |
|
| 157 |
# Look for CVV context clues
|
| 158 |
-
is_cvv_context = any(
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# If explicitly mentioned as a CVV, immediately return true
|
| 161 |
if is_cvv_context:
|
|
@@ -163,17 +192,23 @@ class PIIMasker:
|
|
| 163 |
|
| 164 |
# If it looks like a year, reject it
|
| 165 |
if len(value) == 4 and 1900 <= int(value) <= 2100:
|
| 166 |
-
if any(
|
|
|
|
|
|
|
|
|
|
| 167 |
return False
|
| 168 |
|
| 169 |
# If in expiry date context, reject it
|
| 170 |
if re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
|
| 171 |
return False
|
| 172 |
-
|
| 173 |
-
# If no context clues but we have a credit card mention nearby,
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
return is_cvv_context or (card_context and len(value) in [3, 4])
|
| 178 |
|
| 179 |
def verify_phone_number(self, text: str, match: re.Match) -> bool:
|
|
@@ -182,56 +217,62 @@ class PIIMasker:
|
|
| 182 |
"""
|
| 183 |
value = match.group()
|
| 184 |
start, end = match.span()
|
| 185 |
-
|
| 186 |
# Extract only digits to count them
|
| 187 |
digits = ''.join(c for c in value if c.isdigit())
|
| 188 |
digit_count = len(digits)
|
| 189 |
-
|
| 190 |
# Most phone numbers worldwide have between 7 and 15 digits
|
| 191 |
if digit_count < 7 or digit_count > 15:
|
| 192 |
return False
|
| 193 |
-
|
| 194 |
# Check for common phone number indicators
|
| 195 |
context_window = 50
|
| 196 |
context_before = text[max(0, start - context_window):start].lower()
|
| 197 |
context_after = text[end:min(len(text), end + context_window)].lower()
|
| 198 |
-
|
| 199 |
# Expanded phone keywords
|
| 200 |
phone_keywords = [
|
| 201 |
-
"phone", "call", "tel", "telephone", "contact", "dial", "mobile",
|
| 202 |
-
"number", "direct", "office", "fax", "reach me at",
|
| 203 |
-
"line", "extension", "ext", "phone number"
|
| 204 |
]
|
| 205 |
-
|
| 206 |
# Check for phone context
|
| 207 |
-
has_phone_context = any(
|
| 208 |
-
|
|
|
|
|
|
|
| 209 |
# Check for formatting that indicates a phone number
|
| 210 |
-
has_phone_formatting = bool(re.search(r'[-\s.()
|
| 211 |
-
|
| 212 |
# Check for international prefix
|
| 213 |
has_intl_prefix = value.startswith('+') or value.startswith('00')
|
| 214 |
-
|
| 215 |
# Return true if any of these conditions are met:
|
| 216 |
# 1. Has explicit phone context
|
| 217 |
# 2. Has phone-like formatting AND reasonable digit count
|
| 218 |
# 3. Has international prefix AND reasonable digit count
|
| 219 |
# 4. Has 10 digits exactly (common in many countries) with formatting
|
| 220 |
-
return
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
| 224 |
|
| 225 |
def detect_name_entities(self, text: str) -> List[Entity]:
|
| 226 |
"""Detect name entities using SpaCy NER"""
|
| 227 |
entities = []
|
| 228 |
doc = self.nlp(text)
|
| 229 |
-
|
| 230 |
for ent in doc.ents:
|
| 231 |
# Use PER for person, common in many models like xx_ent_wiki_sm
|
| 232 |
# Also checking for PERSON as some models might use it.
|
| 233 |
if ent.label_ in ["PER", "PERSON"]:
|
| 234 |
-
entities.append(
|
|
|
|
|
|
|
| 235 |
return entities
|
| 236 |
|
| 237 |
def detect_all_entities(self, text: str) -> List[Entity]:
|
|
@@ -265,58 +306,74 @@ class PIIMasker:
|
|
| 265 |
# A simple greedy approach: iterate and remove/adjust overlaps
|
| 266 |
# This can be made more sophisticated
|
| 267 |
resolved_entities: List[Entity] = []
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
| 269 |
is_overlapped_or_contained = False
|
| 270 |
temp_resolved = []
|
| 271 |
for i, res_entity in enumerate(resolved_entities):
|
| 272 |
# Check for overlap:
|
| 273 |
# Current: |----|
|
| 274 |
# Res: |----| or |----| or |--| or |------|
|
| 275 |
-
overlap = max(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
if overlap > 0:
|
| 278 |
is_overlapped_or_contained = True
|
| 279 |
# Preference:
|
| 280 |
-
# 1. NER
|
| 281 |
# 2. Longer entity wins
|
| 282 |
current_len = current_entity.end - current_entity.start
|
| 283 |
res_len = res_entity.end - res_entity.start
|
| 284 |
|
| 285 |
-
# If current is a name and overlaps, and previous is not a name,
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
#
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
temp_resolved.append(res_entity)
|
| 296 |
-
is_overlapped_or_contained = True
|
| 297 |
-
break
|
| 298 |
|
| 299 |
# General case: longer entity wins
|
| 300 |
if current_len > res_len:
|
| 301 |
-
# current is longer, res_entity is removed from
|
| 302 |
-
|
|
|
|
| 303 |
elif res_len > current_len:
|
| 304 |
# res is longer, current is dominated
|
| 305 |
temp_resolved.append(res_entity)
|
| 306 |
-
is_overlapped_or_contained = True
|
| 307 |
break
|
| 308 |
-
else:
|
| 309 |
temp_resolved.append(res_entity)
|
| 310 |
-
is_overlapped_or_contained = True
|
| 311 |
break
|
| 312 |
-
else:
|
| 313 |
temp_resolved.append(res_entity)
|
| 314 |
|
| 315 |
if not is_overlapped_or_contained:
|
| 316 |
temp_resolved.append(current_entity)
|
| 317 |
|
| 318 |
-
resolved_entities = sorted(
|
| 319 |
-
|
|
|
|
| 320 |
|
| 321 |
# Final pass to remove fully contained entities if a larger one exists
|
| 322 |
final_entities = []
|
|
@@ -329,8 +386,10 @@ class PIIMasker:
|
|
| 329 |
if i == j:
|
| 330 |
continue
|
| 331 |
# If 'entity' is strictly contained within 'other_entity'
|
| 332 |
-
if other_entity.start <= entity.start
|
| 333 |
-
|
|
|
|
|
|
|
| 334 |
is_contained = True
|
| 335 |
break
|
| 336 |
if not is_contained:
|
|
@@ -338,7 +397,6 @@ class PIIMasker:
|
|
| 338 |
|
| 339 |
return final_entities
|
| 340 |
|
| 341 |
-
|
| 342 |
def mask_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
|
| 343 |
"""
|
| 344 |
Mask PII entities in the text and return masked text and entity information
|
|
@@ -370,7 +428,6 @@ class PIIMasker:
|
|
| 370 |
|
| 371 |
return "".join(new_text_parts), entity_info
|
| 372 |
|
| 373 |
-
|
| 374 |
def process_email(self, email_text: str) -> Dict[str, Any]:
|
| 375 |
"""
|
| 376 |
Process an email by detecting and masking PII entities.
|
|
@@ -378,56 +435,60 @@ class PIIMasker:
|
|
| 378 |
"""
|
| 379 |
# Mask the email
|
| 380 |
masked_email, entity_info = self.mask_text(email_text)
|
| 381 |
-
|
| 382 |
# Store the email in the SQLite database - only get back email_id now
|
| 383 |
email_id = self.db.store_email(
|
| 384 |
original_email=email_text,
|
| 385 |
masked_email=masked_email,
|
| 386 |
masked_entities=entity_info
|
| 387 |
)
|
| 388 |
-
|
| 389 |
# Return the processed data with just the email_id
|
| 390 |
return {
|
| 391 |
-
"input_email_body": email_text, # Return original
|
| 392 |
"list_of_masked_entities": entity_info,
|
| 393 |
"masked_email": masked_email,
|
| 394 |
"category_of_the_email": "",
|
| 395 |
"email_id": email_id
|
| 396 |
}
|
| 397 |
-
|
| 398 |
-
def get_original_email(
|
|
|
|
|
|
|
| 399 |
"""
|
| 400 |
Retrieve the original email with PII using the email ID and access key.
|
| 401 |
-
|
| 402 |
Args:
|
| 403 |
email_id: The ID of the stored email
|
| 404 |
access_key: The security key for accessing the original email
|
| 405 |
-
|
| 406 |
Returns:
|
| 407 |
The original email data or None if not found or access_key is invalid
|
| 408 |
"""
|
| 409 |
return self.db.get_original_email(email_id, access_key)
|
| 410 |
-
|
| 411 |
def get_masked_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
|
| 412 |
"""
|
| 413 |
Retrieve a masked email by its ID (without the original PII-containing email).
|
| 414 |
-
|
| 415 |
Args:
|
| 416 |
email_id: The ID of the stored email
|
| 417 |
-
|
| 418 |
Returns:
|
| 419 |
The masked email data or None if not found
|
| 420 |
"""
|
| 421 |
return self.db.get_email_by_id(email_id)
|
| 422 |
-
|
| 423 |
-
def get_original_by_masked_email(
|
|
|
|
|
|
|
| 424 |
"""
|
| 425 |
Retrieve the original unmasked email using the masked email content.
|
| 426 |
-
|
| 427 |
Args:
|
| 428 |
masked_email: The masked version of the email to search for
|
| 429 |
-
|
| 430 |
Returns:
|
| 431 |
The original email data or None if not found
|
| 432 |
"""
|
| 433 |
-
return self.db.get_email_by_masked_content(masked_email)
|
|
|
|
| 1 |
import re
|
| 2 |
import spacy
|
| 3 |
from typing import List, Dict, Tuple, Any, Optional
|
| 4 |
+
|
| 5 |
from database import EmailDatabase
|
| 6 |
|
| 7 |
+
|
| 8 |
class Entity:
|
| 9 |
def __init__(self, start: int, end: int, entity_type: str, value: str):
|
| 10 |
self.start = start
|
|
|
|
| 19 |
"entity": self.value
|
| 20 |
}
|
| 21 |
|
| 22 |
+
def __repr__(self): # Added for easier debugging
|
| 23 |
+
return (
|
| 24 |
+
f"Entity(type='{self.entity_type}', value='{self.value}', "
|
| 25 |
+
f"start={self.start}, end={self.end})"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
|
| 29 |
class PIIMasker:
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
spacy_model_name: str = "xx_ent_wiki_sm",
|
| 33 |
+
db_path: str = None
|
| 34 |
+
): # Allow model choice
|
| 35 |
# Load SpaCy model
|
| 36 |
try:
|
| 37 |
self.nlp = spacy.load(spacy_model_name)
|
|
|
|
| 52 |
|
| 53 |
# Initialize database connection with SQLite path
|
| 54 |
self.db = EmailDatabase(connection_string=db_path)
|
| 55 |
+
|
| 56 |
# Initialize regex patterns
|
| 57 |
self._initialize_patterns()
|
| 58 |
|
|
|
|
| 61 |
self.patterns = {
|
| 62 |
"email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
| 63 |
# Simplified phone regex to capture both standard and international formats
|
| 64 |
+
"phone_number": (
|
| 65 |
+
r'\b(?:(?:\+|00)[1-9]\d{0,3}[-\s.]?)?'
|
| 66 |
+
r'(?:\(?\d{1,5}\)?[-\s.]?)?\d{1,5}'
|
| 67 |
+
r'(?:[-\s.]\d{1,5}){1,4}\b'
|
| 68 |
+
),
|
| 69 |
# Card number regex: common formats, allows optional spaces/hyphens
|
| 70 |
"credit_debit_no": r'\b(?:(?:\d{4}[\s-]?){3}\d{4}|\d{13,19})\b',
|
| 71 |
# CVV: 3 or 4 digits, ensuring it's a standalone number (word boundary)
|
|
|
|
| 74 |
"expiry_no": r'\b(0[1-9]|1[0-2])[/\s-]([0-9]{2}|20[0-9]{2})\b',
|
| 75 |
"aadhar_num": r'\b\d{4}\s?\d{4}\s?\d{4}\b',
|
| 76 |
# DOB: DD/MM/YYYY or DD-MM-YYYY etc.
|
| 77 |
+
"dob": (
|
| 78 |
+
r'\b(0[1-9]|[12][0-9]|3[01])[/\s-]'
|
| 79 |
+
r'(0[1-9]|1[0-2])[/\s-](?:19|20)\d\d\b'
|
| 80 |
+
)
|
| 81 |
}
|
| 82 |
|
| 83 |
def detect_regex_entities(self, text: str) -> List[Entity]:
|
|
|
|
| 100 |
if not self.verify_phone_number(text, match):
|
| 101 |
continue
|
| 102 |
elif entity_type == "dob":
|
| 103 |
+
if not self._verify_with_context(
|
| 104 |
+
text, start, end, ["birth", "dob", "born"]
|
| 105 |
+
):
|
| 106 |
continue
|
| 107 |
|
| 108 |
+
# Avoid detecting parts of already matched longer entities
|
| 109 |
+
# (e.g. year within a DOB)
|
| 110 |
# This is a simple check; more robust overlap handling is done later
|
| 111 |
is_substring_of_existing = False
|
| 112 |
for existing_entity in entities:
|
| 113 |
+
if (existing_entity.start <= start
|
| 114 |
+
and existing_entity.end >= end # W504 corrected
|
| 115 |
+
and existing_entity.value != value): # W504 corrected
|
| 116 |
is_substring_of_existing = True
|
| 117 |
break
|
| 118 |
if is_substring_of_existing:
|
|
|
|
| 121 |
entities.append(Entity(start, end, entity_type, value))
|
| 122 |
return entities
|
| 123 |
|
| 124 |
+
def _verify_with_context(
|
| 125 |
+
self, text: str, start: int, end: int, keywords: List[str], window: int = 50
|
| 126 |
+
) -> bool:
|
| 127 |
"""Verify an entity match using surrounding context"""
|
| 128 |
context_before = text[max(0, start - window):start].lower()
|
| 129 |
context_after = text[end:min(len(text), end + window)].lower()
|
|
|
|
| 141 |
context_before = text[max(0, start - context_window):start].lower()
|
| 142 |
context_after = text[end:min(len(text), end + context_window)].lower()
|
| 143 |
|
| 144 |
+
card_keywords = [
|
| 145 |
+
"card", "credit", "debit", "visa", "mastercard",
|
| 146 |
+
"payment", "amex", "account no", "card no"
|
| 147 |
+
]
|
| 148 |
for keyword in card_keywords:
|
| 149 |
if keyword in context_before or keyword in context_after:
|
| 150 |
return True
|
|
|
|
| 152 |
# For simplicity, we'll rely on context here. If needed, Luhn can be added.
|
| 153 |
return False
|
| 154 |
|
|
|
|
| 155 |
def verify_cvv(self, text: str, match: re.Match) -> bool:
|
| 156 |
"""Verify if a 3-4 digit number is actually a CVV using contextual clues"""
|
| 157 |
context_window = 50
|
| 158 |
start, end = match.span()
|
| 159 |
value = match.group()
|
| 160 |
|
| 161 |
+
# If it's part of a longer number sequence (like a phone number or ID),
|
| 162 |
+
# it's likely not a CVV
|
| 163 |
# Check character immediately before and after
|
| 164 |
+
char_before = text[start - 1:start] if start > 0 else ""
|
| 165 |
+
char_after = text[end:end + 1] if end < len(text) else ""
|
| 166 |
if char_before.isdigit() or char_after.isdigit():
|
| 167 |
+
return False # It's part of a larger number
|
| 168 |
|
| 169 |
# Only consider 3-4 digit numbers
|
| 170 |
if not value.isdigit() or len(value) < 3 or len(value) > 4:
|
|
|
|
| 175 |
|
| 176 |
# Expanded list of CVV-related keywords to improve detection
|
| 177 |
cvv_keywords = [
|
| 178 |
+
"cvv", "cvc", "csc", "security code", "card verification",
|
| 179 |
+
"verification no", "security", "security number", "cv2",
|
| 180 |
+
"card code", "security value"
|
| 181 |
]
|
|
|
|
|
|
|
| 182 |
|
| 183 |
# Look for CVV context clues
|
| 184 |
+
is_cvv_context = any(
|
| 185 |
+
keyword in context_before or keyword in context_after
|
| 186 |
+
for keyword in cvv_keywords
|
| 187 |
+
)
|
| 188 |
|
| 189 |
# If explicitly mentioned as a CVV, immediately return true
|
| 190 |
if is_cvv_context:
|
|
|
|
| 192 |
|
| 193 |
# If it looks like a year, reject it
|
| 194 |
if len(value) == 4 and 1900 <= int(value) <= 2100:
|
| 195 |
+
if any(
|
| 196 |
+
k in context_before or k in context_after
|
| 197 |
+
for k in ["year", "born", "established", "since"]
|
| 198 |
+
):
|
| 199 |
return False
|
| 200 |
|
| 201 |
# If in expiry date context, reject it
|
| 202 |
if re.search(r'\b(0[1-9]|1[0-2])[/\s-]$', context_before.strip()):
|
| 203 |
return False
|
| 204 |
+
|
| 205 |
+
# If no context clues but we have a credit card mention nearby,
|
| 206 |
+
# it could be a CVV
|
| 207 |
+
card_context = any(
|
| 208 |
+
k in context_before or k in context_after for k in
|
| 209 |
+
["card", "credit", "visa", "mastercard", "amex", "discover"]
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
return is_cvv_context or (card_context and len(value) in [3, 4])
|
| 213 |
|
| 214 |
def verify_phone_number(self, text: str, match: re.Match) -> bool:
|
|
|
|
| 217 |
"""
|
| 218 |
value = match.group()
|
| 219 |
start, end = match.span()
|
| 220 |
+
|
| 221 |
# Extract only digits to count them
|
| 222 |
digits = ''.join(c for c in value if c.isdigit())
|
| 223 |
digit_count = len(digits)
|
| 224 |
+
|
| 225 |
# Most phone numbers worldwide have between 7 and 15 digits
|
| 226 |
if digit_count < 7 or digit_count > 15:
|
| 227 |
return False
|
| 228 |
+
|
| 229 |
# Check for common phone number indicators
|
| 230 |
context_window = 50
|
| 231 |
context_before = text[max(0, start - context_window):start].lower()
|
| 232 |
context_after = text[end:min(len(text), end + context_window)].lower()
|
| 233 |
+
|
| 234 |
# Expanded phone keywords
|
| 235 |
phone_keywords = [
|
| 236 |
+
"phone", "call", "tel", "telephone", "contact", "dial", "mobile",
|
| 237 |
+
"cell", "number", "direct", "office", "fax", "reach me at",
|
| 238 |
+
"call me", "contact me", "line", "extension", "ext", "phone number"
|
| 239 |
]
|
| 240 |
+
|
| 241 |
# Check for phone context
|
| 242 |
+
has_phone_context = any(
|
| 243 |
+
kw in context_before or kw in context_after for kw in phone_keywords
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
# Check for formatting that indicates a phone number
|
| 247 |
+
has_phone_formatting = bool(re.search(r'[-\s.()\\+]', value))
|
| 248 |
+
|
| 249 |
# Check for international prefix
|
| 250 |
has_intl_prefix = value.startswith('+') or value.startswith('00')
|
| 251 |
+
|
| 252 |
# Return true if any of these conditions are met:
|
| 253 |
# 1. Has explicit phone context
|
| 254 |
# 2. Has phone-like formatting AND reasonable digit count
|
| 255 |
# 3. Has international prefix AND reasonable digit count
|
| 256 |
# 4. Has 10 digits exactly (common in many countries) with formatting
|
| 257 |
+
return (
|
| 258 |
+
has_phone_context
|
| 259 |
+
or (has_phone_formatting and digit_count >= 7)
|
| 260 |
+
or (has_intl_prefix)
|
| 261 |
+
or (digit_count == 10 and has_phone_formatting)
|
| 262 |
+
)
|
| 263 |
|
| 264 |
def detect_name_entities(self, text: str) -> List[Entity]:
|
| 265 |
"""Detect name entities using SpaCy NER"""
|
| 266 |
entities = []
|
| 267 |
doc = self.nlp(text)
|
| 268 |
+
|
| 269 |
for ent in doc.ents:
|
| 270 |
# Use PER for person, common in many models like xx_ent_wiki_sm
|
| 271 |
# Also checking for PERSON as some models might use it.
|
| 272 |
if ent.label_ in ["PER", "PERSON"]:
|
| 273 |
+
entities.append(
|
| 274 |
+
Entity(ent.start_char, ent.end_char, "full_name", ent.text)
|
| 275 |
+
)
|
| 276 |
return entities
|
| 277 |
|
| 278 |
def detect_all_entities(self, text: str) -> List[Entity]:
|
|
|
|
| 306 |
# A simple greedy approach: iterate and remove/adjust overlaps
|
| 307 |
# This can be made more sophisticated
|
| 308 |
resolved_entities: List[Entity] = []
|
| 309 |
+
# Process by start, then by longest
|
| 310 |
+
for current_entity in sorted(
|
| 311 |
+
entities, key=lambda e: (e.start, -(e.end - e.start))
|
| 312 |
+
):
|
| 313 |
is_overlapped_or_contained = False
|
| 314 |
temp_resolved = []
|
| 315 |
for i, res_entity in enumerate(resolved_entities):
|
| 316 |
# Check for overlap:
|
| 317 |
# Current: |----|
|
| 318 |
# Res: |----| or |----| or |--| or |------|
|
| 319 |
+
overlap = max(
|
| 320 |
+
0,
|
| 321 |
+
min(current_entity.end, res_entity.end) # Fixed W504 line break
|
| 322 |
+
- max(current_entity.start, res_entity.start)
|
| 323 |
+
)
|
| 324 |
|
| 325 |
if overlap > 0:
|
| 326 |
is_overlapped_or_contained = True
|
| 327 |
# Preference:
|
| 328 |
+
# 1. NER often trump regex if they are the ones causing overlap
|
| 329 |
# 2. Longer entity wins
|
| 330 |
current_len = current_entity.end - current_entity.start
|
| 331 |
res_len = res_entity.end - res_entity.start
|
| 332 |
|
| 333 |
+
# If current is a name and overlaps, and previous is not a name,
|
| 334 |
+
# prefer current if it's not fully contained
|
| 335 |
+
if (current_entity.entity_type == "full_name" # E501 corrected
|
| 336 |
+
and res_entity.entity_type != "full_name"):
|
| 337 |
+
# current not fully contained by res
|
| 338 |
+
if not (res_entity.start <= current_entity.start
|
| 339 |
+
and res_entity.end >= current_entity.end):
|
| 340 |
+
# remove res_entity, current will be added later
|
| 341 |
+
continue # go to next res_entity, marked for removal
|
| 342 |
+
elif (res_entity.entity_type == "full_name"
|
| 343 |
+
and current_entity.entity_type != "full_name"):
|
| 344 |
+
# res_entity is a name, current is not. Prefer res_entity
|
| 345 |
+
# if it's not fully contained
|
| 346 |
+
if not (current_entity.start <= res_entity.start
|
| 347 |
+
and current_entity.end >= res_entity.end):
|
| 348 |
+
# current entity is subsumed or less important,
|
| 349 |
+
# so don't add current and keep res_entity
|
| 350 |
temp_resolved.append(res_entity)
|
| 351 |
+
is_overlapped_or_contained = True # Mark current as handled
|
| 352 |
+
break # Current is dominated
|
| 353 |
|
| 354 |
# General case: longer entity wins
|
| 355 |
if current_len > res_len:
|
| 356 |
+
# current is longer, res_entity is removed from
|
| 357 |
+
# consideration for this current_entity
|
| 358 |
+
pass # res_entity not added to temp_resolved if fully replaced
|
| 359 |
elif res_len > current_len:
|
| 360 |
# res is longer, current is dominated
|
| 361 |
temp_resolved.append(res_entity)
|
| 362 |
+
is_overlapped_or_contained = True # Mark current as handled
|
| 363 |
break
|
| 364 |
+
else: # Same length, keep existing one (res_entity)
|
| 365 |
temp_resolved.append(res_entity)
|
| 366 |
+
is_overlapped_or_contained = True # Mark current as handled
|
| 367 |
break
|
| 368 |
+
else: # No overlap
|
| 369 |
temp_resolved.append(res_entity)
|
| 370 |
|
| 371 |
if not is_overlapped_or_contained:
|
| 372 |
temp_resolved.append(current_entity)
|
| 373 |
|
| 374 |
+
resolved_entities = sorted(
|
| 375 |
+
temp_resolved, key=lambda e: (e.start, -(e.end - e.start))
|
| 376 |
+
)
|
| 377 |
|
| 378 |
# Final pass to remove fully contained entities if a larger one exists
|
| 379 |
final_entities = []
|
|
|
|
| 386 |
if i == j:
|
| 387 |
continue
|
| 388 |
# If 'entity' is strictly contained within 'other_entity'
|
| 389 |
+
if (other_entity.start <= entity.start
|
| 390 |
+
and other_entity.end >= entity.end
|
| 391 |
+
and (other_entity.end - other_entity.start
|
| 392 |
+
> entity.end - entity.start)):
|
| 393 |
is_contained = True
|
| 394 |
break
|
| 395 |
if not is_contained:
|
|
|
|
| 397 |
|
| 398 |
return final_entities
|
| 399 |
|
|
|
|
| 400 |
def mask_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
|
| 401 |
"""
|
| 402 |
Mask PII entities in the text and return masked text and entity information
|
|
|
|
| 428 |
|
| 429 |
return "".join(new_text_parts), entity_info
|
| 430 |
|
|
|
|
| 431 |
def process_email(self, email_text: str) -> Dict[str, Any]:
|
| 432 |
"""
|
| 433 |
Process an email by detecting and masking PII entities.
|
|
|
|
| 435 |
"""
|
| 436 |
# Mask the email
|
| 437 |
masked_email, entity_info = self.mask_text(email_text)
|
| 438 |
+
|
| 439 |
# Store the email in the SQLite database - only get back email_id now
|
| 440 |
email_id = self.db.store_email(
|
| 441 |
original_email=email_text,
|
| 442 |
masked_email=masked_email,
|
| 443 |
masked_entities=entity_info
|
| 444 |
)
|
| 445 |
+
|
| 446 |
# Return the processed data with just the email_id
|
| 447 |
return {
|
| 448 |
+
"input_email_body": email_text, # Return original for API compatibility
|
| 449 |
"list_of_masked_entities": entity_info,
|
| 450 |
"masked_email": masked_email,
|
| 451 |
"category_of_the_email": "",
|
| 452 |
"email_id": email_id
|
| 453 |
}
|
| 454 |
+
|
| 455 |
+
def get_original_email(
|
| 456 |
+
self, email_id: str, access_key: str
|
| 457 |
+
) -> Optional[Dict[str, Any]]:
|
| 458 |
"""
|
| 459 |
Retrieve the original email with PII using the email ID and access key.
|
| 460 |
+
|
| 461 |
Args:
|
| 462 |
email_id: The ID of the stored email
|
| 463 |
access_key: The security key for accessing the original email
|
| 464 |
+
|
| 465 |
Returns:
|
| 466 |
The original email data or None if not found or access_key is invalid
|
| 467 |
"""
|
| 468 |
return self.db.get_original_email(email_id, access_key)
|
| 469 |
+
|
| 470 |
def get_masked_email_by_id(self, email_id: str) -> Optional[Dict[str, Any]]:
|
| 471 |
"""
|
| 472 |
Retrieve a masked email by its ID (without the original PII-containing email).
|
| 473 |
+
|
| 474 |
Args:
|
| 475 |
email_id: The ID of the stored email
|
| 476 |
+
|
| 477 |
Returns:
|
| 478 |
The masked email data or None if not found
|
| 479 |
"""
|
| 480 |
return self.db.get_email_by_id(email_id)
|
| 481 |
+
|
| 482 |
+
def get_original_by_masked_email(
|
| 483 |
+
self, masked_email: str
|
| 484 |
+
) -> Optional[Dict[str, Any]]:
|
| 485 |
"""
|
| 486 |
Retrieve the original unmasked email using the masked email content.
|
| 487 |
+
|
| 488 |
Args:
|
| 489 |
masked_email: The masked version of the email to search for
|
| 490 |
+
|
| 491 |
Returns:
|
| 492 |
The original email data or None if not found
|
| 493 |
"""
|
| 494 |
+
return self.db.get_email_by_masked_content(masked_email)
|