study / app.py
anamjafar6's picture
Update app.py
f6deddf verified
import os
import uuid
from typing import Dict, List, Optional

import chromadb
import numpy as np
import streamlit as st
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

# Try importing Groq client
try:
    from groq import Groq
except ImportError:
    Groq = None
# -----------------------------
# Utility Functions
# -----------------------------
def load_api_key() -> Optional[str]:
    """Return the Groq API key from the ``GROQ_API_KEY`` environment variable.

    Hugging Face Spaces expose configured secrets to the running app as
    environment variables, so this single lookup covers both local runs and
    Spaces deployments.

    The earlier fallback to ``HfFolder.get_token()`` was removed on purpose:
    that call returns the user's Hugging Face access token, which is not a
    Groq key and would have been sent to Groq as a credential.

    Returns:
        The key with surrounding whitespace stripped, or ``None`` when the
        variable is unset or blank.
    """
    api_key = os.environ.get("GROQ_API_KEY", "").strip()
    return api_key or None
def setup_groq() -> Groq:
    """Build a Groq client, reporting any problem through ``st.error``.

    Returns:
        A ``Groq`` client, or ``None`` when no API key is available, the
        ``groq`` package could not be imported, or the constructor raises.
    """
    key = load_api_key()
    if not key:
        st.error("❌ Missing GROQ_API_KEY in environment or Hugging Face secrets.")
        return None

    # ``Groq`` is rebound to None at import time when the package is absent.
    if Groq is None:
        st.error("❌ Groq library not installed. Please add `groq` to requirements.txt.")
        return None

    try:
        return Groq(api_key=key)
    except Exception as err:
        st.error(f"Failed to initialize Groq client: {err}")
    return None
@st.cache_resource
def load_embedding_model(model_name: str = "all-MiniLM-L6-v2") -> SentenceTransformer:
    """Instantiate the sentence-embedding model named ``model_name``.

    The function is wrapped in ``st.cache_resource`` so the model object can
    be reused instead of being rebuilt on every call.
    """
    model = SentenceTransformer(model_name)
    return model
def pdf_to_chunks(uploaded_file, chunk_size: int = 500, overlap: int = 50) -> List[Dict]:
    """Split the text of a PDF into overlapping, page-tagged word windows.

    Each page is processed independently: its extracted text is split on
    whitespace and cut into windows of at most ``chunk_size`` words, with
    consecutive windows sharing ``overlap`` words.

    Args:
        uploaded_file: Whatever ``PdfReader`` accepts (here, the Streamlit
            upload object).
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of ``{"page_number": int, "text": str}`` dicts with 1-based
        page numbers. Empty when the PDF cannot be opened (an ``st.error`` is
        shown) or contains no extractable text.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        # Previously overlap == chunk_size failed inside range() with an
        # opaque message and overlap > chunk_size silently produced nothing.
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    try:
        reader = PdfReader(uploaded_file)
    except Exception as e:
        st.error(f"Error reading PDF: {e}")
        return []

    step = chunk_size - overlap
    chunks = []
    for page_num, page in enumerate(reader.pages, start=1):
        try:
            text = page.extract_text() or ""
        except Exception:
            # Best effort: a page that fails to extract is skipped rather
            # than aborting the whole document.
            text = ""

        words = text.split()
        if not words:
            continue

        for start in range(0, len(words), step):
            chunks.append({
                "page_number": page_num,
                "text": " ".join(words[start:start + chunk_size]),
            })
            # Once a window reaches the end of the page, any further window
            # would be a pure subset of this one, so stop here.
            if start + chunk_size >= len(words):
                break
    return chunks
def _delete_collection_quietly(client, name: str) -> None:
    """Best-effort removal of a ChromaDB collection; failures are ignored."""
    try:
        client.delete_collection(name)
    except Exception:
        # Cleanup only: the collection may already be gone, and a failure
        # here must not mask the outcome of the caller's real work.
        pass


def create_vector_database(chunks: List[Dict], embedding_model: SentenceTransformer) -> Optional[str]:
    """Embed ``chunks`` into a fresh ChromaDB collection and activate it.

    On success the new collection's name is stored in
    ``st.session_state.collection_name`` and the collection previously
    recorded there (if any) is deleted so superseded documents do not pile up
    in the in-process store.

    Args:
        chunks: Dicts with at least a ``"text"`` key; each dict is also
            stored unchanged as the chunk's metadata.
        embedding_model: Object whose ``encode(list_of_str)`` returns one
            vector per input text.

    Returns:
        The new collection name, or ``None`` when there are no chunks or any
        step fails (the failure is reported with ``st.error``).
    """
    if not chunks:
        st.error("No text chunks extracted from PDF.")
        return None

    client = chromadb.Client()
    # A UUID avoids the name collisions that a 1-in-10000 random integer
    # eventually produces.
    collection_name = f"pdf_chunks_{uuid.uuid4().hex}"
    try:
        collection = client.create_collection(collection_name)
    except Exception as e:
        st.error(f"Error creating collection: {e}")
        return None

    texts = [c["text"] for c in chunks]
    ids = [str(i) for i in range(len(chunks))]

    # Encode in batches to bound peak memory on large documents.
    embeddings = []
    batch_size = 64
    try:
        for start in range(0, len(texts), batch_size):
            emb = embedding_model.encode(texts[start:start + batch_size])
            embeddings.extend(emb.tolist() if hasattr(emb, "tolist") else list(map(list, emb)))
    except Exception:
        # Do not leave a half-built collection behind; the error itself is
        # still propagated exactly as before.
        _delete_collection_quietly(client, collection_name)
        raise

    try:
        collection.add(
            embeddings=embeddings,
            documents=texts,
            ids=ids,
            metadatas=chunks,
        )
    except Exception as e:
        st.error(f"Error adding embeddings: {e}")
        _delete_collection_quietly(client, collection_name)
        return None

    # Store only the collection name (not the object) in session_state, then
    # drop the collection it replaces.
    previous_name = st.session_state.get("collection_name")
    st.session_state.collection_name = collection_name
    if previous_name and previous_name != collection_name:
        _delete_collection_quietly(client, previous_name)
    return collection_name
def query_vector_database(query: str, embedding_model: SentenceTransformer,
                          top_k: int = 5) -> List[Dict]:
    """Return the chunks of the active collection closest to ``query``.

    Args:
        query: Natural-language question to embed and search for.
        embedding_model: Object whose ``encode([text])`` returns an array
            with a ``tolist()`` method.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of ``{"text", "page_number", "similarity"}`` dicts in the
        order ChromaDB returned them. ``similarity`` is ``1 / (1 + distance)``:
        it lies in (0, 1] and always decreases as the distance grows.
        ``1.0`` is used when no distance was returned. The list is empty when
        there is no active collection, nothing to return, or any step fails
        (failures are reported with ``st.error``).
    """
    if "collection_name" not in st.session_state:
        st.error("No active collection found. Upload and process a PDF first.")
        return []

    try:
        client = chromadb.Client()
        collection = client.get_collection(st.session_state.collection_name)
    except Exception as e:
        st.error(f"Error accessing collection: {e}")
        return []

    try:
        query_embedding = embedding_model.encode([query]).tolist()
    except Exception as e:
        st.error(f"Error encoding query: {e}")
        return []

    try:
        # Never ask for more neighbours than the collection holds; short
        # documents can have fewer than top_k chunks.
        n_results = min(top_k, collection.count())
        if n_results < 1:
            return []
        results = collection.query(
            query_embeddings=query_embedding,
            n_results=n_results,
        )
    except Exception as e:
        st.error(f"Error querying database: {e}")
        return []

    # Each field is a list with one entry per query; a field may be missing
    # or None, so fall back to an empty result row.
    documents = (results.get("documents") or [[]])[0] or []
    metadatas = (results.get("metadatas") or [[]])[0] or []
    distances = (results.get("distances") or [[]])[0] or []

    relevant_chunks = []
    for i, doc in enumerate(documents):
        meta = (metadatas[i] if i < len(metadatas) else None) or {}
        distance = distances[i] if i < len(distances) else None
        if distance is None:
            similarity = 1.0
        else:
            # One monotone mapping for every distance. The old code used
            # 1 - d up to 1 but passed larger distances through unchanged,
            # which ranked worse matches above better ones.
            similarity = 1.0 / (1.0 + max(float(distance), 0.0))
        relevant_chunks.append({
            "text": doc,
            "page_number": meta.get("page_number", "N/A"),
            "similarity": similarity,
        })
    return relevant_chunks
def generate_answer_with_groq(client, query: str, relevant_chunks: List[Dict]) -> str:
    """Ask the LLM to answer ``query`` using only the retrieved chunks.

    Args:
        client: Either an object exposing ``chat.completions.create`` or,
            failing that, one exposing a plain ``create(prompt=..., max_tokens=...)``.
        query: The user's question.
        relevant_chunks: Dicts with ``"page_number"`` and ``"text"`` keys
            that form the context section of the prompt.

    Returns:
        The model's answer text. Any exception is caught and returned as the
        string ``"Error generating answer: <exception>"`` instead of raised.
    """
    try:
        context = "\n\n".join(
            f"[Page {chunk['page_number']}]: {chunk['text']}"
            for chunk in relevant_chunks
        )
        prompt = "\n".join([
            "Based ONLY on the following context from a PDF document, answer the user's question.",
            "Context:",
            context,
            f"Question: {query}",
            "Instructions:",
            "- Answer using ONLY the information provided in the context above",
            "- If the context does not contain enough information to answer the question, reply exactly: ❌ Insufficient evidence",
            "- Always include page citations in your answer using the format [Page X]",
            "- Be accurate and concise",
            "- Do not add information not present in the context",
            "Answer:",
        ])

        has_chat_api = hasattr(client, "chat") and hasattr(client.chat, "completions")
        if not has_chat_api:
            response = client.create(prompt=prompt, max_tokens=500)
        else:
            messages = [
                {"role": "system", "content": "You are a strict assistant that only uses provided context."},
                {"role": "user", "content": prompt},
            ]
            response = client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=messages,
                temperature=0.1,
                max_tokens=500,
            )

        # Object-style responses expose ``.choices``; dict-style ones are
        # probed for a message, then a text field, then stringified.
        if hasattr(response, "choices"):
            return response.choices[0].message.content
        if isinstance(response, dict):
            choices = response.get("choices") or []
            if choices:
                first = choices[0]
                return (
                    first.get("message", {}).get("content")
                    or first.get("text")
                    or str(first)
                )
        return str(response)
    except Exception as err:
        return f"Error generating answer: {err}"
# -----------------------------
# Streamlit UI
# -----------------------------
def main():
    """Render the Streamlit page: PDF upload, indexing, and question answering.

    Flow:
      1. A PDF uploaded in the sidebar is chunked and indexed once; the
         file's name and size are remembered in ``st.session_state`` so
         reruns do not re-index it, while a different file that reuses the
         same name still does.
      2. When a question is entered and a collection exists, the closest
         chunks are retrieved, sent to Groq, and the answer is shown together
         with a preview of each supporting chunk.
    """
    st.set_page_config(page_title="PDF Chatbot with Groq", layout="wide")
    # The title emoji was mojibake ("πŸ“š") in the previous revision.
    st.title("📚 PDF Chatbot with Groq")

    st.sidebar.header("Upload PDF")
    uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")
    if uploaded_file:
        file_size = getattr(uploaded_file, "size", None)
        current_file = (uploaded_file.name, file_size)
        processed_file = (
            st.session_state.get("processed_file"),
            st.session_state.get("processed_file_size"),
        )
        # Comparing name AND size catches a new document uploaded under a
        # filename that was already processed.
        if processed_file != current_file:
            with st.spinner("Processing PDF..."):
                embedding_model = load_embedding_model()
                chunks = pdf_to_chunks(uploaded_file)
                if not chunks:
                    st.error("No text extracted from PDF.")
                    return
                collection_name = create_vector_database(chunks, embedding_model)
            if collection_name:
                st.session_state.processed_file = uploaded_file.name
                st.session_state.processed_file_size = file_size
                st.success("PDF processed and vector database created!")

    st.sidebar.header("Ask a Question")
    query = st.sidebar.text_input("Enter your question:")
    if not query:
        return
    if "collection_name" not in st.session_state:
        st.warning("Please upload and process a PDF first.")
        return

    embedding_model = load_embedding_model()
    groq_client = setup_groq()
    if not groq_client:
        # setup_groq has already reported the reason.
        return

    with st.spinner("Generating answer..."):
        relevant_chunks = query_vector_database(query, embedding_model)
        if not relevant_chunks:
            st.error("No relevant chunks found.")
            return
        answer = generate_answer_with_groq(groq_client, query, relevant_chunks)

    st.subheader("Answer:")
    st.write(answer)
    st.subheader("Relevant Chunks:")
    for chunk in relevant_chunks:
        text = chunk["text"]
        # Only mark the preview as truncated when text was actually cut.
        preview = text[:500] + ("..." if len(text) > 500 else "")
        st.markdown(
            f"**Page {chunk['page_number']} (Score: {chunk['similarity']:.2f})**\n\n"
            f"{preview}"
        )
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()