sathish2352 commited on
Commit
4dceb21
·
verified ·
1 Parent(s): e01c7e7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +38 -0
main.py CHANGED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import List, Dict, Any
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
+ from utils import mask_pii_multilingual
7
+
8
+ # Load model and tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained("./model")
10
+ model = AutoModelForSequenceClassification.from_pretrained("./model")
11
+ model.eval()
12
+
13
+ app = FastAPI()
14
+
15
+ class EmailInput(BaseModel):
16
+ input_email_body: str
17
+
18
+ @app.post("/classify")
19
+ def classify_email(input: EmailInput):
20
+ email_body = input.input_email_body
21
+
22
+ # Step 1: Mask PII
23
+ masked_text, masked_entities = mask_pii_multilingual(email_body)
24
+
25
+ # Step 2: Classification
26
+ inputs = tokenizer(masked_text, return_tensors="pt", truncation=True, padding=True)
27
+ with torch.no_grad():
28
+ logits = model(**inputs).logits
29
+ pred = torch.argmax(logits, dim=1).item()
30
+
31
+ label_map = {0: "Incident", 1: "Request", 2: "Change", 3: "Problem"}
32
+
33
+ return {
34
+ "input_email_body": email_body,
35
+ "list_of_masked_entities": masked_entities,
36
+ "masked_email": masked_text,
37
+ "category_of_the_email": label_map[pred]
38
+ }