# app.py — RAG demo: chat with uploaded documents on CPU.
# (Originally created by Navya-Sree; commit b8c6e8d, verified.)
import os, glob, pickle, gc
from typing import List, Dict, Tuple
import gradio as gr
import numpy as np
from tqdm import tqdm
from pypdf import PdfReader
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --------- CONFIG ----------
EMB_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # 384-dim, fast on CPU
LLM_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # small chat model
CHUNK_SIZE = 700  # words per chunk
CHUNK_OVERLAP = 120  # words shared between consecutive chunks
TOP_K = 4  # chunks retrieved per question
# Cached globals (Space stays warm between requests)
_emb_model = None  # SentenceTransformer, loaded lazily by _ensure_models
_llm_tokenizer = None  # HF tokenizer for LLM_MODEL_ID
_llm_model = None  # HF causal LM (CPU, float32)
_faiss_index = None  # faiss.IndexFlatIP over normalized chunk embeddings
_meta = None  # list of {"source", "chunk"} rows aligned with index ids
# --------- HELPERS ----------
def _load_pdf(path: str) -> str:
text = []
try:
pdf = PdfReader(path)
for p in pdf.pages:
text.append(p.extract_text() or "")
except Exception as e:
print(f"[WARN] PDF read failed for {path}: {e}")
return "\n".join(text)
def _load_txt(path: str) -> str:
with open(path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
def _chunk_text(text: str, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping word windows.

    Windows are chunk_size words long and start every
    max(1, chunk_size - chunk_overlap) words, so consecutive chunks share
    roughly chunk_overlap words. Empty text yields an empty list.
    """
    words = text.split()
    stride = max(1, chunk_size - chunk_overlap)
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), stride)
    ]
def _ensure_models():
    """Lazily load the embedding model and the chat LLM into module globals.

    Idempotent: models already present are left untouched, so repeated
    calls after the first are cheap.
    """
    global _emb_model, _llm_tokenizer, _llm_model
    if _emb_model is None:
        _emb_model = SentenceTransformer(EMB_MODEL_ID)
    need_llm = _llm_tokenizer is None or _llm_model is None
    if need_llm:
        _llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
        _llm_model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
def _reset_index():
global _faiss_index, _meta
_faiss_index = None
_meta = None
gc.collect()
def _build_index_from_files(files: List[str]) -> Tuple[int, int]:
    """Build the global FAISS index from the given file paths.

    Supports .pdf, .txt and .md (case-insensitive; other extensions are
    skipped). Each non-empty document is split into overlapping word chunks,
    embedded with the MiniLM model, and stored in an inner-product FAISS
    index — cosine similarity, since the embeddings are L2-normalized.

    Returns:
        (#files actually indexed, #chunks in the index). Returns (0, 0) and
        clears any existing index when no usable text was found.
    """
    global _faiss_index, _meta
    _ensure_models()
    # Load raw text per file; skip unsupported extensions and empty docs.
    docs = []
    for path in files:
        lower = path.lower()
        if lower.endswith(".pdf"):
            txt = _load_pdf(path)
        elif lower.endswith((".txt", ".md")):
            txt = _load_txt(path)
        else:
            continue
        if txt.strip():
            docs.append({"source": os.path.basename(path), "text": txt})
    dataset = [
        {"source": d["source"], "chunk": ch}
        for d in docs
        for ch in _chunk_text(d["text"])
    ]
    if not dataset:
        _reset_index()
        return (0, 0)
    # Batch-encode all chunks in a single call: far faster than the previous
    # per-chunk Python loop, and yields the same normalized vectors.
    texts = [row["chunk"] for row in dataset]
    embs = _emb_model.encode(
        texts,
        batch_size=32,
        show_progress_bar=True,
        normalize_embeddings=True,
    ).astype("float32")
    index = faiss.IndexFlatIP(embs.shape[1])  # inner product == cosine on unit vectors
    index.add(embs)
    _faiss_index = index
    _meta = dataset
    return (len(docs), len(dataset))
def _retrieve(query: str, k=TOP_K) -> List[Dict]:
    """Return the up-to-k indexed chunks most similar to `query`.

    Each result dict has: score (cosine similarity), source (filename),
    text (the chunk). Requires _build_index_from_files to have populated
    _faiss_index and _meta.
    """
    q = _emb_model.encode(query, normalize_embeddings=True).astype("float32")
    # BUG FIX: never ask FAISS for more neighbors than the index holds —
    # it pads the result with id -1, and _meta[-1] would silently alias
    # the last chunk.
    k = min(k, _faiss_index.ntotal)
    D, I = _faiss_index.search(q[None, :], k)
    results = []
    for score, idx in zip(D[0], I[0]):
        if idx < 0:  # defensive: skip any padding ids
            continue
        row = _meta[idx]
        results.append({"score": float(score), "source": row["source"], "text": row["chunk"]})
    return results
def _build_prompt(question: str, ctxs: List[Dict]) -> str:
context_block = "\n\n---\n".join(
[f"[{i+1}] Source: {c['source']}\n{c['text']}" for i, c in enumerate(ctxs)]
)
system_rules = (
"You are a careful assistant. Answer ONLY using the provided context. "
"If the answer is not in the context, say you don't know."
)
user_block = (
f"Question: {question}\n\n"
f"Context (use strictly):\n{context_block}\n\n"
"Answer:"
)
return f"<|system|>\n{system_rules}\n<|user|>\n{user_block}\n<|assistant|>\n"
def _generate_answer(question: str, ctxs: List[Dict], max_new_tokens=220) -> str:
    """Greedy-decode an answer grounded in the retrieved chunks.

    FIX: decode only the newly generated tokens instead of decoding the
    whole sequence and splitting on "<|assistant|>" — with
    skip_special_tokens=True that marker can be stripped by the tokenizer,
    in which case the split returned the entire prompt. Also drops the
    temperature argument, which is ignored (and warns) when do_sample=False.
    """
    inputs = _llm_tokenizer(_build_prompt(question, ctxs), return_tensors="pt")
    with torch.no_grad():
        out = _llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic greedy decoding
        )
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = out[0][prompt_len:]
    return _llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# --------- GRADIO LOGIC ----------
def init_with_samples():
    """Optionally seed the index from a bundled ./docs folder at startup.

    Drop .txt/.pdf files into a local /docs directory to pre-load the demo.
    Returns a human-readable status string either way.
    """
    sample_dir = "docs"
    if not os.path.isdir(sample_dir):
        return "No sample docs bundled. Upload your own to get started."
    files = [p for p in glob.glob(os.path.join(sample_dir, "*")) if os.path.isfile(p)]
    if not files:
        return "No sample docs bundled. Upload your own to get started."
    nfiles, nchunks = _build_index_from_files(files)
    return f"Initialized with {nfiles} sample files → {nchunks} chunks."
def upload_and_index(files):
    """Gradio callback: (re)build the index from uploaded files.

    FIX: depending on the Gradio version / gr.File `type`, `files` is either
    a list of tempfile wrappers (with a .name path) or plain path strings
    (the Gradio 4 "filepath" default) — the old `f.name` crashed on strings.
    Returns a status string for the UI; an empty upload clears the index.
    """
    if not files:
        _reset_index()
        return "No files uploaded. Index cleared."
    paths = [f if isinstance(f, str) else f.name for f in files]
    nfiles, nchunks = _build_index_from_files(paths)
    return f"Indexed {nfiles} files → {nchunks} chunks. (Embedding dim=384)"
def ask_question(history, question):
    """Gradio callback: answer `question` over the indexed docs.

    Appends a [question, answer] pair to the chat history; the answer ends
    with a "Sources:" footer whose [n:...] tags match the numbered chunks
    fed to the prompt. Without an index, a hint message is appended instead.
    """
    if _faiss_index is None or _meta is None:
        return history + [[question, "No index yet. Upload documents first or add sample docs."]]
    ctxs = _retrieve(question, k=TOP_K)
    answer = _generate_answer(question, ctxs)
    # Citation footer: [1:file.pdf] [2:notes.txt] ... matching prompt numbering.
    citations = " ".join(f"[{n}:{c['source']}]" for n, c in enumerate(ctxs, start=1))
    return history + [[question, f"{answer}\n\nSources: {citations}"]]
def clear_index():
    """Gradio callback: drop the index and report the result to the UI."""
    _reset_index()
    return "Index cleared."
# UI layout and event wiring. Callbacks are module-level functions above;
# they are passed directly (the previous `lambda: clear_index()` wrapper
# added nothing).
with gr.Blocks(title="RAG: Chat with Your Docs (CPU)") as demo:
    gr.Markdown("# 🔎 Retrieval-Augmented Generation (CPU)\nUpload PDFs or text notes, then ask questions.")
    # Build from bundled sample docs (if any) once at startup; show the result.
    status = gr.Markdown(init_with_samples())
    with gr.Row():
        with gr.Column(scale=1):
            file_u = gr.File(label="Upload PDFs or .txt", file_count="multiple", file_types=[".pdf",".txt",".md"])
            build_btn = gr.Button("Build / Rebuild Index")
            clear_btn = gr.Button("Clear Index")
            build_out = gr.Markdown()
            clear_out = gr.Markdown()
        with gr.Column(scale=2):
            chat = gr.Chatbot(height=420)
            q = gr.Textbox(label="Ask a question")
            ask_btn = gr.Button("Ask")
    build_btn.click(upload_and_index, inputs=file_u, outputs=build_out)
    clear_btn.click(clear_index, outputs=clear_out)
    # Both the Ask button and pressing Enter in the textbox submit a question.
    ask_btn.click(ask_question, inputs=[chat, q], outputs=chat)
    q.submit(ask_question, inputs=[chat, q], outputs=chat)

if __name__ == "__main__":
    demo.launch()