SujathaL committed on
Commit
4c3d6b3
·
verified ·
1 Parent(s): f04c935

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -16
app.py CHANGED
@@ -2,16 +2,13 @@ import streamlit as st
2
  import pdfplumber
3
  import faiss
4
  import numpy as np
5
- import torch
6
  from sentence_transformers import SentenceTransformer
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
- # Load Flan-T5 Model for Detailed Answers
11
- model_name = "google/flan-t5-small" # Smaller version that loads faster
12
-
13
- tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
15
 
16
  # Load Sentence Embeddings Model
17
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
@@ -48,16 +45,10 @@ def find_best_chunk(question, index, chunks, embeddings):
48
  _, closest_idx = index.search(np.array(question_embedding), 1)
49
  return chunks[closest_idx[0][0]]
50
 
51
- # Function to Generate a Long, Detailed Answer
52
  def get_answer(question, context):
53
- input_text = f"Question: {question}\nContext: {context}\nAnswer:"
54
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
55
-
56
- # Generate response
57
- with torch.no_grad():
58
- output = model.generate(**inputs, max_length=300, temperature=0.7)
59
-
60
- return tokenizer.decode(output[0], skip_special_tokens=True)
61
 
62
  # Streamlit UI
63
  st.title("Chat with AWS Restart PDF (Like ChatPDF)")
 
2
  import pdfplumber
3
  import faiss
4
  import numpy as np
 
5
  from sentence_transformers import SentenceTransformer
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from transformers import pipeline
8
 
9
# Load Extractive QA Model (Like ChatPDF)
# deepset/roberta-base-squad2 is an extractive question-answering model:
# the pipeline selects an answer span from the supplied context rather than
# generating free-form text.
# NOTE(review): constructing the pipeline downloads the model on first run —
# this happens at import time of the module.
model_name = "deepset/roberta-base-squad2"
qa_pipeline = pipeline("question-answering", model=model_name)
 
 
12
 
13
  # Load Sentence Embeddings Model
14
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
45
  _, closest_idx = index.search(np.array(question_embedding), 1)
46
  return chunks[closest_idx[0][0]]
47
 
48
def get_answer(question, context):
    """Extract the best answer span for *question* from *context*.

    Runs the module-level extractive QA pipeline and returns the
    extracted answer text (ChatPDF-like behavior).
    """
    return qa_pipeline(question=question, context=context)["answer"]
 
 
 
 
 
 
52
 
53
  # Streamlit UI
54
  st.title("Chat with AWS Restart PDF (Like ChatPDF)")