amasood commited on
Commit
86cad8f
·
verified ·
1 Parent(s): 8fbc00e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -91
app.py CHANGED
@@ -1,120 +1,102 @@
1
  import os
2
- import streamlit as st
3
- import PyPDF2
4
  import torch
5
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
6
  import faiss
7
- import numpy as np
8
-
9
-
10
- # Load GPT-2 Model and Tokenizer
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model_name = "gpt2"
13
 
14
- tokenizer = GPT2Tokenizer.from_pretrained(model_name)
15
- model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
16
 
17
- # Set pad_token to eos_token
18
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
19
 
20
- # Sidebar for file upload
21
- st.sidebar.title("Upload PDFs")
22
- uploaded_files = st.sidebar.file_uploader("Upload one or more PDF files", accept_multiple_files=True, type=["pdf"])
23
 
24
- # Process PDF files
25
- def extract_text_from_pdf(pdf_files):
26
  text_data = []
27
- for file in pdf_files:
28
- pdf_reader = PyPDF2.PdfReader(file)
29
  text = ""
30
- for page in pdf_reader.pages:
31
- text += page.extract_text()
32
  text_data.append(text)
33
  return text_data
34
 
35
- # Create FAISS index
36
  def create_faiss_index(text_data):
37
- """
38
- Creates a FAISS index from the text data.
39
- """
40
- # Enable hidden states in the model configuration
41
- model.config.output_hidden_states = True
42
-
43
- # Initialize FAISS index
44
- dim = model.config.hidden_size # GPT-2 hidden size
45
- index = faiss.IndexFlatL2(dim)
46
-
47
  embeddings = []
48
-
49
  for text in text_data:
50
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
51
  with torch.no_grad():
52
- outputs = model(**inputs)
53
- # Extract the last layer's hidden state
54
- hidden_states = outputs.hidden_states[-1]
55
- embedding = hidden_states.mean(dim=1).cpu().numpy()
56
- embeddings.append(embedding)
57
- index.add(embedding)
58
-
59
  return index, embeddings
60
 
61
- # Answer queries
62
- def answer_query(query, index, embeddings, text_data):
63
- """
64
- Answers a query based on the FAISS index and text data.
65
- """
66
- # Check if FAISS index is populated
67
- if index.ntotal == 0:
68
- raise ValueError("The FAISS index is empty. Please upload documents to populate the database.")
69
-
70
- # Enable hidden states in the model configuration
71
- model.config.output_hidden_states = True
72
-
73
- # Tokenize the query
74
- inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True).to(device)
75
-
76
  with torch.no_grad():
77
- outputs = model(**inputs)
78
  query_embedding = outputs.hidden_states[-1].mean(dim=1).cpu().numpy()
79
 
80
- # Search for the nearest neighbor in the FAISS index
81
  _, indices = index.search(query_embedding, k=1)
82
- if len(indices) == 0 or indices[0][0] < 0:
83
- raise ValueError("No relevant context found for the given query.")
84
-
85
  nearest_index = indices[0][0]
86
-
87
- # Ensure text data size matches the FAISS index
88
- if nearest_index >= len(text_data):
89
- raise IndexError("Index out of range in text data. Please ensure data alignment.")
90
-
91
- # Retrieve the most relevant text
92
  relevant_text = text_data[nearest_index]
93
 
94
- # Generate an answer using the model
95
  input_text = f"Context: {relevant_text}\nQuestion: {query}\nAnswer:"
96
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True).to(device)
97
-
98
  with torch.no_grad():
99
  outputs = model.generate(**inputs, max_new_tokens=200)
100
-
101
- # Decode the generated answer
102
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
103
-
104
- return answer
105
-
106
- # Main app
107
- if uploaded_files:
108
- st.title("RAG Query Application")
109
- text_data = extract_text_from_pdf(uploaded_files)
110
- index, embeddings = create_faiss_index(text_data)
111
-
112
- query = st.text_input("Enter your query:")
113
- if query:
114
- with st.spinner("Fetching answer..."):
115
- answer = answer_query(query, index, embeddings, text_data)
116
- st.success(answer)
117
- else:
118
- st.title("Upload PDFs to Build RAG Database")
119
- st.write("Please upload one or more PDF files using the sidebar to start.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
1
import json
import os

import faiss
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 
 
7
 
8
# Pick the compute device once at startup; everything below moves to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model_and_tokenizer():
    """Load GPT-2 and its tokenizer once per Streamlit session.

    Returns:
        tuple: ``(model, tokenizer)`` — the GPT-2 LM-head model moved to
        ``device``, and its tokenizer with the pad token aliased to EOS
        (GPT-2 ships without a dedicated padding token).
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # enable padding in batch encodes
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()
 
 
20
 
21
def extract_text_from_pdfs(uploaded_files):
    """Extract the full text of each uploaded PDF.

    Args:
        uploaded_files: iterable of file-like objects accepted by
            ``PyPDF2.PdfReader``.

    Returns:
        list[str]: one concatenated string per input file. Pages whose
        extraction yields ``None`` (scanned/image pages) contribute "".
    """
    return [
        "".join(page.extract_text() or "" for page in PdfReader(pdf).pages)
        for pdf in uploaded_files
    ]
31
 
32
def create_faiss_index(text_data):
    """Embed each document with GPT-2 and build a FAISS L2 index over them.

    Each document is truncated to 1024 tokens and represented by the mean
    of the final hidden layer (one vector per document).

    Args:
        text_data: list of document strings, one per uploaded PDF.

    Returns:
        tuple: ``(index, embeddings)`` — a populated ``faiss.IndexFlatL2``
        and the ``(n_docs, hidden_size)`` float32 embedding matrix, with
        row ``i`` corresponding to ``text_data[i]``.

    Raises:
        ValueError: if ``text_data`` is empty (an empty ``torch.cat`` would
            otherwise fail with an opaque RuntimeError).
    """
    if not text_data:
        raise ValueError("No text to index. Please upload at least one PDF with extractable text.")
    embeddings = []
    for text in text_data:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden layer into a single document vector.
        embeddings.append(outputs.hidden_states[-1].mean(dim=1).cpu().numpy())
    # Stack per-document vectors into one (n_docs, dim) matrix for FAISS.
    embeddings = torch.cat([torch.tensor(embed) for embed in embeddings], dim=0).numpy()
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, embeddings
45
 
46
def answer_query(query, index, text_data):
    """Answer a query with GPT-2, grounded on the nearest indexed document.

    Restores the safety checks present in the previous revision (empty
    index, invalid neighbor, index/text misalignment) that this rewrite
    had dropped.

    Args:
        query: the user's question.
        index: a populated FAISS index of document embeddings.
        text_data: documents in the same order as the index rows.

    Returns:
        str: the decoded generation (includes the context/question prompt,
        as GPT-2 echoes its input).

    Raises:
        ValueError: if the index is empty or no neighbor is found.
        IndexError: if the nearest row has no matching entry in text_data.
    """
    if index.ntotal == 0:
        raise ValueError("The FAISS index is empty. Please upload documents to populate the database.")

    # Embed the query exactly as documents were embedded (mean of last layer).
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    query_embedding = outputs.hidden_states[-1].mean(dim=1).cpu().numpy()

    _, indices = index.search(query_embedding, k=1)
    # FAISS returns -1 for slots it could not fill.
    if len(indices) == 0 or indices[0][0] < 0:
        raise ValueError("No relevant context found for the given query.")
    nearest_index = int(indices[0][0])
    if nearest_index >= len(text_data):
        raise IndexError("Index out of range in text data. Please ensure data alignment.")
    relevant_text = text_data[nearest_index]

    # Generate an answer conditioned on the retrieved context.
    input_text = f"Context: {relevant_text}\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
62
+
63
# Streamlit UI
st.title("RAG App with GPT-2")
st.write("Upload PDF files to build a database and ask questions!")

# Upload PDF files
uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

# Nothing is queryable until a database has been built or loaded.
index = None
text_data = []

# Build database
if st.button("Build Database") and uploaded_files:
    with st.spinner("Processing files..."):
        text_data = extract_text_from_pdfs(uploaded_files)
        index, _ = create_faiss_index(text_data)
        # Persist the index and the documents. JSON keeps exactly one entry
        # per document even when a document contains newlines, so the stored
        # texts stay aligned with the FAISS rows. (The previous plain-text
        # dump reloaded with readlines() split documents apart and broke
        # retrieval alignment.)
        faiss.write_index(index, "faiss_index.bin")
        with open("text_data.json", "w", encoding="utf-8") as f:
            json.dump(text_data, f, ensure_ascii=False)
        st.success("Database built successfully!")

# Load existing database
if os.path.exists("faiss_index.bin") and os.path.exists("text_data.json"):
    with st.spinner("Loading existing database..."):
        index = faiss.read_index("faiss_index.bin")
        with open("text_data.json", "r", encoding="utf-8") as f:
            text_data = json.load(f)
        st.success("Database loaded successfully!")

# Query input
query = st.text_input("Enter your query:")

# Get answer
if st.button("Get Answer") and query:
    if index is None:
        # Previously this path raised a NameError on `index`; fail clearly.
        st.error("No database available. Please build or load one first.")
    else:
        with st.spinner("Searching and generating answer..."):
            try:
                answer = answer_query(query, index, text_data)
                st.success("Answer generated successfully!")
                st.write(answer)
            except Exception as e:
                st.error(f"Error: {e}")
102