Spaces:
Sleeping
Sleeping
Surya8663
committed on
Commit
·
4c21dbb
1
Parent(s):
5c04706
Final version for free tier (removes gmail fetch)
Browse files- backend/app.py +62 -49
backend/app.py
CHANGED
|
@@ -1,66 +1,75 @@
|
|
| 1 |
-
# backend/app.py (
|
| 2 |
|
| 3 |
from transformers import pipeline
|
| 4 |
import gradio as gr
|
| 5 |
import pandas as pd
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
# --- This is the core change ---
|
| 11 |
-
# 1. We load a pre-trained "zero-shot-classification" pipeline from Hugging Face.
|
| 12 |
-
# The first time this code runs, it will automatically download the model (approx. 1.6GB).
|
| 13 |
print("Loading zero-shot classification model...")
|
| 14 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
| 15 |
print("Model loaded successfully.")
|
| 16 |
-
# -----------------------------
|
| 17 |
|
| 18 |
database.init_db()
|
| 19 |
|
| 20 |
-
# --- These are the labels we want to sort emails into ---
|
| 21 |
-
# You can change or add to this list!
|
| 22 |
POSSIBLE_LABELS = ["Work", "Promotions", "Personal", "Spam", "Important"]
|
| 23 |
-
email_storage = {}
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
sequences = (df['Subject'] + " " + df['Body']).tolist()
|
| 32 |
|
| 33 |
-
print(f"Classifying {len(sequences)} emails...")
|
| 34 |
-
# Use the zero-shot pipeline. It's powerful but can be slow on CPU for many emails.
|
| 35 |
predictions = classifier(sequences, candidate_labels=POSSIBLE_LABELS)
|
| 36 |
print("Classification complete.")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
# Add results to the DataFrame
|
| 43 |
df['predicted_folder'] = pred_labels
|
| 44 |
df['confidence'] = pred_scores
|
| 45 |
|
| 46 |
-
# Save to the database
|
| 47 |
database.add_email_records(df)
|
| 48 |
-
print(f"Saved {len(df)} records to the database.")
|
| 49 |
|
| 50 |
# Format for display
|
| 51 |
df_display = df.copy()
|
| 52 |
df_display['confidence'] = [f"{score:.2%}" for score in pred_scores]
|
| 53 |
-
return df_display
|
| 54 |
-
|
| 55 |
-
def fetch_and_classify_emails():
|
| 56 |
-
"""Fetches emails from Gmail and triggers the classification."""
|
| 57 |
-
emails = fetch_latest_emails()
|
| 58 |
-
if not emails:
|
| 59 |
-
email_storage.clear()
|
| 60 |
-
return pd.DataFrame({"From": ["No new emails found or an error occurred."], "Subject": [""]})
|
| 61 |
-
|
| 62 |
-
df = pd.DataFrame(emails)
|
| 63 |
-
return predict_and_save_emails(df)
|
| 64 |
|
| 65 |
# --- History & Feedback Functions (Unchanged) ---
|
| 66 |
def show_history():
|
|
@@ -73,20 +82,28 @@ def show_history():
|
|
| 73 |
def save_feedback_and_refresh(record_id, corrected_label):
|
| 74 |
if not record_id or not corrected_label:
|
| 75 |
return show_history(), "Please enter a Record ID and a correct label first."
|
| 76 |
-
|
| 77 |
record_id_int = int(record_id)
|
| 78 |
database.update_feedback(record_id_int, corrected_label)
|
| 79 |
-
|
| 80 |
return show_history(), f"Success! Record {record_id_int} updated."
|
| 81 |
|
| 82 |
-
# --- Build the Gradio Interface
|
| 83 |
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
| 84 |
gr.Markdown("## 📧 Smart Email Sorter (Zero-Shot Model)")
|
| 85 |
|
| 86 |
-
with gr.Tab("
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
with gr.Tab("History & Feedback"):
|
| 92 |
gr.Markdown("### Prediction History & Feedback")
|
|
@@ -101,8 +118,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
| 101 |
history_df = gr.Dataframe(label="All Saved Predictions", wrap=True)
|
| 102 |
app.load(fn=show_history, outputs=history_df)
|
| 103 |
refresh_history_btn.click(fn=show_history, outputs=history_df)
|
| 104 |
-
submit_feedback_btn.click(
|
| 105 |
-
fn=save_feedback_and_refresh,
|
| 106 |
-
inputs=[record_id_input, feedback_label],
|
| 107 |
-
outputs=[history_df, feedback_status]
|
| 108 |
-
)
|
|
|
|
| 1 |
+
# backend/app.py (FINAL VERSION FOR FREE TIER)
|
| 2 |
|
| 3 |
from transformers import pipeline
|
| 4 |
import gradio as gr
|
| 5 |
import pandas as pd
|
| 6 |
+
# We no longer need gmail_fetcher, so we remove the import
|
| 7 |
+
import database
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
# --- Diagnostic and Model Loading (Unchanged) ---
print("--- Starting Application ---")
try:
    # Report the installed Gradio version for Space-startup debugging.
    gradio_version = gr.__version__
    print(f"Gradio Version: {gradio_version}")
except Exception as e:
    print(f"Could not get Gradio version: {e}")
print("--------------------------")

# Downloads ~1.6GB on first run; subsequent runs use the HF cache.
print("Loading zero-shot classification model...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("Model loaded successfully.")

# Ensure the SQLite tables exist before any handler touches them.
database.init_db()

# Candidate folders handed to the zero-shot classifier.
POSSIBLE_LABELS = ["Work", "Promotions", "Personal", "Spam", "Important"]
|
|
|
|
| 25 |
|
| 26 |
+
# --- NEW/RESTORED: Functions for Single and Batch Prediction ---
|
| 27 |
+
|
| 28 |
+
def predict_single(subject, body):
    """Classifies a single email from subject and body text."""
    # Bail out when both fields are blank (whitespace-only counts as blank).
    if not (subject.strip() or body.strip()):
        return "Please provide a subject or body."

    prediction = classifier(subject + " " + body, candidate_labels=POSSIBLE_LABELS)

    # Pipeline output is sorted best-first; grab the top label/score pair.
    top_label, top_score = prediction['labels'][0], prediction['scores'][0]
    return f"Predicted Folder: {top_label} (Confidence: {top_score:.2%})"
|
| 40 |
+
|
| 41 |
+
def predict_batch_csv(file):
    """Classifies a batch of emails from an uploaded CSV file.

    The CSV must contain 'subject' and 'body' columns. Top-1 predictions are
    appended as 'predicted_folder'/'confidence' columns, persisted via the
    database module, and a display-formatted copy of the DataFrame is
    returned for the Gradio table. Errors are reported as a one-row
    DataFrame with an 'Error' column rather than raised.
    """
    # Guard: the "Classify Batch" button can be clicked with no file uploaded,
    # in which case Gradio passes None and file.name would raise AttributeError.
    if file is None:
        return pd.DataFrame({"Error": ["Please upload a CSV file first."]})

    try:
        # Gradio versions differ: gr.File may yield a tempfile-like object
        # (with .name) or a plain path string — accept both.
        df = pd.read_csv(getattr(file, "name", file))
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if 'subject' not in df.columns or 'body' not in df.columns:
        return pd.DataFrame({"Error": ["CSV must have 'subject' and 'body' columns."]})

    # Guard: an empty batch would be handed to the classifier pointlessly.
    if df.empty:
        return pd.DataFrame({"Error": ["CSV contains no rows to classify."]})

    # Normalize missing cells so the string concatenation below can't hit NaN.
    df['Subject'] = df['subject'].fillna('')
    df['Body'] = df['body'].fillna('')
    sequences = (df['Subject'] + " " + df['Body']).tolist()

    print(f"Classifying {len(sequences)} emails from CSV...")
    predictions = classifier(sequences, candidate_labels=POSSIBLE_LABELS)
    print("Classification complete.")

    # Top-1 label/score per email (pipeline output is sorted best-first).
    pred_labels = [p['labels'][0] for p in predictions]
    pred_scores = [p['scores'][0] for p in predictions]

    df['predicted_folder'] = pred_labels
    df['confidence'] = pred_scores

    # Save results to the database
    database.add_email_records(df)
    print(f"Saved {len(df)} records from CSV to the database.")

    # Format for display
    df_display = df.copy()
    df_display['confidence'] = [f"{score:.2%}" for score in pred_scores]
    return df_display
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# --- History & Feedback Functions (Unchanged) ---
|
| 75 |
def show_history():
|
|
|
|
| 82 |
def save_feedback_and_refresh(record_id, corrected_label):
    """Persists a user's label correction and refreshes the history table.

    Returns a (history_dataframe, status_message) pair for the two Gradio
    outputs wired to the feedback submit button.
    """
    if not record_id or not corrected_label:
        return show_history(), "Please enter a Record ID and a correct label first."
    try:
        record_id_int = int(record_id)
    except (TypeError, ValueError):
        # A non-numeric ID would otherwise crash the click handler with an
        # unhandled ValueError; surface it as a status message instead.
        return show_history(), f"Record ID must be a whole number, got: {record_id!r}"
    database.update_feedback(record_id_int, corrected_label)
    return show_history(), f"Success! Record {record_id_int} updated."
|
| 88 |
|
| 89 |
+
# --- MODIFIED: Build the Final Gradio Interface ---
|
| 90 |
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
| 91 |
gr.Markdown("## 📧 Smart Email Sorter (Zero-Shot Model)")
|
| 92 |
|
| 93 |
+
with gr.Tab("Single Prediction"):
|
| 94 |
+
gr.Markdown("### Classify a Single Email")
|
| 95 |
+
subject_input = gr.Textbox(label="Subject", lines=1)
|
| 96 |
+
body_input = gr.Textbox(label="Body", lines=5, placeholder="Paste the email body here...")
|
| 97 |
+
predict_btn = gr.Button("Classify Email", variant="primary")
|
| 98 |
+
output_label = gr.Label(label="Result")
|
| 99 |
+
predict_btn.click(fn=predict_single, inputs=[subject_input, body_input], outputs=output_label, show_progress="full")
|
| 100 |
+
|
| 101 |
+
with gr.Tab("Batch Prediction (CSV)"):
|
| 102 |
+
gr.Markdown("### Classify a Batch of Emails from a CSV File")
|
| 103 |
+
csv_input = gr.File(label="Upload a CSV file with 'subject' and 'body' columns")
|
| 104 |
+
csv_btn = gr.Button("Classify Batch", variant="primary")
|
| 105 |
+
batch_output = gr.Dataframe(label="Classification Results", wrap=True)
|
| 106 |
+
csv_btn.click(fn=predict_batch_csv, inputs=csv_input, outputs=batch_output, show_progress="full")
|
| 107 |
|
| 108 |
with gr.Tab("History & Feedback"):
|
| 109 |
gr.Markdown("### Prediction History & Feedback")
|
|
|
|
| 118 |
history_df = gr.Dataframe(label="All Saved Predictions", wrap=True)
|
| 119 |
app.load(fn=show_history, outputs=history_df)
|
| 120 |
refresh_history_btn.click(fn=show_history, outputs=history_df)
|
| 121 |
+
submit_feedback_btn.click(fn=save_feedback_and_refresh, inputs=[record_id_input, feedback_label], outputs=[history_df, feedback_status])
|
|
|
|
|
|
|
|
|
|
|
|