File size: 5,196 Bytes
4c21dbb
4ded330
 
 
 
4c21dbb
 
 
 
 
 
 
 
 
 
 
4ded330
 
d804a65
4ded330
 
 
 
 
4c21dbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ded330
4c21dbb
 
 
 
 
4ded330
 
4c21dbb
4ded330
 
4c21dbb
 
 
 
4ded330
 
 
4c21dbb
4ded330
4c21dbb
4ded330
 
 
 
4c21dbb
4ded330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c21dbb
4ded330
 
 
4c21dbb
 
 
 
 
 
 
 
 
 
 
 
 
 
4ded330
 
 
 
 
 
 
 
 
 
 
 
 
 
4c21dbb
6db3cd8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# backend/app.py (FINAL VERSION FOR FREE TIER)

from transformers import pipeline
import gradio as gr
import pandas as pd
# We no longer need gmail_fetcher, so we remove the import
import database
import sys

# --- Startup diagnostics, model loading, and database setup ---
# Everything below runs once at import time, strictly in this order:
# report the Gradio version, load the zero-shot model, then initialise the DB.
print("--- Starting Application ---")
try:
    print(f"Gradio Version: {gr.__version__}")
except Exception as version_error:
    # Purely diagnostic: failing to read the version must not abort startup.
    print(f"Could not get Gradio version: {version_error}")
print("--------------------------")

_ZERO_SHOT_TASK = "zero-shot-classification"
_ZERO_SHOT_MODEL = "valhalla/distilbart-mnli-12-3"

print("Loading zero-shot classification model...")
classifier = pipeline(_ZERO_SHOT_TASK, model=_ZERO_SHOT_MODEL)
print("Model loaded successfully.")
database.init_db()

# Candidate folders: passed to the classifier and offered in the feedback dropdown.
POSSIBLE_LABELS = ["Work", "Promotions", "Personal", "Spam", "Important"]

# --- NEW/RESTORED: Functions for Single and Batch Prediction ---

def predict_single(subject, body):
    """Classify a single email into one of ``POSSIBLE_LABELS``.

    Args:
        subject: Email subject line. ``None`` is treated as an empty string.
        body: Email body text. ``None`` is treated as an empty string.

    Returns:
        A human-readable string naming the top-scoring folder and its
        confidence, or a prompt message when both inputs are blank.
    """
    # Gradio textboxes normally send "", but API/programmatic callers may send None.
    subject = subject or ""
    body = body or ""
    if not subject.strip() and not body.strip():
        return "Please provide a subject or body."

    # Trim so a missing subject or body doesn't leave a stray space in the model input.
    sequence = f"{subject} {body}".strip()
    prediction = classifier(sequence, candidate_labels=POSSIBLE_LABELS)

    # The pipeline returns labels/scores ordered by likelihood, so index 0 is the best match.
    top_label = prediction["labels"][0]
    top_score = prediction["scores"][0]

    return f"Predicted Folder: {top_label} (Confidence: {top_score:.2%})"

def predict_batch_csv(file):
    """Classify every row of an uploaded CSV and persist the results.

    Args:
        file: Value from a ``gr.File`` component. Depending on the Gradio
            version this is either a temp-file wrapper exposing ``.name`` or a
            plain filesystem path string; both are accepted. ``None`` (nothing
            uploaded) yields an error table.

    Returns:
        A DataFrame for display. On success it holds the CSV's columns plus
        ``Subject``/``Body`` (NaN-filled text), ``predicted_folder`` and a
        percent-formatted ``confidence``. On any validation/read failure it
        holds a single ``Error`` column with a message.
    """
    if file is None:
        return pd.DataFrame({"Error": ["Please upload a CSV file."]})

    # Older Gradio passes a tempfile wrapper (path in .name); newer Gradio passes the path itself.
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if 'subject' not in df.columns or 'body' not in df.columns:
        return pd.DataFrame({"Error": ["CSV must have 'subject' and 'body' columns."]})

    # A header-only CSV would otherwise send an empty batch to the model and the database.
    if df.empty:
        return pd.DataFrame({"Error": ["CSV contains no rows to classify."]})

    # astype(str): numeric cells (e.g. a subject of 2024) would make the "+" below raise TypeError.
    df['Subject'] = df['subject'].fillna('').astype(str)
    df['Body'] = df['body'].fillna('').astype(str)
    sequences = (df['Subject'] + " " + df['Body']).tolist()

    print(f"Classifying {len(sequences)} emails from CSV...")
    predictions = classifier(sequences, candidate_labels=POSSIBLE_LABELS)
    print("Classification complete.")
    # Some transformers versions unwrap a one-element batch into a bare dict.
    if isinstance(predictions, dict):
        predictions = [predictions]

    # Labels/scores are ordered by likelihood; take the top entry for each email.
    pred_labels = [p['labels'][0] for p in predictions]
    pred_scores = [p['scores'][0] for p in predictions]

    df['predicted_folder'] = pred_labels
    df['confidence'] = pred_scores

    # Persist the raw float confidences before formatting them for display.
    database.add_email_records(df)
    print(f"Saved {len(df)} records from CSV to the database.")

    # Display copy: confidence rendered as a percentage string.
    df_display = df.copy()
    df_display['confidence'] = [f"{score:.2%}" for score in pred_scores]
    return df_display

# --- History & Feedback Functions (Unchanged) ---
def show_history():
    """Return every stored prediction as a DataFrame for the history table.

    If the database holds no records, a one-row placeholder frame with a
    ``Message`` column is returned so the table is not left blank.
    """
    rows, column_names = database.get_all_emails()
    if rows:
        return pd.DataFrame(rows, columns=column_names)
    return pd.DataFrame({"Message": ["No history found."]})

def save_feedback_and_refresh(record_id, corrected_label):
    """Store a user's corrected folder for one record, then reload history.

    Args:
        record_id: ID entered in the ``gr.Number`` box (a number, or ``None``
            when left empty). Must represent a whole number.
        corrected_label: Folder chosen in the correction dropdown.

    Returns:
        ``(history_dataframe, status_message)``. Both are always returned so
        the two bound outputs (table and status text) refresh together.
    """
    if not record_id or not corrected_label:
        return show_history(), "Please enter a Record ID and a correct label first."

    try:
        record_id_value = float(record_id)
    except (TypeError, ValueError):
        return show_history(), f"Invalid Record ID: {record_id!r}."

    # int() would silently truncate e.g. 3.7 -> 3 and correct the wrong record;
    # is_integer() also rejects NaN and infinities.
    if not record_id_value.is_integer():
        return show_history(), f"Record ID must be a whole number, got {record_id}."

    record_id_int = int(record_id_value)
    database.update_feedback(record_id_int, corrected_label)
    return show_history(), f"Success! Record {record_id_int} updated."

# --- MODIFIED: Build the Final Gradio Interface ---
# One Blocks app with three tabs: single-email classification, CSV batch
# classification (whose results predict_batch_csv also saves to the database),
# and a history view where stored predictions can be corrected.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("## 📧 Smart Email Sorter (Zero-Shot Model)")

    # Tab 1: classify one email typed or pasted into the two textboxes.
    with gr.Tab("Single Prediction"):
        gr.Markdown("### Classify a Single Email")
        subject_input = gr.Textbox(label="Subject", lines=1)
        body_input = gr.Textbox(label="Body", lines=5, placeholder="Paste the email body here...")
        predict_btn = gr.Button("Classify Email", variant="primary")
        # predict_single returns a plain string (not a label->confidence dict), so the Label shows only that text.
        output_label = gr.Label(label="Result")
        predict_btn.click(fn=predict_single, inputs=[subject_input, body_input], outputs=output_label, show_progress="full")

    # Tab 2: classify every row of an uploaded CSV; the results table is shown in batch_output.
    # NOTE(review): this does not refresh the History tab; users must press "Refresh History".
    with gr.Tab("Batch Prediction (CSV)"):
        gr.Markdown("### Classify a Batch of Emails from a CSV File")
        csv_input = gr.File(label="Upload a CSV file with 'subject' and 'body' columns")
        csv_btn = gr.Button("Classify Batch", variant="primary")
        batch_output = gr.Dataframe(label="Classification Results", wrap=True)
        csv_btn.click(fn=predict_batch_csv, inputs=csv_input, outputs=batch_output, show_progress="full")

    # Tab 3: browse saved predictions and submit corrections.
    with gr.Tab("History & Feedback"):
        gr.Markdown("### Prediction History & Feedback")
        # NOTE(review): the text says "textbox", but the ID field below is a gr.Number.
        gr.Markdown("To correct a prediction, find its `id` in the table, then enter it into the textbox below.")
        with gr.Row():
            record_id_input = gr.Number(label="Enter Record ID to Correct")
            # Corrections are limited to the same label set the classifier uses.
            feedback_label = gr.Dropdown(label="Correct Folder", choices=POSSIBLE_LABELS)
        with gr.Row():
            submit_feedback_btn = gr.Button("Submit Correction", variant="primary")
            refresh_history_btn = gr.Button("Refresh History")
        feedback_status = gr.Markdown("")
        history_df = gr.Dataframe(label="All Saved Predictions", wrap=True)
        # Fill the history table when the page loads, and again on demand.
        app.load(fn=show_history, outputs=history_df)
        refresh_history_btn.click(fn=show_history, outputs=history_df)
        # save_feedback_and_refresh returns (table, status message), matching this outputs order.
        submit_feedback_btn.click(fn=save_feedback_and_refresh, inputs=[record_id_input, feedback_label], outputs=[history_df, feedback_status])
# NOTE(review): launch() runs unconditionally at module level (no __main__ guard),
# so importing this module starts the server.
app.launch()