# Provenance: uploaded by nandini2455508 — "Update api.py" (commit bf35fed, verified)
# -*- coding: utf-8 -*-
"""api.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1_vvePPVuXHIon25W71PD4kt1Z5J2sKgh
"""
# Install required packages before running (fastapi, pydantic, joblib, cloudpickle, nltk)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any
import joblib
import cloudpickle
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import os
import logging
# Your custom utils
from utils import mask_pii # Keep this if you have utils.py uploaded
# Logging: module-level logger, standard practice for library/app code.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NLTK setup: corpora are downloaded into a writable tmp dir so this works
# on read-only / serverless hosts where the default nltk_data path fails.
nltk_data_dir = "/tmp/nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
# Register the directory *before* downloading so lookups resolve even when a
# download is skipped because the corpus is already present.
nltk.data.path.append(nltk_data_dir)
nltk.download('stopwords', download_dir=nltk_data_dir, quiet=True)
nltk.download('wordnet', download_dir=nltk_data_dir, quiet=True)
# omw-1.4 is required by WordNetLemmatizer on NLTK >= 3.6; harmless on older versions.
nltk.download('omw-1.4', download_dir=nltk_data_dir, quiet=True)
# Load Models
def load_pickle_file(file_path):
    """Deserialize a pickled artifact, preferring joblib.

    Falls back to cloudpickle when joblib cannot resolve a module
    referenced inside the pickle (e.g. objects pickled from __main__).
    """
    try:
        loaded = joblib.load(file_path)
    except ModuleNotFoundError:
        with open(file_path, 'rb') as handle:
            loaded = cloudpickle.load(handle)
    return loaded
# Load model artifacts once at import time; fail fast if any are missing or corrupt.
try:
    model = load_pickle_file('rf_classifier_v3.pkl')
    vectorizer = load_pickle_file('vectorizer.pkl')
    label_encoder = load_pickle_file('label_encoder.pkl')
except Exception as e:
    logger.critical(f"Model loading failed: {e}")
    # `exit()` is a site.py convenience not guaranteed to exist outside the
    # interactive interpreter (e.g. under `python -S` or when embedded);
    # raising SystemExit is the portable way to abort startup.
    raise SystemExit(1)

# Text-preprocessing resources shared by clean_text().
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

app = FastAPI(title="Email Classification API")
class EmailRequest(BaseModel):
    """Request payload for POST /predict."""
    # Raw email text to classify; PII is masked server-side before classification.
    email_body: str
class EmailResponse(BaseModel):
    """Response payload for POST /predict."""
    # Email body with PII replaced by mask tokens (output of utils.mask_pii).
    masked_email: str
    # Category label decoded from the classifier's prediction.
    predicted_category: str
    # PII entities found by mask_pii; dict schema is defined in utils — TODO confirm.
    pii_entities: List[Dict[str, Any]]
def clean_text(text):
    """Normalize raw email text into a lemmatized, stopword-free string.

    Steps: flatten line breaks, strip a literal 'Subject:' marker, lowercase,
    drop every character that is not a lowercase letter or whitespace,
    collapse runs of whitespace, then lemmatize the surviving tokens.
    """
    flattened = re.sub(r'\n|\r', ' ', text)
    # NOTE: this runs before lower(), so only the exact-case 'Subject:' is removed.
    no_subject = re.sub(r'Subject:', '', flattened)
    lowered = no_subject.lower()
    letters_only = re.sub(r'[^a-z\s]', '', lowered)
    collapsed = re.sub(r'\s+', ' ', letters_only)
    kept = (word for word in collapsed.split() if word not in stop_words)
    return ' '.join(lemmatizer.lemmatize(word) for word in kept)
def classify_email(text: str):
    """Run the full pipeline (clean -> vectorize -> predict) on one email."""
    features = vectorizer.transform([clean_text(text)])
    label_ids = model.predict(features)
    # Decode the numeric class id back to its human-readable category name.
    return label_encoder.inverse_transform(label_ids)[0]
@app.get("/")
async def root():
    """Liveness endpoint: confirms the service is up."""
    status = {"message": "API is running."}
    return status
@app.post("/predict", response_model=EmailResponse)
async def predict(request: EmailRequest):
    """Mask PII in the submitted email, classify the masked text, return both.

    Raises HTTP 400 when the email body is empty or whitespace-only.
    """
    body = request.email_body
    if not body.strip():
        raise HTTPException(status_code=400, detail="Email body is required.")
    # Classify the *masked* text so raw PII never reaches the model.
    masked_text, pii_entities = mask_pii(body)
    category = classify_email(masked_text)
    return EmailResponse(
        masked_email=masked_text,
        predicted_category=category,
        pii_entities=pii_entities,
    )