# RAG / app.py — Streamlit RAG demo (Hugging Face Space by jk12p)
import streamlit as st
import torch
import fitz # PyMuPDF
import os
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ================= CONFIG =================
st.set_page_config(page_title="RAG with Phi-2", layout="wide")

# Optional Hugging Face token (needed only for gated/private checkpoints).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Prefer GPU when available; everything below falls back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

with st.sidebar:
    st.markdown("### πŸ–₯️ System Info")
    st.text(f"Device: {DEVICE}")
# ================= LOAD MODEL =================
@st.cache_resource
def load_llm():
    """Load the Phi-2 tokenizer and model once per Streamlit session.

    Returns:
        tuple: (tokenizer, model) with the model in eval (inference) mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2",
        token=HF_TOKEN
    )
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        token=HF_TOKEN,  # fix: token was passed to the tokenizer but not the model
        torch_dtype=torch.float32,  # REQUIRED for CPU
        low_cpu_mem_usage=True
    )
    model.eval()
    return tokenizer, model
@st.cache_resource
def load_embedder():
    """Return the cached MiniLM sentence-transformer used for chunk/query embeddings."""
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    return embedding_model
# Materialize the cached models for this session.
tokenizer, model = load_llm()
embedder = load_embedder()

# ================= UI =================
st.title("πŸ” RAG App using πŸ€– Phi-2")

with st.sidebar:
    st.header("πŸ“ Upload Document")
    uploaded_file = st.file_uploader("Upload PDF or TXT", type=["pdf", "txt"])
# ================= HELPERS =================
def extract_text(file):
    """Return the plain text of an uploaded file.

    Args:
        file: a Streamlit UploadedFile; ``file.type`` is its MIME type.

    Returns:
        str: page texts joined by newlines for PDFs, or the decoded bytes
        for any other (assumed text) upload.
    """
    if file.type == "application/pdf":
        # Fix: close the PyMuPDF document to release native resources
        # (the original leaked one handle per upload/rerun).
        with fitz.open(stream=file.read(), filetype="pdf") as doc:
            return "\n".join(page.get_text() for page in doc)
    # Fix: a strict utf-8 decode crashed on non-UTF-8 text files; replace
    # undecodable bytes instead so the app keeps working.
    return file.read().decode("utf-8", errors="replace")
def split_into_chunks(text, chunk_size=500):
    """Split `text` into consecutive chunks of at most `chunk_size` characters."""
    pieces = []
    offset = 0
    while offset < len(text):
        pieces.append(text[offset:offset + chunk_size])
        offset += chunk_size
    return pieces
def create_faiss_index(chunks):
    """Embed every chunk and build an exact L2 FAISS index over the vectors.

    Returns:
        tuple: (index, embeddings) — the populated IndexFlatL2 and the
        float32 embedding matrix it was built from.
    """
    vectors = np.asarray(
        embedder.encode(chunks, show_progress_bar=True)
    ).astype("float32")
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index, vectors
def retrieve_chunks(query, chunks, index, k=5):
    """Return up to `k` chunks most similar to `query` (L2 distance).

    Fix: when the index holds fewer than `k` vectors, FAISS pads the result
    with -1 ids; the original then did `chunks[-1]` and silently returned
    the LAST chunk for every missing hit. Clamp `k` and drop -1 ids instead.
    """
    query_embedding = embedder.encode([query]).astype("float32")
    _, indices = index.search(query_embedding, min(k, len(chunks)))
    return [chunks[i] for i in indices[0] if i != -1]
def generate_answer(context, question):
    """Generate an answer with Phi-2, grounded only in `context`.

    Returns:
        str: the text after the final "Answer:" marker, stripped.
    """
    prompt = f"""
Instruction: Answer ONLY using the context below.
If the answer is not present, say "Information not found."
Context:
{context}
Question:
{question}
Answer:
"""
    # Fix: truncate to Phi-2's 2048-token context window — a long retrieved
    # context previously overflowed the model and crashed generation.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    # Keep input tensors on the same device as the model (no-op on CPU).
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.2,
            do_sample=True,
            top_p=0.9
        )
    # Decode the full sequence and keep only the generated answer portion.
    return tokenizer.decode(output[0], skip_special_tokens=True).split("Answer:")[-1].strip()
# ================= MAIN LOGIC =================
if uploaded_file:
    # Fix: Streamlit reruns this whole script on every chat message, so the
    # original re-extracted, re-chunked, and re-embedded the document on each
    # question. Cache the expensive work in session_state, keyed per file.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("rag_file_key") != file_key:
        raw_text = extract_text(uploaded_file)
        chunks = split_into_chunks(raw_text)
        faiss_index, _ = create_faiss_index(chunks)
        st.session_state.rag_file_key = file_key
        st.session_state.rag_raw_text = raw_text
        st.session_state.rag_chunks = chunks
        st.session_state.rag_index = faiss_index
    raw_text = st.session_state.rag_raw_text
    chunks = st.session_state.rag_chunks
    index = st.session_state.rag_index

    st.sidebar.success(f"βœ… {len(chunks)} chunks created")
    with st.sidebar.expander("πŸ“„ Extracted Text"):
        st.text_area("Text", raw_text, height=300)

    st.markdown("### πŸ’¬ Chat with your document")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the chat history on every rerun.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if user_query := st.chat_input("Ask a question"):
        with st.chat_message("user"):
            st.markdown(user_query)
        st.session_state.messages.append(
            {"role": "user", "content": user_query}
        )
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                context = "\n".join(
                    retrieve_chunks(user_query, chunks, index)
                )
                answer = generate_answer(context, user_query)
                st.markdown(answer)
        st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )
else:
    st.info("πŸ‘ˆ Upload a document to begin chatting")