# Revai / app.py
# Author: muhammad121 — "Update app.py" (commit 9459b34, verified)
"""
RevAI Chatbot Application
This Streamlit application provides a chat interface for users to interact with their documents
using AI-powered semantic search and natural language processing. The application uses
OpenAI's embeddings and chat models to provide relevant responses based on document content.
Features:
- Document-based Q&A using semantic search
- User-specific document repositories
- Interactive chat interface
"""
# Standard library imports
import json
import logging
import os
import re
import time
import uuid
from datetime import datetime
# Third-party imports
import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI
from supabase import create_client
# Configure logging: timestamped INFO-level messages for the whole app.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Initialize environment variables from a local .env file, if present.
load_dotenv()

# Initialize OpenAI client.
# Left as None until init_openai_client() runs; helpers that use it handle
# failures via their own try/except blocks.
client = None

# Initialize Supabase client.
# Both the URL and the key are mandatory — fail fast at import time rather
# than on the first query.
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
if not SUPABASE_URL or not SUPABASE_KEY:
    raise ValueError("SUPABASE_URL and SUPABASE_KEY environment variables must be set")
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
def init_openai_client():
    """
    Initialize the module-level OpenAI client from environment variables.

    Reads OPENAI_API_KEY; on success, rebinds the global ``client``.

    Returns:
        OpenAI: Initialized OpenAI client object, or None if no key is set
    """
    global client
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        # Surface the misconfiguration in the UI; `client` stays None and
        # downstream helpers report their own errors when they try to use it.
        st.error("OpenAI API key not found in environment variables")
    else:
        client = OpenAI(api_key=api_key)
    return client
# Initialize session state variables.
# `messages` holds the full chat transcript as dicts with "role" and
# "content" keys (assistant entries also carry "turnaround_time"); it is
# created once per browser session and survives Streamlit reruns.
if 'messages' not in st.session_state:
    st.session_state.messages = []
def sanitize_customer_id(customer_id):
    """
    Sanitize customer ID to prevent SQL injection and other security issues.

    Only IDs made up entirely of ASCII letters, digits, underscores, and
    hyphens are accepted.

    Args:
        customer_id (str): The raw customer ID input

    Returns:
        str or None: Sanitized customer ID if valid, None otherwise
    """
    # Guard against None / non-string input: passing those straight into
    # re.match would raise TypeError instead of rejecting the value.
    # fullmatch also makes the whole-string anchoring explicit.
    if isinstance(customer_id, str) and re.fullmatch(r"[a-zA-Z0-9_-]+", customer_id):
        return customer_id
    logger.warning("Invalid customer ID: %r", customer_id)
    return None
def get_customer_table(customer_id):
    """
    Get the database table name associated with a customer ID.

    Args:
        customer_id (str): The customer ID to look up

    Returns:
        str or None: The table name if found, None otherwise
    """
    sanitized_customer_id = sanitize_customer_id(customer_id)
    if not sanitized_customer_id:
        logger.error("Customer ID is invalid or potentially malicious")
        return None
    try:
        logger.info("Fetching table for customer ID: %s", sanitized_customer_id)
        # Only the first mapping row is ever used, so cap the query at one row.
        response = (
            supabase.table("customer_mappings")
            .select("table_name")
            .eq("customer_id", sanitized_customer_id)
            .limit(1)
            .execute()
        )
        if response.data:
            table_name = response.data[0]["table_name"]
            # Log the sanitized value — never echo unvalidated input
            # (the original logged the raw customer_id here).
            logger.info("Found table '%s' for customer ID: %s", table_name, sanitized_customer_id)
            return table_name
        logger.warning("No table found for customer ID: %s", sanitized_customer_id)
        return None
    except Exception as e:
        logger.error("Error finding customer table: %s", e)
        return None
def get_embeddings(text):
    """
    Generate embeddings for the provided text using OpenAI's embedding model.

    Args:
        text (str): The text to generate embeddings for

    Returns:
        list: The embedding vector if successful, empty list otherwise
    """
    try:
        logger.info("Generating embeddings for text")
        # Reject empty / whitespace-only input before spending an API call.
        if not text or not text.strip():
            logger.warning("Empty or invalid text provided for embeddings")
            return []
        api_response = client.embeddings.create(model="text-embedding-3-large", input=text)
        vector = api_response.data[0].embedding
        logger.info("Embeddings generated successfully")
        return vector
    except Exception as e:
        logger.error(f"Error generating embeddings: {str(e)}")
        return []
def perform_similarity_search(customer_id, query_embedding):
    """
    Perform similarity search using the query embedding against customer's documents.

    Calls the ``match_documents`` RPC on the customer's mapped table with a
    fixed threshold (0.3) and at most 3 matches.

    Args:
        customer_id (str): The customer ID to search documents for
        query_embedding (list): The embedding vector for the query

    Returns:
        dict or None: Search results containing matches and count, None if error occurs
    """
    try:
        logger.info(f"Performing similarity search for customer ID: {customer_id}")
        table_name = get_customer_table(customer_id)
        if not table_name:
            logger.warning(f"No table found for customer ID: {customer_id}")
            return None
        result = supabase.rpc(
            "match_documents",
            {
                "query_embedding": query_embedding,
                "match_threshold": 0.3,
                "match_count": 3,
                "table_name": table_name,
            },
        ).execute()
        if result.data and len(result.data) > 0:
            logger.info(f"Similarity search returned {len(result.data)} results")
            # Dump raw matches only when DEBUG logging is enabled:
            # unconditionally rewriting search_results.json on every query is
            # pointless disk churn and a hazard when two sessions run at once.
            if logger.isEnabledFor(logging.DEBUG):
                output_file = "search_results.json"
                with open(output_file, "w") as f:
                    json.dump(result.data, f, indent=2)
                logger.info(f"Search results saved to {output_file}")
            return {"matches": result.data, "count": len(result.data)}
        logger.info("Similarity search returned no results")
        return {"matches": [], "count": 0}
    except Exception as e:
        logger.error(f"Error in similarity search: {str(e)}")
        return None
def call_inference_api(prompt, user_id):
    """
    Process user query by generating embeddings and performing similarity search.

    Args:
        prompt (str): The user's query
        user_id (str): The user's ID for document access

    Returns:
        dict: Results or error information
    """
    try:
        query_embedding = get_embeddings(prompt)
        if not query_embedding:
            return {"error": "Failed to generate embeddings for query"}
        search_results = perform_similarity_search(user_id, query_embedding)
        if search_results is None:
            return {"error": f"No data found for customer ID: {user_id}"}
        # Flatten each match into the shape format_chunks() expects.
        results = [
            {
                "chunk_content": match.get("content", ""),
                "extra_key_data": match.get("metadata", {}).get("extra_key_data", {}),
                "document_name": match.get("metadata", {}).get("document_name", "Unknown document"),
            }
            for match in search_results.get("matches", [])
        ]
        return {"results": results}
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}")
        return {"error": f"Error during inference: {str(e)}"}
def format_chunks(chunks):
    """
    Format the retrieved document chunks into a readable text format.

    Each chunk becomes a "Source N (doc):" section with its content and,
    when present, its summary; sections are separated by a blank line.

    Args:
        chunks (list): List of document chunks from similarity search

    Returns:
        str: Formatted text containing content from retrieved chunks
    """
    sections = []
    for index, chunk in enumerate(chunks, start=1):
        body = chunk.get("chunk_content", "No content")
        summary = chunk.get("extra_key_data", "")
        source_name = chunk.get("document_name", "Unknown document")
        lines = [f"Source {index} ({source_name}):", f"Content: {body}"]
        if summary:
            lines.append(f"Summary: {summary}")
        sections.append("\n".join(lines) + "\n\n")
    return "".join(sections)
def prepare_messages(chunks_text):
    """
    Prepare the messages for the OpenAI API with chat history and context.

    Args:
        chunks_text (str): Formatted document chunks to provide as context

    Returns:
        list: Messages formatted for OpenAI chat completion API
    """
    # The system prompt carries the retrieved context; only the last 10
    # chat turns are replayed to keep the request small.
    system_message = {
        "role": "system",
        "content": f"You are a helpful assistant. Use the following information to answer the user's question:\n\n{chunks_text} keep your answer clear and complete and fast ",
    }
    history = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in st.session_state.messages[-10:]
    ]
    return [system_message] + history
def generate_ai_response(messages):
    """
    Generate a response using the OpenAI API.

    Args:
        messages (list): Formatted messages for the OpenAI chat completion API

    Returns:
        str: AI-generated response or error message
    """
    try:
        # Lightweight model, short answers, moderate creativity.
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0.7,
            max_tokens=400,
        )
        return completion.choices[0].message.content
    except Exception as e:
        logger.error(f"Error generating AI response: {str(e)}")
        return f"Error generating response: {str(e)}"
def get_response(prompt, user_id):
    """
    Main function to handle the user query and generate a response.

    Runs the retrieval pipeline (embed → search → format → chat completion)
    behind Streamlit spinners and times the whole round trip.

    Args:
        prompt (str): The user's query
        user_id (str): The user's ID for document access

    Returns:
        tuple: (AI-generated response or error message, turnaround time)
    """
    started = time.time()
    with st.spinner("Gathering insights..."):
        data = call_inference_api(prompt, user_id)
    if "error" in data:
        logger.error(f"Error in inference API: {data['error']}")
        response = data["error"]
    elif "results" in data:
        with st.spinner("Formatting document chunks..."):
            chunks_text = format_chunks(data["results"])
        with st.spinner("Preparing messages for OpenAI..."):
            messages = prepare_messages(chunks_text)
        with st.spinner("Generating AI response..."):
            response = generate_ai_response(messages)
    else:
        # Neither "error" nor "results" — fall back to any server message.
        response = data.get("response", "No response received from server")
    return response, time.time() - started
def chat_page():
    """
    Render the chat interface page in the Streamlit app.

    The document-repository user ID is read from the REVAI_USER_ID
    environment variable, falling back to the previously hardcoded ID so
    existing deployments keep working unchanged.
    """
    st.title("StartSmart AI")
    # Previously a hardcoded constant; now overridable without a code change.
    user_id = os.environ.get("REVAI_USER_ID", "670e5a5c-149f-4090-bae6-4c69e4cda540")
    logger.info(f"Using User ID: {user_id}")
    # Container so the transcript renders below the input handling.
    chat_container = st.container()
    # Chat input that processes on Enter key.
    user_query = st.chat_input("Type your message here...")
    if user_query:
        # Add user message to chat history.
        st.session_state.messages.append({"role": "user", "content": user_query})
        # Get assistant response.
        response, turnaround_time = get_response(user_query, user_id)
        # Add assistant response (with timing) to chat history.
        st.session_state.messages.append({"role": "assistant", "content": response, "turnaround_time": turnaround_time})
        # Force a rerun to update the UI immediately.
        st.rerun()
    # Display all chat messages in the container.
    with chat_container:
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.write(message["content"])
                if message["role"] == "assistant" and "turnaround_time" in message:
                    st.caption(f"Response time: {message['turnaround_time']:.2f} seconds")
def main():
    """
    Main function to run the Streamlit application.

    Sets up the OpenAI client first, then renders the chat page.
    """
    init_openai_client()
    chat_page()


if __name__ == "__main__":
    main()