# Spaces: Sleeping
# backend/app.py (FINAL VERSION FOR FREE TIER)
import sys

import gradio as gr
import pandas as pd
from transformers import pipeline

# gmail_fetcher is no longer needed, so its import has been removed.
import database
# --- Diagnostic and Model Loading ---
print("--- Starting Application ---")
try:
    print(f"Gradio Version: {gr.__version__}")
except Exception as e:
    # Best-effort diagnostic only: a failed version lookup must not stop startup.
    print(f"Could not get Gradio version: {e}")
print("--------------------------")

# The model is loaded once at import time and shared by every request handler.
print("Loading zero-shot classification model...")
classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-3")
print("Model loaded successfully.")

# NOTE(review): `database` is project-local; init_db() presumably creates the
# storage schema if it does not exist yet -- confirm against database.py.
database.init_db()

# Candidate folders handed to the classifier and offered in the feedback dropdown.
POSSIBLE_LABELS = ["Work", "Promotions", "Personal", "Spam", "Important"]
# --- NEW/RESTORED: Functions for Single and Batch Prediction ---
def predict_single(subject, body):
    """Classify a single email from its subject and body text.

    Args:
        subject: Subject line of the email. ``None`` is treated as empty.
        body: Body text of the email. ``None`` is treated as empty.

    Returns:
        A human-readable string naming the top-scoring folder and its
        confidence, or a prompt asking for input when both fields are blank.
    """
    # A cleared textbox or a direct API call may deliver None instead of "";
    # normalise so the blank check below cannot raise AttributeError.
    subject = subject or ""
    body = body or ""

    if not subject.strip() and not body.strip():
        return "Please provide a subject or body."

    sequence = subject + " " + body
    prediction = classifier(sequence, candidate_labels=POSSIBLE_LABELS)

    # The pipeline returns labels and scores sorted best-first.
    top_label = prediction['labels'][0]
    top_score = prediction['scores'][0]
    return f"Predicted Folder: {top_label} (Confidence: {top_score:.2%})"
def predict_batch_csv(file):
    """Classify a batch of emails from an uploaded CSV file.

    Args:
        file: The upload handed over by the Gradio ``File`` component. Either
            an object exposing the temp-file path as ``.name`` or the path
            itself is accepted; ``None`` means nothing was uploaded.

    Returns:
        A DataFrame with the original columns plus ``Subject``, ``Body``,
        ``predicted_folder`` and a percent-formatted ``confidence`` column.
        On any input problem a one-column ``Error`` DataFrame is returned
        instead, so the UI table always has something to render.
    """
    if file is None:
        return pd.DataFrame({"Error": ["Please upload a CSV file first."]})

    # Depending on the Gradio version the component yields a temp-file wrapper
    # (with `.name`) or a plain path string; support both.
    csv_path = getattr(file, "name", file)
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Broad on purpose: this is the UI boundary and pandas can raise many
        # unrelated error types (parser, encoding, OS) for a bad upload.
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if 'subject' not in df.columns or 'body' not in df.columns:
        return pd.DataFrame({"Error": ["CSV must have 'subject' and 'body' columns."]})

    # The zero-shot pipeline rejects an empty list of sequences, so stop early.
    if df.empty:
        return pd.DataFrame({"Error": ["CSV contains no rows to classify."]})

    # Missing cells become empty strings; astype(str) keeps the concatenation
    # below working even when a column was parsed as numeric.
    df['Subject'] = df['subject'].fillna('').astype(str)
    df['Body'] = df['body'].fillna('').astype(str)
    sequences = (df['Subject'] + " " + df['Body']).tolist()

    print(f"Classifying {len(sequences)} emails from CSV...")
    predictions = classifier(sequences, candidate_labels=POSSIBLE_LABELS)
    print("Classification complete.")

    # Some transformers releases return a bare dict rather than a one-element
    # list when given a single sequence; normalise to a list.
    if isinstance(predictions, dict):
        predictions = [predictions]

    pred_labels = [p['labels'][0] for p in predictions]
    pred_scores = [p['scores'][0] for p in predictions]
    df['predicted_folder'] = pred_labels
    df['confidence'] = pred_scores

    # Save results to the database (raw float confidence, not the display string).
    database.add_email_records(df)
    print(f"Saved {len(df)} records from CSV to the database.")

    # Format for display
    df_display = df.copy()
    df_display['confidence'] = [f"{score:.2%}" for score in pred_scores]
    return df_display
# --- History & Feedback Functions (Unchanged) ---
def show_history():
    """Return every stored prediction as a DataFrame for the history table.

    Returns:
        A DataFrame built from the rows and column names supplied by
        ``database.get_all_emails()``, or a single-cell ``Message`` DataFrame
        when no rows are stored.
    """
    records, columns = database.get_all_emails()
    if not records:
        # Give the UI table a placeholder row instead of an empty grid.
        return pd.DataFrame({"Message": ["No history found."]})
    return pd.DataFrame(records, columns=columns)
def save_feedback_and_refresh(record_id, corrected_label):
    """Store a user's corrected label for one record and reload the history.

    Args:
        record_id: Record identifier from the numeric input (typically a
            float, or ``None`` when the field is empty).
        corrected_label: Folder chosen in the dropdown, or a falsy value when
            nothing is selected.

    Returns:
        A ``(history, status)`` tuple: the refreshed history DataFrame and a
        status message for the user.
    """
    if not record_id or not corrected_label:
        return show_history(), "Please enter a Record ID and a correct label first."

    try:
        record_id_int = int(record_id)
        # int() truncates silently, so 3.7 would update record 3; require the
        # value to be integral instead.
        is_whole = float(record_id) == record_id_int
    except (TypeError, ValueError, OverflowError):
        # Covers NaN, infinity and non-numeric input.
        is_whole = False

    if not is_whole:
        return show_history(), "Record ID must be a whole number."

    database.update_feedback(record_id_int, corrected_label)
    return show_history(), f"Success! Record {record_id_int} updated."
# --- MODIFIED: Build the Final Gradio Interface ---
# NOTE(review): the original indentation was lost, so the nesting of tabs and
# rows below is reconstructed from the order of the components -- confirm the
# layout matches the intended UI.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("## 📧 Smart Email Sorter (Zero-Shot Model)")

    # Tab 1: classify one email typed or pasted by the user.
    with gr.Tab("Single Prediction"):
        gr.Markdown("### Classify a Single Email")
        subject_input = gr.Textbox(label="Subject", lines=1)
        body_input = gr.Textbox(label="Body", lines=5, placeholder="Paste the email body here...")
        predict_btn = gr.Button("Classify Email", variant="primary")
        output_label = gr.Label(label="Result")
        predict_btn.click(fn=predict_single, inputs=[subject_input, body_input], outputs=output_label, show_progress="full")

    # Tab 2: classify every row of an uploaded CSV.
    with gr.Tab("Batch Prediction (CSV)"):
        gr.Markdown("### Classify a Batch of Emails from a CSV File")
        csv_input = gr.File(label="Upload a CSV file with 'subject' and 'body' columns")
        csv_btn = gr.Button("Classify Batch", variant="primary")
        batch_output = gr.Dataframe(label="Classification Results", wrap=True)
        csv_btn.click(fn=predict_batch_csv, inputs=csv_input, outputs=batch_output, show_progress="full")

    # Tab 3: browse stored predictions and submit corrections.
    with gr.Tab("History & Feedback"):
        gr.Markdown("### Prediction History & Feedback")
        gr.Markdown("To correct a prediction, find its `id` in the table, then enter it into the textbox below.")
        with gr.Row():
            record_id_input = gr.Number(label="Enter Record ID to Correct")
            feedback_label = gr.Dropdown(label="Correct Folder", choices=POSSIBLE_LABELS)
        with gr.Row():
            submit_feedback_btn = gr.Button("Submit Correction", variant="primary")
            refresh_history_btn = gr.Button("Refresh History")
        feedback_status = gr.Markdown("")
        history_df = gr.Dataframe(label="All Saved Predictions", wrap=True)

        # Populate the history table on page load, on demand, and after a correction.
        app.load(fn=show_history, outputs=history_df)
        refresh_history_btn.click(fn=show_history, outputs=history_df)
        submit_feedback_btn.click(fn=save_feedback_and_refresh, inputs=[record_id_input, feedback_label], outputs=[history_df, feedback_status])

app.launch()