# NOTE(review): the lines below are Hugging Face blob-viewer UI residue
# (username, commit message, commit hash, "raw/history/blame", file size)
# captured when this file was scraped; commented out so the module parses.
# abubakaraabi786's picture
# Added API Key (#1)
# fb9b7af verified
# raw
# history blame
# 4.49 kB
import os
import tempfile
from typing import BinaryIO, List

import gradio as gr
import numpy as np
import PyPDF2
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# -----------------------------
# Configuration
# -----------------------------
# SECURITY: the original code passed a live Groq API key ("gsk_...") as the
# *name* of the environment variable to look up — which both leaked the
# secret into source control and made the lookup return None (no env var by
# that name exists). Read the key from the conventionally named variable
# instead; the leaked key itself must be revoked and rotated.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
MODEL_NAME = "llama3-8b-8192"

# Module-level singletons shared by the functions below.
client = Groq(api_key=GROQ_API_KEY)
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# -----------------------------
# PDF Processing
# -----------------------------
def extract_text_from_pdfs(files: List[BinaryIO]) -> List[dict]:
    """Extract per-page text from uploaded PDF files.

    Args:
        files: open file-like objects (e.g. Gradio upload tempfiles) that
            PyPDF2 can read and that expose a ``.name`` attribute.
            (The original annotation used ``tempfile.NamedTemporaryFile``,
            which is a factory function, not a type.)

    Returns:
        One dict per page that yields text:
        ``{"text": <page text>, "source": "<filename> - page N"}``.
        Pages with no extractable text (e.g. scanned images) are skipped.
    """
    documents = []
    for file in files:
        reader = PyPDF2.PdfReader(file)
        # enumerate from 1 so the source label uses human page numbers.
        for page_num, page in enumerate(reader.pages, start=1):
            text = page.extract_text()
            if text:
                documents.append({
                    "text": text,
                    "source": f"{os.path.basename(file.name)} - page {page_num}",
                })
    return documents
# -----------------------------
# Chunking
# -----------------------------
def chunk_text(documents, chunk_size=500, overlap=50):
    """Split documents into overlapping word-window chunks.

    Args:
        documents: list of ``{"text": ..., "source": ...}`` dicts.
        chunk_size: maximum words per chunk.
        overlap: words shared between consecutive chunks.

    Returns:
        List of ``{"text": ..., "source": ...}`` chunk dicts.
    """
    # Guard the window step: the original `start += chunk_size - overlap`
    # looped forever whenever overlap >= chunk_size.
    step = max(1, chunk_size - overlap)
    chunks = []
    for doc in documents:
        words = doc["text"].split()
        # (Also renamed the original local `chunk_text`, which shadowed
        # this function's own name.)
        for start in range(0, len(words), step):
            chunks.append({
                "text": " ".join(words[start:start + chunk_size]),
                "source": doc["source"],
            })
    return chunks
# -----------------------------
# Embeddings & Retrieval
# -----------------------------
def embed_chunks(chunks):
    """Encode every chunk's text with the shared sentence-transformer model."""
    return embedder.encode([chunk["text"] for chunk in chunks])
def retrieve_relevant_chunks(query, chunks, embeddings, top_k=3):
    """Return the ``top_k`` chunks most cosine-similar to ``query``, best first."""
    query_vec = embedder.encode([query])
    scores = cosine_similarity(query_vec, embeddings)[0]
    # Take the top_k highest-scoring indices in descending order.
    best_indices = np.argsort(scores)[-top_k:][::-1]
    return [chunks[idx] for idx in best_indices]
# -----------------------------
# LLM Call
# -----------------------------
def ask_llm(question, context, history):
    """Query the Groq chat model, grounding the answer in the retrieved context.

    Builds a message list of: system prompt, the prior conversation turns,
    then a fresh user turn embedding ``context`` and ``question``. The
    caller's ``history`` list is not mutated.
    """
    system_prompt = (
        "You are a helpful assistant. Answer the question strictly using the provided context. "
        "If the answer is not in the context, say so."
    )
    user_turn = {
        "role": "user",
        "content": f"Context:\n{context}\n\nQuestion:\n{question}"
    }
    messages = [{"role": "system", "content": system_prompt}, *history, user_turn]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=0.3,
    )
    return response.choices[0].message.content
# -----------------------------
# Main Chat Logic
# -----------------------------
def chat(files, question, chat_history):
    """Answer ``question`` from the uploaded PDFs and extend the chat history.

    Args:
        files: uploaded PDF file objects (may be None/empty).
        question: the user's question string.
        chat_history: mutable list of ``{"role", "content"}`` message dicts;
            appended to in place so the Gradio State persists across turns.

    Returns:
        ``(chat_history, answer_with_sources)``.
    """
    if not files:
        return chat_history, "Please upload PDF files first."
    # NOTE(review): the corpus is re-extracted and re-embedded on every
    # question; caching per upload would avoid this repeated work.
    documents = extract_text_from_pdfs(files)
    chunks = chunk_text(documents)
    # Guard: scanned/image-only PDFs can yield zero chunks, which would
    # crash embedding/retrieval downstream.
    if not chunks:
        return chat_history, "No extractable text was found in the uploaded PDFs."
    embeddings = embed_chunks(chunks)
    relevant_chunks = retrieve_relevant_chunks(question, chunks, embeddings)
    context = "\n\n".join(c["text"] for c in relevant_chunks)
    # dict.fromkeys dedupes while preserving retrieval (relevance) order;
    # the original set() made the source listing order nondeterministic.
    sources = list(dict.fromkeys(c["source"] for c in relevant_chunks))
    answer = ask_llm(question, context, chat_history)
    answer_with_sources = (
        f"{answer}\n\n"
        f"Sources:\n" + "\n".join(sources)
    )
    chat_history.append({"role": "user", "content": question})
    chat_history.append({"role": "assistant", "content": answer_with_sources})
    return chat_history, answer_with_sources
# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(title="Enhanced RAG Chatbot") as demo:
    gr.Markdown("## 📚 Enhanced RAG-Based Chatbot (PDF QA)")
    gr.Markdown("Upload multiple PDFs and ask questions based on their content.")
    with gr.Row():
        pdf_files = gr.File(
            label="Upload PDF Files",
            file_types=[".pdf"],
            file_count="multiple"
        )
    # NOTE(review): chat() builds history as {"role": ..., "content": ...}
    # dicts; gr.Chatbot renders that format only when constructed with
    # type="messages" — confirm against the installed Gradio version.
    chatbot = gr.Chatbot()
    question = gr.Textbox(label="Ask a question")
    # Running chat history. It is never listed in outputs, so cross-turn
    # persistence relies on chat() mutating this list in place.
    state = gr.State([])
    submit = gr.Button("Ask")
    # NOTE(review): chat() returns (chat_history, answer_with_sources), but
    # both outputs are wired to the SAME Chatbot component, so the second
    # return value (a plain string) overwrites the first render — the second
    # output slot was probably meant to be a different component. TODO confirm.
    submit.click(
        fn=chat,
        inputs=[pdf_files, question, state],
        outputs=[chatbot, chatbot]
    )
demo.launch()