import streamlit as st
import os
import tempfile
import torch
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
from threading import Thread
# --- Page Config & Styling ---
# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="DocTalk - Chat With PDF",
    page_icon="📗💬",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for polished UI and Footer
# NOTE(review): the style block below is empty — confirm whether the CSS was
# stripped from this file or is intentionally a no-op.
st.markdown("""
""", unsafe_allow_html=True)
# --- Session State Management ---
# Streamlit reruns the whole script on every interaction; session_state keeps
# the chat history, the FAISS index and the loaded models alive across reruns.
if 'messages' not in st.session_state:
    # chat transcript: list of {"role": ..., "content": ...} dicts
    st.session_state.messages = []
if 'processing_done' not in st.session_state:
    # True once a PDF has been indexed and models are ready
    st.session_state.processing_done = False
if 'vector_store' not in st.session_state:
    # FAISS vector store over the PDF chunks
    st.session_state.vector_store = None
if 'model' not in st.session_state:
    # causal LM used for answering
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    # tokenizer paired with the model above
    st.session_state.tokenizer = None
# --- Authentication (Secrets Only) ---
# Hugging Face token is read from the environment (e.g. Space repository secrets).
hf_token = os.environ.get("HF_TOKEN")
# --- Model Loading (Cached & Optimized) ---
@st.cache_resource
def load_embedding_model():
    """Build the sentence-embedding model a single time and reuse it.

    Returns:
        A HuggingFaceEmbeddings instance configured for CPU with normalized
        embeddings, or None if loading fails (error shown in the UI).
    """
    try:
        return HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
    except Exception as exc:
        st.error(f"Error loading embedding model: {exc}")
        return None
@st.cache_resource
def load_llm_model(token):
    """Log in to Hugging Face and load the Gemma chat model exactly once.

    Args:
        token: Hugging Face access token (Gemma is a gated model).

    Returns:
        A (model, tokenizer) pair for streaming generation, or (None, None)
        when any step fails (error shown in the UI).
    """
    model_id = "google/gemma-2-2b-it"
    try:
        login(token=token)
        tok = AutoTokenizer.from_pretrained(model_id, token=token)
        # CPU-only load: float32 weights, reduced peak RAM during loading.
        lm = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            token=token,
        )
        return lm, tok
    except Exception as exc:
        st.error(f"Error loading LLM: {exc}")
        return None, None
# --- PDF Processing (Optimized for better accuracy) ---
def process_document(uploaded_file, embedding_model):
    """Index an uploaded PDF into a FAISS vector store.

    Args:
        uploaded_file: Streamlit UploadedFile holding the PDF bytes.
        embedding_model: embeddings object used to vectorize the chunks.

    Returns:
        A FAISS vector store over the chunked document, or None on failure
        (the error is surfaced in the Streamlit UI).
    """
    tmp_path = None
    try:
        # PyPDFLoader needs a real filesystem path, so spill the upload
        # to a temp file first.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name
        loader = PyPDFLoader(tmp_path)
        docs = loader.load()
        # Balanced chunking: 1000 chars with 100-char overlap so answers that
        # span a chunk boundary remain retrievable.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = splitter.split_documents(docs)
        return FAISS.from_documents(chunks, embedding_model)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None
    finally:
        # Bug fix: previously the temp file leaked whenever loading, splitting
        # or embedding raised (unlink was only on the success path).
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def get_relevant_context(vector_store, question):
    """Fetch the top-3 document chunks most similar to *question*.

    Args:
        vector_store: FAISS store built from the PDF chunks.
        question: the user's natural-language query.

    Returns:
        A (context_text, documents) pair; on failure the error is shown in
        the UI and ("", []) is returned.
    """
    try:
        docs = vector_store.as_retriever(search_kwargs={"k": 3}).invoke(question)
        joined = "\n\n".join(doc.page_content for doc in docs)
        return joined, docs
    except Exception as exc:
        st.error(f"Error retrieving context: {exc}")
        return "", []
def stream_response(model, tokenizer, prompt):
    """Yield the model's answer chunk-by-chunk as it is generated.

    Generation runs on a background thread while this generator consumes
    text pieces from a TextIteratorStreamer on the calling thread. On any
    error a single error-message string is yielded instead.
    """
    try:
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Sampling settings tuned for Gemma: low temperature for factual
        # answers, mild repetition penalty.
        gen_kwargs = {
            **encoded,
            "streamer": streamer,
            "max_new_tokens": 512,
            "temperature": 0.3,
            "top_p": 0.95,
            "repetition_penalty": 1.1,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        worker = Thread(target=model.generate, kwargs=gen_kwargs)
        worker.start()
        # Relay text pieces to the caller as they arrive.
        yield from streamer
        worker.join()
    except Exception as exc:
        yield f"Error generating response: {exc}"
# --- Main Layout ---
# 1. Sidebar Configuration
with st.sidebar:
    st.title("Configuration")
    st.markdown("---")
    # Hard requirement: without an HF token the gated Gemma model cannot load.
    if not hf_token:
        st.error("🚨 **HF_TOKEN missing!**")
        st.info("Go to Space Settings → Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
        # Halts the script here on every rerun until the secret exists.
        st.stop()
    else:
        st.success("✅ Hugging Face Connected")
    st.subheader("📄 Document Upload")
    uploaded_file = st.file_uploader("Upload your PDF", type="pdf", help="Upload a PDF document to chat with")
    if uploaded_file:
        process_btn = st.button("🚀 Process Document", type="primary", use_container_width=True)
        if process_btn:
            with st.spinner("🧠 Analyzing PDF document..."):
                # Load models (cached across reruns by st.cache_resource).
                model, tokenizer = load_llm_model(hf_token)
                embed_model = load_embedding_model()
                if model and tokenizer and embed_model:
                    vector_store = process_document(uploaded_file, embed_model)
                    if vector_store:
                        # Persist everything the chat loop needs across reruns.
                        st.session_state.vector_store = vector_store
                        st.session_state.model = model
                        st.session_state.tokenizer = tokenizer
                        st.session_state.processing_done = True
                        st.success("✅ Document processed! Start chatting below.")
                        # Re-render immediately so the chat UI appears.
                        st.rerun()
                    else:
                        st.error("❌ Failed to process document. Please try again.")
                else:
                    st.error("❌ Failed to load AI models. Check your token permissions.")
    if st.session_state.processing_done:
        st.markdown("---")
        st.success("✅ Start Chatting")
        # uploaded_file can be None after a rerun; fall back to a generic label.
        st.info(f"📄 **{uploaded_file.name if uploaded_file else 'Document'}** loaded")
        if st.button("🗑️ Clear Chat History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()
        if st.button("🔄 Upload New Document", use_container_width=True):
            # Keep the cached models; only drop the per-document state.
            st.session_state.processing_done = False
            st.session_state.vector_store = None
            st.session_state.messages = []
            st.rerun()
# 2. Main Chat Area
st.title("📗💬 DocTalk - Chat With PDF")
if st.session_state.processing_done:
    # Replay the chat history (the script reruns on every interaction).
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
    # Chat Input
    if user_input := st.chat_input("Ask a question about your document..."):
        # Record and echo the user message.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)
        # Generate assistant response
        with st.chat_message("assistant"):
            try:
                # Retrieve the chunks most relevant to the question.
                context, source_docs = get_relevant_context(st.session_state.vector_store, user_input)
                if not context:
                    st.warning("⚠️ Could not find relevant information in the document.")
                else:
                    # Bug fix: the prompt previously contained bare "user"/"model"
                    # lines — Gemma's chat format requires explicit
                    # <start_of_turn>/<end_of_turn> delimiters around each turn.
                    prompt = (
                        "<start_of_turn>user\n"
                        "Answer the question based strictly on the context below. "
                        "Be concise and accurate.\n"
                        f"Context: {context}\n"
                        f"Question: {user_input}<end_of_turn>\n"
                        "<start_of_turn>model\n"
                    )
                    # Stream the answer chunk-by-chunk into a placeholder.
                    response_placeholder = st.empty()
                    full_response = ""
                    for chunk in stream_response(st.session_state.model, st.session_state.tokenizer, prompt):
                        full_response += chunk
                        # Render as plain markdown (no unsafe HTML): model output
                        # is untrusted and must not be able to inject markup.
                        response_placeholder.markdown(full_response + " ✍")
                    # Final update without the typing cursor.
                    response_placeholder.markdown(full_response)
                    # Save to history so the exchange survives the next rerun.
                    st.session_state.messages.append({"role": "assistant", "content": full_response})
                    # Show which document chunks the answer was grounded in.
                    if source_docs:
                        with st.expander("🔎 View Source Context"):
                            for i, doc in enumerate(source_docs):
                                st.markdown(f"**Source {i+1}** (Page {doc.metadata.get('page', 'Unknown')})")
                                st.caption(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
                                st.markdown("---")
            except Exception as e:
                st.error(f"❌ An error occurred: {e}")
                st.info("Please try asking your question again or upload a new document.")
else:
    # Empty state shown before any document has been processed.
    st.info("👋 **Welcome to DocTalk!** Upload a PDF document in the sidebar to begin chatting.")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("### 📤 Upload")
        st.markdown("Upload your PDF document using the sidebar")
    with col2:
        st.markdown("### 🔄 Process")
        st.markdown("Click 'Process Document' to analyze it")
    with col3:
        st.markdown("### 💬 Chat")
        st.markdown("Ask questions and get instant answers")
    st.markdown("---")
# --- Footer ---
# NOTE(review): the footer markup below is empty — confirm whether the HTML
# footer was stripped from this file or intentionally left blank.
st.markdown("""
""", unsafe_allow_html=True)