| | import os |
| | import torch |
| | import faiss |
| | from PyPDF2 import PdfReader |
| | from transformers import GPT2Tokenizer, GPT2LMHeadModel |
| | import streamlit as st |
| |
|
| | |
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| | |
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPT-2 model and tokenizer once; Streamlit caches the result.

    The tokenizer's pad token is aliased to its EOS token because GPT-2
    ships without a dedicated padding token.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
    return gpt2_model.to(device), gpt2_tokenizer
| |
|
# Instantiate the cached model/tokenizer at import time so every Streamlit
# rerun of this script reuses the same objects instead of reloading weights.
model, tokenizer = load_model_and_tokenizer()
| |
|
| | |
def extract_text_from_pdfs(uploaded_files):
    """Extract the full text of each uploaded PDF.

    Returns a list with one string per file; pages whose extraction yields
    None contribute the empty string.
    """
    documents = []
    for uploaded in uploaded_files:
        pages = PdfReader(uploaded).pages
        documents.append("".join(page.extract_text() or "" for page in pages))
    return documents
| |
|
| | |
def create_faiss_index(text_data):
    """Embed each document with GPT-2 and build an exact L2 FAISS index.

    Each document is truncated to the model's 1024-token window and embedded
    as the mean of the final hidden layer.

    Args:
        text_data: list of document strings (one per uploaded PDF).

    Returns:
        (index, embeddings) where embeddings is a numpy array of shape
        (num_docs, hidden_dim) and index is a faiss.IndexFlatL2 over it.

    Raises:
        ValueError: if text_data is empty (FAISS dimension would be undefined).
    """
    if not text_data:
        raise ValueError("text_data must contain at least one document")
    doc_vectors = []
    for text in text_data:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden layer into one vector per document; keep
        # it as a CPU tensor so we can concatenate once at the end.
        doc_vectors.append(outputs.hidden_states[-1].mean(dim=1).cpu())
    # Single concatenation replaces the original numpy -> tensor -> numpy
    # round trip (tensors were converted to numpy, wrapped back into tensors,
    # concatenated, then converted to numpy again).
    embeddings = torch.cat(doc_vectors, dim=0).numpy()
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
| |
|
| | |
def answer_query(query, index, text_data):
    """Answer a query via nearest-document retrieval plus GPT-2 generation.

    Embeds the query the same way the index was built (mean of the final
    hidden layer), retrieves the single closest document, and generates up
    to 200 new tokens conditioned on a Context/Question/Answer prompt.

    Args:
        query: the user's question string.
        index: FAISS index built over the document embeddings.
        text_data: documents in the same order as the index rows.

    Returns:
        Only the generated answer text (the prompt is not echoed back).
    """
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    query_embedding = outputs.hidden_states[-1].mean(dim=1).cpu().numpy()

    # k=1: condition generation on just the single nearest document.
    _, indices = index.search(query_embedding, k=1)
    relevant_text = text_data[indices[0][0]]

    input_text = f"Context: {relevant_text}\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    # Bug fix: model.generate returns prompt + continuation; decode only the
    # newly generated tokens instead of echoing the whole context/question
    # block back to the user.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
| |
|
| | |
# Page header and usage hint.
st.title("RAG App with GPT-2")
st.write("Upload PDF files to build a database and ask questions!")

# Multiple PDFs allowed; each file becomes one document (one index row).
uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
| |
|
| | |
# Build-and-persist branch: embed the uploaded PDFs, write the FAISS index
# and the raw documents to disk so later runs can reload them.
if st.button("Build Database") and uploaded_files:
    with st.spinner("Processing files..."):
        text_data = extract_text_from_pdfs(uploaded_files)
        index, _ = create_faiss_index(text_data)
        faiss.write_index(index, "faiss_index.bin")
        # Bug fix: the loader reads this file one document per line, but PDF
        # text contains embedded newlines, which splits one document across
        # many lines and desynchronizes FAISS row ids from the reloaded
        # texts. Collapse newlines so each document stays on a single line.
        # Explicit UTF-8 avoids platform-default encoding errors on PDF text.
        with open("text_data.txt", "w", encoding="utf-8") as f:
            for text in text_data:
                f.write(text.replace("\n", " ") + "\n")
        st.success("Database built successfully!")
| |
|
| | |
# Reload branch: if a previously built database exists on disk, load the
# FAISS index and the document texts for this session.
if os.path.exists("faiss_index.bin") and os.path.exists("text_data.txt"):
    with st.spinner("Loading existing database..."):
        index = faiss.read_index("faiss_index.bin")
        # Bug fix: match the writer's UTF-8 encoding, and use splitlines()
        # so documents don't carry the trailing "\n" that readlines() keeps
        # (that newline would otherwise leak into the generation prompt).
        with open("text_data.txt", "r", encoding="utf-8") as f:
            text_data = f.read().splitlines()
        st.success("Database loaded successfully!")
| |
|
| | |
# Free-text question box; answering is triggered by the button below.
query = st.text_input("Enter your query:")
| |
|
| | |
# Answer branch: retrieve the nearest document and generate an answer.
if st.button("Get Answer") and query:
    # Bug fix: if no database was built or loaded this run, `index` and
    # `text_data` are undefined and the original code surfaced a cryptic
    # NameError via the generic error banner. Guard explicitly instead.
    if "index" not in globals() or "text_data" not in globals():
        st.warning("Please build or load a database before asking a question.")
    else:
        with st.spinner("Searching and generating answer..."):
            try:
                answer = answer_query(query, index, text_data)
                st.success("Answer generated successfully!")
                st.write(answer)
            except Exception as e:
                st.error(f"Error: {e}")
| |
|
| |
|