| | import streamlit as st |
| | import pdfplumber |
| | import docx |
| | import os |
| | import re |
| | import numpy as np |
| | import google.generativeai as palm |
| | from sklearn.metrics.pairwise import cosine_similarity |
| | import logging |
| | import time |
| | import uuid |
| | import json |
| | import firebase_admin |
| | from firebase_admin import credentials, firestore |
| | from dotenv import load_dotenv |
| | import chromadb |
| |
|
| | |
# Root logging: INFO and above, timestamped, written to a stream handler.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
# Module logger used by every function in this file.
logger = logging.getLogger(__name__)
| |
|
| | |
# Populate environment variables from a .env file before the os.getenv()
# lookups further down (FIREBASE_CRED, GOOGLE_API_KEY); the user-facing error
# messages for missing configuration point at this .env file.
load_dotenv()
| |
|
| | |
class Config:
    """Static settings for chunking, retrieval and the Gemini models."""

    # Model identifiers for the embedding and text-generation API calls.
    EMBEDDING_MODEL = "models/text-embedding-004"
    GENERATION_MODEL = "models/gemini-1.5-flash"

    # Word limit per document chunk, and how many chunks to retrieve per question.
    CHUNK_WORDS = 300
    TOP_N = 5

    # Instruction placed at the top of every generation prompt.
    SYSTEM_PROMPT = (
        "You are a helpful assistant. "
        "Answer the question using the provided context below. "
        "Answer based on your knowledge if the context given is not enough."
    )
| |
|
| | |
def init_firebase():
    """Initialise the default Firebase app from the FIREBASE_CRED env variable.

    FIREBASE_CRED must hold a service-account credential as a JSON string.
    Does nothing if a Firebase app already exists. On a missing, malformed or
    rejected credential it logs the problem, shows an error in the UI and
    calls ``st.stop()``.
    """
    if firebase_admin._apps:
        return

    try:
        raw_credentials = os.getenv("FIREBASE_CRED")
        if not raw_credentials:
            logger.error("Firebase credentials not found in environment variables")
            st.error("Firebase configuration is missing. Please check your .env file.")
            st.stop()

        service_account = json.loads(raw_credentials)
        certificate = credentials.Certificate(service_account)
        firebase_admin.initialize_app(certificate)
        logger.info("Firebase initialized successfully")
    except json.JSONDecodeError:
        logger.error("Invalid Firebase credentials format")
        st.error("Firebase credentials are invalid. Please check your .env file.")
        st.stop()
    except Exception:
        logger.error("Firebase initialization failed", exc_info=True)
        st.error("Failed to initialize Firebase. Please contact support.")
        st.stop()
| |
|
| | |
def init_chroma():
    """Open the persistent ChromaDB store and return ``(client, collection)``.

    The store lives in a ``chroma_db`` directory relative to the working
    directory (created if missing). It holds the ``document_embeddings``
    collection, configured with ``hnsw:space = cosine``. Any failure is
    logged, shown in the UI, and followed by ``st.stop()``.
    """
    db_dir = "chroma_db"
    try:
        os.makedirs(db_dir, exist_ok=True)
        chroma = chromadb.PersistentClient(path=db_dir)
        collection = chroma.get_or_create_collection(
            name="document_embeddings",
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("ChromaDB initialized successfully")
        return chroma, collection
    except Exception:
        logger.error("ChromaDB initialization failed", exc_info=True)
        st.error("Failed to initialize ChromaDB. Please check your configuration.")
        st.stop()
| |
|
| | |
# Module-level service setup. This executes every time the script runs
# (presumably on each Streamlit rerun — confirm); init_firebase() itself skips
# work when a Firebase app is already registered.
init_firebase()
fs_client = firestore.client()
chroma_client, embedding_collection = init_chroma()

# The Google API key backs both palm.embed_content and palm.GenerativeModel
# calls in this file; without it the app halts here.
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    st.error("Google API key is not configured.")
    st.stop()
palm.configure(api_key=API_KEY)
| |
|
| | |
@st.cache_data(show_spinner=True)
def _embed_text_cached(text: str) -> list:
    """Return the embedding vector for *text*, raising on any failure.

    This is the memoized part. ``st.cache_data`` stores only successful
    return values, so raising here (instead of returning a placeholder)
    means a transient API error is retried the next time the same text is
    embedded rather than being cached permanently.
    """
    logger.info(f"Generating embedding for text: {text[:50]}...")
    response = palm.embed_content(
        model=Config.EMBEDDING_MODEL,
        content=text,
        task_type="retrieval_document",
    )
    if "embedding" not in response or not response["embedding"]:
        raise ValueError("No embedding returned from API")

    embedding = np.array(response["embedding"])
    if embedding.ndim == 2:
        # A nested (2-D) embedding is flattened into a single vector.
        embedding = embedding.flatten()
    return embedding.tolist()


def generate_embedding_cached(text: str) -> list:
    """Return the embedding of *text* as a list of floats.

    Successful results are cached via ``_embed_text_cached``. On any failure
    the error is logged and an all-zero 768-element vector is returned. That
    vector is the failure signal (check with ``not any(embedding)``). It is
    deliberately not cached, so the next call tries the API again.
    """
    try:
        return _embed_text_cached(text)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return [0.0] * 768
| |
|
def extract_text_from_file(uploaded_file) -> str:
    """Return the plain text of an uploaded .txt, .pdf or .docx file.

    The format is chosen from the file name's extension (case-insensitive).
    .txt files are decoded as UTF-8. PDF pages with no extractable text are
    skipped, and the remaining pages are joined with newlines. DOCX
    paragraphs are joined with newlines.

    Raises:
        ValueError: for any other extension.
        UnicodeDecodeError: if a .txt file is not valid UTF-8.
    """
    file_name = uploaded_file.name.lower()

    if file_name.endswith(".txt"):
        return uploaded_file.read().decode("utf-8")

    if file_name.endswith(".pdf"):
        with pdfplumber.open(uploaded_file) as pdf:
            # extract_text() is costly, so call it exactly once per page and
            # drop pages whose result is empty.
            page_texts = (page.extract_text() for page in pdf.pages)
            return "\n".join(text for text in page_texts if text)

    if file_name.endswith(".docx"):
        document = docx.Document(uploaded_file)
        return "\n".join(paragraph.text for paragraph in document.paragraphs)

    raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
| |
|
def chunk_text(text: str, max_words: "int | None" = None) -> list[str]:
    """Split *text* into chunks of at most ``max_words`` words where possible.

    Args:
        text: The full document text.
        max_words: Word limit per chunk; ``None`` (the default) means
            ``Config.CHUNK_WORDS``.

    Returns:
        The chunks in document order, each without surrounding whitespace.

    Paragraphs are the blocks separated by a blank line. They are packed
    greedily into chunks, joined by a blank line, while the running word
    count stays within the limit. A paragraph longer than the limit closes
    the current chunk and is instead split on sentence boundaries (whitespace
    after ``.``, ``!`` or ``?``). Its pieces are joined with single spaces
    and never merged with neighbouring paragraphs. A single sentence longer
    than the limit is kept whole, so such a chunk can exceed the limit.

    Raises:
        ValueError: if the effective limit is less than 1.
    """
    limit = Config.CHUNK_WORDS if max_words is None else max_words
    if limit < 1:
        raise ValueError(f"max_words must be at least 1, got {limit}")

    chunks: list[str] = []
    pending: list[str] = []  # short paragraphs waiting to be emitted together
    pending_words = 0

    for raw_paragraph in text.split("\n\n"):
        paragraph = raw_paragraph.strip()
        if not paragraph:
            continue
        words = len(paragraph.split())

        if words <= limit:
            if pending and pending_words + words > limit:
                chunks.append("\n\n".join(pending))
                pending, pending_words = [], 0
            pending.append(paragraph)
            pending_words += words
            continue

        # Oversized paragraph: flush what is pending, then pack its sentences.
        if pending:
            chunks.append("\n\n".join(pending))
            pending, pending_words = [], 0

        group: list[str] = []
        group_words = 0
        for sentence in re.split(r'(?<=[.!?])\s+', paragraph):
            sentence_words = len(sentence.split())
            if group and group_words + sentence_words > limit:
                chunks.append(" ".join(group).strip())
                group, group_words = [], 0
            group.append(sentence)
            group_words += sentence_words
        if group:
            chunks.append(" ".join(group).strip())

    if pending:
        chunks.append("\n\n".join(pending))
    return chunks
| |
|
def process_document(uploaded_file) -> None:
    """Extract, chunk, embed and index *uploaded_file* for the chat.

    Steps:
    1. Clear the previous document's keys from ``st.session_state``.
    2. Extract the text and split it with ``chunk_text``.
    3. Embed every chunk, skipping chunks whose embedding is all zeros (the
       failure signal of ``generate_embedding_cached``).
    4. Add the surviving chunks to the ChromaDB collection, tagged with a
       fresh per-upload ``doc_id``.
    5. Store the text, chunks, embeddings, chunk ids and ``doc_id`` in
       session state.

    Problems are reported with ``st.error`` instead of being raised.
    """
    try:
        # Remove the previous document first so a failed run leaves no stale state.
        for key in ("document_text", "document_chunks", "document_embeddings", "chunk_ids", "doc_id"):
            st.session_state.pop(key, None)

        file_text = extract_text_from_file(uploaded_file)
        if not file_text.strip():
            st.error("The uploaded file contains no valid text.")
            return

        chunks = chunk_text(file_text)
        if not chunks:
            st.error("Failed to split text into chunks.")
            return

        # These three lists are kept in lockstep: a chunk whose embedding failed
        # is dropped from all of them, so position i always describes one chunk.
        kept_chunks = []
        embeddings = []
        chunk_ids = []
        progress_bar = st.progress(0)
        for i, chunk in enumerate(chunks):
            embedding = generate_embedding_cached(chunk)
            if any(embedding):
                kept_chunks.append(chunk)
                embeddings.append(embedding)
                chunk_ids.append(str(uuid.uuid4()))
            # Advance for skipped chunks too, so the bar always reaches 100%.
            progress_bar.progress((i + 1) / len(chunks))

        if not embeddings:
            st.error("Failed to generate valid embeddings for the document.")
            return

        if embedding_collection is None:
            st.error("ChromaDB collection is not initialized.")
            return

        # The collection is persistent and shared by every upload, so tag this
        # upload's chunks to allow searches to be restricted to them.
        doc_id = str(uuid.uuid4())
        embedding_collection.add(
            ids=chunk_ids,
            documents=kept_chunks,
            embeddings=embeddings,
            metadatas=[{"chunk_index": idx, "doc_id": doc_id} for idx in range(len(kept_chunks))],
        )

        st.session_state.update({
            "document_text": file_text,
            "document_chunks": kept_chunks,
            "document_embeddings": embeddings,
            "chunk_ids": chunk_ids,
            "doc_id": doc_id,
        })

        if not st.session_state.get("doc_processed", False):
            st.success("Document processing complete! You can now start chatting.")
            st.session_state.doc_processed = True

    except Exception as e:
        logger.error(f"Document processing failed: {e}", exc_info=True)
        st.error(f"An error occurred while processing the document: {e}")
| |
|
def search_query(query: str) -> list[tuple[str, float]]:
    """Return the stored chunks closest to *query*.

    Returns up to ``Config.TOP_N`` ``(chunk_text, distance)`` pairs in the
    order ChromaDB ranks them. ``distance`` is the raw ChromaDB distance.
    The collection is created with cosine space, where a lower value means
    more similar. Only chunks belonging to the current upload (per session
    state) are returned. Any failure is logged and yields an empty list.
    """
    try:
        query_embedding = generate_embedding_cached(query)
        if not any(query_embedding):
            # An all-zero vector is the embedding helper's failure value;
            # searching with it would be meaningless.
            logger.error("Search query failed: could not embed the query")
            return []

        query_args = {"query_embeddings": [query_embedding], "n_results": Config.TOP_N}
        # The persistent collection can hold chunks from earlier uploads. When
        # the session records a doc_id for the current upload, search only it.
        doc_id = st.session_state.get("doc_id")
        if doc_id:
            query_args["where"] = {"doc_id": doc_id}
        results = embedding_collection.query(**query_args)

        # ChromaDB returns one result list per query embedding; we sent exactly one.
        ids = results["ids"][0]
        documents = results["documents"][0]
        distances = results["distances"][0]

        # Never surface chunks that are not part of the current upload.
        own_ids = st.session_state.get("chunk_ids")
        allowed = set(own_ids) if own_ids else None
        return [
            (document, distance)
            for chunk_id, document, distance in zip(ids, documents, distances)
            if allowed is None or chunk_id in allowed
        ]
    except Exception as e:
        logger.error(f"Search query failed: {e}")
        return []
| |
|
def generate_answer(user_query: str, context: str) -> str:
    """Ask the Gemini model to answer *user_query* using *context*.

    The prompt is a plain-text transcript: ``Config.SYSTEM_PROMPT``, then the
    retrieved context, then the user's question. Always returns a string:
    either the model's text, or a fixed apology if the API call fails or the
    response carries no text.
    """
    fallback = "I'm sorry, I encountered an error generating a response."
    prompt = (
        f"System: {Config.SYSTEM_PROMPT}\n\n"
        f"Context:\n{context}\n\n"
        f"User: {user_query}\nAssistant:"
    )
    try:
        model = palm.GenerativeModel(Config.GENERATION_MODEL)
        response = model.generate_content(prompt)
        # Reading .text is inside the try: any exception it raises is handled below.
        text = getattr(response, "text", None)
    except Exception as e:
        logger.error(f"Answer generation failed: {e}")
        return fallback

    if not isinstance(text, str):
        # Keep the declared str return type instead of leaking the response object.
        logger.error("Answer generation failed: response contained no text")
        return fallback
    return text
| |
|
| | |
def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
    """Store one question/answer exchange in Firestore and return its new id.

    The record is added under ``sessions/<session_id>/conversations`` with an
    auto-generated document id and a server-side timestamp.
    """
    record = {
        "user_question": user_question,
        "assistant_answer": assistant_answer,
        "feedback": feedback,
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    conversations = (
        fs_client.collection("sessions")
        .document(session_id)
        .collection("conversations")
    )
    added = conversations.add(record)
    # The second element of add()'s result is the reference to the new document.
    new_document = added[1]
    return new_document.id
| |
|
def update_feedback_in_firestore(session_id, conversation_id, feedback):
    """Set the ``feedback`` field on an existing conversation record.

    The record is ``sessions/<session_id>/conversations/<conversation_id>``.
    """
    record = (
        fs_client.collection("sessions")
        .document(session_id)
        .collection("conversations")
        .document(conversation_id)
    )
    record.update({"feedback": feedback})
| |
|
def handle_feedback(feedback_val):
    """Button callback that records *feedback_val* for the latest exchange.

    The value is written to Firestore first. It is then copied onto the last
    in-memory conversation so the chat history shows it.
    """
    state = st.session_state
    update_feedback_in_firestore(state.session_id, state.latest_conversation_id, feedback_val)
    state.conversations[-1]["feedback"] = feedback_val
| |
|
| | |
def chat_app():
    """Render the chat: history, input box, retrieval-backed answer, feedback buttons.

    Session state used:
    - ``conversations``: list of dicts with ``user_question``,
      ``assistant_answer`` and, once rated, ``feedback``.
    - ``session_id``: random UUID grouping this session's Firestore records.
    - ``latest_conversation_id``: Firestore id of the newest exchange, read by
      ``handle_feedback``.
    """
    if "conversations" not in st.session_state:
        st.session_state.conversations = []
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())

    for conv in st.session_state.conversations:
        with st.chat_message("user"):
            st.write(conv["user_question"])
        with st.chat_message("assistant"):
            st.write(conv["assistant_answer"])
            if conv.get("feedback"):
                st.markdown(f"**Feedback:** {conv['feedback']}")

    user_input = st.chat_input("Type your message here")
    if user_input:
        with st.chat_message("user"):
            st.write(user_input)

        results = search_query(user_input)
        context = "\n\n".join(chunk for chunk, _score in results)
        answer = generate_answer(user_input, context)

        with st.chat_message("assistant"):
            st.write(answer)

        conversation_id = save_conversation_to_firestore(
            st.session_state.session_id,
            user_question=user_input,
            assistant_answer=answer,
        )
        st.session_state.latest_conversation_id = conversation_id
        st.session_state.conversations.append({
            "user_question": user_input,
            "assistant_answer": answer,
        })

    # Offer rating buttons for the newest exchange until it has been rated.
    # The emptiness check matters: before the first question there is nothing to rate.
    conversations = st.session_state.conversations
    if conversations and "feedback" not in conversations[-1]:
        turn = len(conversations)
        # Ten narrow columns keep the two buttons compact; only the first two are used.
        like_col, dislike_col = st.columns(10)[:2]
        like_col.button("👍", key=f"feedback_like_{turn}",
                        on_click=handle_feedback, args=("positive",))
        dislike_col.button("👎", key=f"feedback_dislike_{turn}",
                           on_click=handle_feedback, args=("negative",))
| |
|
def main():
    """Render the app: title, sidebar uploader, ingestion of new uploads, chat, footer."""
    st.title("Chat with your files")

    st.sidebar.header("Upload Document")
    uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])

    if uploaded_file is not None:
        # Remember which upload has been ingested. Repeated runs with the same
        # file do not re-embed it, while a different upload replaces the
        # current document. A plain "doc_processed" flag alone would ignore
        # every file after the first one.
        file_key = (getattr(uploaded_file, "file_id", None), uploaded_file.name, uploaded_file.size)
        if st.session_state.get("processed_file_key") != file_key:
            # Reset so process_document reports success for this new file too.
            st.session_state.doc_processed = False
            process_document(uploaded_file)
            # Only mark the upload as handled on success, so failures are retried.
            if st.session_state.get("doc_processed", False):
                st.session_state.processed_file_key = file_key

    if "document_text" in st.session_state:
        chat_app()
    else:
        st.info("Please upload and process a document from the sidebar to start chatting.")

    st.markdown(
        """
        <div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
        Made by Danny.<br>
        Your questions, our response as well as your feedback will be saved for evaluation purposes.
        </div>
        """,
        unsafe_allow_html=True
    )
| |
|
# Script entry point: render the whole app when this file is run as the main module.
if __name__ == "__main__":
    main()