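"""RAG chatbot: Streamlit front end, FAISS similarity search over PDF chunks,
sentence-transformers embeddings, and a Groq-hosted Llama model for answers."""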
import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
import os
import pypdf
from langchain.text_splitter import RecursiveCharacterTextSplitter
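# Assumed dependencies (the source pins no versions; faiss usually ships as
# the faiss-cpu package, and newer LangChain versions move the splitter into
# the langchain-text-splitters package):
#   pip install streamlit faiss-cpu numpy sentence-transformers groq pypdf langchain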

# Initialize session state variables
if "faiss_index" not in st.session_state:
    st.session_state["faiss_index"] = None
if "chunks" not in st.session_state:
    st.session_state["chunks"] = []

# Read the Groq API key from the environment or st.secrets; never hard-code it
GROQ_API_KEY = os.getenv("GROQ_API_KEY") or st.secrets.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("⚠️ GROQ_API_KEY is missing! Please set it in your environment variables or secrets.toml file.")
    st.stop()
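
# For local runs, st.secrets is read from .streamlit/secrets.toml; a minimal
# file needs a single line (key name matching the lookup above):
#   GROQ_API_KEY = "gsk_..."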

# Load the embedding model with error handling; st.cache_resource keeps the
# model in memory across Streamlit reruns instead of reloading it every time
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer("all-MiniLM-L6-v2")

try:
    embedding_model = load_embedding_model()
except Exception as e:
    st.error(f"❌ Failed to load embedding model: {str(e)}")
    st.stop()

# Set up Groq client with error handling
try:
    client = Groq(api_key=GROQ_API_KEY)
except Exception as e:
    st.error(f"❌ Failed to initialize Groq client: {str(e)}")
    st.stop()

# Extract text from an uploaded PDF, skipping pages with no extractable text
def extract_text_from_pdf(uploaded_file):
    try:
        reader = pypdf.PdfReader(uploaded_file)
        extracted_text = []
        for page in reader.pages:
            page_text = page.extract_text()  # call once per page, not twice
            if page_text:
                extracted_text.append(page_text)
        return "\n".join(extracted_text)
    except Exception as e:
        st.error(f"❌ Error extracting text from PDF: {str(e)}")
        return ""

# Split text into overlapping chunks for embedding
def create_chunks(text, chunk_size=500, chunk_overlap=100):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""]  # prefer paragraph, then line, then word breaks
    )
    return text_splitter.split_text(text)
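
# Rough illustration of the defaults: a 1,200-character document with
# chunk_size=500 and chunk_overlap=100 splits into about three chunks, each
# sharing ~100 characters with its neighbor; exact boundaries depend on where
# the "\n\n", "\n", and " " separators fall.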

# Build an in-memory FAISS index over the chunk embeddings
def create_faiss_index(chunks):
    try:
        embeddings = embedding_model.encode(chunks, convert_to_numpy=True)
        embeddings = np.asarray(embeddings, dtype=np.float32)  # FAISS expects float32
        # Create a flat (exact-search) L2 index sized to the embedding dimension
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index, chunks
    except Exception as e:
        st.error(f"❌ Error creating FAISS index: {str(e)}")
        return None, []
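
# Design note: IndexFlatL2 ranks by Euclidean distance. For cosine similarity,
# one could L2-normalize the embeddings with faiss.normalize_L2 and use
# faiss.IndexFlatIP (inner product) instead.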

# Retrieve the top_k chunks most similar to the query
def search_faiss(query, index, chunks, top_k=2):
    if index is None or not chunks:
        return []
    try:
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        distances, indices = index.search(query_embedding, top_k)
        # FAISS pads missing results with -1, so guard both index bounds
        return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    except Exception as e:
        st.error(f"❌ Search error: {str(e)}")
        return []
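
# Caveat: a flat index always returns top_k results, however weak the match.
# Filtering on the returned distances (with a cutoff tuned for this embedding
# model) would drop irrelevant chunks; no threshold is applied here.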

# Send the question (plus any retrieved context) to a Groq-hosted chat model
def query_groq(query, context=None):
    try:
        prompt = f"""Use the following context to answer the question.
If you don't know the answer, say you don't know. Don't make up answers.
Context: {context if context else 'No specific context provided'}
Question: {query}
Answer:"""
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-70b-8192",  # Groq's model ID has no hyphen after "llama"
            temperature=0.3,
            max_tokens=1024
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return f"Error querying Groq: {str(e)}"

# Streamlit UI
st.set_page_config(page_title="RAG Chatbot", page_icon="🤖", layout="wide")
st.title("📚 RAG-Based Chatbot with FAISS & Groq")

# Sidebar for settings
with st.sidebar:
    st.header("Settings")
    top_k = st.slider("Number of chunks to retrieve", 1, 5, 2)
    chunk_size = st.slider("Chunk size (characters)", 200, 1000, 500)
    chunk_overlap = st.slider("Chunk overlap (characters)", 0, 200, 100)
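
# Note: with the re-processing guard below, chunk_size and chunk_overlap only
# apply when a new file is uploaded; an already-indexed document is not
# re-split as the sliders move.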

# Upload PDF; Streamlit re-runs this script on every interaction, so only
# re-process when a new file arrives rather than on each chat message
uploaded_file = st.file_uploader("📤 Upload a PDF file", type="pdf")
if uploaded_file and st.session_state.get("processed_file") != uploaded_file.name:
    with st.spinner("🔄 Processing PDF..."):
        text = extract_text_from_pdf(uploaded_file)
        if text.strip():
            chunks = create_chunks(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
            # Create FAISS index
            index, chunks = create_faiss_index(chunks)
            # Store in session state
            st.session_state["faiss_index"] = index
            st.session_state["chunks"] = chunks
            st.session_state["processed_file"] = uploaded_file.name
            st.success(f"✅ PDF processed successfully! Created {len(chunks)} chunks.")
        else:
            st.error("❌ No text found in the uploaded PDF.")

# Chat interface
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User query input
if prompt := st.chat_input("💬 Ask me something about the document:"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("🔍 Retrieving response..."):
        retrieved_text = search_faiss(prompt, st.session_state["faiss_index"], st.session_state["chunks"], top_k=top_k)
        context = "\n".join(retrieved_text) if retrieved_text else "No relevant context found."
        response = query_groq(prompt, context)

    st.session_state.messages.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)