File size: 5,776 Bytes
927fe6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
import os
import pypdf
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Initialize session state variables
# Streamlit re-executes this script on every interaction; only seed the keys
# when they are missing so the index built on an earlier run is kept.
if "faiss_index" not in st.session_state:
    st.session_state["faiss_index"] = None
if "chunks" not in st.session_state:
    st.session_state["chunks"] = []


def _read_groq_api_key():
    """Return the Groq API key from the environment or Streamlit secrets.

    Returns None when the key is configured in neither place.
    """
    env_key = os.getenv("GROQ_API_KEY")
    if env_key:
        return env_key
    try:
        return st.secrets.get("GROQ_API_KEY")
    except FileNotFoundError:
        # No secrets.toml is present; treat that as "not configured" so the
        # friendly error below is shown instead of a traceback.
        return None


# SECURITY: never hard-code an API key as a fallback value. A literal key in
# source control is public to anyone who can read the repository, and it also
# makes the "missing key" check below unreachable. Any key that was previously
# committed here should be considered compromised and revoked.
GROQ_API_KEY = _read_groq_api_key()
if not GROQ_API_KEY:
    st.error("⚠️ GROQ_API_KEY is missing! Please set it in your environment variables or secrets.toml file.")
    st.stop()


# show_spinner=False keeps this loader from emitting UI elements, since it runs
# before st.set_page_config() further down the script.
@st.cache_resource(show_spinner=False)
def _load_embedding_model():
    """Load the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer("all-MiniLM-L6-v2")


# Load embedding model with error handling
try:
    embedding_model = _load_embedding_model()
except Exception as e:
    st.error(f"❌ Failed to load embedding model: {str(e)}")
    st.stop()

# Set up Groq client with error handling
try:
    client = Groq(api_key=GROQ_API_KEY)
except Exception as e:
    st.error(f"❌ Failed to initialize Groq client: {str(e)}")
    st.stop()

# Function to extract text from PDF with error handling
def extract_text_from_pdf(uploaded_file):
    """Extract the text of every page of a PDF and join it with newlines.

    Args:
        uploaded_file: Object handed to ``pypdf.PdfReader`` (here, the file
            returned by the Streamlit uploader).

    Returns:
        The non-empty page texts joined by ``"\\n"``, or ``""`` when no page
        yields text or when reading fails (the failure is reported with
        ``st.error``).
    """
    try:
        reader = pypdf.PdfReader(uploaded_file)
        # Extract each page exactly once; calling extract_text() in both the
        # filter and the value position doubles the parsing work per page.
        page_texts = [
            page_text
            for page in reader.pages
            if (page_text := page.extract_text())
        ]
        return "\n".join(page_texts)
    except Exception as e:
        st.error(f"❌ Error extracting text from PDF: {str(e)}")
        return ""

# Function to create text chunks
def create_chunks(text, chunk_size=500, chunk_overlap=100):
    """Split ``text`` into chunks with a RecursiveCharacterTextSplitter.

    Args:
        text: The full document text to split.
        chunk_size: Value passed to the splitter as ``chunk_size``.
        chunk_overlap: Value passed to the splitter as ``chunk_overlap``.

    Returns:
        Whatever ``split_text`` returns for ``text``.
    """
    # Separator list handed to the splitter: paragraph break, line break,
    # space, then the empty string.
    split_points = ["\n\n", "\n", " ", ""]
    splitter = RecursiveCharacterTextSplitter(
        separators=split_points,
        chunk_overlap=chunk_overlap,
        chunk_size=chunk_size,
    )
    pieces = splitter.split_text(text)
    return pieces

# Function to create and save FAISS index
def create_faiss_index(chunks):
    """Embed ``chunks`` and build an in-memory FAISS L2 index over them.

    Args:
        chunks: List of text chunks to embed with the module-level
            ``embedding_model``.

    Returns:
        ``(index, chunks)`` on success. ``(None, [])`` when ``chunks`` is
        empty or when embedding/indexing fails; in both cases the problem is
        reported with ``st.error``.
    """
    if not chunks:
        # Without this guard an empty input reaches embeddings.shape[1] and
        # surfaces as an unhelpful "tuple index out of range" message.
        st.error("❌ Error creating FAISS index: no text chunks to index.")
        return None, []

    try:
        embeddings = embedding_model.encode(chunks, convert_to_numpy=True)
        # Hand FAISS a C-contiguous float32 matrix regardless of what the
        # encoder produced.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        # Create FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)

        return index, chunks
    except Exception as e:
        st.error(f"❌ Error creating FAISS index: {str(e)}")
        return None, []

# Function to search FAISS
def search_faiss(query, index, chunks, top_k=2):
    """Return the chunks whose embeddings are nearest to ``query``.

    Args:
        query: The user's question text.
        index: FAISS index built over ``chunks``, or ``None`` if no document
            has been indexed yet.
        chunks: The text chunks, in the same order they were added to
            ``index``.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of at most ``top_k`` chunk strings. Empty when there is no
        index, no chunks, ``top_k`` is not positive, or the search fails (the
        failure is reported with ``st.error``).
    """
    if index is None or not chunks or top_k < 1:
        return []

    try:
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        # Never ask for more neighbours than there are chunks.
        neighbour_count = min(top_k, len(chunks))
        _distances, indices = index.search(query_embedding, neighbour_count)
        # FAISS fills missing result slots with -1. A bare "i < len(chunks)"
        # lets -1 through, and chunks[-1] would silently return the last chunk
        # as if it were a match, so the lower bound must be checked too.
        return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    except Exception as e:
        st.error(f"❌ Search error: {str(e)}")
        return []

# Function to query Groq with enhanced prompt
def query_groq(query, context=None, model="llama-3.3-70b-versatile"):
    """Ask the Groq chat model to answer ``query`` using ``context``.

    Args:
        query: The user's question.
        context: Retrieved document text to ground the answer. When empty or
            ``None`` the prompt says that no specific context was provided.
        model: Groq model id to use. The previous hard-coded value,
            "llama-3-70b-8192", does not match Groq's model naming, so the
            default is now a current id. NOTE(review): confirm the id against
            the Groq model list for your account.

    Returns:
        The model's reply text, or a string starting with
        ``"Error querying Groq:"`` if the request fails. This function does
        not raise.
    """
    try:
        # Build the prompt line by line: an indented triple-quoted f-string
        # would send the source indentation to the model as part of the text.
        prompt = (
            "Use the following context to answer the question.\n"
            "If you don't know the answer, say you don't know. "
            "Don't make up answers.\n"
            "\n"
            f"Context: {context if context else 'No specific context provided'}\n"
            "\n"
            f"Question: {query}\n"
            "\n"
            "Answer:"
        )

        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            temperature=0.3,
            max_tokens=1024,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return f"Error querying Groq: {str(e)}"

# Streamlit UI
# The emoji literals below were previously mis-encoded (mojibake); page_icon in
# particular needs a real emoji character to be accepted as an icon.
st.set_page_config(page_title="RAG Chatbot", page_icon="🤖", layout="wide")
st.title("📄 RAG-Based Chatbot with FAISS & Groq")

# Sidebar for settings
with st.sidebar:
    st.header("Settings")
    top_k = st.slider("Number of chunks to retrieve", 1, 5, 2)
    chunk_size = st.slider("Chunk size (characters)", 200, 1000, 500)
    chunk_overlap = st.slider("Chunk overlap (characters)", 0, 200, 100)

# Upload PDF
uploaded_file = st.file_uploader("📤 Upload a PDF file", type="pdf")

if uploaded_file:
    # Streamlit reruns this script on every interaction, including each chat
    # message. Key the stored index on the file and the chunking settings so
    # the PDF is only re-parsed and re-embedded when one of them changes.
    doc_signature = (uploaded_file.name, uploaded_file.size, chunk_size, chunk_overlap)

    if st.session_state.get("indexed_doc") != doc_signature:
        with st.spinner("🔄 Processing PDF..."):
            text = extract_text_from_pdf(uploaded_file)
            if text.strip():
                chunks = create_chunks(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

                # Create FAISS index
                index, chunks = create_faiss_index(chunks)

                # Store in session state
                st.session_state["faiss_index"] = index
                st.session_state["chunks"] = chunks
                # Only remember successful builds so a failed one is retried
                # on the next rerun (create_faiss_index already reported it).
                st.session_state["indexed_doc"] = doc_signature if index is not None else None
            else:
                st.error("❌ No text found in the uploaded PDF.")

    if st.session_state.get("indexed_doc") == doc_signature:
        st.success(
            f"✅ PDF processed successfully! Created {len(st.session_state['chunks'])} chunks."
        )

# Chat interface
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User query input
if prompt := st.chat_input("💬 Ask me something about the document:"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("🔎 Retrieving response..."):
        retrieved_text = search_faiss(
            prompt,
            st.session_state["faiss_index"],
            st.session_state["chunks"],
            top_k=top_k,
        )
        context = "\n".join(retrieved_text) if retrieved_text else "No relevant context found."

        response = query_groq(prompt, context)

        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            st.markdown(response)