"""DocTalk — a Streamlit app for chatting with the contents of a PDF.

Pipeline: upload a PDF -> split it into overlapping chunks -> embed the
chunks into a FAISS index -> retrieve the top-k chunks per question ->
stream an answer from google/gemma-2-2b-it running on CPU.

Requires an `HF_TOKEN` environment secret (Gemma is a gated model).
"""

import os
import tempfile
from threading import Thread

import streamlit as st
import torch
from huggingface_hub import login
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Page Config & Styling ---
st.set_page_config(
    page_title="DocTalk - Chat With PDF",
    page_icon="📗💬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for polished UI and Footer
st.markdown(""" """, unsafe_allow_html=True)

# --- Session State Management ---
# Initialize every key exactly once; st.session_state persists across reruns.
for _key, _default in (
    ("messages", []),
    ("processing_done", False),
    ("vector_store", None),
    ("model", None),
    ("tokenizer", None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# --- Authentication (Secrets Only) ---
hf_token = os.environ.get("HF_TOKEN")


# --- Model Loading (Cached & Optimized) ---
@st.cache_resource
def load_embedding_model():
    """Load the sentence-embedding model once per process.

    Returns:
        HuggingFaceEmbeddings on success, or None on failure (the error is
        surfaced in the UI via st.error).
    """
    try:
        return HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None


@st.cache_resource
def load_llm_model(token):
    """Load the Gemma LLM once — returns (model, tokenizer) for streaming.

    Args:
        token: Hugging Face access token (Gemma is a gated model).

    Returns:
        (model, tokenizer) on success, (None, None) on failure.
    """
    try:
        login(token=token)
        model_id = "google/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        # Load model to CPU with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            token=token,
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading LLM: {e}")
        return None, None


# --- PDF Processing (Optimized for better accuracy) ---
def process_document(uploaded_file, embedding_model):
    """Process an uploaded PDF into a FAISS vector store.

    Args:
        uploaded_file: Streamlit UploadedFile holding the PDF bytes.
        embedding_model: embeddings instance used to vectorize the chunks.

    Returns:
        A FAISS vector store on success, or None on failure.
    """
    tmp_path = None
    try:
        # PyPDFLoader needs a real path, so spill the upload to a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name

        # Load & Split with balanced parameters for accuracy
        loader = PyPDFLoader(tmp_path)
        docs = loader.load()

        # Balanced chunking for better accuracy
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            separators=["\n\n", "\n", " ", ""],
        )
        chunks = splitter.split_documents(docs)

        return FAISS.from_documents(chunks, embedding_model)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None
    finally:
        # BUGFIX: cleanup used to run only on success, leaking the temp PDF
        # whenever loading/splitting/embedding raised.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def get_relevant_context(vector_store, question):
    """Retrieve the top-3 most relevant chunks for *question*.

    Returns:
        (context_text, docs) — the joined chunk text plus the raw documents;
        ("", []) when retrieval fails.
    """
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        docs = retriever.invoke(question)
        context = "\n\n".join(doc.page_content for doc in docs)
        return context, docs
    except Exception as e:
        st.error(f"Error retrieving context: {e}")
        return "", []


def stream_response(model, tokenizer, prompt):
    """Yield the model's answer for *prompt* as incremental text chunks.

    model.generate blocks, so it runs on a worker thread while
    TextIteratorStreamer feeds decoded text back to this generator.
    """
    try:
        # Tokenize input (truncated so context + question fit the window)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

        # Create streamer
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Generation config optimized for Gemma
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.3,
            top_p=0.95,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Start generation in a separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens as they're generated
        for text in streamer:
            yield text
        thread.join()
    except Exception as e:
        yield f"Error generating response: {e}"


# --- Main Layout ---

# 1. Sidebar Configuration
with st.sidebar:
    st.title("Configuration")
    st.markdown("---")

    if not hf_token:
        st.error("🚨 **HF_TOKEN missing!**")
        st.info("Go to Space Settings → Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
        st.stop()
    else:
        st.success("✅ Hugging Face Connected")

    st.subheader("📄 Document Upload")
    uploaded_file = st.file_uploader("Upload your PDF", type="pdf", help="Upload a PDF document to chat with")

    if uploaded_file:
        process_btn = st.button("🚀 Process Document", type="primary", use_container_width=True)
        if process_btn:
            with st.spinner("🧠 Analyzing PDF document..."):
                # Load models (cached across reruns by st.cache_resource)
                model, tokenizer = load_llm_model(hf_token)
                embed_model = load_embedding_model()

                if model and tokenizer and embed_model:
                    vector_store = process_document(uploaded_file, embed_model)
                    if vector_store:
                        st.session_state.vector_store = vector_store
                        st.session_state.model = model
                        st.session_state.tokenizer = tokenizer
                        st.session_state.processing_done = True
                        # BUGFIX: an st.success here was wiped instantly by the
                        # rerun; the processed-state sidebar section below is
                        # what the user actually sees.
                        st.rerun()
                    else:
                        st.error("❌ Failed to process document. Please try again.")
                else:
                    st.error("❌ Failed to load AI models. Check your token permissions.")

    if st.session_state.processing_done:
        st.markdown("---")
        st.success("✅ Start Chatting")
        st.info(f"📄 **{uploaded_file.name if uploaded_file else 'Document'}** loaded")

        if st.button("🗑️ Clear Chat History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()

        if st.button("🔄 Upload New Document", use_container_width=True):
            st.session_state.processing_done = False
            st.session_state.vector_store = None
            st.session_state.messages = []
            st.rerun()

# 2. Main Chat Area
st.title("📗💬 DocTalk - Chat With PDF")

if st.session_state.processing_done:
    # Display Chat History
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Chat Input
    if user_input := st.chat_input("Ask a question about your document..."):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        # Generate assistant response
        with st.chat_message("assistant"):
            try:
                # Get relevant context
                context, source_docs = get_relevant_context(st.session_state.vector_store, user_input)

                if not context:
                    st.warning("⚠️ Could not find relevant information in the document.")
                else:
                    # BUGFIX: build the prompt with the tokenizer's chat
                    # template so Gemma's required turn delimiters
                    # (<start_of_turn>user ... <end_of_turn><start_of_turn>model)
                    # are emitted, instead of a hand-rolled prompt that lacked
                    # them and degraded answer quality.
                    question = (
                        "Answer the question based strictly on the context below. "
                        "Be concise and accurate.\n\n"
                        f"Context:\n{context}\n\n"
                        f"Question: {user_input}"
                    )
                    prompt = st.session_state.tokenizer.apply_chat_template(
                        [{"role": "user", "content": question}],
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    # Stream the response, showing a cursor while tokens arrive
                    response_placeholder = st.empty()
                    full_response = ""
                    for chunk in stream_response(st.session_state.model, st.session_state.tokenizer, prompt):
                        full_response += chunk
                        response_placeholder.markdown(full_response + "▌")

                    # Final update without cursor
                    response_placeholder.markdown(full_response)

                    # Save to history
                    st.session_state.messages.append({"role": "assistant", "content": full_response})

                    # Show sources
                    if source_docs:
                        with st.expander("🔎 View Source Context"):
                            for i, doc in enumerate(source_docs):
                                st.markdown(f"**Source {i+1}** (Page {doc.metadata.get('page', 'Unknown')})")
                                st.caption(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
                                st.markdown("---")
            except Exception as e:
                st.error(f"❌ An error occurred: {e}")
                st.info("Please try asking your question again or upload a new document.")
else:
    # Empty State
    st.info("👋 **Welcome to DocTalk!** Upload a PDF document in the sidebar to begin chatting.")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("### 📤 Upload")
        st.markdown("Upload your PDF document using the sidebar")
    with col2:
        st.markdown("### 🔄 Process")
        st.markdown("Click 'Process Document' to analyze it")
    with col3:
        st.markdown("### 💬 Chat")
        st.markdown("Ask questions and get instant answers")

st.markdown("---")

# --- Footer ---
st.markdown(""" """, unsafe_allow_html=True)