# myRAG / app.py — Streamlit RAG demo by amasood (rev 86cad8f)
import json
import os

import faiss
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Device setup: run the model on GPU when CUDA is available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load GPT-2 model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPT-2 tokenizer and language model once per session.

    Streamlit's ``cache_resource`` decorator ensures this heavyweight
    setup runs a single time and is reused across app reruns.

    Returns:
        (model, tokenizer): the GPT-2 LM head model moved to ``device``
        and its tokenizer with a pad token configured.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padded batches work.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    return gpt2_model, gpt2_tokenizer


model, tokenizer = load_model_and_tokenizer()
# Function to extract text from uploaded PDFs
def extract_text_from_pdfs(uploaded_files):
    """Return one concatenated text string per uploaded PDF.

    Pages whose extraction yields ``None`` contribute an empty string,
    so every input file always maps to exactly one output entry.
    """
    return [
        "".join(page.extract_text() or "" for page in PdfReader(pdf).pages)
        for pdf in uploaded_files
    ]
# Function to create a FAISS index
def create_faiss_index(text_data):
    """Embed each document with GPT-2 and build an L2 FAISS index.

    Each document is truncated to the model's 1024-token window and
    represented by the mean of the final hidden layer over its tokens.

    Args:
        text_data: list of document strings (one entry per PDF).

    Returns:
        (index, embeddings): the populated ``faiss.IndexFlatL2`` and the
        numpy array of shape (num_docs, hidden_dim) it was built from.

    Raises:
        ValueError: if ``text_data`` is empty (FAISS needs >= 1 vector).
    """
    if not text_data:
        raise ValueError("text_data must contain at least one document")
    doc_vectors = []
    for text in text_data:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden layer; keep it as a tensor so we
        # avoid the original's needless numpy -> torch -> numpy round trip.
        doc_vectors.append(outputs.hidden_states[-1].mean(dim=1))
    # Single device->host transfer and conversion for all documents.
    embeddings = torch.cat(doc_vectors, dim=0).cpu().numpy()
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, embeddings
# Function to answer queries
def answer_query(query, index, text_data):
    """Retrieve the best-matching document and generate an answer.

    Args:
        query: the user's question string.
        index: a FAISS index whose row i corresponds to text_data[i].
        text_data: list of document strings used to build the index.

    Returns:
        The generated answer text (only the newly generated tokens,
        not an echo of the prompt).
    """
    # Embed the query exactly the way documents were embedded.
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    query_embedding = outputs.hidden_states[-1].mean(dim=1).cpu().numpy()
    # Nearest-neighbour lookup: single best-matching document.
    _, indices = index.search(query_embedding, k=1)
    relevant_text = text_data[indices[0][0]]
    input_text = f"Context: {relevant_text}\nQuestion: {query}\nAnswer:"
    # GPT-2 has a 1024-position window; reserve 200 positions for the
    # generated tokens so prompt + generation cannot overflow it
    # (the original truncated to 1024 and could exceed the window).
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=824).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
        )
    # Decode only the newly generated tokens: the original decoded the
    # whole sequence, echoing the full context/question back to the user.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# Streamlit UI
st.title("RAG App with GPT-2")
st.write("Upload PDF files to build a database and ask questions!")

# Upload PDF files
uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

# Persistence locations for the index and its source documents.
INDEX_PATH = "faiss_index.bin"
TEXTS_PATH = "text_data.json"

# Build database
if st.button("Build Database") and uploaded_files:
    with st.spinner("Processing files..."):
        text_data = extract_text_from_pdfs(uploaded_files)
        index, _ = create_faiss_index(text_data)
        # Persist the index and documents. JSON keeps one entry per
        # document even when extracted text contains newlines — the
        # original wrote raw lines and reloaded with readlines(), which
        # broke the FAISS-row <-> document mapping for multi-line docs.
        faiss.write_index(index, INDEX_PATH)
        with open(TEXTS_PATH, "w", encoding="utf-8") as f:
            json.dump(text_data, f)
        st.success("Database built successfully!")

# Load existing database (covers both a fresh build above and reruns).
index = None
text_data = None
if os.path.exists(INDEX_PATH) and os.path.exists(TEXTS_PATH):
    with st.spinner("Loading existing database..."):
        index = faiss.read_index(INDEX_PATH)
        with open(TEXTS_PATH, "r", encoding="utf-8") as f:
            text_data = json.load(f)
    st.success("Database loaded successfully!")

# Query input
query = st.text_input("Enter your query:")

# Get answer
if st.button("Get Answer") and query:
    if index is None:
        # The original raised a NameError here when no database existed;
        # tell the user what to do instead.
        st.warning("Please build or load a database first.")
    else:
        with st.spinner("Searching and generating answer..."):
            try:
                answer = answer_query(query, index, text_data)
                st.success("Answer generated successfully!")
                st.write(answer)
            except Exception as e:
                st.error(f"Error: {e}")