SujathaL committed on
Commit
6c5b356
·
verified ·
1 Parent(s): 345c264

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -20
app.py CHANGED
@@ -2,11 +2,15 @@ import streamlit as st
2
  from transformers import pipeline
3
  import PyPDF2
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
5
 
6
- # Load Hugging Face Question Answering model
7
- qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
8
 
9
- # Function to extract text from PDF
 
 
 
10
  def extract_text_from_pdf(pdf_path):
11
  with open(pdf_path, "rb") as f:
12
  pdf_reader = PyPDF2.PdfReader(f)
@@ -15,31 +19,31 @@ def extract_text_from_pdf(pdf_path):
15
  text += page.extract_text() + "\n"
16
  return text
17
 
18
# Function to split text into smaller chunks
def split_text(text):
    """Break *text* into 500-character chunks with a 50-character overlap."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_text(text)
23
 
24
# Function to find the most relevant chunk for a question
def find_relevant_chunk(question, chunks):
    """Scan every chunk with the QA model and return the one where the model
    answers *question* with the highest confidence.

    Returns an empty string when *chunks* is empty.
    """
    best_chunk, best_score = "", 0
    for candidate in chunks:
        # The pipeline returns a dict with an 'answer' and a confidence 'score'.
        confidence = qa_pipeline(question=question, context=candidate)["score"]
        if confidence > best_score:
            best_chunk, best_score = candidate, confidence
    return best_chunk
35
 
36
# Streamlit UI
st.title("Chat with AWS Restart PDF")

# Use the uploaded PDF file
pdf_path = "AWS restart program information.docx.pdf"  # Update with your file name
pdf_text = extract_text_from_pdf(pdf_path)
chunks = split_text(pdf_text)

st.write("✅ PDF Loaded Successfully!")

question = st.text_input("Ask a question about AWS Restart program:")

if st.button("Get Answer") and question:
    # Narrow the context to the single best-matching chunk, then answer on it.
    context = find_relevant_chunk(question, chunks)
    response = qa_pipeline(question=question, context=context)
    st.write("Answer:", response["answer"])
 
2
  from transformers import pipeline
3
  import PyPDF2
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
from sentence_transformers import SentenceTransformer, util

# Extractive QA model: given (question, context) it returns an answer span
# plus a confidence score.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")

# Sentence-embedding model used to rank chunks by semantic similarity.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
12
+
13
+ # Function to Extract Text from PDF
14
  def extract_text_from_pdf(pdf_path):
15
  with open(pdf_path, "rb") as f:
16
  pdf_reader = PyPDF2.PdfReader(f)
 
19
  text += page.extract_text() + "\n"
20
  return text
21
 
22
# Function to Split Text into Chunks
def split_text(text, chunk_size=500, chunk_overlap=50):
    """Split *text* into overlapping chunks for retrieval.

    Args:
        text: Full document text.
        chunk_size: Maximum characters per chunk (default 500, as before).
        chunk_overlap: Characters shared between consecutive chunks so that
            sentences straddling a boundary are not lost (default 50).

    Returns:
        list[str]: The chunk strings, in document order.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_text(text)
27
 
28
# Function to Find the Most Relevant Chunk Using Embeddings
def find_best_chunk(question, chunks):
    """Return the chunk whose embedding is most similar to *question*.

    Args:
        question: The user's question.
        chunks: List of text chunks to search.

    Returns:
        str: The most similar chunk, or "" when *chunks* is empty (the
        previous version raised ValueError via max() on an empty list).
    """
    if not chunks:
        return ""

    question_embedding = embedding_model.encode(question, convert_to_tensor=True)
    # Encode all chunks in one batched call instead of one encode() per chunk.
    chunk_embeddings = embedding_model.encode(chunks, convert_to_tensor=True)

    # Cosine similarity of the question against every chunk at once;
    # result shape is (1, len(chunks)).
    similarities = util.pytorch_cos_sim(question_embedding, chunk_embeddings)[0]

    best_chunk_index = int(similarities.argmax())
    return chunks[best_chunk_index]
39
 
40
# Streamlit UI
st.title("Chat with AWS Restart PDF")

# Load and Process PDF
pdf_path = "AWS restart program information.docx.pdf"
pdf_text = extract_text_from_pdf(pdf_path)
chunks = split_text(pdf_text)

st.write("✅ PDF Loaded Successfully!")

question = st.text_input("Ask a question about AWS Restart program:")

if st.button("Get Answer") and question:
    # Retrieve the single most relevant chunk, then run extractive QA on it.
    context = find_best_chunk(question, chunks)
    response = qa_pipeline(question=question, context=context)
    st.write("Answer:", response["answer"])