sathish2352's picture
Update main.py
4dceb21 verified
raw
history blame
1.16 kB
import os
from typing import Any, Dict, List

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from utils import mask_pii_multilingual
# Directory holding the tokenizer and sequence-classification model files.
# Defaults to "./model" (resolved against the current working directory, as
# before); set the MODEL_DIR environment variable to load from elsewhere.
MODEL_DIR = os.environ.get("MODEL_DIR", "./model")

# Load model and tokenizer once at import time; the /classify handler reuses
# these module-level objects for every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()  # evaluation mode: layers such as dropout behave deterministically

app = FastAPI()
class EmailInput(BaseModel):
    """Request body accepted by the ``/classify`` endpoint.

    Holds the email text exactly as the client submitted it; PII masking
    happens later, inside the endpoint.
    """

    # Raw (unmasked) email body text.
    input_email_body: str
# Maps the classifier's predicted index to its category name.
# Defined once at import time rather than rebuilt on every request.
LABEL_MAP: Dict[int, str] = {0: "Incident", 1: "Request", 2: "Change", 3: "Problem"}


@app.post("/classify")
def classify_email(input: EmailInput):
    """Mask PII in an email body, then classify the masked text.

    Args:
        input: Request payload; its ``input_email_body`` is the raw email text.

    Returns:
        A dict with the original email body, the entity list returned by
        ``mask_pii_multilingual``, the masked email text, and the predicted
        category name from ``LABEL_MAP``.

    Raises:
        KeyError: if the model predicts an index not present in ``LABEL_MAP``.
    """
    email_body = input.input_email_body

    # Step 1: mask PII; only the masked text is passed to the classifier below.
    masked_text, masked_entities = mask_pii_multilingual(email_body)

    # Step 2: classify the masked text (single input, truncated to the
    # tokenizer's maximum length).
    inputs = tokenizer(masked_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(**inputs).logits
    pred = torch.argmax(logits, dim=1).item()

    return {
        "input_email_body": email_body,
        "list_of_masked_entities": masked_entities,
        "masked_email": masked_text,
        "category_of_the_email": LABEL_MAP[pred],
    }