Spaces:
Sleeping
Sleeping
File size: 5,776 Bytes
927fe6a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
import os
import pypdf
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Seed the per-session keys read by the upload and chat sections below.
# Only missing keys are set, so values survive Streamlit's script reruns.
for _state_key, _initial_value in (("faiss_index", None), ("chunks", [])):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Resolve the Groq API key: environment variable first, then Streamlit secrets.
# Never hard-code a fallback key in source -- anything committed here is public
# and must be treated as leaked (revoke any key that was previously embedded).
def _read_groq_api_key():
    """Return the Groq API key from the environment or ``st.secrets``, else None."""
    env_key = os.getenv("GROQ_API_KEY")
    if env_key:
        return env_key
    try:
        return st.secrets.get("GROQ_API_KEY")
    except Exception:
        # NOTE(review): st.secrets can raise when no secrets.toml is present;
        # treat that the same as "key not configured".
        return None


GROQ_API_KEY = _read_groq_api_key()
if not GROQ_API_KEY:
    st.error("β οΈ GROQ_API_KEY is missing! Please set it in your environment variables or secrets.toml file.")
    st.stop()
# Load the sentence-embedding model once per server process. Streamlit re-runs
# this whole script on every interaction (each chat message included), so an
# uncached load would re-instantiate the model every time. The spinner is off
# so this call emits no UI element before st.set_page_config runs further down.
@st.cache_resource(show_spinner=False)
def _load_embedding_model(model_name="all-MiniLM-L6-v2"):
    """Return a process-wide shared SentenceTransformer for ``model_name``."""
    return SentenceTransformer(model_name)


try:
    embedding_model = _load_embedding_model()
except Exception as e:
    st.error(f"β Failed to load embedding model: {str(e)}")
    st.stop()
# Build the Groq chat client from the resolved key; halt this run if that fails.
try:
    client = Groq(api_key=GROQ_API_KEY)
except Exception as init_error:
    st.error(f"β Failed to initialize Groq client: {init_error!s}")
    st.stop()
# Function to extract text from PDF with error handling
def extract_text_from_pdf(uploaded_file):
    """Return the text of every page of a PDF, joined with newlines.

    Args:
        uploaded_file: anything ``pypdf.PdfReader`` accepts, e.g. a path or a
            binary file-like object such as a Streamlit upload.

    Returns:
        Page texts joined by ``"\\n"``, skipping pages that yield no text.
        Returns ``""`` when nothing is extractable or parsing fails; in the
        failure case an error is also shown in the UI.
    """
    try:
        reader = pypdf.PdfReader(uploaded_file)
        # extract_text() does real parsing work, so call it exactly once per
        # page and filter the results, rather than once to test and once to keep.
        page_texts = (page.extract_text() for page in reader.pages)
        # The generator is consumed here, inside the try, so per-page parse
        # errors are still caught below.
        return "\n".join(text for text in page_texts if text)
    except Exception as e:
        st.error(f"β Error extracting text from PDF: {str(e)}")
        return ""
# Function to create text chunks
def create_chunks(text, chunk_size=500, chunk_overlap=100):
    """Split ``text`` into overlapping chunks for embedding.

    Uses LangChain's RecursiveCharacterTextSplitter, trying separators in
    order: blank line, newline, space, then single characters.

    Args:
        text: full document text.
        chunk_size: target maximum chunk length in characters.
        chunk_overlap: characters shared between consecutive chunks.

    Returns:
        The list of chunk strings produced by the splitter.
    """
    separator_priority = ["\n\n", "\n", " ", ""]
    splitter = RecursiveCharacterTextSplitter(
        separators=separator_priority,
        chunk_overlap=chunk_overlap,
        chunk_size=chunk_size,
    )
    return splitter.split_text(text)
# Function to build an in-memory FAISS index over text chunks
def create_faiss_index(chunks):
    """Embed ``chunks`` and add them to an exact (flat) L2 FAISS index.

    Args:
        chunks: list of text strings; row ``i`` of the index corresponds to
            ``chunks[i]``.

    Returns:
        ``(index, chunks)`` on success. Returns ``(None, [])`` when ``chunks``
        is empty, or when embedding/indexing fails (an error is then shown in
        the UI).
    """
    if not chunks:
        # Nothing to index: skip the model call and don't build an empty index.
        return None, []
    try:
        embeddings = embedding_model.encode(chunks, convert_to_numpy=True)
        # FAISS's Python bindings expect a C-contiguous float32 matrix.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index, chunks
    except Exception as e:
        st.error(f"β Error creating FAISS index: {str(e)}")
        return None, []
# Function to search FAISS
def search_faiss(query, index, chunks, top_k=2):
    """Return up to ``top_k`` chunks whose embeddings are nearest to ``query``.

    Args:
        query: the user's question text.
        index: FAISS index built by ``create_faiss_index``, or None.
        chunks: texts aligned with the index rows.
        top_k: maximum number of chunks to return.

    Returns:
        Matching chunk strings, nearest first. Returns ``[]`` when there is no
        index, no chunks, ``top_k < 1``, or the search fails (an error is then
        shown in the UI).
    """
    if index is None or not chunks or top_k < 1:
        return []
    try:
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        # When k exceeds the number of stored vectors FAISS pads the result with
        # label -1, and chunks[-1] would silently return the last chunk. Clamp k
        # to the index size and still reject negative labels defensively.
        k = min(top_k, index.ntotal)
        if k < 1:
            return []
        _, indices = index.search(query_embedding, k)
        return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    except Exception as e:
        st.error(f"β Search error: {str(e)}")
        return []
# Function to query Groq with enhanced prompt
def query_groq(query, context=None, *, model="llama-3.3-70b-versatile",
               temperature=0.3, max_tokens=1024):
    """Ask the Groq chat model to answer ``query`` using ``context``.

    Args:
        query: the user's question.
        context: retrieved document text; a placeholder is used when falsy.
        model: Groq model id. The previous value "llama-3-70b-8192" is not a
            Groq model id (Groq used "llama3-70b-8192", since retired in favour
            of "llama-3.3-70b-versatile").
        temperature: sampling temperature passed to the API.
        max_tokens: completion length cap passed to the API.

    Returns:
        The model's reply text (``""`` if the API returns no content), or a
        string starting with "Error querying Groq:" if the call fails.
    """
    # Built line by line so source indentation can never leak into the prompt.
    prompt = (
        "Use the following context to answer the question.\n"
        "If you don't know the answer, say you don't know. Don't make up answers.\n"
        f"Context: {context if context else 'No specific context provided'}\n"
        f"Question: {query}\n"
        "Answer:"
    )
    try:
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return chat_completion.choices[0].message.content or ""
    except Exception as e:
        return f"Error querying Groq: {str(e)}"
# Streamlit UI
# NOTE(review): older Streamlit releases require set_page_config to be the first
# Streamlit command of a run. Code above this point only touches session state
# and secrets, or calls st.stop() after an error, which appears to keep this safe.
# Consider moving this call to the top of the script anyway.
st.set_page_config(layout="wide", page_icon="π€", page_title="RAG Chatbot")
st.title("π RAG-Based Chatbot with FAISS & Groq")
# Sidebar for settings: retrieval depth and chunking parameters used below.
with st.sidebar:
    st.header("Settings")
    top_k = st.slider("Number of chunks to retrieve", min_value=1, max_value=5, value=2)
    chunk_size = st.slider("Chunk size (characters)", min_value=200, max_value=1000, value=500)
    chunk_overlap = st.slider("Chunk overlap (characters)", min_value=0, max_value=200, value=100)
# Upload PDF
uploaded_file = st.file_uploader("π€ Upload a PDF file", type="pdf")
if uploaded_file:
    # Streamlit re-runs this script on every interaction (each chat message too),
    # so only re-extract and re-embed when the file or the chunking settings
    # actually change.
    pdf_fingerprint = (uploaded_file.name, uploaded_file.size, chunk_size, chunk_overlap)
    if st.session_state.get("pdf_fingerprint") != pdf_fingerprint:
        with st.spinner("π Processing PDF..."):
            text = extract_text_from_pdf(uploaded_file)
            if text.strip():
                chunks = create_chunks(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
                index, chunks = create_faiss_index(chunks)
            else:
                index, chunks = None, []
                st.error("β No text found in the uploaded PDF.")
        # Always overwrite the stored index so a failed upload cannot leave the
        # previous document's index answering questions.
        st.session_state["faiss_index"] = index
        st.session_state["chunks"] = chunks
        # Remember the fingerprint only on success, so a failed file is retried
        # (and re-uploading an earlier good file is re-processed).
        st.session_state["pdf_fingerprint"] = pdf_fingerprint if index is not None else None
        if index is not None:
            st.success(f"✅ PDF processed successfully! Created {len(chunks)} chunks.")
# Chat interface: the conversation history lives in session state so it
# persists across reruns; create it on the first run only.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Replay the stored conversation so it stays visible after each rerun.
for entry in st.session_state.messages:
    speaker, body = entry["role"], entry["content"]
    with st.chat_message(speaker):
        st.markdown(body)
# User query input: record the question, retrieve context, ask the LLM, and
# render and store the answer.
if prompt := st.chat_input("π¬ Ask me something about the document:"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner("π Retrieving response..."):
        retrieved_text = search_faiss(
            prompt,
            st.session_state["faiss_index"],
            st.session_state["chunks"],
            top_k=top_k,
        )
        context = "\n".join(retrieved_text) if retrieved_text else "No relevant context found."
        response = query_groq(prompt, context)
    st.session_state.messages.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)