Update app.py

app.py CHANGED
@@ -1,19 +1,30 @@
 import streamlit as st
+import faiss
+import numpy as np
+from sentence_transformers import SentenceTransformer, CrossEncoder
+import requests
+import os
+import torch
+import pickle
+from tqdm import tqdm
 from googleapiclient.discovery import build
 from google_auth_oauthlib.flow import InstalledAppFlow
 from google.auth.transport.requests import Request
 from google.oauth2.credentials import Credentials
-import os
-import json
-import pandas as pd
 import base64
-import
-import
-from fpdf import FPDF
+import re
+from pyngrok import ngrok
 
-
+# ===============================
+# 1. Streamlit App Configuration
+# ===============================
+st.set_page_config(page_title="📥 Email Chat Application", layout="wide")
+st.title("✉️ Email Chat Application")
 
-#
+# ===============================
+# 2. Gmail Authentication Configuration
+# ===============================
+SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
 if "authenticated" not in st.session_state:
     st.session_state.authenticated = False
 if "creds" not in st.session_state:
@@ -24,178 +35,394 @@ if "auth_code" not in st.session_state:
     st.session_state.auth_code = ""
 if "flow" not in st.session_state:
     st.session_state.flow = None
+if "data_chunks" not in st.session_state:
+    st.session_state.data_chunks = []  # List to store all email chunks
+if "embeddings" not in st.session_state:
+    st.session_state.embeddings = None
+if "vector_store" not in st.session_state:
+    st.session_state.vector_store = None
 
-
-
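+# Crude whitespace-based token counter (not an exact LLM tokenizer).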
+def count_tokens(text):
+    return len(text.split())
+
+# ===============================
+# 3. Gmail Authentication Functions
+# ===============================
+def reset_session_state():
+    st.session_state.authenticated = False
+    st.session_state.creds = None
+    st.session_state.auth_url = None
+    st.session_state.auth_code = ""
+    st.session_state.flow = None
+    st.session_state.data_chunks = []
+    st.session_state.embeddings = None
+    st.session_state.vector_store = None
+    for filename in ["token.json", "data_chunks.pkl", "embeddings.pkl", "vector_store.index"]:
+        if os.path.exists(filename):
+            os.remove(filename)
+
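+# Load cached credentials from token.json when present; refresh expired ones;
+# otherwise start the OAuth consent flow and surface the authorization URL.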
+def authenticate_gmail(credentials_file):
+    creds = None
     if os.path.exists('token.json'):
         try:
             creds = Credentials.from_authorized_user_file('token.json', SCOPES)
             if creds and creds.valid:
                 st.session_state.creds = creds
                 st.session_state.authenticated = True
-                st.success("Authentication successful!")
+                st.success("✅ Authentication successful!")
                 return creds
         except Exception as e:
-            st.error(f"Invalid token.json file: {e}")
+            st.error(f"❌ Invalid token.json file: {e}")
             os.remove('token.json')
-
-
-
-            st.session_state.creds
+    if not creds or not creds.valid:
+        if creds and creds.expired and creds.refresh_token:
+            creds.refresh(Request())
+            st.session_state.creds = creds
             st.session_state.authenticated = True
-            st.success("Authentication successful!")
-
+            st.success("✅ Authentication successful!")
+            with open('token.json', 'w') as token_file:
+                token_file.write(creds.to_json())
+            return creds
     else:
         if not st.session_state.flow:
             st.session_state.flow = InstalledAppFlow.from_client_secrets_file(credentials_file, SCOPES)
             st.session_state.flow.redirect_uri = 'http://localhost'
-
-
-
-        st.
-        st.code(st.session_state.auth_url)
+        auth_url, _ = st.session_state.flow.authorization_url(prompt='consent')
+        st.session_state.auth_url = auth_url
+        st.info("👉 **Authorize the application by visiting the URL below:**")
+        st.markdown(f"[Authorize]({st.session_state.auth_url})")
 
-# Submit Authentication Code
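+# Submit Authentication Code: exchange the pasted authorization code for
+# credentials and persist them to token.json.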
 def submit_auth_code():
     try:
         st.session_state.flow.fetch_token(code=st.session_state.auth_code)
         st.session_state.creds = st.session_state.flow.credentials
         st.session_state.authenticated = True
         with open('token.json', 'w') as token_file:
-
-
-                "refresh_token": st.session_state.creds.refresh_token,
-                "token_uri": st.session_state.creds.token_uri,
-                "client_id": st.session_state.creds.client_id,
-                "client_secret": st.session_state.creds.client_secret,
-                "scopes": st.session_state.creds.scopes
-            }, token_file)
-        st.success("Authentication successful!")
+            token_file.write(st.session_state.creds.to_json())
+        st.success("✅ Authentication successful!")
     except Exception as e:
-        st.error(f"Error during authentication: {e}")
+        st.error(f"❌ Error during authentication: {e}")
-
-#
-... [old lines 78-201: remainder of the previous implementation, removed; content not preserved in this view]
+
+# ===============================
+# 4. Email Data Extraction, Embedding and Vector Store Functions
+# ===============================
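+# Best-effort plain-text extraction from a Gmail API payload: try the top-level
+# body first, then any text/plain part, then fall back to the first part
+# (all base64url-decoded).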
+def extract_email_body(payload):
+    if 'body' in payload and 'data' in payload['body'] and payload['body']['data']:
+        try:
+            return base64.urlsafe_b64decode(payload['body']['data'].encode('UTF-8')).decode('UTF-8')
+        except Exception as e:
+            st.error(f"Error decoding email body: {e}")
+            return ""
+    if 'parts' in payload:
+        for part in payload['parts']:
+            if part.get('mimeType') == 'text/plain' and 'data' in part.get('body', {}):
+                try:
+                    return base64.urlsafe_b64decode(part['body']['data'].encode('UTF-8')).decode('UTF-8')
+                except Exception as e:
+                    st.error(f"Error decoding email part: {e}")
+                    continue
+        if payload['parts']:
+            first_part = payload['parts'][0]
+            if 'data' in first_part.get('body', {}):
+                try:
+                    return base64.urlsafe_b64decode(first_part['body']['data'].encode('UTF-8')).decode('UTF-8')
+                except Exception as e:
+                    st.error(f"Error decoding fallback email part: {e}")
+                    return ""
+    return ""
+
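+# Flatten an email dict into one labeled text block; this is the unit that gets
+# embedded and later stuffed into the LLM prompt.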
+def combine_email_text(email):
+    parts = []
+    if email.get("sender"):
+        parts.append(f"Sender: {email['sender']}")
+    if email.get("to"):
+        parts.append(f"To: {email['to']}")
+    if email.get("date"):
+        parts.append(f"Date: {email['date']}")
+    if email.get("subject"):
+        parts.append(f"Subject: {email['subject']}")
+    if email.get("body"):
+        parts.append(f"Body: {email['body']}")
+    return "\n".join(parts)
+
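+# Page through every message under the selected Gmail label (500 per page) and
+# build one dict per email with headers and decoded body.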
+def create_chunks_from_gmail(service, label):
+    try:
+        messages = []
+        result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute()
+        messages.extend(result.get('messages', []))
+        while 'nextPageToken' in result:
+            token = result["nextPageToken"]
+            result = service.users().messages().list(userId='me', labelIds=[label],
+                                                     maxResults=500, pageToken=token).execute()
+            messages.extend(result.get('messages', []))
+
+        data_chunks = []
+        progress_bar = st.progress(0)
+        total = len(messages)
+        for idx, msg in enumerate(messages):
+            msg_data = service.users().messages().get(userId='me', id=msg['id'], format='full').execute()
+            headers = msg_data.get('payload', {}).get('headers', [])
+            email_dict = {"id": msg['id']}
+            for header in headers:
+                name = header.get('name', '').lower()
+                if name == 'from':
+                    email_dict['sender'] = header.get('value', '')
+                elif name == 'subject':
+                    email_dict['subject'] = header.get('value', '')
+                elif name == 'to':
+                    email_dict['to'] = header.get('value', '')
+                elif name == 'date':
+                    email_dict['date'] = header.get('value', '')
+            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
+            data_chunks.append(email_dict)
+            progress_bar.progress((idx + 1) / total)
+        st.session_state.data_chunks = data_chunks
+        st.success(f"✅ Data chunks created successfully from Gmail! Total emails processed: {len(data_chunks)}")
+        # Save chunks locally for future use.
+        with open("data_chunks.pkl", "wb") as f:
+            pickle.dump(data_chunks, f)
+    except Exception as e:
+        st.error(f"❌ Error creating chunks from Gmail: {e}")
+
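+# Encode all emails in batches with the all-MiniLM-L6-v2 bi-encoder, L2-normalize,
+# and index with inner product (equivalent to cosine similarity on unit vectors),
+# then persist both artifacts to disk.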
+def embed_emails(email_chunks):
+    st.header("🔄 Embedding Data and Creating Vector Store")
+    with st.spinner('🔄 Embedding data...'):
+        try:
+            embed_model = SentenceTransformer("all-MiniLM-L6-v2")
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            embed_model.to(device)
+            combined_texts = [combine_email_text(email) for email in email_chunks]
+            batch_size = 64
+            embeddings = []
+            for i in range(0, len(combined_texts), batch_size):
+                batch = combined_texts[i:i+batch_size]
+                batch_embeddings = embed_model.encode(
+                    batch,
+                    convert_to_numpy=True,
+                    show_progress_bar=False,
+                    device=device
+                )
+                embeddings.append(batch_embeddings)
+            embeddings = np.vstack(embeddings)
+            faiss.normalize_L2(embeddings)
+            st.session_state.embeddings = embeddings
+            dimension = embeddings.shape[1]
+            index = faiss.IndexFlatIP(dimension)
+            index.add(embeddings)
+            st.session_state.vector_store = index
+            st.success("✅ Data embedding and vector store created successfully!")
+            # Save embeddings and index to disk.
+            with open('embeddings.pkl', 'wb') as f:
+                pickle.dump(embeddings, f)
+            faiss.write_index(index, 'vector_store.index')
+        except Exception as e:
+            st.error(f"❌ Error during embedding: {e}")
+
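+# Save/load helpers so a processed mailbox can be reused across sessions without
+# re-fetching or re-embedding.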
+def save_embeddings_and_index():
+    try:
+        with open('embeddings.pkl', 'wb') as f:
+            pickle.dump(st.session_state.embeddings, f)
+        faiss.write_index(st.session_state.vector_store, 'vector_store.index')
+        st.success("💾 Embeddings and vector store saved successfully!")
+    except Exception as e:
+        st.error(f"❌ Error saving embeddings and vector store: {e}")
+
+def load_embeddings_and_index():
+    try:
+        with open('embeddings.pkl', 'rb') as f:
+            st.session_state.embeddings = pickle.load(f)
+        st.session_state.vector_store = faiss.read_index('vector_store.index')
+        st.success("📂 Embeddings and vector store loaded successfully!")
+    except Exception as e:
+        st.error(f"❌ Error loading embeddings and vector store: {e}")
+
+def load_chunks():
+    try:
+        with open("data_chunks.pkl", "rb") as f:
+            st.session_state.data_chunks = pickle.load(f)
+        st.success("📂 Email chunks loaded successfully!")
+    except Exception as e:
+        st.error(f"❌ Error loading email chunks: {e}")
+
+# ===============================
+# 5. Handling User Queries
+# ===============================
+def preprocess_query(query):
+    return query.lower().strip()
+
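+# RAG pipeline: embed the query, retrieve top-k emails from FAISS, optionally
+# boost/filter by email addresses found in the query, re-rank with a
+# cross-encoder, then ask the Groq-hosted LLM with the retrieved context.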
+def handle_user_query():
+    st.header("💬 Let's chat with your Email")
+    user_query = st.text_input("Enter your query:")
+    TOP_K = 10
+    SIMILARITY_THRESHOLD = 0.4
+
+    if st.button("🔍 Get Response"):
+        if (st.session_state.vector_store is None or
+            st.session_state.embeddings is None or
+            st.session_state.data_chunks is None):
+            st.error("❌ Please process your email data or load saved chunks/embeddings first.")
+            return
+        if not user_query.strip():
+            st.error("❌ Please enter a valid query.")
+            return
+        with st.spinner('🔄 Processing your query...'):
+            try:
+                # Retrieve candidates using the bi-encoder.
+                embed_model = SentenceTransformer("all-MiniLM-L6-v2")
+                device = 'cuda' if torch.cuda.is_available() else 'cpu'
+                embed_model.to(device)
+                processed_query = preprocess_query(user_query)
+                query_embedding = embed_model.encode(
+                    [processed_query],
+                    convert_to_numpy=True,
+                    show_progress_bar=False,
+                    device=device
+                )
+                faiss.normalize_L2(query_embedding)
+                distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K)
+                candidates = []
+                for idx, score in zip(indices[0], distances[0]):
+                    candidates.append((st.session_state.data_chunks[idx], score))
+
+                # Boost candidates if sender or "to" field contains query tokens (e.g., email addresses).
+                query_tokens = re.findall(r'\S+@\S+', user_query)
+                if query_tokens:
+                    for i in range(len(candidates)):
+                        candidate_email_str = (
+                            (candidates[i][0].get("sender", "") + " " + candidates[i][0].get("to", "")).lower()
+                        )
+                        for token in query_tokens:
+                            if token.lower() in candidate_email_str:
+                                candidates[i] = (candidates[i][0], max(candidates[i][1], 1.0))
+                    filtered_candidates = []
+                    for candidate, score in candidates:
+                        candidate_text = combine_email_text(candidate).lower()
+                        if any(token.lower() in candidate_text for token in query_tokens):
+                            filtered_candidates.append((candidate, score))
+                    if filtered_candidates:
+                        candidates = filtered_candidates
+                    else:
+                        st.info("No candidate emails contain the query token(s) exactly. Proceeding with all candidates.")
+
+                candidates.sort(key=lambda x: x[1], reverse=True)
+                if not candidates:
+                    st.subheader("🤖 AI Response:")
+                    st.write("⚠️ No documents found.")
+                    return
+                if candidates[0][1] < SIMILARITY_THRESHOLD:
+                    st.subheader("🤖 AI Response:")
+                    st.write("⚠️ No document strongly matches your query. Try refining your query.")
+                    return
+
+                # Re-rank candidates using the cross-encoder.
+                cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
+                candidate_pairs = [(user_query, combine_email_text(candidate[0])) for candidate in candidates]
+                rerank_scores = cross_encoder.predict(candidate_pairs)
+                reranked_candidates = [(candidates[i][0], rerank_scores[i]) for i in range(len(candidates))]
+                reranked_candidates.sort(key=lambda x: x[1], reverse=True)
+                retrieved_emails = [email for email, score in reranked_candidates]
+                retrieved_scores = [score for email, score in reranked_candidates]
+                average_similarity = np.mean(retrieved_scores)
+
+                # Build the final context string.
+                context_str = "\n\n".join([combine_email_text(email) for email in retrieved_emails])
+                MAX_CONTEXT_TOKENS = 500
+                context_tokens = context_str.split()
+                if len(context_tokens) > MAX_CONTEXT_TOKENS:
+                    context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
+
+                payload = {
+                    "model": "llama3-8b-8192",  # Adjust as needed.
+                    "messages": [
+                        {"role": "system", "content": f"Use the following context:\n{context_str}"},
+                        {"role": "user", "content": user_query}
+                    ]
+                }
+                api_key = "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"  # Replace with your API key.
+                url = "https://api.groq.com/openai/v1/chat/completions"
+                headers = {
+                    "Authorization": f"Bearer {api_key}",
+                    "Content-Type": "application/json"
+                }
+                response = requests.post(url, headers=headers, json=payload)
+                if response.status_code == 200:
+                    response_json = response.json()
+                    generated_text = response_json["choices"][0]["message"]["content"]
+                    st.subheader("🤖 AI Response:")
+                    st.write(generated_text)
+                    st.write(f"Average Re-Ranked Score: {average_similarity:.4f}")
+                else:
+                    st.error(f"❌ Error from LLM API: {response.status_code} - {response.text}")
+            except Exception as e:
+                st.error(f"❌ An error occurred during processing: {e}")
+
+# ===============================
+# 6. Main Application Logic
+# ===============================
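+# Sidebar drives authentication plus loading/saving of chunks, embeddings, and
+# the index; the chat UI is enabled only once all three exist in session state.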
+def main():
+    st.sidebar.header("🔐 Gmail Authentication")
+    credentials_file = st.sidebar.file_uploader("📁 Upload `credentials.json`", type=["json"])
+    if credentials_file and st.sidebar.button("🔓 Authenticate"):
+        reset_session_state()
+        with open("credentials.json", "wb") as f:
+            f.write(credentials_file.getbuffer())
+        authenticate_gmail("credentials.json")
+
+    # Option to load previously saved email chunks.
+    chunks_file = st.sidebar.file_uploader("📁 Upload saved email chunks (data_chunks.pkl)", type=["pkl"])
+    if chunks_file:
+        try:
+            st.session_state.data_chunks = pickle.load(chunks_file)
+            st.success("📂 Email chunks loaded successfully from upload!")
+        except Exception as e:
+            st.error(f"❌ Error loading uploaded email chunks: {e}")
+
+    # Option to load previously saved embeddings and vector store.
+    embeddings_file = st.sidebar.file_uploader("📁 Upload saved embeddings (embeddings.pkl)", type=["pkl"])
+    vector_file = st.sidebar.file_uploader("📁 Upload saved vector store (vector_store.index)", type=["index", "idx"])
+    if embeddings_file and vector_file:
+        try:
+            st.session_state.embeddings = pickle.load(embeddings_file)
+            st.session_state.vector_store = faiss.read_index(vector_file.name)
+            st.success("📂 Embeddings and vector store loaded successfully from upload!")
+        except Exception as e:
+            st.error(f"❌ Error loading uploaded embeddings/vector store: {e}")
+
+    if st.session_state.auth_url:
+        st.sidebar.markdown("### 🔗 **Authorization URL:**")
+        st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})")
+        st.sidebar.text_input("🔑 Enter the authorization code:", key="auth_code")
+        if st.sidebar.button("✅ Submit Authentication Code"):
+            submit_auth_code()
+
+    if st.session_state.authenticated:
+        st.sidebar.success("✅ You are authenticated!")
+        st.sidebar.header("📊 Data Management")
+        label = st.sidebar.selectbox("📬 Select Label to Process Emails From:",
+                                     ["INBOX", "SENT", "DRAFTS", "TRASH", "SPAM"],
+                                     key="label_selector")
+        if st.sidebar.button("📥 Create Chunks and Embed Data"):
+            service = build('gmail', 'v1', credentials=st.session_state.creds)
+            create_chunks_from_gmail(service, label)
+            if st.session_state.data_chunks:
+                embed_emails(st.session_state.data_chunks)
+        if (st.session_state.embeddings is not None and st.session_state.vector_store is not None):
+            with st.expander("💾 Save Data"):
+                if st.button("💾 Save Email Chunks"):
+                    try:
+                        with open("data_chunks.pkl", "wb") as f:
+                            pickle.dump(st.session_state.data_chunks, f)
+                        st.success("💾 Email chunks saved to disk!")
+                    except Exception as e:
+                        st.error(f"❌ Error saving email chunks: {e}")
+                if st.button("💾 Save Embeddings & Vector Store"):
+                    save_embeddings_and_index()
+        if (st.session_state.vector_store is not None and
+            st.session_state.embeddings is not None and
+            st.session_state.data_chunks is not None):
+            handle_user_query()
+    else:
+        st.warning("⚠️ You are not authenticated yet. Please authenticate to access your Gmail data.")
+
+if __name__ == "__main__":
+    main()