RAG / app.py
Talha812's picture
Create app.py
22d793c verified
import html
import io
import os
from urllib.parse import parse_qs, urlparse

import faiss
import gradio as gr
import numpy as np
import requests
from groq import Groq
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
# -----------------------
# CONFIG
# -----------------------
# Hugging Face Space will inject GROQ_API_KEY as a secret
# Retrieval state shared by the helper functions; both start empty and are
# filled in once the knowledge base has been built.
faiss_index = None
chunks = []

# Lowest best-match similarity score for which a question is still answered.
RELEVANCE_THRESHOLD = 0.28

# Fail fast at import time when the API key is absent or empty, before any
# client or embedding model is created.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("❌ Missing GROQ_API_KEY. Please set it in Hugging Face Space → Settings → Repository secrets.")

# Chat-completion client and the sentence-embedding model used for both
# indexing and querying.
client = Groq(api_key=GROQ_API_KEY)
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# -----------------------
# HELPER FUNCTIONS
# -----------------------
def download_drive_file_bytes(drive_url: str) -> bytes:
    """Download the raw bytes of a file from a Google Drive share link.

    Two link shapes are understood:

    * ``https://drive.google.com/file/d/<id>/view?...`` (ID in the path,
      with or without a trailing path segment)
    * ``https://drive.google.com/open?id=<id>`` / ``.../uc?id=<id>``
      (ID in the query string)

    Args:
        drive_url: The share link.

    Returns:
        The response body of the direct-download request.

    Raises:
        ValueError: If no file ID can be located in ``drive_url``.
        requests.HTTPError: If the download request returns an error status.
    """
    parsed = urlparse(drive_url)

    # Path form: the ID is the segment right after "d". Working on the parsed
    # path (not the raw string) keeps "?usp=sharing" out of the ID when the
    # link has no trailing "/view".
    file_id = ""
    segments = parsed.path.split("/")
    if "d" in segments:
        position = segments.index("d") + 1
        if position < len(segments):
            file_id = segments[position]

    # Query form: "...?id=<id>".
    if not file_id:
        file_id = parse_qs(parsed.query).get("id", [""])[0]

    if not file_id:
        raise ValueError(f"Could not find a Google Drive file ID in link: {drive_url!r}")

    # NOTE(review): Drive may answer with an HTML confirmation page instead of
    # the file for large downloads - not handled here; confirm if relevant.
    r = requests.get(
        "https://drive.google.com/uc",
        params={"export": "download", "id": file_id},
        timeout=30,
    )
    r.raise_for_status()
    return r.content
def pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    """Return the text of every page of a PDF, one page per line group.

    Pages for which no text is extracted contribute an empty string. Any
    failure while reading or extracting yields ``""`` instead of raising, so
    callers can treat an empty result as "not a readable PDF".
    """
    try:
        document = PdfReader(io.BytesIO(pdf_bytes))
        page_texts = []
        for page in document.pages:
            extracted = page.extract_text()
            page_texts.append(extracted or "")
        return "\n".join(page_texts)
    except Exception:
        # Best-effort by design: unreadable input is reported as empty text.
        return ""
def chunk_text(text, chunk_size=250, overlap=50):
    """Split ``text`` into overlapping chunks of whitespace-separated words.

    Consecutive chunks start ``chunk_size - overlap`` words apart, so each
    chunk repeats the last ``overlap`` words of the previous one. Chunking
    stops as soon as a chunk reaches the final word, so no trailing chunk is
    emitted that only repeats words already present in its predecessor.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by neighbouring chunks; must be
            smaller than ``chunk_size``.

    Returns:
        A list of chunk strings (empty when ``text`` contains no words).

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (the window would never advance).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    words = text.split()
    step = chunk_size - overlap
    pieces = []
    for start in range(0, len(words), step):
        pieces.append(" ".join(words[start : start + chunk_size]))
        # This window already covers the last word; a further window would
        # consist solely of overlap.
        if start + chunk_size >= len(words):
            break
    return pieces
def build_faiss_index_from_drive_links(drive_links):
    """Build the global FAISS index from the documents behind ``drive_links``.

    Each link is downloaded and read as a PDF; when that yields no text the
    bytes are tried as UTF-8 instead. The resulting text is chunked, and all
    chunks are embedded, L2-normalised and stored in an inner-product index.
    A failure on one link is printed and does not stop the others.

    On success the module-level ``faiss_index`` and ``chunks`` are replaced.

    Returns:
        A human-readable status message (success or "no valid text").
    """
    global faiss_index, chunks

    gathered = []
    for link in drive_links:
        try:
            raw = download_drive_file_bytes(link)
            text = pdf_bytes_to_text(raw)
            has_text = bool(text.strip())
            if not has_text:
                # Not a readable PDF: fall back to treating it as plain text.
                try:
                    text = raw.decode("utf-8")
                except Exception:
                    text = ""
            gathered += chunk_text(text)
        except Exception as err:
            print(f"[Error] {link}: {err}")

    if len(gathered) == 0:
        return "❌ No valid text found. Please check your Drive file links."

    vectors = embedder.encode(gathered, convert_to_numpy=True).astype("float32")
    # Unit-length rows make the inner product equal to cosine similarity.
    lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    np.divide(vectors, lengths, out=vectors)

    dimension = vectors.shape[1]
    faiss_index = faiss.IndexFlatIP(dimension)
    faiss_index.add(vectors)
    chunks = gathered
    return f"✅ Knowledge base ready! {len(chunks)} chunks indexed."
def retrieve_top_k(query, k=4):
    """Return the chunks most similar to ``query`` with their scores.

    Args:
        query: The question text to embed and search for.
        k: Maximum number of chunks to return.

    Returns:
        A ``(retrieved, scores)`` pair: the matching chunk strings and the
        corresponding similarity scores from the index, best match first.
        Both are empty when no index has been built or nothing can be
        returned.
    """
    if faiss_index is None:
        return [], []

    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with label -1, and chunks[-1] would silently return the last
    # chunk as a bogus hit.
    k = min(k, int(faiss_index.ntotal))
    if k <= 0:
        return [], []

    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    # Normalise so the inner-product search scores are cosine similarities.
    q_emb /= np.linalg.norm(q_emb, axis=1, keepdims=True)
    D, I = faiss_index.search(q_emb, k)

    # Defensive: drop any padded (-1) labels that are still present.
    valid = I[0] >= 0
    retrieved = [chunks[i] for i in I[0][valid]]
    return retrieved, D[0][valid]
def ask_groq_with_rag(query):
    """Answer ``query`` from the knowledge base and return an HTML snippet.

    The best-matching chunks are retrieved; when the best score is below
    ``RELEVANCE_THRESHOLD`` a fixed "not covered" message is returned without
    calling the model. Otherwise the chunks are sent to the chat model as the
    only permitted context and its reply is rendered in a styled card.

    Args:
        query: The user's question. ``None`` and blank strings are treated as
            "no question".

    Returns:
        An HTML string: the answer card, or a short message for the
        empty-question, empty-knowledge-base, not-covered and error cases.
    """
    # The textbox value may be missing as well as blank.
    if not query or not query.strip():
        return "<div style='color:#b91c1c;'>Please type a question.</div>"

    retrieved, scores = retrieve_top_k(query)
    if not retrieved:
        return "<div style='color:#b91c1c;'>Knowledge base not initialized.</div>"

    if max(scores) < RELEVANCE_THRESHOLD:
        return "<div style='border:2px solid #ef4444; border-radius:10px; padding:15px; background:#fff7f7;'><h3 style='color:#b91c1c;'>❌ Sorry, I don’t know — that’s not covered in my knowledge base.</h3></div>"

    context = "\n\n---\n\n".join(retrieved)
    prompt = f"""
You are a helpful assistant. Use ONLY the following context to answer.
If answer not found in context, say exactly:
"❌ Sorry, I don’t know — that’s not covered in my knowledge base."
Context:
{context}
Question:
{query}
"""
    try:
        resp = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama-3.3-70b-versatile",
        )
        ans = resp.choices[0].message.content.strip()
    except Exception as e:
        # Exception text may contain markup-significant characters.
        return f"<div style='color:#b91c1c;'>Error: {html.escape(str(e))}</div>"

    # The reply is model-generated text built from document content and user
    # input; escape it so it is shown as text and cannot inject markup.
    # "white-space:pre-wrap" keeps the reply's own line breaks visible.
    safe_ans = html.escape(ans)
    return f"""
<div style="border-radius:12px; padding:16px; background:#faf5ff; border:1px solid #d8b4fe;">
<h3 style="color:#6b21a8;">💡 Answer from Knowledge Base</h3>
<div style="font-size:15px; color:#1f2937; line-height:1.45; white-space:pre-wrap;">{safe_ans}</div>
</div>
"""
# -----------------------
# INITIALIZE KNOWLEDGE BASE
# -----------------------
# Google Drive share links of the documents that make up the knowledge base.
DRIVE_LINKS = [
    "https://drive.google.com/file/d/1gl_6EAvN5uzTUbir_ytOBUaSmr9pWKNF/view?usp=sharing"
]
# Build the index once at import time; the returned status text is displayed
# in the UI.
status_msg = build_faiss_index_from_drive_links(DRIVE_LINKS)
# -----------------------
# GRADIO UI (HCI-driven)
# -----------------------
# Page-wide styling: background gradient, centred container and title, and the
# green status box that shows the knowledge-base build result.
css = """
body {
font-family: 'Inter', system-ui, sans-serif;
background: linear-gradient(135deg, #ede9fe, #fce7f3);
}
.gradio-container {max-width: 900px; margin: auto;}
h1 {text-align:center; color:#6d28d9;}
.status-box {
text-align:center;
font-size:16px;
color:#047857;
background:#ecfdf5;
border:1px solid #6ee7b7;
border-radius:8px;
padding:8px;
}
"""

demo = gr.Blocks(title="RAG Knowledge Chatbot", css=css)
with demo:
    # Header: title, tagline and the index build status.
    gr.Markdown(value="<h1>📚 RAG Knowledge Chatbot</h1>")
    gr.Markdown(
        value="<p style='text-align:center;color:#7c3aed;'>Ask questions only from my internal knowledge base</p>"
    )
    gr.Markdown(value=f"<div class='status-box'>{status_msg}</div>")

    with gr.Row():
        # Left (wider) column: question input and rendered answer.
        with gr.Column(scale=2):
            gr.Markdown(value="### 💭 Ask your question")
            query = gr.Textbox(
                show_label=False,
                lines=2,
                placeholder="Type your question...",
            )
            ask_btn = gr.Button(value="🚀 Ask", variant="primary")
            answer_html = gr.HTML(
                value="<div style='font-size:15px;color:#374151;'>Answer will appear here...</div>"
            )
            # Clicking the button sends the question through the RAG pipeline
            # and shows the returned HTML.
            ask_btn.click(fn=ask_groq_with_rag, inputs=query, outputs=answer_html)

        # Right (narrower) column: one clickable entry per source document.
        with gr.Column(scale=1):
            gr.Markdown(value="### 📄 Knowledge Base Files")
            for link in DRIVE_LINKS:
                gr.Markdown(value=f"- <a href='{link}' target='_blank'>{link}</a>")

demo.launch()