Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import tempfile | |
| import torch | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| from huggingface_hub import login | |
| from threading import Thread | |
# --- Page Config & Styling ---
# Browser-tab title/icon and overall layout; this is the first Streamlit
# command the script executes.
_PAGE_CONFIG = {
    "page_title": "DocTalk - Chat With PDF",
    "page_icon": "ππ¬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS for polished UI and Footer.
# Also defines the `blink` keyframes used by the streaming cursor in the chat
# area, hides Streamlit's menu/footer, and adds responsive tweaks.
_CUSTOM_CSS = """
<style>
/* Chat styling */
.stChatInput {
padding-bottom: 1rem;
}
/* Custom Footer */
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: #555;
text-align: center;
padding: 10px;
font-size: 14px;
border-top: 1px solid #eee;
z-index: 100;
}
/* Hide Streamlit branding for cleaner look */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
/* Adjust sidebar padding for footer */
[data-testid="stSidebar"] {
padding-bottom: 50px;
}
/* Responsive Design */
@media (max-width: 768px) {
/* Make sidebar collapsible on mobile */
[data-testid="stSidebar"] {
width: 100% !important;
}
/* Adjust chat input for mobile */
.stChatInput {
font-size: 16px !important;
}
/* Better spacing on mobile */
.block-container {
padding: 1rem !important;
}
/* Footer text smaller on mobile */
.footer {
font-size: 12px;
padding: 8px;
}
}
@media (max-width: 480px) {
/* Extra small devices */
h1 {
font-size: 1.5rem !important;
}
.stButton button {
font-size: 14px !important;
}
}
/* Touch-friendly buttons */
.stButton button {
min-height: 44px;
padding: 0.5rem 1rem;
}
/* Better chat message display on mobile */
[data-testid="stChatMessage"] {
max-width: 100%;
padding: 0.5rem;
}
/* Animated typing indicator */
@keyframes blink {
0%, 49% { opacity: 1; }
50%, 100% { opacity: 0; }
}
@keyframes pulse {
0%, 100% { transform: scale(1); opacity: 1; }
50% { transform: scale(1.2); opacity: 0.7; }
}
@keyframes shimmer {
0% { background-position: -100% 0; }
100% { background-position: 100% 0; }
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# --- Session State Management ---
# Seed each per-session key only on the first run of a session; Streamlit
# re-executes this script on every interaction and existing values must survive.
for _state_key, _state_default in (
    ('messages', []),
    ('processing_done', False),
    ('vector_store', None),
    ('model', None),
    ('tokenizer', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Authentication (Secrets Only) ---
# Hugging Face access token taken from the environment (e.g. a Space secret);
# None when the variable is not set, which the sidebar reports and stops on.
hf_token = os.getenv("HF_TOKEN")
# --- Model Loading (Cached & Optimized) ---
@st.cache_resource(show_spinner=False)
def _load_embedding_model_cached():
    """Build the MiniLM sentence embedder once and share it across reruns/sessions.

    Errors propagate instead of being turned into ``None`` so that a failed
    load is not stored in the cache and the next attempt retries.
    """
    return HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )


def load_embedding_model():
    """Load the embedding model once to save time.

    The original version rebuilt the model on every call despite the
    "cached" comments; loading is now memoised by ``st.cache_resource``.

    Returns:
        The shared ``HuggingFaceEmbeddings`` instance (CPU, normalized
        embeddings), or ``None`` after showing the error in the UI.
    """
    try:
        return _load_embedding_model_cached()
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None
@st.cache_resource(show_spinner=False)
def _load_llm_model_cached(token):
    """Log in to the Hub and load the Gemma tokenizer + model exactly once.

    Cached per ``token`` value and shared across reruns/sessions. Errors
    propagate so that a failed load is not cached and can be retried.
    """
    login(token=token)
    model_id = "google/gemma-2-2b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    # Load model to CPU with optimizations
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        token=token
    )
    return model, tokenizer


def load_llm_model(token):
    """Load the Gemma LLM once - returns model and tokenizer for streaming.

    Previously the ~2B-parameter model was reloaded on every "Process
    Document" click; it is now memoised with ``st.cache_resource``.

    Args:
        token: Hugging Face access token used for login and model download.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` after showing
        the error in the UI.
    """
    try:
        return _load_llm_model_cached(token)
    except Exception as e:
        st.error(f"Error loading LLM: {e}")
        return None, None
# --- PDF Processing (Optimized for better accuracy) ---
def process_document(uploaded_file, embedding_model, chunk_size=1000, chunk_overlap=100):
    """Process PDF and create vector store.

    The upload is written to a temporary ``.pdf`` file (PyPDFLoader reads
    from a path), split into overlapping text chunks, and embedded into a
    FAISS index. The temporary file is removed on every path, including
    failures.

    Args:
        uploaded_file: Streamlit upload object; only ``getvalue()`` is used.
        embedding_model: Embeddings instance handed to ``FAISS.from_documents``.
        chunk_size: Maximum characters per chunk (default 1000).
        chunk_overlap: Characters shared between neighbouring chunks (default 100).

    Returns:
        The FAISS vector store, or ``None`` after reporting an error in the UI.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(uploaded_file.getvalue())
        # Load & Split with balanced parameters for accuracy
        docs = PyPDFLoader(tmp_path).load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = splitter.split_documents(docs)
        if not chunks:
            # e.g. image-only/scanned PDFs yield no text; building an index from
            # zero chunks would otherwise fail with an opaque low-level error.
            st.error("Error processing PDF: no extractable text found (is it a scanned/image-only PDF?)")
            return None
        return FAISS.from_documents(chunks, embedding_model)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None
    finally:
        # The original only unlinked on success, leaking a temp file per failure.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the result.
                pass
def get_relevant_context(vector_store, question, k=3):
    """Retrieve relevant context from vector store.

    Args:
        vector_store: Store exposing ``as_retriever(search_kwargs=...)``.
        question: User question used as the retrieval query.
        k: Number of chunks to retrieve (default 3, the previous fixed value).

    Returns:
        ``(context, docs)`` where ``context`` is the retrieved chunks' text
        joined by blank lines and ``docs`` the retrieved documents; on error
        the message is shown in the UI and ``("", [])`` is returned.
    """
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(question)
        context = "\n\n".join(doc.page_content for doc in docs)
        return context, docs
    except Exception as e:
        st.error(f"Error retrieving context: {e}")
        return "", []
def stream_response(model, tokenizer, prompt):
    """Generate streaming response from the model.

    ``model.generate`` runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator yields decoded text pieces as
    they arrive.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Fully formatted prompt string.

    Yields:
        str: Fragments of the reply. On any failure, including one raised
        inside the generation thread, a single
        ``"Error generating response: ..."`` fragment is yielded instead.
    """
    try:
        # NOTE(review): if the prompt exceeds 2048 tokens and the tokenizer
        # truncates on the right (the usual default), the question and the
        # model-turn marker at the end of the prompt are cut off.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Generation config optimized for Gemma
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.3,
            top_p=0.95,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        thread_errors = []

        def _generate():
            # An exception in the worker thread never reaches this generator,
            # and the streamer would then wait forever for more tokens. Record
            # it and close the stream so the loop below terminates.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                thread_errors.append(exc)
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()
        # Yield tokens as they're generated
        for text in streamer:
            yield text
        thread.join()
        if thread_errors:
            raise thread_errors[0]
    except Exception as e:
        yield f"Error generating response: {e}"
# --- Main Layout ---
# 1. Sidebar Configuration: token check, PDF upload/processing, chat controls.
with st.sidebar:
    st.title("Configuration")
    st.markdown("---")
    if not hf_token:
        st.error("π¨ **HF_TOKEN missing!**")
        st.info("Go to Space Settings β Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
        # Nothing below can work without a token; halt this script run.
        st.stop()
    else:
        st.success("β Hugging Face Connected")
    st.subheader("π Document Upload")
    uploaded_file = st.file_uploader("Upload your PDF", type="pdf", help="Upload a PDF document to chat with")
    if uploaded_file:
        process_btn = st.button("π Process Document", type="primary", use_container_width=True)
        if process_btn:
            with st.spinner("π§ Analyzing PDF document..."):
                # Load models (cached inside the loader functions)
                model, tokenizer = load_llm_model(hf_token)
                embed_model = load_embedding_model()
                if model and tokenizer and embed_model:
                    vector_store = process_document(uploaded_file, embed_model)
                    if vector_store:
                        st.session_state.vector_store = vector_store
                        st.session_state.model = model
                        st.session_state.tokenizer = tokenizer
                        st.session_state.processing_done = True
                        # Remember which file is actually indexed: the uploader
                        # widget may later hold a different (unprocessed) file.
                        st.session_state.doc_name = uploaded_file.name
                        # Answers about a previous document no longer apply.
                        st.session_state.messages = []
                        st.success("β Document processed! Start chatting below.")
                        st.rerun()
                    else:
                        st.error("β Failed to process document. Please try again.")
                else:
                    st.error("β Failed to load AI models. Check your token permissions.")
    if st.session_state.processing_done:
        st.markdown("---")
        st.success("β Start Chatting")
        # Prefer the name of the processed document; fall back to the old behaviour.
        doc_label = st.session_state.get("doc_name") or (uploaded_file.name if uploaded_file else 'Document')
        st.info(f"π **{doc_label}** loaded")
        if st.button("ποΈ Clear Chat History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()
        if st.button("π Upload New Document", use_container_width=True):
            st.session_state.processing_done = False
            st.session_state.vector_store = None
            st.session_state.doc_name = None
            st.session_state.messages = []
            st.rerun()
# 2. Main Chat Area
import html  # escapes streamed model output before it is rendered as HTML

st.title("ππ¬ DocTalk - Chat With PDF")
if st.session_state.processing_done:
    # Display Chat History (the whole script re-runs on every interaction)
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
    # Chat Input
    if user_input := st.chat_input("Ask a question about your document..."):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)
        # Generate assistant response
        with st.chat_message("assistant"):
            try:
                # Get relevant context
                context, source_docs = get_relevant_context(st.session_state.vector_store, user_input)
                if not context:
                    st.warning("β οΈ Could not find relevant information in the document.")
                else:
                    # Build prompt with Gemma's chat-turn markers
                    prompt = f"""<start_of_turn>user
Answer the question based strictly on the context below. Be concise and accurate.
Context: {context}
Question: {user_input}<end_of_turn>
<start_of_turn>model
"""
                    # Stream the response
                    response_placeholder = st.empty()
                    full_response = ""
                    cursor_html = " <span style='animation: blink 1s infinite; color: #00d4ff; font-weight: bold;'>β</span>"
                    for chunk in stream_response(st.session_state.model, st.session_state.tokenizer, prompt):
                        full_response += chunk
                        # The animated cursor needs unsafe_allow_html, so escape the
                        # model text: it derives from the untrusted PDF and must not
                        # be rendered as raw HTML.
                        response_placeholder.markdown(html.escape(full_response, quote=False) + cursor_html, unsafe_allow_html=True)
                    # Final update without cursor (HTML rendering disabled)
                    response_placeholder.markdown(full_response)
                    # Save to history
                    st.session_state.messages.append({"role": "assistant", "content": full_response})
                    # Show sources
                    if source_docs:
                        with st.expander("π View Source Context"):
                            for i, doc in enumerate(source_docs):
                                page = doc.metadata.get('page')
                                # PyPDFLoader page indices are 0-based; show 1-based numbers.
                                page_label = page + 1 if isinstance(page, int) else doc.metadata.get('page', 'Unknown')
                                st.markdown(f"**Source {i+1}** (Page {page_label})")
                                st.caption(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
                                st.markdown("---")
            except Exception as e:
                st.error(f"β An error occurred: {e}")
                st.info("Please try asking your question again or upload a new document.")
else:
    # Empty State
    st.info("π **Welcome to DocTalk!** Upload a PDF document in the sidebar to begin chatting.")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("### π€ Upload")
        st.markdown("Upload your PDF document using the sidebar")
    with col2:
        st.markdown("### π Process")
        st.markdown("Click 'Process Document' to analyze it")
    with col3:
        st.markdown("### π¬ Chat")
        st.markdown("Ask questions and get instant answers")
    st.markdown("---")
# --- Footer ---
# Fixed-position credit bar; its look comes from the `.footer` rule injected
# in the custom CSS near the top of the script.
_FOOTER_HTML = """
<div class="footer">
Made with β€οΈ using Streamlit and Gemma model, by Tannu Yadav
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)