|
|
import streamlit as st |
|
|
import torch |
|
|
import fitz |
|
|
import os |
|
|
import faiss |
|
|
import numpy as np |
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
st.set_page_config(page_title="RAG with Phi-2", layout="wide")

# Hugging Face access token from the environment. May be None, in which
# case downloads are anonymous (gated/private models would then fail).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Surface the selected device in the sidebar for quick sanity checks.
with st.sidebar:
    st.markdown("### 🖥️ System Info")
    st.text(f"Device: {DEVICE}")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_llm():
    """Load the Phi-2 tokenizer and model once per Streamlit session.

    Returns:
        tuple: (tokenizer, model) where the model has been moved to
        DEVICE and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2",
        token=HF_TOKEN,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        token=HF_TOKEN,  # was missing: authenticated downloads would fail here
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # The original never moved the model, leaving it on CPU even when
    # DEVICE resolved to "cuda".
    model.to(DEVICE)
    model.eval()
    return tokenizer, model
|
|
|
|
|
|
|
|
@st.cache_resource
def load_embedder():
    """Return the cached MiniLM sentence-embedding model."""
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    return embedding_model
|
|
|
|
|
|
|
|
# Load the heavyweight models up front; both loaders are cached with
# @st.cache_resource, so reruns reuse the same objects.
tokenizer, model = load_llm()
embedder = load_embedder()

st.title("📚 RAG App using 🤖 Phi-2")

# Document upload lives in the sidebar; the main pane is the chat.
with st.sidebar:
    st.header("📄 Upload Document")
    uploaded_file = st.file_uploader("Upload PDF or TXT", type=["pdf", "txt"])
|
|
|
|
|
|
|
|
def extract_text(file):
    """Return the plain text of an uploaded PDF or TXT file.

    Args:
        file: an uploaded-file object exposing ``.type`` (MIME string)
            and ``.read()`` (bytes).

    Returns:
        str: the extracted text; PDF pages are joined with newlines.
    """
    if file.type == "application/pdf":
        # Close the PyMuPDF document explicitly — the original leaked it.
        with fitz.open(stream=file.read(), filetype="pdf") as doc:
            return "\n".join(page.get_text() for page in doc)
    # NOTE(review): assumes TXT uploads are UTF-8 — a different encoding
    # raises UnicodeDecodeError here; confirm whether that is acceptable.
    return file.read().decode("utf-8")
|
|
|
|
|
|
|
|
def split_into_chunks(text, chunk_size=500, overlap=0):
    """Split *text* into fixed-size character chunks.

    Args:
        text: the document text to split.
        chunk_size: maximum characters per chunk (must be positive).
        overlap: characters shared between consecutive chunks
            (0 <= overlap < chunk_size). The default 0 preserves the
            original non-overlapping behaviour.

    Returns:
        list[str]: the chunks in order; an empty string yields [].

    Raises:
        ValueError: if chunk_size <= 0 or overlap is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]
|
|
|
|
|
|
|
|
def create_faiss_index(chunks):
    """Embed *chunks* and build a flat L2 FAISS index over them.

    Args:
        chunks: list of text chunks to embed.

    Returns:
        tuple: (faiss.IndexFlatL2, float32 numpy array of embeddings).
    """
    vectors = embedder.encode(chunks, show_progress_bar=True)
    vectors = np.array(vectors).astype("float32")

    dim = vectors.shape[1]
    idx = faiss.IndexFlatL2(dim)
    idx.add(vectors)
    return idx, vectors
|
|
|
|
|
|
|
|
def retrieve_chunks(query, chunks, index, k=5):
    """Return the *k* chunks most similar to *query* by L2 distance.

    Args:
        query: the user question.
        chunks: the chunk list the index was built from.
        index: the FAISS index over those chunks.
        k: number of neighbours to return (clamped to len(chunks)).

    Returns:
        list[str]: the nearest chunks, most similar first.
    """
    # Clamp k: asking FAISS for more neighbours than stored vectors
    # returns -1 placeholder ids, which chunks[-1] in the original code
    # silently aliased to the last chunk.
    k = min(k, len(chunks))
    if k <= 0:
        return []
    query_embedding = embedder.encode([query]).astype("float32")
    _, indices = index.search(query_embedding, k)
    return [chunks[i] for i in indices[0] if i >= 0]
|
|
|
|
|
|
|
|
def generate_answer(context, question):
    """Generate an answer to *question* grounded in *context* with Phi-2.

    Args:
        context: retrieved document chunks joined into one string.
        question: the user's question.

    Returns:
        str: the text Phi-2 produced after the final "Answer:" marker.
    """
    prompt = f"""
Instruction: Answer ONLY using the context below.
If the answer is not present, say "Information not found."

Context:
{context}

Question:
{question}

Answer:
"""

    # Truncate so the prompt plus 256 new tokens fits Phi-2's 2048-token
    # context window; the original could overflow on long documents.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048 - 256,
    )
    # Move tensors to wherever the model lives — the original left them
    # on CPU, which crashes if the model was placed on CUDA.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.2,
            do_sample=True,
            top_p=0.9,
        )

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only what follows the last "Answer:" marker.
    return decoded.split("Answer:")[-1].strip()
|
|
|
|
|
|
|
|
if uploaded_file:
    # Streamlit reruns this whole script on every chat message, so cache
    # the expensive extraction + embedding work per uploaded file instead
    # of rebuilding the FAISS index on each rerun (the original re-embedded
    # the entire document for every question).
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("doc_key") != file_key:
        raw_text = extract_text(uploaded_file)
        chunks = split_into_chunks(raw_text)
        index, _ = create_faiss_index(chunks)
        st.session_state.doc_key = file_key
        st.session_state.raw_text = raw_text
        st.session_state.chunks = chunks
        st.session_state.index = index
    raw_text = st.session_state.raw_text
    chunks = st.session_state.chunks
    index = st.session_state.index

    st.sidebar.success(f"✅ {len(chunks)} chunks created")

    with st.sidebar.expander("📄 Extracted Text"):
        st.text_area("Text", raw_text, height=300)

    st.markdown("### 💬 Chat with your document")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if user_query := st.chat_input("Ask a question"):
        with st.chat_message("user"):
            st.markdown(user_query)

        st.session_state.messages.append(
            {"role": "user", "content": user_query}
        )

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                context = "\n".join(
                    retrieve_chunks(user_query, chunks, index)
                )
                answer = generate_answer(context, user_query)
                st.markdown(answer)

        st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )

else:
    st.info("📂 Upload a document to begin chatting")
|
|
|