import streamlit as st
import logging
import os
from io import BytesIO
import pdfplumber
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import re
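# Rough dependency sketch (an assumption inferred from the imports above; pin versions as needed):
#   streamlit, pdfplumber, langchain, langchain-community,
#   sentence-transformers, transformers, torch, faiss-cpu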
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ----------- Load Models -----------
# Cached with st.cache_resource so each model is loaded once per session instead of on every rerun.
@st.cache_resource
def load_embeddings_model():
    try:
        return SentenceTransformer("all-MiniLM-L12-v2")
    except Exception as e:
        st.error(f"Embedding model error: {str(e)}")
        return None

@st.cache_resource
def load_qa_pipeline():
    try:
        return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)
    except Exception as e:
        st.error(f"QA model error: {str(e)}")
        return None

@st.cache_resource
def load_summary_pipeline():
    try:
        return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", max_length=150)
    except Exception as e:
        st.error(f"Summary model error: {str(e)}")
        return None
# ----------- PDF Processing -----------
def process_pdf(uploaded_file):
    text = ""
    code_blocks = []
    try:
        with pdfplumber.open(BytesIO(uploaded_file.read())) as pdf:
            for page in pdf.pages[:20]:  # cap at 20 pages to keep processing fast
                extracted = page.extract_text(layout=False)
                if extracted:
                    text += extracted + "\n"
                # Collect text rendered in a monospaced font as a crude code signal
                mono_chars = "".join(
                    char['text'] for char in page.chars
                    if 'mono' in char.get('fontname', '').lower()
                )
                if mono_chars.strip():
                    code_blocks.append(mono_chars)
                # Also treat runs of consecutively indented lines as likely code
                code_text_page = page.extract_text() or ""
                code_matches = re.finditer(r'(^[ \t]{2,}.*(?:\n[ \t]{2,}.*)*)', code_text_page, re.MULTILINE)
                for match in code_matches:
                    code_blocks.append(match.group().strip())
                tables = page.extract_tables()
                if tables:
                    for table in tables:
                        text += "\n".join(
                            " | ".join("" if cell is None else str(cell) for cell in row)
                            for row in table if row
                        ) + "\n"
        code_text = "\n".join(code_blocks).strip()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=100, separators=["\n\n", "\n", ".", " "]
        )
        text_chunks = text_splitter.split_text(text)[:50]
        code_chunks = text_splitter.split_text(code_text)[:25] if code_text else []
        embeddings_model = load_embeddings_model()
        if not embeddings_model:
            return None, None, text, code_text
        text_vectors = [embeddings_model.encode(chunk) for chunk in text_chunks]
        code_vectors = [embeddings_model.encode(chunk) for chunk in code_chunks]
        # FAISS.from_embeddings iterates its input more than once, so pass a list rather than a zip iterator
        text_vector_store = FAISS.from_embeddings(list(zip(text_chunks, text_vectors)), embeddings_model.encode) if text_chunks else None
        code_vector_store = FAISS.from_embeddings(list(zip(code_chunks, code_vectors)), embeddings_model.encode) if code_chunks else None
        return text_vector_store, code_vector_store, text, code_text
    except Exception as e:
        st.error(f"PDF error: {str(e)}")
        return None, None, "", ""
# ----------- Preload Dataset -----------
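# Expected layout (an assumption inferred from the code below; adjust to your setup):
#   data/
#     example_paper.pdf   # hypothetical sample file
#     example_notes.txt   # hypothetical sample file
# Any .pdf/.txt files found here are indexed once at startup.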
def preload_dataset():
    dataset_path = "data"
    combined_text = ""
    combined_code = ""
    text_vector_store = None
    code_vector_store = None
    if not os.path.exists(dataset_path):
        return text_vector_store, code_vector_store, combined_text, combined_code
    embeddings_model = load_embeddings_model()
    if not embeddings_model:
        return text_vector_store, code_vector_store, combined_text, combined_code
    all_text_chunks = []
    all_text_vectors = []
    all_code_chunks = []
    all_code_vectors = []
    for file_name in os.listdir(dataset_path):
        file_path = os.path.join(dataset_path, file_name)
        if file_name.lower().endswith(".pdf"):
            with open(file_path, "rb") as f:
                t_store, c_store, t_text, c_text = process_pdf(f)
            combined_text += t_text + "\n"
            combined_code += c_text + "\n"
            # Pull the stored chunks back out of each per-file FAISS index (relies on the default InMemoryDocstore)
            if t_store:
                for doc in t_store.docstore._dict.values():
                    all_text_chunks.append(doc.page_content)
                    all_text_vectors.append(embeddings_model.encode(doc.page_content))
            if c_store:
                for doc in c_store.docstore._dict.values():
                    all_code_chunks.append(doc.page_content)
                    all_code_vectors.append(embeddings_model.encode(doc.page_content))
        elif file_name.lower().endswith(".txt"):
            with open(file_path, "r", encoding="utf-8") as f:
                text_content = f.read()
            combined_text += text_content + "\n"
            chunks = [c for c in text_content.split("\n\n") if c.strip()]
            for chunk in chunks:
                all_text_chunks.append(chunk)
                all_text_vectors.append(embeddings_model.encode(chunk))
    if all_text_chunks:
        text_vector_store = FAISS.from_embeddings(list(zip(all_text_chunks, all_text_vectors)), embeddings_model.encode)
    if all_code_chunks:
        code_vector_store = FAISS.from_embeddings(list(zip(all_code_chunks, all_code_vectors)), embeddings_model.encode)
    return text_vector_store, code_vector_store, combined_text, combined_code
# ----------- Streamlit UI -----------
st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
# Fixed CSS for chat colors
st.markdown("""
<style>
/* Chat container */
.chat-container {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 10px;
    height: 60vh;
    overflow-y: auto;
    margin-top: 20px;
}
/* Chat bubbles */
.stChatMessage {
    border-radius: 15px;
    padding: 10px;
    margin: 5px;
    max-width: 70%;
    word-wrap: break-word;
}
/* User message */
.user {
    background-color: #e6f3ff !important;
    color: #000 !important;
    align-self: flex-end;
    text-align: right;
}
/* Assistant message */
.assistant {
    background-color: #f0f0f0 !important;
    color: #000 !important;
    text-align: left;
}
/* Dark mode support */
body[data-theme="dark"] .user {
    background-color: #2a2a72 !important;
    color: #fff !important;
}
body[data-theme="dark"] .assistant {
    background-color: #2e2e2e !important;
    color: #fff !important;
}
/* Buttons */
.stButton>button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 5px;
}
.stButton>button:hover {
    background-color: #45a049;
}
/* Preformatted code */
pre {
    background-color: #f8f8f8;
    padding: 10px;
    border-radius: 5px;
    overflow-x: auto;
}
/* Header */
.header {
    background: linear-gradient(90deg, #4CAF50, #81C784);
    color: white;
    padding: 10px;
    border-radius: 5px;
    text-align: center;
}
</style>
""", unsafe_allow_html=True)
st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'.")
# Session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "text_vector_store" not in st.session_state:
    st.session_state.text_vector_store = None
if "code_vector_store" not in st.session_state:
    st.session_state.code_vector_store = None
if "pdf_text" not in st.session_state:
    st.session_state.pdf_text = ""
if "code_text" not in st.session_state:
    st.session_state.code_text = ""
# Preload dataset at start
if st.session_state.text_vector_store is None and st.session_state.code_vector_store is None:
    st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = preload_dataset()
    if st.session_state.text_vector_store or st.session_state.code_vector_store:
        st.info("Preloaded sample dataset for better QA and code retrieval.")
# PDF upload & buttons
uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
col1, col2 = st.columns([1, 1])
with col1:
    if st.button("Process PDF") and uploaded_file:
        with st.spinner("Processing PDF..."):
            st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
        if st.session_state.text_vector_store or st.session_state.code_vector_store:
            st.success("PDF processed! Ask away or summarize.")
            st.session_state.messages = []
        else:
            st.error("Failed to process PDF.")
with col2:
    if st.button("Summarize PDF") and st.session_state.pdf_text:
        with st.spinner("Summarizing..."):
            summary_pipeline = load_summary_pipeline()
            if summary_pipeline:
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, separators=["\n\n", "\n", ".", " "])
                chunks = text_splitter.split_text(st.session_state.pdf_text)[:2]
                summaries = []
                for chunk in chunks:
                    summary = summary_pipeline(chunk[:500], max_length=100, min_length=30, do_sample=False)[0]['summary_text']
                    summaries.append(summary.strip())
                combined_summary = " ".join(summaries)
                st.session_state.messages.append({"role": "assistant", "content": combined_summary})
                st.markdown(combined_summary)
# Chat interface
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
# Display chat history first so a new exchange is not rendered twice in the same run
for msg in st.session_state.messages:
    cls = "user" if msg["role"] == "user" else "assistant"
    st.markdown(f"<div class='{cls}' style='margin:5px;padding:10px;border-radius:15px;'>{msg['content']}</div>", unsafe_allow_html=True)
prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(f"<div class='user'>{prompt}</div>", unsafe_allow_html=True)
    with st.chat_message("assistant"):
        qa_pipeline = load_qa_pipeline()
        is_code_query = any(k in prompt.lower() for k in ["code", "script", "function", "programming", "give me code", "show code"])
        if is_code_query and st.session_state.code_vector_store:
            answer = f"Here's the code from the PDF:\n```python\n{st.session_state.code_text}\n```"
        elif st.session_state.text_vector_store and qa_pipeline:
            docs = st.session_state.text_vector_store.similarity_search(prompt, k=5)
            context = "\n".join(doc.page_content for doc in docs)
            answer = qa_pipeline(f"Context: {context}\nQuestion: {prompt}\nProvide a detailed answer.")[0]['generated_text']
        else:
            answer = "Please upload a PDF first!"
        st.markdown(f"<div class='assistant'>{answer}</div>", unsafe_allow_html=True)
    st.session_state.messages.append({"role": "assistant", "content": answer})
st.markdown('</div>', unsafe_allow_html=True)
# Download chat
if st.session_state.messages:
    chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
    st.download_button("Download Chat History", chat_text, "chat_history.txt")
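# To run locally (assuming this file is saved as app.py, which is only a guess at the filename):
#   streamlit run app.py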