Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import requests | |
| import os | |
| import torch | |
| import pickle | |
| import base64 | |
| from googleapiclient.discovery import build | |
| from google_auth_oauthlib.flow import InstalledAppFlow | |
| from google.auth.transport.requests import Request | |
# ===============================
# 1. Streamlit App Configuration
# ===============================
st.set_page_config(page_title="π₯ Email Chat Application", layout="wide")
st.title("π¬ Turn Emails into ConversationsβEffortless Chat with Your Inbox! π©")
# ===============================
# 2. Initialize Session State Variables
# ===============================
# Seed each session-state key with its default on first run only;
# setdefault leaves keys that already exist untouched, which is exactly
# what the original per-key `if ... not in st.session_state` checks did.
st.session_state.setdefault("authenticated", False)
st.session_state.setdefault("creds", None)
st.session_state.setdefault("auth_url", None)
st.session_state.setdefault("auth_code", "")
st.session_state.setdefault("flow", None)
st.session_state.setdefault("data_chunks", [])  # list of all extracted email chunks
st.session_state.setdefault("embeddings", None)
st.session_state.setdefault("vector_store", None)
# Candidate context details for the current query.
st.session_state.setdefault("candidate_context", None)
st.session_state.setdefault("raw_candidates", None)
# Chat transcript shown in the main panel.
st.session_state.setdefault("messages", [])
# One-shot flags so each success banner is shown only once.
st.session_state.setdefault("candidates_message_shown", False)
st.session_state.setdefault("vector_db_message_shown", False)
def count_tokens(text):
    """Approximate a token count as the number of whitespace-separated words."""
    words = text.split()
    return len(words)
| # =============================== | |
| # 3. Gmail Authentication Functions (Updated) | |
| # =============================== | |
def reset_session_state():
    """Restore every session-state key to its default and remove cached files.

    Clears authentication, email data, embeddings, chat history, and the
    one-shot banner flags, then deletes any artifacts persisted to disk by
    earlier runs (OAuth token and pickled vector-store files).
    """
    defaults = {
        "authenticated": False,
        "creds": None,
        "auth_url": None,
        "auth_code": "",
        "flow": None,
        "data_chunks": [],
        "embeddings": None,
        "vector_store": None,
        "candidate_context": None,
        "raw_candidates": None,
        "messages": [],
        "candidates_message_shown": False,
        "vector_db_message_shown": False,
    }
    for key, value in defaults.items():
        st.session_state[key] = value
    # Drop stale on-disk artifacts so the next authentication starts clean.
    stale_files = ("token.json", "data_chunks.pkl", "embeddings.pkl",
                   "vector_store.index", "vector_database.pkl")
    for path in stale_files:
        if os.path.exists(path):
            os.remove(path)
def authenticate_gmail(credentials_file):
    """Authenticate against Gmail (read-only scope) and cache credentials.

    Resolution order:
      1. Reuse a valid token.json from a previous run.
      2. Refresh an expired token when a refresh token is available.
      3. Otherwise start an InstalledAppFlow and show the authorization URL;
         the user then pastes the code back via submit_auth_code().

    Returns the credentials in cases 1-2; returns None after merely
    displaying the authorization URL in case 3.
    """
    SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
    creds = None
    if os.path.exists('token.json'):
        try:
            # Imported lazily; only needed when a saved token exists.
            from google.oauth2.credentials import Credentials
            creds = Credentials.from_authorized_user_file('token.json', SCOPES)
            if creds and creds.valid:
                st.session_state.creds = creds
                st.session_state.authenticated = True
                # One-shot banner, guarded by the shared session flag.
                if not st.session_state.candidates_message_shown:
                    st.success("β Authentication successful!")
                    st.session_state.candidates_message_shown = True
                return creds
        except Exception as e:
            # Corrupt/incompatible token file: report and discard it so the
            # flow below can mint a fresh one.
            st.error(f"β Invalid token.json file: {e}")
            os.remove('token.json')
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            # Silent refresh path: no user interaction required.
            creds.refresh(Request())
            st.session_state.creds = creds
            st.session_state.authenticated = True
            if not st.session_state.candidates_message_shown:
                st.success("β Authentication successful!")
                st.session_state.candidates_message_shown = True
            # Persist the refreshed token for the next session.
            with open('token.json', 'w') as token_file:
                token_file.write(creds.to_json())
            return creds
        else:
            # Interactive OAuth: build the flow once and keep it in session
            # state so submit_auth_code() can finish the exchange later.
            if not st.session_state.flow:
                st.session_state.flow = InstalledAppFlow.from_client_secrets_file(credentials_file, SCOPES)
                st.session_state.flow.redirect_uri = 'http://localhost'
                auth_url, _ = st.session_state.flow.authorization_url(prompt='consent')
                st.session_state.auth_url = auth_url
            st.info("π **Authorize the application by visiting the URL below:**")
            st.markdown(f"[Authorize]({st.session_state.auth_url})")
def submit_auth_code():
    """Exchange the pasted authorization code for credentials and persist them.

    Reads the code from st.session_state.auth_code, fetches a token via the
    flow stored in session state, writes token.json, and flips the
    `authenticated` flag. Any failure leaves the flag False and shows the
    error.
    """
    try:
        flow = st.session_state.flow
        # Complete the OAuth exchange with the user-supplied code.
        flow.fetch_token(code=st.session_state.auth_code)
        st.session_state.creds = flow.credentials
        # Persist credentials for future sessions before declaring success.
        with open('token.json', 'w') as token_file:
            token_file.write(st.session_state.creds.to_json())
        st.session_state.authenticated = True
        st.success("β Authentication successful!")
    except Exception as e:
        # Any failure (bad code, I/O error) leaves the session unauthenticated.
        st.session_state.authenticated = False
        st.error(f"β Error during authentication: {e}")
| # =============================== | |
| # 4. Email Data Extraction, Embedding and Vector Store Functions | |
| # =============================== | |
def extract_email_body(payload):
    """Decode and return the body text of a Gmail message payload.

    Tries, in order: the top-level body data, the first text/plain sub-part
    that decodes cleanly, and finally the first sub-part regardless of MIME
    type. Returns "" when nothing is present or decoding fails.
    """
    if 'body' in payload and payload['body'].get('data'):
        try:
            return base64.urlsafe_b64decode(payload['body']['data'].encode('UTF-8')).decode('UTF-8')
        except Exception as e:
            st.error(f"Error decoding email body: {e}")
            return ""
    parts = payload.get('parts')
    if parts:
        # Prefer an explicit text/plain part.
        for part in parts:
            part_body = part.get('body', {})
            if part.get('mimeType') != 'text/plain' or 'data' not in part_body:
                continue
            try:
                return base64.urlsafe_b64decode(part_body['data'].encode('UTF-8')).decode('UTF-8')
            except Exception as e:
                st.error(f"Error decoding email part: {e}")
        # Fallback: take whatever the first part contains.
        fallback_body = parts[0].get('body', {})
        if 'data' in fallback_body:
            try:
                return base64.urlsafe_b64decode(fallback_body['data'].encode('UTF-8')).decode('UTF-8')
            except Exception as e:
                st.error(f"Error decoding fallback email part: {e}")
            return ""
    return ""
def combine_email_text(email):
    """Render an email dict as one text block, fields joined by HTML '<br>'.

    Includes From/To/Date/Subject/Body in that order, skipping any field
    that is missing or empty.
    """
    field_labels = (
        ("sender", "From: "),
        ("to", "To: "),
        ("date", "Date: "),
        ("subject", "Subject: "),
        ("body", "Body: "),
    )
    segments = [label + email[key] for key, label in field_labels if email.get(key)]
    return "<br>".join(segments)
def create_chunks_from_gmail(service, label):
    """Fetch all messages under a Gmail label and append parsed chunks to session state.

    Each chunk is a dict with id/sender/to/date/subject/body keys. Progress
    is reported via a Streamlit progress bar; any failure is surfaced with
    st.error and leaves session state unchanged for this label.

    Fix: the success banner previously said "Vector database loaded
    successfully from upload!" — copy-pasted from the upload path — even
    though this function fetches fresh data from Gmail.
    """
    try:
        # Page through the label's complete message-id list (500 ids per page).
        messages = []
        result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute()
        messages.extend(result.get('messages', []))
        while 'nextPageToken' in result:
            token = result["nextPageToken"]
            result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500, pageToken=token).execute()
            messages.extend(result.get('messages', []))
        data_chunks = []
        progress_bar = st.progress(0)
        total = len(messages)
        for idx, msg in enumerate(messages):
            msg_data = service.users().messages().get(userId='me', id=msg['id'], format='full').execute()
            headers = msg_data.get('payload', {}).get('headers', [])
            email_dict = {"id": msg['id']}
            # Extract the standard headers (header names are case-insensitive).
            for header in headers:
                name = header.get('name', '').lower()
                if name == 'from':
                    email_dict['sender'] = header.get('value', '')
                elif name == 'subject':
                    email_dict['subject'] = header.get('value', '')
                elif name == 'to':
                    email_dict['to'] = header.get('value', '')
                elif name == 'date':
                    email_dict['date'] = header.get('value', '')
            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
            data_chunks.append(email_dict)
            progress_bar.progress(min((idx + 1) / total, 1.0))
        st.session_state.data_chunks.extend(data_chunks)
        if not st.session_state.vector_db_message_shown:
            # Corrected message: this path fetched from Gmail, not an upload.
            st.success(f"β Emails fetched successfully! Total emails processed for label '{label}': {len(data_chunks)}")
            st.session_state.vector_db_message_shown = True
    except Exception as e:
        st.error(f"β Error creating chunks from Gmail for label '{label}': {e}")
| # ------------------------------- | |
| # Cached model loaders for efficiency | |
| # ------------------------------- | |
@st.cache_resource(show_spinner=False)
def get_embed_model():
    """Load the sentence-embedding model once per process and reuse it.

    Returns:
        (model, device): a SentenceTransformer("all-MiniLM-L6-v2") moved to
        CUDA when available, plus the device string to pass to encode().

    Fix: the section header promises "Cached model loaders for efficiency",
    but the original re-instantiated the model on every Streamlit rerun.
    st.cache_resource makes the caching real and shares one model across
    reruns and sessions.
    """
    model = SentenceTransformer("all-MiniLM-L6-v2")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model, device
def embed_emails(email_chunks):
    """Embed all email chunks and build a normalized FAISS inner-product index.

    Stores the embedding matrix in st.session_state.embeddings and the index
    in st.session_state.vector_store. Embeddings are L2-normalized so inner
    product equals cosine similarity.

    Fix: an empty `email_chunks` previously fell through to np.vstack([]),
    which raises a cryptic "need at least one array to concatenate"; the
    guard below reports the real problem instead.
    """
    st.header("π Embedding Data and Creating Vector Store")
    if not email_chunks:
        # Nothing to embed — fail loudly with an actionable message.
        st.error("β No email data to embed. Fetch or load emails first.")
        return
    progress_bar = st.progress(0)
    with st.spinner('π Embedding data...'):
        try:
            embed_model, device = get_embed_model()
            combined_texts = [combine_email_text(email) for email in email_chunks]
            batch_size = 64
            embeddings = []
            # Encode in batches so the progress bar can advance between calls.
            for i in range(0, len(combined_texts), batch_size):
                batch = combined_texts[i:i + batch_size]
                batch_embeddings = embed_model.encode(
                    batch,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    device=device
                )
                embeddings.append(batch_embeddings)
                progress_bar.progress(min((i + batch_size) / len(combined_texts), 1.0))
            embeddings = np.vstack(embeddings)
            faiss.normalize_L2(embeddings)  # unit vectors: IP search == cosine similarity
            st.session_state.embeddings = embeddings
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatIP(dimension)
            index.add(embeddings)
            st.session_state.vector_store = index
            if not st.session_state.candidates_message_shown:
                st.success("β Data embedding and vector store created successfully!")
                st.session_state.candidates_message_shown = True
        except Exception as e:
            st.error(f"β Error during embedding: {e}")
| # New function to save the entire vector database as a single pickle file. | |
def save_vector_database():
    """Offer the FAISS index, embeddings, and email chunks as one pickle download.

    Bundles the three session-state artifacts into a single dict, serializes
    it, and renders a Streamlit download button for vector_database.pkl.
    """
    try:
        snapshot = {
            "vector_store": st.session_state.vector_store,
            "embeddings": st.session_state.embeddings,
            "data_chunks": st.session_state.data_chunks,
        }
        st.download_button(
            label="πΎ Download Vector Database",
            data=pickle.dumps(snapshot),
            file_name="vector_database.pkl",
            mime="application/octet-stream",
        )
    except Exception as e:
        st.error(f"β Error saving vector database: {e}")
| # =============================== | |
| # 5. Handling User Queries (User-Controlled Threshold) | |
| # =============================== | |
def preprocess_query(query):
    """Normalize a user query: trim surrounding whitespace and lowercase it."""
    return query.strip().lower()
def process_candidate_emails(query, similarity_threshold):
    """
    Embed the query, search the vector store, keep hits at or above the
    similarity threshold, and build the retrieved-email context string in
    st.session_state.candidate_context (raw (chunk, score) pairs go to
    st.session_state.raw_candidates).

    Fix: FAISS pads `indices` with -1 when the index holds fewer than TOP_K
    vectors; the original indexed data_chunks[-1], silently duplicating the
    last email for small mailboxes. Negative ids are now skipped.
    """
    TOP_K = 20  # retrieve extra candidates so threshold filtering still leaves results
    # Reset candidate context for each query.
    st.session_state.candidate_context = None
    st.session_state.raw_candidates = None
    if st.session_state.vector_store is None:
        st.error("β Please process your email data or load a saved vector database first.")
        return
    try:
        embed_model, device = get_embed_model()
        processed_query = preprocess_query(query)
        query_embedding = embed_model.encode(
            [processed_query],
            convert_to_numpy=True,
            show_progress_bar=False,
            device=device
        )
        faiss.normalize_L2(query_embedding)
        # Perform search; distances are inner products (cosine after normalization).
        distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K)
        candidates = []
        for idx, sim in zip(indices[0], distances[0]):
            if idx < 0:
                # FAISS "missing neighbor" marker — not a real hit.
                continue
            # Include candidate only if similarity meets the threshold.
            if sim >= similarity_threshold:
                candidates.append((st.session_state.data_chunks[idx], sim))
        if not candidates:
            # Surface the miss in the chat transcript rather than as an error.
            st.session_state.messages.append({"role": "assistant", "content": "β οΈ No matching embeddings found for your query with the selected threshold."})
            return
        # Concatenate all matching email texts using HTML breaks.
        context_str = ""
        for candidate, sim in candidates:
            context_str += combine_email_text(candidate) + "<br><br>"
        # Cap context length (whitespace tokens) to keep the LLM prompt small.
        MAX_CONTEXT_TOKENS = 500
        context_tokens = context_str.split()
        if len(context_tokens) > MAX_CONTEXT_TOKENS:
            context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
        st.session_state.candidate_context = context_str
        st.session_state.raw_candidates = candidates
    except Exception as e:
        st.error(f"β An error occurred during processing: {e}")
def call_llm_api(query):
    """
    Send the user's query plus the retrieved email context to the Groq
    chat-completions API and append the model's reply to the chat transcript.

    Requires st.session_state.candidate_context (set by
    process_candidate_emails) and the 'GroqAPI' environment variable.
    Errors are appended to the transcript as assistant messages.
    """
    if not st.session_state.candidate_context:
        st.error("β No candidate context available. Please try again.")
        return
    # Retrieve the API key from the environment variable 'GroqAPI'.
    api_key = os.getenv("GroqAPI")
    if not api_key:
        st.error("β API key not found. Please ensure 'GroqAPI' is set in Hugging Face Secrets.")
        return
    url = "https://api.groq.com/openai/v1/chat/completions"  # Verify this endpoint
    request_body = {
        "model": "llama-3.3-70b-versatile",  # Adjust model as needed.
        "messages": [
            {"role": "system", "content": f"Use the following context:\n{st.session_state.candidate_context}"},
            {"role": "user", "content": query},
        ],
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(url, headers=auth_headers, json=request_body)
        response.raise_for_status()  # Raises stored HTTPError, if one occurred.
        answer = response.json()["choices"][0]["message"]["content"]
        st.session_state.messages.append({"role": "assistant", "content": answer})
    except requests.exceptions.HTTPError as http_err:
        # Prefer the API's structured error message; fall back to raw status/text.
        try:
            error_message = response.json().get("error", {}).get("message", "An unknown error occurred.")
            st.session_state.messages.append({"role": "assistant", "content": f"β HTTP error occurred: {error_message}"})
        except ValueError:
            st.session_state.messages.append({"role": "assistant", "content": f"β HTTP error occurred: {response.status_code} - {response.text}"})
    except Exception as err:
        st.session_state.messages.append({"role": "assistant", "content": f"β An unexpected error occurred: {err}"})
def handle_user_query():
    """Render the chat UI: threshold slider, chat input, transcript, and matches.

    On each submitted query: appends the user message, retrieves candidate
    emails above the chosen threshold, and (when context was found) calls
    the LLM. The full transcript and the raw matching chunks are re-rendered
    on every Streamlit rerun.
    """
    st.header("π¬ Let's Chat with Your Emails")
    # Expander for threshold selection
    with st.expander("π§ Adjust Similarity Threshold", expanded=False):
        similarity_threshold = st.slider(
            "Select Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.05,
            help="Adjust the similarity threshold to control the relevance of retrieved emails. Higher values yield more relevant results.",
            key='similarity_threshold'
        )
    # Chat input for user queries
    user_input = st.chat_input("Enter your query:")
    if user_input:
        # Append user message to chat
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Process the query; fills candidate_context / raw_candidates.
        process_candidate_emails(user_input, similarity_threshold)
        if st.session_state.candidate_context:
            # Send the query to the LLM API (appends the reply to messages).
            call_llm_api(user_input)
    # Display chat messages
    for msg in st.session_state.messages:
        if msg["role"] == "user":
            with st.chat_message("user"):
                st.markdown(msg["content"])
        elif msg["role"] == "assistant":
            with st.chat_message("assistant"):
                st.markdown(msg["content"])
    # Display matching email chunks in an expander
    if st.session_state.raw_candidates:
        with st.expander("π Matching Email Chunks:", expanded=False):
            for candidate, sim in st.session_state.raw_candidates:
                # Get a snippet (first 150 characters) of the body instead of full body content.
                body = candidate.get('body', 'No Content')
                snippet = (body[:150] + "...") if len(body) > 150 else body
                st.markdown(
                    f"**From:** {candidate.get('sender','Unknown')} <br>"
                    f"**To:** {candidate.get('to','Unknown')} <br>"
                    f"**Date:** {candidate.get('date','Unknown')} <br>"
                    f"**Subject:** {candidate.get('subject','No Subject')} <br>"
                    f"**Body Snippet:** {snippet} <br>"
                    f"**Similarity:** {sim:.4f}",
                    unsafe_allow_html=True
                )
| # =============================== | |
| # 6. Main Application Logic | |
| # =============================== | |
def main():
    """Top-level Streamlit flow: authenticate, build/load the vector DB, then chat.

    Sidebar offers two paths: upload a previously saved vector_database.pkl,
    or authenticate with Gmail and build a fresh database from selected
    labels. Once a vector store exists, the chat UI is shown.
    """
    # NOTE(review): this SCOPES is unused here — authenticate_gmail defines
    # its own copy; kept for parity with the original.
    SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
    st.sidebar.header("π Gmail Authentication")
    credentials_file = st.sidebar.file_uploader("π Upload credentials.json", type=["json"])
    data_management_option = st.sidebar.selectbox(
        "Choose an option",
        ["Upload Pre-existing Data", "Authenticate and Create New Data"],
        index=1  # Default to "Authenticate and Create New Data"
    )
    if data_management_option == "Upload Pre-existing Data":
        uploaded_db = st.sidebar.file_uploader("π Upload vector database (vector_database.pkl)", type=["pkl"])
        if uploaded_db:
            # Check file size; if larger than 200MB, show a warning and then continue.
            file_size_mb = uploaded_db.size / (1024 * 1024)
            if file_size_mb > 200:
                st.warning("β οΈ The uploaded file is larger than 200MB. It may take longer to load, but processing will continue.")
            try:
                # NOTE(review): pickle.load on a user-uploaded file executes
                # arbitrary code during unpickling — only load trusted files.
                vector_db = pickle.load(uploaded_db)
                st.session_state.vector_store = vector_db.get("vector_store")
                st.session_state.embeddings = vector_db.get("embeddings")
                st.session_state.data_chunks = vector_db.get("data_chunks")
                if not st.session_state.vector_db_message_shown:
                    st.success("π Vector database loaded successfully from upload!")
                    st.session_state.vector_db_message_shown = True
            except Exception as e:
                st.error(f"β Error loading vector database: {e}")
    elif data_management_option == "Authenticate and Create New Data":
        if credentials_file and st.sidebar.button("π Authenticate"):
            # Start from a clean slate, persist the uploaded client secrets,
            # then kick off the OAuth flow.
            reset_session_state()
            with open("credentials.json", "wb") as f:
                f.write(credentials_file.getbuffer())
            authenticate_gmail("credentials.json")
        if st.session_state.auth_url:
            # Interactive OAuth in progress: show the URL and collect the code.
            st.sidebar.markdown("### π **Authorization URL:**")
            st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})")
            st.sidebar.text_input("π Enter the authorization code:", key="auth_code")
            if st.sidebar.button("β Submit Authentication Code"):
                submit_auth_code()
    if data_management_option == "Authenticate and Create New Data" and st.session_state.authenticated:
        st.sidebar.success("β You are authenticated!")
        st.header("π Data Management")
        # Multi-select widget for folders (labels)
        folders = st.multiselect("Select Labels (Folders) to Process Emails From:",
                                 ["INBOX", "SENT", "DRAFTS", "TRASH", "SPAM"], default=["INBOX"])
        if st.button("π₯ Create Chunks and Embed Data"):
            service = build('gmail', 'v1', credentials=st.session_state.creds)
            all_chunks = []
            # Process each selected folder
            for folder in folders:
                # Clear temporary data_chunks so that each folder's data is separate
                st.session_state.data_chunks = []
                create_chunks_from_gmail(service, folder)
                if st.session_state.data_chunks:
                    all_chunks.extend(st.session_state.data_chunks)
            st.session_state.data_chunks = all_chunks
            if st.session_state.data_chunks:
                embed_emails(st.session_state.data_chunks)
        if st.session_state.vector_store is not None:
            with st.expander("πΎ Download Data", expanded=False):
                save_vector_database()
    # Chat UI is available for both paths once a vector store exists.
    if st.session_state.vector_store is not None:
        handle_user_query()
if __name__ == "__main__":
    main()