|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Dict, Any |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
from utils import mask_pii_multilingual |
|
|
|
|
|
|
|
|
# Directory holding the fine-tuned tokenizer and classifier weights. This is a
# relative path, so it is resolved against the process's current working
# directory: the server must be started from the folder that contains `model/`.
_MODEL_DIR = "./model"

# Load both artifacts once at import time; every request below reuses them.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR)

# `nn.Module.eval()` returns the module itself, so the inference-mode switch
# (which disables dropout-style training behaviour) can be chained onto the load.
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_DIR).eval()

# ASGI application object; the /classify route is registered on it below.
app = FastAPI()
|
|
|
|
|
class EmailInput(BaseModel):
    """JSON request body accepted by the ``POST /classify`` endpoint.

    Attributes:
        input_email_body: Raw, unmasked email text to be PII-masked and
            classified.
    """

    input_email_body: str
|
|
|
|
|
@app.post("/classify")
def classify_email(input: EmailInput):
    """Mask PII in the submitted email, then classify the masked text.

    Only the output of ``mask_pii_multilingual`` is tokenized and sent through
    the model; the classifier never sees the raw body.

    Args:
        input: Parsed request body. The name shadows the builtin ``input``
            but is kept so existing callers are unaffected.

    Returns:
        A dict with, in this order:
        ``input_email_body`` -- the original text, echoed back UNMASKED;
        ``list_of_masked_entities`` -- second value returned by
        ``mask_pii_multilingual`` (its structure is defined in ``utils``);
        ``masked_email`` -- the masked text that was classified;
        ``category_of_the_email`` -- one of the four category names below.

    Raises:
        KeyError: if the predicted class index falls outside 0-3 (FastAPI
            reports this to the client as a server error).
    """
    # Index -> category name. NOTE(review): presumably mirrors the label order
    # used when the model in ./model was trained; confirm against its config.
    categories = {0: "Incident", 1: "Request", 2: "Change", 3: "Problem"}

    raw_body = input.input_email_body

    # Mask first, so only de-identified text reaches the tokenizer and model.
    masked_body, entities = mask_pii_multilingual(raw_body)

    encoded = tokenizer(
        masked_body,
        return_tensors="pt",
        truncation=True,  # over-long emails are cut rather than rejected
        padding=True,
    )

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(**encoded).logits

    # Single-example batch: take the highest-scoring class along dim 1.
    predicted_index = scores.argmax(dim=1).item()

    response = {
        "input_email_body": raw_body,
        "list_of_masked_entities": entities,
        "masked_email": masked_body,
        "category_of_the_email": categories[predicted_index],
    }
    return response
|
|
|