# Streamlit RAG app: index uploaded PDFs with FAISS and answer queries using GPT-2.
import json
import os

import faiss
import numpy as np
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Device setup: prefer the GPU when CUDA is available, otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load GPT-2 once per session; Streamlit caches the resource across reruns.
@st.cache_resource
def load_model_and_tokenizer():
    """Return a (model, tokenizer) pair for GPT-2 on the active device.

    GPT-2 ships without a padding token, so the EOS token is reused for
    padding to keep the tokenizer usable with ``padding=True``.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    return gpt2_model, gpt2_tokenizer


model, tokenizer = load_model_and_tokenizer()
# Extract raw text from user-uploaded PDFs.
def extract_text_from_pdfs(uploaded_files):
    """Return one string of concatenated page text per uploaded PDF.

    Pages with no extractable text (``extract_text()`` returning ``None``)
    contribute an empty string, so every file always yields a ``str``.
    """
    return [
        "".join(page.extract_text() or "" for page in PdfReader(file).pages)
        for file in uploaded_files
    ]
# Build a FAISS vector index from document texts.
def create_faiss_index(text_data):
    """Embed each document with GPT-2 and build an exact L2 FAISS index.

    Each document is embedded as the mean of the final-layer hidden states
    of its first 1024 tokens (longer texts are truncated by the tokenizer).

    Args:
        text_data: non-empty list of document strings.

    Returns:
        (index, embeddings): the FAISS ``IndexFlatL2`` and the
        ``(n_docs, dim)`` float32 embedding matrix that was added to it.

    Raises:
        ValueError: if ``text_data`` is empty — the original code crashed
        with an opaque ``torch.cat`` error in that case.
    """
    if not text_data:
        raise ValueError("text_data must contain at least one document")
    embeddings = []
    for text in text_data:
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=1024
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden layer over the sequence dimension.
        embeddings.append(outputs.hidden_states[-1].mean(dim=1).cpu().numpy())
    # Stack directly in NumPy; the original numpy -> torch -> numpy round-trip
    # was redundant. FAISS requires contiguous float32 input.
    embeddings = np.ascontiguousarray(np.concatenate(embeddings, axis=0), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
# Answer a query by retrieving the nearest document and prompting GPT-2.
def answer_query(query, index, text_data):
    """Retrieve the closest document to ``query`` and generate an answer.

    Args:
        query: the user's question.
        index: FAISS index whose rows correspond 1:1 with ``text_data``.
        text_data: list of document strings.

    Returns:
        The decoded GPT-2 output (prompt followed by the generated answer).

    Raises:
        ValueError: if the index contains no vectors (FAISS returns -1).
    """
    inputs = tokenizer(
        query, return_tensors="pt", truncation=True, padding=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    query_embedding = outputs.hidden_states[-1].mean(dim=1).cpu().numpy()
    _, indices = index.search(query_embedding, k=1)
    nearest_index = indices[0][0]
    if nearest_index < 0:  # FAISS signals "no neighbor" with -1
        raise ValueError("The FAISS index is empty; build the database first.")
    relevant_text = text_data[nearest_index]

    # Truncate the *context*, not the assembled prompt: the original code
    # truncated the full prompt from the right at 1024 tokens, which cut off
    # the question and the "Answer:" cue whenever the context was long.
    scaffold = f"Context: \nQuestion: {query}\nAnswer:"
    scaffold_len = len(tokenizer(scaffold)["input_ids"])
    context_budget = max(1024 - scaffold_len, 1)
    context_ids = tokenizer(
        relevant_text, truncation=True, max_length=context_budget
    )["input_ids"]
    relevant_text = tokenizer.decode(context_ids)

    input_text = f"Context: {relevant_text}\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, padding=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Streamlit UI ---
st.title("RAG App with GPT-2")
st.write("Upload PDF files to build a database and ask questions!")

# Upload PDF files
uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

# Build database
if st.button("Build Database") and uploaded_files:
    with st.spinner("Processing files..."):
        text_data = extract_text_from_pdfs(uploaded_files)
        index, _ = create_faiss_index(text_data)
        # Persist the index and documents. JSON preserves document
        # boundaries: the original newline-joined format split any document
        # containing newlines into multiple entries on reload, so
        # text_data no longer lined up row-for-row with the FAISS index.
        faiss.write_index(index, "faiss_index.bin")
        with open("text_data.txt", "w", encoding="utf-8") as f:
            json.dump(text_data, f)
    st.success("Database built successfully!")

# Load existing database
if os.path.exists("faiss_index.bin") and os.path.exists("text_data.txt"):
    with st.spinner("Loading existing database..."):
        index = faiss.read_index("faiss_index.bin")
        with open("text_data.txt", "r", encoding="utf-8") as f:
            try:
                text_data = json.load(f)
            except json.JSONDecodeError:
                # Fall back to the legacy newline-joined format.
                f.seek(0)
                text_data = f.readlines()
    st.success("Database loaded successfully!")

# Query input
query = st.text_input("Enter your query:")

# Get answer
if st.button("Get Answer") and query:
    with st.spinner("Searching and generating answer..."):
        try:
            answer = answer_query(query, index, text_data)
            st.success("Answer generated successfully!")
            st.write(answer)
        except NameError:
            # index/text_data are only defined once a database exists.
            st.error("No database available. Build or load a database first.")
        except Exception as e:
            st.error(f"Error: {e}")