"""Minimal RAG (retrieval-augmented generation) app: Streamlit + GPT-2 + FAISS.

Uploaded PDFs are embedded with GPT-2's mean final hidden state, indexed with
FAISS (L2), persisted to disk, and queries are answered by generating from
GPT-2 conditioned on the single nearest document.
"""

import json
import os

import faiss
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# On-disk artifacts. Text is stored as JSON (one list entry per document) so
# that reload preserves document boundaries and stays aligned with the index.
INDEX_PATH = "faiss_index.bin"
TEXT_DATA_PATH = "text_data.json"

# Use the GPU when available; both embedding and generation benefit.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model_and_tokenizer():
    """Load GPT-2 and its tokenizer once per session (cached by Streamlit).

    Returns:
        (model, tokenizer) with the model moved to ``device``.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    # GPT-2 ships without a pad token; reuse EOS so padded batches work.
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()


def extract_text_from_pdfs(uploaded_files):
    """Return one concatenated text string per uploaded PDF.

    Args:
        uploaded_files: iterable of file-like objects (Streamlit uploads).

    Returns:
        list[str], one entry per file. Image-only pages, for which
        ``extract_text()`` returns None, contribute an empty string.
    """
    text_data = []
    for file in uploaded_files:
        reader = PdfReader(file)
        text = "".join(page.extract_text() or "" for page in reader.pages)
        text_data.append(text)
    return text_data


def _embed_text(text):
    """Embed *text* as the mean of GPT-2's final-layer hidden states.

    Returns a (1, hidden_size) float32 numpy array, suitable for FAISS.
    Input is truncated to GPT-2's 1024-token context window.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    return outputs.hidden_states[-1].mean(dim=1).cpu().numpy()


def create_faiss_index(text_data):
    """Build an L2 FAISS index with one embedding per document.

    Args:
        text_data: list[str] of document texts.

    Returns:
        (index, embeddings) where ``embeddings`` is an (n_docs, hidden) array.

    Raises:
        ValueError: if *text_data* is empty — a clear message instead of an
            opaque tensor-concatenation failure.
    """
    if not text_data:
        raise ValueError("No text extracted from the uploaded PDFs.")
    embeddings = torch.cat(
        [torch.from_numpy(_embed_text(text)) for text in text_data], dim=0
    ).numpy()
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings


def answer_query(query, index, text_data):
    """Answer *query* using the nearest document as context.

    Args:
        query: the user's question.
        index: FAISS index aligned with *text_data*.
        text_data: list[str], one entry per indexed document.

    Returns:
        The decoded GPT-2 generation (prompt included, as before).
    """
    query_embedding = _embed_text(query)
    _, indices = index.search(query_embedding, k=1)
    relevant_text = text_data[indices[0][0]]

    # Truncate the CONTEXT, not the prompt: right-truncating the assembled
    # prompt at 1024 tokens would silently drop the question whenever the
    # retrieved document is long. Reserve ~224 tokens for question/template.
    context_ids = tokenizer.encode(relevant_text, truncation=True, max_length=800)
    relevant_text = tokenizer.decode(context_ids)

    input_text = f"Context: {relevant_text}\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, padding=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.title("RAG App with GPT-2")
st.write("Upload PDF files to build a database and ask questions!")

uploaded_files = st.file_uploader(
    "Upload PDF files", type=["pdf"], accept_multiple_files=True
)

if st.button("Build Database") and uploaded_files:
    with st.spinner("Processing files..."):
        text_data = extract_text_from_pdfs(uploaded_files)
        index, _ = create_faiss_index(text_data)
        faiss.write_index(index, INDEX_PATH)
        # JSON keeps each document as a single list entry. The previous
        # newline-joined text file was re-read with readlines(), which split
        # multi-line documents and desynced the text list from the index.
        with open(TEXT_DATA_PATH, "w", encoding="utf-8") as f:
            json.dump(text_data, f, ensure_ascii=False)
        st.success("Database built successfully!")

# Reload persisted state on every rerun so queries work across sessions.
index = None
text_data = None
if os.path.exists(INDEX_PATH) and os.path.exists(TEXT_DATA_PATH):
    with st.spinner("Loading existing database..."):
        index = faiss.read_index(INDEX_PATH)
        with open(TEXT_DATA_PATH, "r", encoding="utf-8") as f:
            text_data = json.load(f)
    st.success("Database loaded successfully!")

query = st.text_input("Enter your query:")

if st.button("Get Answer") and query:
    if index is None:
        # Explicit guard: previously this raised a NameError swallowed by
        # the broad except below.
        st.warning("Build or load a database before asking questions.")
    else:
        with st.spinner("Searching and generating answer..."):
            try:
                answer = answer_query(query, index, text_data)
                st.success("Answer generated successfully!")
                st.write(answer)
            except Exception as e:
                st.error(f"Error: {e}")