"""Smart PDF Q&A — Streamlit app.

Upload a PDF (or preload a bundled ./data set) and then:
- ask free-form questions answered via FAISS retrieval + flan-t5,
- request a short summary (distilbart), or
- extract code snippets ("give me code").
"""

import logging
import os
import re
from io import BytesIO

import pdfplumber
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


# ----------- Load Models -----------
@st.cache_resource(ttl=1800)
def load_embeddings_model():
    """Load and cache the sentence-embedding model. Returns None on failure."""
    try:
        return SentenceTransformer("all-MiniLM-L12-v2")
    except Exception as e:
        st.error(f"Embedding model error: {str(e)}")
        return None


@st.cache_resource(ttl=1800)
def load_qa_pipeline():
    """Load and cache the text2text QA pipeline. Returns None on failure."""
    try:
        return pipeline("text2text-generation", model="google/flan-t5-small",
                        max_length=300)
    except Exception as e:
        st.error(f"QA model error: {str(e)}")
        return None


@st.cache_resource(ttl=1800)
def load_summary_pipeline():
    """Load and cache the summarization pipeline. Returns None on failure."""
    try:
        return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6",
                        max_length=150)
    except Exception as e:
        st.error(f"Summary model error: {str(e)}")
        return None


# ----------- PDF Processing -----------
def process_pdf(uploaded_file):
    """Extract text, code, and tables from a PDF and build FAISS stores.

    Args:
        uploaded_file: a binary file-like object supporting .read().

    Returns:
        (text_vector_store, code_vector_store, text, code_text).
        Either store may be None when there is no content or the embedding
        model failed to load; on error returns (None, None, "", "").
    """
    text = ""
    code_blocks = []
    try:
        with pdfplumber.open(BytesIO(uploaded_file.read())) as pdf:
            for page in pdf.pages[:20]:  # cap pages to bound processing time
                extracted = page.extract_text(layout=False)
                if extracted:
                    text += extracted + "\n"

                # BUG FIX: the original appended one character per list entry
                # and later joined with "\n", scattering code one char per
                # line. Collect a page's monospace-font characters as one run.
                mono_chars = [
                    char['text'] for char in page.chars
                    if 'fontname' in char and 'mono' in char['fontname'].lower()
                ]
                if mono_chars:
                    code_blocks.append("".join(mono_chars))

                # Heuristic: runs of indented lines are likely code snippets.
                code_text_page = page.extract_text() or ""
                for match in re.finditer(r'(^\s{2,}.*?(?:\n\s{2,}.*?)*)',
                                         code_text_page, re.MULTILINE):
                    code_blocks.append(match.group().strip())

                # Flatten tables into " | "-separated rows; render missing
                # cells as empty strings rather than the literal "None".
                for table in page.extract_tables() or []:
                    text += "\n".join(
                        " | ".join("" if cell is None else str(cell)
                                   for cell in row)
                        for row in table if row
                    ) + "\n"

        code_text = "\n".join(code_blocks).strip()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", ".", " "],
        )
        # Cap chunk counts to keep embedding time bounded.
        text_chunks = text_splitter.split_text(text)[:50]
        code_chunks = text_splitter.split_text(code_text)[:25] if code_text else []

        embeddings_model = load_embeddings_model()
        if not embeddings_model:
            return None, None, text, code_text

        text_vectors = [embeddings_model.encode(chunk) for chunk in text_chunks]
        code_vectors = [embeddings_model.encode(chunk) for chunk in code_chunks]
        text_vector_store = (
            FAISS.from_embeddings(list(zip(text_chunks, text_vectors)),
                                  embeddings_model.encode)
            if text_chunks else None
        )
        code_vector_store = (
            FAISS.from_embeddings(list(zip(code_chunks, code_vectors)),
                                  embeddings_model.encode)
            if code_chunks else None
        )
        return text_vector_store, code_vector_store, text, code_text
    except Exception as e:
        logger.exception("PDF processing failed")
        st.error(f"PDF error: {str(e)}")
        return None, None, "", ""


# ----------- Preload Dataset -----------
def preload_dataset():
    """Build vector stores from bundled PDFs/TXTs in ./data, if present.

    Returns:
        (text_vector_store, code_vector_store, combined_text, combined_code);
        stores are None when the directory, the embedding model, or any
        content is missing.
    """
    dataset_path = "data"
    combined_text = ""
    combined_code = ""
    if not os.path.exists(dataset_path):
        return None, None, combined_text, combined_code

    embeddings_model = load_embeddings_model()
    if not embeddings_model:
        return None, None, combined_text, combined_code

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", " "],
    )
    all_text_chunks = []
    all_code_chunks = []
    for file_name in os.listdir(dataset_path):
        file_path = os.path.join(dataset_path, file_name)
        if file_name.lower().endswith(".pdf"):
            with open(file_path, "rb") as f:
                _t_store, _c_store, t_text, c_text = process_pdf(f)
            combined_text += t_text + "\n"
            combined_code += c_text + "\n"
            # BUG FIX: the original called the nonexistent FAISS method
            # index_to_docstore() and crashed. Re-chunk the raw text returned
            # by process_pdf instead of poking at store internals.
            all_text_chunks.extend(splitter.split_text(t_text))
            if c_text:
                all_code_chunks.extend(splitter.split_text(c_text))
        elif file_name.lower().endswith(".txt"):
            with open(file_path, "r", encoding="utf-8") as f:
                text_content = f.read()
            combined_text += text_content + "\n"
            # Skip blank paragraphs so we never embed empty strings.
            all_text_chunks.extend(
                chunk for chunk in text_content.split("\n\n") if chunk.strip()
            )

    text_vector_store = None
    code_vector_store = None
    if all_text_chunks:
        text_vectors = [embeddings_model.encode(c) for c in all_text_chunks]
        text_vector_store = FAISS.from_embeddings(
            list(zip(all_text_chunks, text_vectors)), embeddings_model.encode)
    if all_code_chunks:
        code_vectors = [embeddings_model.encode(c) for c in all_code_chunks]
        code_vector_store = FAISS.from_embeddings(
            list(zip(all_code_chunks, code_vectors)), embeddings_model.encode)
    return text_vector_store, code_vector_store, combined_text, combined_code


# ----------- Streamlit UI -----------
st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")

# NOTE(review): the original embedded custom HTML/CSS here, but its content
# was lost to garbling; reconstructed with native Streamlit widgets.
st.title("Smart PDF Q&A")
st.markdown("Upload a PDF to ask questions, summarize (~150 words), "
            "or extract code with 'give me code'.")

# Session state defaults (idempotent across reruns).
for _key, _default in (
    ("messages", []),
    ("text_vector_store", None),
    ("code_vector_store", None),
    ("pdf_text", ""),
    ("code_text", ""),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Preload the bundled dataset once at start.
if (st.session_state.text_vector_store is None
        and st.session_state.code_vector_store is None):
    (st.session_state.text_vector_store,
     st.session_state.code_vector_store,
     st.session_state.pdf_text,
     st.session_state.code_text) = preload_dataset()
    if st.session_state.text_vector_store or st.session_state.code_vector_store:
        st.info("Preloaded sample dataset loaded for better QA and code retrieval.")

# PDF upload & buttons
uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
col1, col2 = st.columns([1, 1])
with col1:
    if st.button("Process PDF") and uploaded_file:
        with st.spinner("Processing PDF..."):
            (st.session_state.text_vector_store,
             st.session_state.code_vector_store,
             st.session_state.pdf_text,
             st.session_state.code_text) = process_pdf(uploaded_file)
        if st.session_state.text_vector_store or st.session_state.code_vector_store:
            st.success("PDF processed! Ask away or summarize.")
            st.session_state.messages = []
        else:
            st.error("Failed to process PDF.")
with col2:
    if st.button("Summarize PDF") and st.session_state.pdf_text:
        with st.spinner("Summarizing..."):
            summary_pipeline = load_summary_pipeline()
            if summary_pipeline is None:
                # Loader returns None on failure; guard instead of crashing.
                st.error("Summarization model unavailable.")
            else:
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=500, chunk_overlap=50,
                    separators=["\n\n", "\n", ".", " "])
                chunks = text_splitter.split_text(st.session_state.pdf_text)[:2]
                summaries = []
                for chunk in chunks:
                    summary = summary_pipeline(
                        chunk[:500], max_length=100, min_length=30,
                        do_sample=False)[0]['summary_text']
                    summaries.append(summary.strip())
                combined_summary = " ".join(summaries)
                st.session_state.messages.append(
                    {"role": "assistant", "content": combined_summary})
                st.markdown(combined_summary)

# Chat interface.
# BUG FIX: render prior history BEFORE handling the new prompt; the original
# displayed the new turn, appended it, then rendered the full history,
# duplicating the latest exchange within the same run.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

prompt = st.chat_input(
    "Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        qa_pipeline = load_qa_pipeline()
        code_keywords = ("code", "script", "function", "programming",
                         "give me code", "show code")
        is_code_query = any(k in prompt.lower() for k in code_keywords)
        if is_code_query and st.session_state.code_vector_store:
            answer = (f"Here's the code from the PDF:\n```python\n"
                      f"{st.session_state.code_text}\n```")
        elif st.session_state.text_vector_store and qa_pipeline:
            docs = st.session_state.text_vector_store.similarity_search(prompt, k=5)
            context = "\n".join(doc.page_content for doc in docs)
            answer = qa_pipeline(
                f"Context: {context}\nQuestion: {prompt}\n"
                f"Provide a detailed answer.")[0]['generated_text']
        elif st.session_state.text_vector_store:
            # QA loader returned None; report instead of raising TypeError.
            answer = "The QA model failed to load; please retry."
        else:
            answer = "Please upload a PDF first!"
        st.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})

# Download chat
if st.session_state.messages:
    chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}"
                          for m in st.session_state.messages)
    st.download_button("Download Chat History", chat_text, "chat_history.txt")