# Scraped Hugging Face Spaces page header (not part of the app):
# Spaces: Sleeping
import os
import uuid
from typing import Dict, List, Optional

import chromadb
import numpy as np
import streamlit as st
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

# Try importing Groq client
# (optional at import time so the UI can report a missing package instead of crashing)
try:
    from groq import Groq
except ImportError:
    Groq = None
# -----------------------------
# Utility Functions
# -----------------------------
def load_api_key() -> Optional[str]:
    """Return the Groq API key from the ``GROQ_API_KEY`` environment variable.

    Hugging Face Spaces exposes repository secrets to the running app as
    environment variables, so a Space secret named ``GROQ_API_KEY`` is found
    here as well.

    Returns:
        The key with surrounding whitespace removed, or ``None`` when the
        variable is unset or blank.
    """
    # NOTE: an earlier fallback returned ``huggingface_hub.HfFolder.get_token()``.
    # That is a *Hugging Face* access token, which Groq always rejects, so it
    # replaced a clear "missing key" message with an opaque auth failure.
    api_key = os.environ.get("GROQ_API_KEY", "").strip()
    return api_key or None
def setup_groq() -> Optional["Groq"]:
    """Create a Groq client from the configured API key.

    Every failure is reported in the UI via ``st.error`` and yields ``None``,
    so callers only need a truthiness check.

    Returns:
        A ``Groq`` client, or ``None`` when the ``groq`` package is not
        installed, no API key is configured, or the client constructor raises.
    """
    if Groq is None:
        # Checked first: without the package, a configured key is useless anyway.
        st.error("❌ Groq library not installed. Please add `groq` to requirements.txt.")
        return None

    api_key = load_api_key()
    if not api_key:
        st.error("❌ Missing GROQ_API_KEY in environment or Hugging Face secrets.")
        return None

    try:
        return Groq(api_key=api_key)
    except Exception as e:
        st.error(f"Failed to initialize Groq client: {e}")
        return None
@st.cache_resource
def load_embedding_model(model_name: str = "all-MiniLM-L6-v2") -> SentenceTransformer:
    """Load the sentence-transformers model ``model_name`` once per process.

    ``st.cache_resource`` keeps one shared instance per ``model_name`` across
    Streamlit reruns and sessions. Without it, every question would reload the
    model from disk.

    Args:
        model_name: Name or path understood by ``SentenceTransformer``.

    Returns:
        The (shared) loaded model.
    """
    return SentenceTransformer(model_name)
def pdf_to_chunks(uploaded_file, chunk_size: int = 500, overlap: int = 50) -> List[Dict]:
    """Split the text of a PDF into overlapping, page-tagged word chunks.

    Each page is chunked independently, so every chunk belongs to exactly one
    page number for citations. A chunk holds up to ``chunk_size``
    whitespace-separated words, and consecutive chunks on a page share
    ``overlap`` words.

    Args:
        uploaded_file: Anything ``pypdf.PdfReader`` accepts (a path or a binary
            file-like object such as a Streamlit upload).
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Words repeated between consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of ``{"page_number": int, "text": str}`` dicts with 1-based page
        numbers. It is empty (after an ``st.error``) if the PDF cannot be opened,
        or empty if no page has extractable text.

    Raises:
        ValueError: If ``chunk_size``/``overlap`` would not move the window
            forward (previously a zero step crashed inside ``range`` and a
            negative step silently produced no chunks).
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError(
            "pdf_to_chunks needs chunk_size > 0 and 0 <= overlap < chunk_size; "
            f"got chunk_size={chunk_size}, overlap={overlap}"
        )

    try:
        reader = PdfReader(uploaded_file)
    except Exception as e:
        st.error(f"Error reading PDF: {e}")
        return []

    chunks: List[Dict] = []
    for page_num, page in enumerate(reader.pages, start=1):
        try:
            text = page.extract_text() or ""
        except Exception:
            # Best effort: a page whose extraction raises is treated as empty
            # instead of failing the whole document.
            text = ""
        chunks.extend(
            {"page_number": page_num, "text": " ".join(window)}
            for window in _word_windows(text.split(), chunk_size, overlap)
        )
    return chunks


def _word_windows(words: List[str], chunk_size: int, overlap: int) -> List[List[str]]:
    """Slice ``words`` into windows of ``chunk_size`` advancing by ``chunk_size - overlap``."""
    step = chunk_size - overlap
    windows = []
    for start in range(0, len(words), step):
        windows.append(words[start:start + chunk_size])
        if start + chunk_size >= len(words):
            # This window already reaches the end of the page; another step
            # would only re-emit its overlapping tail as a duplicate chunk.
            break
    return windows
def create_vector_database(chunks: List[Dict], embedding_model: SentenceTransformer) -> Optional[str]:
    """Embed ``chunks`` into a fresh ChromaDB collection and make it the active one.

    On success, only the new collection's *name* (not the collection object) is
    stored in ``st.session_state.collection_name``. The collection previously
    recorded there (this session's earlier upload, if any) is then deleted, so
    repeated uploads do not accumulate in the in-process store.

    Args:
        chunks: Dicts with ``"text"`` and ``"page_number"`` keys, as produced
            by ``pdf_to_chunks``.
        embedding_model: Model whose ``encode`` turns the chunk texts into vectors.

    Returns:
        The new collection name, or ``None`` (after an ``st.error``) when
        ``chunks`` is empty or creating, encoding or populating the collection
        fails. A collection that fails midway is removed again.
    """
    if not chunks:
        st.error("No text chunks extracted from PDF.")
        return None

    client = chromadb.Client()
    # uuid4 makes name clashes practically impossible; the former
    # randint(10000) suffix could collide and make create_collection fail.
    collection_name = f"pdf_chunks_{uuid.uuid4().hex}"
    try:
        collection = client.create_collection(collection_name)
    except Exception as e:
        st.error(f"Error creating collection: {e}")
        return None

    texts = [c["text"] for c in chunks]
    ids = [str(i) for i in range(len(chunks))]
    # The text is already stored as the document; metadata only needs the page
    # number that query time reads back for citations.
    metadatas = [{"page_number": c["page_number"]} for c in chunks]

    try:
        # encode() batches internally; batch_size keeps the previous 64-text batches.
        embeddings = np.asarray(embedding_model.encode(texts, batch_size=64)).tolist()
    except Exception as e:
        st.error(f"Error encoding chunks: {e}")
        _drop_collection(client, collection_name)
        return None

    try:
        collection.add(
            embeddings=embeddings,
            documents=texts,
            ids=ids,
            metadatas=metadatas,
        )
    except Exception as e:
        st.error(f"Error adding embeddings: {e}")
        _drop_collection(client, collection_name)
        return None

    previous_name = st.session_state.get("collection_name")
    st.session_state.collection_name = collection_name
    if previous_name and previous_name != collection_name:
        _drop_collection(client, previous_name)
    return collection_name


def _drop_collection(client, name: str) -> None:
    """Delete collection ``name`` from ``client``; cleanup failures are ignored."""
    try:
        client.delete_collection(name)
    except Exception:
        # Deliberately best-effort: a leftover collection only costs memory and
        # must not mask the result the caller is about to report.
        pass
def query_vector_database(query: str, embedding_model: SentenceTransformer,
                          top_k: int = 5) -> List[Dict]:
    """Return the chunks of the active collection that are closest to ``query``.

    The active collection is the one named in ``st.session_state.collection_name``.

    Args:
        query: User question to embed and search with.
        embedding_model: Model used to embed ``query``. Presumably it must be the
            same model that indexed the collection; this is not checked here.
        top_k: Maximum number of chunks to return; clamped to the collection size.

    Returns:
        Up to ``top_k`` dicts in Chroma's result order, each with these keys:
        ``"text"``, ``"page_number"`` (``"N/A"`` if absent) and ``"similarity"``.
        The similarity is ``1 / (1 + distance)``: in (0, 1] and always higher
        for closer chunks. The list is empty when there is no active collection
        or a step fails (both reported via ``st.error``), or when there is
        nothing to return.
    """
    collection_name = st.session_state.get("collection_name")
    if not collection_name:
        st.error("No active collection found. Upload and process a PDF first.")
        return []

    try:
        client = chromadb.Client()
        collection = client.get_collection(collection_name)
        available = collection.count()
    except Exception as e:
        st.error(f"Error accessing collection: {e}")
        return []

    # Asking for more results than stored vectors raises in some Chroma versions.
    n_results = min(top_k, available)
    if n_results < 1:
        return []

    try:
        query_embedding = np.asarray(embedding_model.encode([query])).tolist()
    except Exception as e:
        st.error(f"Error encoding query: {e}")
        return []

    try:
        results = collection.query(
            query_embeddings=query_embedding,
            n_results=n_results,
        )
    except Exception as e:
        st.error(f"Error querying database: {e}")
        return []

    documents = _first_row(results, "documents")
    metadatas = _first_row(results, "metadatas")
    distances = _first_row(results, "distances")

    relevant_chunks = []
    for i, doc in enumerate(documents):
        meta = (metadatas[i] if i < len(metadatas) else None) or {}
        distance = distances[i] if i < len(distances) else None
        relevant_chunks.append({
            "text": doc,
            "page_number": meta.get("page_number", "N/A"),
            "similarity": _distance_to_similarity(distance),
        })
    return relevant_chunks


def _first_row(results, key: str) -> list:
    """Return the row for the single query embedding under ``key``, or ``[]`` if missing."""
    # Chroma returns one inner list per query embedding; exactly one was sent.
    rows = results.get(key) or []
    return list(rows[0] or []) if rows else []


def _distance_to_similarity(distance) -> float:
    """Map a Chroma distance to a score in (0, 1] that falls as distance grows."""
    if distance is None:
        # No distance reported: keep the previous neutral default.
        return 1.0
    # 1 / (1 + d) is monotonic for every d >= 0, whatever the collection's
    # metric. The old mapping returned the raw distance whenever d > 1, so the
    # *least* similar chunks were displayed with the *highest* scores.
    return 1.0 / (1.0 + max(float(distance), 0.0))
def generate_answer_with_groq(client, query: str, relevant_chunks: List[Dict],
                              model: str = "llama-3.1-8b-instant",
                              temperature: float = 0.1,
                              max_tokens: int = 500) -> str:
    """Answer ``query`` with a Groq chat model restricted to the retrieved context.

    Args:
        client: A Groq client (anything exposing ``chat.completions.create``).
        query: The user's question.
        relevant_chunks: Dicts with ``"page_number"`` and ``"text"`` keys; each
            is rendered as ``[Page N]: text`` in the prompt.
        model: Groq model id.
        temperature: Sampling temperature; kept low for faithful answers.
        max_tokens: Upper bound on the generated answer length.

    Returns:
        The model's reply. It is ``"❌ Insufficient evidence"`` directly (no API
        call) when ``relevant_chunks`` is empty, or ``"Error generating answer:
        ..."`` if building the prompt, calling the API or reading the response
        raises.
    """
    insufficient = "❌ Insufficient evidence"
    if not relevant_chunks:
        # Nothing to ground an answer on: reply exactly as the prompt tells the
        # model to, without spending an API call.
        return insufficient

    try:
        context = "\n\n".join(
            f"[Page {c['page_number']}]: {c['text']}" for c in relevant_chunks
        )
        prompt = f"""Based ONLY on the following context from a PDF document, answer the user's question.

Context:
{context}

Question: {query}

Instructions:
- Answer using ONLY the information provided in the context above
- If the context does not contain enough information to answer the question, reply exactly: {insufficient}
- Always include page citations in your answer using the format [Page X]
- Be accurate and concise
- Do not add information not present in the context

Answer:"""
        chat_resp = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a strict assistant that only uses provided context."},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return _response_text(chat_resp)
    except Exception as e:
        return f"Error generating answer: {e}"


def _response_text(chat_resp) -> str:
    """Extract the reply text from a chat-completion response in object or dict form."""
    if hasattr(chat_resp, "choices"):
        # The message content can be None; never hand None to the UI.
        return chat_resp.choices[0].message.content or ""
    if isinstance(chat_resp, dict):
        choices = chat_resp.get("choices") or []
        if choices:
            first = choices[0]
            return ((first.get("message") or {}).get("content")
                    or first.get("text")
                    or str(first))
    return str(chat_resp)
# -----------------------------
# Streamlit UI
# -----------------------------
def main():
    """Render the app: index an uploaded PDF, then answer sidebar questions about it.

    Streamlit reruns this function on every widget interaction. Indexing is
    therefore skipped when the current upload is already the active document.
    """
    st.set_page_config(page_title="PDF Chatbot with Groq", layout="wide")
    st.title("📄 PDF Chatbot with Groq")

    st.sidebar.header("Upload PDF")
    uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")
    if uploaded_file and not _ensure_indexed(uploaded_file):
        return

    st.sidebar.header("Ask a Question")
    query = st.sidebar.text_input("Enter your question:")
    if not query:
        return
    if "collection_name" not in st.session_state:
        st.warning("Please upload and process a PDF first.")
        return
    _answer_question(query)


def _ensure_indexed(uploaded_file) -> bool:
    """Index ``uploaded_file`` unless it is already active; return False if indexing failed."""
    # The name alone is not enough: a different file with the same name would
    # otherwise be ignored and questions answered from the old one.
    file_key = (getattr(uploaded_file, "file_id", None)
                or f"{uploaded_file.name}:{uploaded_file.size}")
    if st.session_state.get("processed_file") == file_key:
        return True

    with st.spinner("Processing PDF..."):
        embedding_model = load_embedding_model()
        chunks = pdf_to_chunks(uploaded_file)
        collection_name = create_vector_database(chunks, embedding_model) if chunks else None

    if not chunks:
        st.error("No text extracted from PDF.")
    if not collection_name:
        # Forget the previously indexed PDF so later questions are not silently
        # answered from a document that is no longer the one uploaded.
        st.session_state.pop("collection_name", None)
        st.session_state.pop("processed_file", None)
        return False

    st.session_state.processed_file = file_key
    st.success("PDF processed and vector database created!")
    return True


def _answer_question(query: str) -> None:
    """Retrieve context for ``query``, ask Groq, and render the answer and its sources."""
    embedding_model = load_embedding_model()
    groq_client = setup_groq()
    if not groq_client:
        return  # setup_groq has already shown the reason via st.error

    with st.spinner("Generating answer..."):
        relevant_chunks = query_vector_database(query, embedding_model)
        if not relevant_chunks:
            st.error("No relevant chunks found.")
            return
        answer = generate_answer_with_groq(groq_client, query, relevant_chunks)

    st.subheader("Answer:")
    st.write(answer)

    st.subheader("Relevant Chunks:")
    preview_chars = 500
    for chunk in relevant_chunks:
        text = chunk["text"]
        # Only mark the preview as truncated when something was actually cut.
        preview = text[:preview_chars] + ("..." if len(text) > preview_chars else "")
        st.markdown(
            f"**Page {chunk['page_number']} (Score: {chunk['similarity']:.2f})**\n\n"
            f"{preview}"
        )
# Entry point: `streamlit run` executes this file as the __main__ module,
# which is what triggers rendering the UI.
if __name__ == "__main__":
    main()