Surya8663 committed on
Commit
4c21dbb
·
1 Parent(s): 5c04706

Final version for free tier (removes gmail fetch)

Browse files
Files changed (1) hide show
  1. backend/app.py +62 -49
backend/app.py CHANGED
@@ -1,66 +1,75 @@
1
- # backend/app.py (NEW VERSION USING A PRE-TRAINED MODEL)
2
 
3
  from transformers import pipeline
4
  import gradio as gr
5
  import pandas as pd
6
- # NEW, CORRECTED LINES
7
- from .gmail_fetcher import fetch_latest_emails
8
- from . import database
 
 
 
 
 
 
 
 
9
 
10
- # --- This is the core change ---
11
- # 1. We load a pre-trained "zero-shot-classification" pipeline from Hugging Face.
12
- # The first time this code runs, it will automatically download the model (approx. 1.6GB).
13
  print("Loading zero-shot classification model...")
14
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
15
  print("Model loaded successfully.")
16
- # -----------------------------
17
 
18
  database.init_db()
19
 
20
- # --- These are the labels we want to sort emails into ---
21
- # You can change or add to this list!
22
  POSSIBLE_LABELS = ["Work", "Promotions", "Personal", "Spam", "Important"]
23
- email_storage = {}
24
 
25
- def predict_and_save_emails(df):
26
- """Takes a DataFrame of emails, classifies them, and saves to DB."""
27
- if df.empty:
28
- return pd.DataFrame({"Message": ["No emails to process."]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Get the text content for classification
 
 
 
 
31
  sequences = (df['Subject'] + " " + df['Body']).tolist()
32
 
33
- print(f"Classifying {len(sequences)} emails...")
34
- # Use the zero-shot pipeline. It's powerful but can be slow on CPU for many emails.
35
  predictions = classifier(sequences, candidate_labels=POSSIBLE_LABELS)
36
  print("Classification complete.")
37
-
38
- # Extract the top label and score for each prediction
39
- pred_labels = [pred['labels'][0] for pred in predictions]
40
- pred_scores = [pred['scores'][0] for pred in predictions]
41
-
42
- # Add results to the DataFrame
43
  df['predicted_folder'] = pred_labels
44
  df['confidence'] = pred_scores
45
 
46
- # Save to the database
47
  database.add_email_records(df)
48
- print(f"Saved {len(df)} records to the database.")
49
 
50
  # Format for display
51
  df_display = df.copy()
52
  df_display['confidence'] = [f"{score:.2%}" for score in pred_scores]
53
- return df_display[["From", "Subject", "predicted_folder", "confidence"]]
54
-
55
- def fetch_and_classify_emails():
56
- """Fetches emails from Gmail and triggers the classification."""
57
- emails = fetch_latest_emails()
58
- if not emails:
59
- email_storage.clear()
60
- return pd.DataFrame({"From": ["No new emails found or an error occurred."], "Subject": [""]})
61
-
62
- df = pd.DataFrame(emails)
63
- return predict_and_save_emails(df)
64
 
65
  # --- History & Feedback Functions (Unchanged) ---
66
  def show_history():
@@ -73,20 +82,28 @@ def show_history():
73
  def save_feedback_and_refresh(record_id, corrected_label):
74
  if not record_id or not corrected_label:
75
  return show_history(), "Please enter a Record ID and a correct label first."
76
-
77
  record_id_int = int(record_id)
78
  database.update_feedback(record_id_int, corrected_label)
79
-
80
  return show_history(), f"Success! Record {record_id_int} updated."
81
 
82
- # --- Build the Gradio Interface (Unchanged) ---
83
  with gr.Blocks(theme=gr.themes.Soft()) as app:
84
  gr.Markdown("## 📧 Smart Email Sorter (Zero-Shot Model)")
85
 
86
- with gr.Tab("Fetch & Classify Gmail"):
87
- predict_gmail_btn = gr.Button("Fetch and Classify Latest Emails", variant="primary")
88
- email_display = gr.Dataframe(headers=["From", "Subject", "predicted_folder", "confidence"], label="Fetched & Classified Emails", wrap=True)
89
- predict_gmail_btn.click(fn=fetch_and_classify_emails, outputs=email_display, show_progress="full")
 
 
 
 
 
 
 
 
 
 
90
 
91
  with gr.Tab("History & Feedback"):
92
  gr.Markdown("### Prediction History & Feedback")
@@ -101,8 +118,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
101
  history_df = gr.Dataframe(label="All Saved Predictions", wrap=True)
102
  app.load(fn=show_history, outputs=history_df)
103
  refresh_history_btn.click(fn=show_history, outputs=history_df)
104
- submit_feedback_btn.click(
105
- fn=save_feedback_and_refresh,
106
- inputs=[record_id_input, feedback_label],
107
- outputs=[history_df, feedback_status]
108
- )
 
1
+ # backend/app.py (FINAL VERSION FOR FREE TIER)
2
 
3
  from transformers import pipeline
4
  import gradio as gr
5
  import pandas as pd
6
+ # We no longer need gmail_fetcher, so we remove the import
7
+ import database
8
+ import sys
9
+
10
# --- Startup diagnostics and model loading ---
print("--- Starting Application ---")
try:
    # Look the version up first so a failure is reported by the handler below.
    gradio_version = gr.__version__
    print(f"Gradio Version: {gradio_version}")
except Exception as e:
    print(f"Could not get Gradio version: {e}")
print("--------------------------")

# Build the zero-shot classification pipeline once at import time; the
# prediction functions in this module call the resulting `classifier`.
print("Loading zero-shot classification model...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("Model loaded successfully.")

database.init_db()

# Candidate labels handed to the classifier for every prediction.
POSSIBLE_LABELS = ["Work", "Promotions", "Personal", "Spam", "Important"]
 
25
 
26
+ # --- NEW/RESTORED: Functions for Single and Batch Prediction ---
27
+
28
def predict_single(subject, body):
    """Classify a single email from its subject and body text.

    Args:
        subject: Subject line of the email. ``None`` is treated as empty.
        body: Body text of the email. ``None`` is treated as empty.

    Returns:
        A message naming the top-scoring label and its confidence as a
        percentage, or a prompt asking for input when both fields are blank.
    """
    # Normalise missing values so .strip() cannot raise AttributeError when a
    # caller supplies None instead of an empty string.
    subject = (subject or "").strip()
    body = (body or "").strip()
    if not subject and not body:
        return "Please provide a subject or body."

    # Strip again so a subject-only or body-only email carries no stray
    # leading/trailing space into the model input.
    sequence = f"{subject} {body}".strip()
    prediction = classifier(sequence, candidate_labels=POSSIBLE_LABELS)

    # The first entry of 'labels'/'scores' is used as the top prediction.
    top_label = prediction['labels'][0]
    top_score = prediction['scores'][0]

    return f"Predicted Folder: {top_label} (Confidence: {top_score:.2%})"
40
+
41
def predict_batch_csv(file):
    """Classify a batch of emails from an uploaded CSV file and save them.

    Args:
        file: The uploaded CSV, given either as a filesystem path
            (``str`` / ``os.PathLike``) or as an object exposing the path
            through a ``.name`` attribute. ``None`` means nothing was uploaded.

    Returns:
        A DataFrame holding the input rows plus ``Subject``, ``Body``,
        ``predicted_folder`` and ``confidence`` (formatted as a percentage
        string) columns. On any problem, a one-column ``Error`` DataFrame
        describing what went wrong is returned instead.
    """
    import os  # local import: only needed for the path-type check below

    if file is None:
        return pd.DataFrame({"Error": ["Please upload a CSV file first."]})

    try:
        # Accept both a plain path and a file-like wrapper carrying `.name`.
        csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
        df = pd.read_csv(csv_path)
    except Exception as e:
        # UI boundary: report any read/parse failure instead of crashing.
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if 'subject' not in df.columns or 'body' not in df.columns:
        return pd.DataFrame({"Error": ["CSV must have 'subject' and 'body' columns."]})

    # A header-only file would hand the classifier an empty batch.
    if df.empty:
        return pd.DataFrame({"Error": ["CSV contains no rows to classify."]})

    # Cast to str so numeric or otherwise non-text cells can be concatenated.
    df['Subject'] = df['subject'].fillna('').astype(str)
    df['Body'] = df['body'].fillna('').astype(str)
    sequences = (df['Subject'] + " " + df['Body']).tolist()

    print(f"Classifying {len(sequences)} emails from CSV...")
    predictions = classifier(sequences, candidate_labels=POSSIBLE_LABELS)
    print("Classification complete.")

    # NOTE(review): some transformers releases appear to return a bare dict
    # for a one-element batch; normalise so the comprehensions below work.
    if isinstance(predictions, dict):
        predictions = [predictions]

    pred_labels = [p['labels'][0] for p in predictions]
    pred_scores = [p['scores'][0] for p in predictions]

    df['predicted_folder'] = pred_labels
    df['confidence'] = pred_scores

    # Save results to the database (raw float confidence, not the display string).
    database.add_email_records(df)
    print(f"Saved {len(df)} records from CSV to the database.")

    # Format for display
    df_display = df.copy()
    df_display['confidence'] = [f"{score:.2%}" for score in pred_scores]
    return df_display
 
 
 
 
 
 
 
 
 
 
73
 
74
  # --- History & Feedback Functions (Unchanged) ---
75
  def show_history():
 
82
  def save_feedback_and_refresh(record_id, corrected_label):
83
  if not record_id or not corrected_label:
84
  return show_history(), "Please enter a Record ID and a correct label first."
 
85
  record_id_int = int(record_id)
86
  database.update_feedback(record_id_int, corrected_label)
 
87
  return show_history(), f"Success! Record {record_id_int} updated."
88
 
89
+ # --- MODIFIED: Build the Final Gradio Interface ---
90
  with gr.Blocks(theme=gr.themes.Soft()) as app:
91
  gr.Markdown("## 📧 Smart Email Sorter (Zero-Shot Model)")
92
 
93
+ with gr.Tab("Single Prediction"):
94
+ gr.Markdown("### Classify a Single Email")
95
+ subject_input = gr.Textbox(label="Subject", lines=1)
96
+ body_input = gr.Textbox(label="Body", lines=5, placeholder="Paste the email body here...")
97
+ predict_btn = gr.Button("Classify Email", variant="primary")
98
+ output_label = gr.Label(label="Result")
99
+ predict_btn.click(fn=predict_single, inputs=[subject_input, body_input], outputs=output_label, show_progress="full")
100
+
101
+ with gr.Tab("Batch Prediction (CSV)"):
102
+ gr.Markdown("### Classify a Batch of Emails from a CSV File")
103
+ csv_input = gr.File(label="Upload a CSV file with 'subject' and 'body' columns")
104
+ csv_btn = gr.Button("Classify Batch", variant="primary")
105
+ batch_output = gr.Dataframe(label="Classification Results", wrap=True)
106
+ csv_btn.click(fn=predict_batch_csv, inputs=csv_input, outputs=batch_output, show_progress="full")
107
 
108
  with gr.Tab("History & Feedback"):
109
  gr.Markdown("### Prediction History & Feedback")
 
118
  history_df = gr.Dataframe(label="All Saved Predictions", wrap=True)
119
  app.load(fn=show_history, outputs=history_df)
120
  refresh_history_btn.click(fn=show_history, outputs=history_df)
121
+ submit_feedback_btn.click(fn=save_feedback_and_refresh, inputs=[record_id_input, feedback_label], outputs=[history_df, feedback_status])