# server.py — last commit 6afefd9 ("Update server.py") by Mummia-99 (verified).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import imaplib, email
from email.header import decode_header
from transformers import pipeline
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import os
app = FastAPI()

# Zero-shot NLI model used to label each email against ``categories``.
MODEL_NAME = "facebook/bart-large-mnli"

# Loaded through the PyTorch AutoModel class (no ``from_flax``).
# ``force_download=True`` re-fetches the model files on every start, overriding
# any cached copy; this is slower to boot but avoids a stale/corrupt cache.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    force_download=True,  # Forces re-download
)
# The pipeline needs the tokenizer that matches the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)

# Candidate labels passed to the zero-shot classifier for every email.
categories = ["Spam", "Not Spam"]
class EmailCredentials(BaseModel):
    """Request body for ``/classify_emails``: the IMAP login of the mailbox to scan."""

    # Mailbox address; passed as the user name to the IMAP LOGIN command.
    email: str
    # Password passed alongside ``email`` to the IMAP LOGIN command.
    password: str
def _decode_bytes(data, charset, errors):
    """Decode *data* using *charset* (UTF-8 when None), falling back to UTF-8 for unknown charsets."""
    try:
        return data.decode(charset or "utf-8", errors=errors)
    except LookupError:
        # Unrecognised charset label (e.g. "unknown-8bit"): best-effort UTF-8 instead of failing.
        return data.decode("utf-8", errors=errors)


def _decode_header_value(raw):
    """Return header *raw* as one str, joining every RFC 2047 chunk; "" if the header is absent."""
    if raw is None:
        return ""
    pieces = []
    for chunk, charset in decode_header(raw):
        # decode_header yields str when the header has no encoded words, bytes otherwise.
        if isinstance(chunk, bytes):
            chunk = _decode_bytes(chunk, charset, "replace")
        pieces.append(chunk)
    return "".join(pieces)


def extract_email_content(msg):
    """Return ``(sender, subject, body)`` for a parsed email message.

    sender  -- the raw ``From`` header, or None if it is absent.
    subject -- the decoded ``Subject`` with all encoded-word chunks joined;
               "" if absent. Undecodable bytes are replaced, not fatal.
    body    -- for multipart messages, the first non-attachment text/plain part;
               for single-part messages, the whole payload. Decoded with the
               part's declared charset (UTF-8 if none), undecodable bytes dropped.
               "" if no suitable part exists.
    """
    subject = _decode_header_value(msg.get("Subject"))
    sender = msg.get("From")
    body = ""
    if msg.is_multipart():
        for part in msg.walk():
            if part.get_content_type() != "text/plain":
                continue
            # Skip attached .txt files so they are not mistaken for the message body.
            if part.get_content_disposition() == "attachment":
                continue
            payload = part.get_payload(decode=True) or b""
            body = _decode_bytes(payload, part.get_content_charset(), "ignore")
            break
    else:
        payload = msg.get_payload(decode=True) or b""
        body = _decode_bytes(payload, msg.get_content_charset(), "ignore")
    return sender, subject, body
def _safe_logout(mail):
    """Best-effort IMAP LOGOUT; connection errors are ignored so they cannot mask the real outcome."""
    try:
        mail.logout()
    except (imaplib.IMAP4.error, OSError):
        pass


def _classify_recent(mail):
    """Classify the 10 most recent INBOX messages on an authenticated IMAP session.

    Returns a list of dicts with keys "from", "subject", "category", "confidence",
    where category/confidence are the top label and score from ``classifier``.
    Raises RuntimeError if INBOX cannot be selected or searched.
    """
    # readonly=True: fetching (RFC822) in a read-write mailbox implicitly sets \Seen,
    # which would silently mark the user's mail as read.
    status, _ = mail.select("inbox", readonly=True)
    if status != "OK":
        raise RuntimeError("Could not select INBOX")
    status, messages = mail.search(None, "ALL")
    if status != "OK":
        raise RuntimeError("IMAP search failed")

    results = []
    for email_id in messages[0].split()[-10:]:
        status, msg_data = mail.fetch(email_id, "(RFC822)")
        if status != "OK":
            continue  # skip a message that could not be fetched rather than fail the batch
        for response_part in msg_data:
            # Only the tuple entries carry (envelope, raw message bytes); skip the rest.
            if not isinstance(response_part, tuple):
                continue
            msg = email.message_from_bytes(response_part[1])
            sender, subject, body = extract_email_content(msg)
            classification = classifier(subject + " " + body[:200], categories)
            results.append({
                "from": sender,
                "subject": subject,
                "category": classification["labels"][0],
                "confidence": classification["scores"][0],
            })
    return results


@app.post("/classify_emails")
def classify_emails(credentials: EmailCredentials):
    """Log in to Gmail over IMAP and zero-shot classify the 10 newest INBOX emails.

    Raises HTTPException 401 when the IMAP server rejects the credentials and
    500 for any other failure. Once connected, the session is always logged out.
    """
    mail = None
    try:
        mail = imaplib.IMAP4_SSL("imap.gmail.com")
        try:
            mail.login(credentials.email, credentials.password)
        except imaplib.IMAP4.abort:
            raise  # dropped connection, not bad credentials -> reported as 500 below
        except imaplib.IMAP4.error as e:
            # Rejected credentials are a client error, not a server failure.
            raise HTTPException(status_code=401, detail=str(e)) from e
        return _classify_recent(mail)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if mail is not None:
            _safe_logout(mail)
if __name__ == "__main__":
    # Imported here: the server runner is only needed when this file is executed directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)