File size: 2,718 Bytes
70985f9
 
85327fb
 
 
 
4658146
3e847ba
70985f9
6afefd9
 
3e847ba
85327fb
6afefd9
 
 
 
 
 
43ba64b
85327fb
3e847ba
85327fb
43ba64b
6afefd9
 
 
85327fb
 
 
3e847ba
85327fb
 
 
 
 
 
 
 
 
 
 
 
 
 
3e847ba
85327fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e847ba
85327fb
43ba64b
3e847ba
85327fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import imaplib, email
from email.header import decode_header
from transformers import pipeline
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import os

app = FastAPI()

MODEL_NAME = "facebook/bart-large-mnli"

# Zero-shot NLI model used to score each email against ``categories``.
# force_download=True re-fetches the weights on every start-up instead of
# reusing the local Hugging Face cache; it is slow, so consider dropping it
# once the cache is known to be healthy.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    force_download=True,  # Forces re-download
)
# The pipeline needs the tokenizer that matches the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)

# Candidate labels; classify_emails reports the first (top-ranked) label.
categories = ["Spam", "Not Spam"]


class EmailCredentials(BaseModel):
    """Request body for ``POST /classify_emails``: IMAP login for the mailbox to scan."""

    email: str  # IMAP username passed to ``mail.login`` (the account address)
    password: str  # IMAP password passed to ``mail.login``

def _decode_bytes(data, charset):
    """Decode *data* with *charset*, falling back to UTF-8 if it is missing or unknown.

    Undecodable bytes are dropped (``errors="ignore"``), the same policy this
    module has always used for message bodies.
    """
    try:
        return data.decode(charset or "utf-8", errors="ignore")
    except LookupError:
        # Declared charset Python does not know (e.g. "unknown-8bit").
        return data.decode("utf-8", errors="ignore")


def _decode_header_value(raw):
    """Return header value *raw* with every RFC 2047 encoded word decoded to ``str``."""
    pieces = []
    for chunk, charset in decode_header(raw):
        # decode_header yields str when nothing was encoded, bytes otherwise.
        pieces.append(_decode_bytes(chunk, charset) if isinstance(chunk, bytes) else chunk)
    return "".join(pieces)


def _decode_payload(part):
    """Return the transfer-decoded payload of a non-multipart *part* as text."""
    payload = part.get_payload(decode=True)
    if payload is None:
        return ""
    return _decode_bytes(payload, part.get_content_charset())


def extract_email_content(msg):
    """Extract ``(sender, subject, body)`` from a parsed ``email.message.Message``.

    Args:
        msg: Message parsed with ``email.message_from_bytes`` (or similar).

    Returns:
        A tuple ``(sender, subject, body)``:

        * ``sender`` is the decoded ``From`` header, or ``None`` if absent.
        * ``subject`` is the fully decoded ``Subject`` header (all encoded
          words joined), or ``""`` if absent.
        * ``body`` is the first inline ``text/plain`` part, decoded with its
          declared charset (UTF-8 if none), or ``""`` if there is none.
          ``text/plain`` attachments are skipped.
    """
    raw_subject = msg.get("Subject")
    subject = "" if raw_subject is None else _decode_header_value(raw_subject)

    sender = msg.get("From")
    if sender is not None:
        sender = _decode_header_value(sender)

    body = ""
    if msg.is_multipart():
        for part in msg.walk():
            # Attached .txt files are not the message body.
            if (
                part.get_content_type() == "text/plain"
                and part.get_content_disposition() != "attachment"
            ):
                body = _decode_payload(part)
                break
    else:
        body = _decode_payload(msg)
    return sender, subject, body

@app.post("/classify_emails")
def classify_emails(credentials: EmailCredentials):
    """Classify the last ten inbox messages of a Gmail account as spam or not.

    Logs in to ``imap.gmail.com`` over IMAP/SSL with *credentials*. It takes
    the ten highest message sequence numbers in ``inbox`` (normally the most
    recently received) and runs each subject plus the first 200 body
    characters through the zero-shot ``classifier``.

    Returns:
        A list of dicts with keys ``from``, ``subject``, ``category`` (the
        top-ranked label) and ``confidence`` (its score).

    Raises:
        HTTPException: 401 if the IMAP login is rejected, 500 for any other
            failure (connection, IMAP status, parsing, classification).
    """
    try:
        # The context manager logs out even when an exception escapes.
        with imaplib.IMAP4_SSL("imap.gmail.com") as mail:
            try:
                mail.login(credentials.email, credentials.password)
            except imaplib.IMAP4.error as e:
                # Rejected credentials are a client error, not a server fault.
                raise HTTPException(status_code=401, detail=str(e)) from e

            # readonly=True (EXAMINE) and BODY.PEEK[] together keep this scan
            # from marking the user's messages as \Seen.
            status, _ = mail.select("inbox", readonly=True)
            if status != "OK":
                raise HTTPException(status_code=500, detail="Could not open inbox")
            status, messages = mail.search(None, "ALL")
            if status != "OK":
                raise HTTPException(status_code=500, detail="Inbox search failed")

            email_ids = messages[0].split()[-10:]
            results = []
            for email_id in email_ids:
                status, msg_data = mail.fetch(email_id, "(BODY.PEEK[])")
                if status != "OK":
                    continue
                for response_part in msg_data:
                    # Message bytes arrive as the 2nd item of a tuple; bare
                    # bytes entries are protocol framing such as b")".
                    if isinstance(response_part, tuple):
                        msg = email.message_from_bytes(response_part[1])
                        sender, subject, body = extract_email_content(msg)
                        classification = classifier(subject + " " + body[:200], categories)
                        results.append({
                            "from": sender,
                            "subject": subject,
                            "category": classification["labels"][0],
                            "confidence": classification["scores"][0]
                        })
            return results
    except HTTPException:
        # Already carries the intended status code; don't rewrap it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Imported here because uvicorn is only needed when this module is run
    # directly as a script (a deployed ASGI server imports ``app`` itself).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)