File size: 3,797 Bytes
251d59a
62bebe1
251d59a
86cad8f
 
 
251d59a
86cad8f
 
251d59a
86cad8f
 
 
 
 
 
 
65ff670
86cad8f
251d59a
86cad8f
 
62bebe1
86cad8f
 
62bebe1
86cad8f
 
62bebe1
 
251d59a
86cad8f
62bebe1
 
 
86cad8f
62bebe1
86cad8f
 
 
 
 
 
62bebe1
251d59a
86cad8f
 
 
62bebe1
86cad8f
3ddc813
 
 
 
 
 
 
86cad8f
3ddc813
8fbc00e
86cad8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251d59a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import json
import os

import faiss
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Device setup
# Prefer GPU when available; the model and all tokenized inputs below are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load GPT-2 model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load GPT-2 and its tokenizer once; st.cache_resource reuses them across reruns."""
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padded batches encode cleanly.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    return gpt2_model, gpt2_tokenizer

# Module-level globals used by the embedding and answering functions below.
model, tokenizer = load_model_and_tokenizer()

# Function to extract text from uploaded PDFs
def extract_text_from_pdfs(uploaded_files):
    """Return one concatenated-text string per uploaded PDF, in upload order."""
    extracted = []
    for pdf_file in uploaded_files:
        # extract_text() can return None for image-only pages; substitute "".
        pages = PdfReader(pdf_file).pages
        extracted.append("".join(page.extract_text() or "" for page in pages))
    return extracted

# Function to create a FAISS index
def create_faiss_index(text_data):
    """Embed each document with GPT-2 and build an exact L2 FAISS index.

    Each document is embedded as the mean of GPT-2's last-layer hidden states
    (truncated to GPT-2's 1024-token window). Vector i in the returned index
    corresponds to text_data[i].

    Args:
        text_data: List of document strings (must be non-empty).

    Returns:
        (index, embeddings): the FAISS IndexFlatL2 and the (n, dim) float array.

    Raises:
        ValueError: if text_data is empty (previously this surfaced as an
        opaque RuntimeError from torch.cat on an empty list).
    """
    if not text_data:
        raise ValueError("text_data must contain at least one document")
    embeddings = []
    for text in text_data:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden layer; keep as a CPU tensor until the final cat.
        # (The original converted tensor -> numpy -> tensor -> numpy for no gain.)
        embeddings.append(outputs.hidden_states[-1].mean(dim=1).cpu())
    embeddings = torch.cat(embeddings, dim=0).numpy()
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, embeddings

# Function to answer queries
def answer_query(query, index, text_data):
    """Retrieve the single most relevant document for `query` and generate an answer.

    Args:
        query: The user's question.
        index: FAISS index whose vector i corresponds to text_data[i].
        text_data: List of document strings aligned with the index.

    Returns:
        The decoded generation (prompt + answer), as before.
    """
    # Embed the query exactly like the documents (mean of last hidden states).
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        query_embedding = outputs.hidden_states[-1].mean(dim=1).cpu().numpy()

    # Nearest-neighbour lookup: take the single best-matching document.
    _, indices = index.search(query_embedding, k=1)
    relevant_text = text_data[indices[0][0]]

    # GPT-2's context window is 1024 tokens total (prompt + generation). Budget it
    # so the question and the generated tokens always fit and only the *context*
    # is truncated. Previously the whole prompt was truncated from the right, so a
    # long retrieved document silently cut off the "Question: ... Answer:" tail,
    # and prompt + 200 new tokens could exceed the model's window.
    max_new_tokens = 200
    tail_ids = tokenizer(f"\nQuestion: {query}\nAnswer:")["input_ids"]
    context_budget = max(0, 1024 - max_new_tokens - len(tail_ids))
    if context_budget:
        context_ids = tokenizer(
            f"Context: {relevant_text}", truncation=True, max_length=context_budget
        )["input_ids"]
    else:
        context_ids = []
    input_ids = torch.tensor([context_ids + tail_ids], device=device)
    with torch.no_grad():
        # Explicit attention_mask and pad_token_id avoid transformers' warnings
        # and undefined padding behavior (GPT-2 has no native pad token).
        generated = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Streamlit UI
st.title("RAG App with GPT-2")
st.write("Upload PDF files to build a database and ask questions!")

# Upload PDF files
uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

# Paths for the persisted database. The texts are stored as JSON because the
# documents themselves contain newlines: the previous newline-delimited .txt
# format (one write(text + "\n") per document, readlines() on reload) broke the
# 1:1 alignment between FAISS vectors and text entries.
INDEX_PATH = "faiss_index.bin"
TEXTS_PATH = "text_data.json"

# Build database
if st.button("Build Database") and uploaded_files:
    with st.spinner("Processing files..."):
        text_data = extract_text_from_pdfs(uploaded_files)
        index, _ = create_faiss_index(text_data)
        # Persist the index and the per-document texts for later sessions.
        faiss.write_index(index, INDEX_PATH)
        with open(TEXTS_PATH, "w", encoding="utf-8") as f:
            json.dump(text_data, f, ensure_ascii=False)
        st.success("Database built successfully!")

# Load existing database (also picks up one built earlier in this run).
index = None
text_data = []
if os.path.exists(INDEX_PATH) and os.path.exists(TEXTS_PATH):
    with st.spinner("Loading existing database..."):
        index = faiss.read_index(INDEX_PATH)
        with open(TEXTS_PATH, "r", encoding="utf-8") as f:
            text_data = json.load(f)
    st.success("Database loaded successfully!")

# Query input
query = st.text_input("Enter your query:")

# Get answer
if st.button("Get Answer") and query:
    if index is None:
        # Previously an undefined `index` surfaced as a NameError caught by the
        # generic handler below; give the user an actionable message instead.
        st.error("No database found. Upload PDFs and click 'Build Database' first.")
    else:
        with st.spinner("Searching and generating answer..."):
            try:
                answer = answer_query(query, index, text_data)
                st.success("Answer generated successfully!")
                st.write(answer)
            except Exception as e:
                st.error(f"Error: {e}")