akazmi commited on
Commit
51138f0
·
verified ·
1 Parent(s): 145a85d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -36
app.py CHANGED
@@ -1,26 +1,20 @@
1
  import gradio as gr
2
- import os
3
  from PyPDF2 import PdfReader
4
  from sentence_transformers import SentenceTransformer
5
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
- import torch
7
  import numpy as np
8
  from sklearn.metrics.pairwise import cosine_similarity
9
 
10
- # Load sentence transformer for embeddings
11
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
12
 
13
- # Load Zephyr model (use accelerate for GPU/CPU device map)
14
- model_name = "HuggingFaceH4/zephyr-7b-beta"
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
- model = AutoModelForCausalLM.from_pretrained(
17
- model_name,
18
- torch_dtype=torch.float16,
19
- device_map="auto" # βœ… Works if 'accelerate' is installed
20
- )
21
- rag_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
22
 
23
- # Function to extract text from PDF
24
  def read_pdf(file_path):
25
  try:
26
  with open(file_path, "rb") as file:
@@ -34,12 +28,12 @@ def read_pdf(file_path):
34
  except Exception as e:
35
  return f"Error reading PDF: {str(e)}"
36
 
37
- # Function to split text into chunks (~500 words)
38
def chunk_text(text, chunk_size=500):
    """Break *text* into consecutive pieces of at most ``chunk_size`` words.

    Returns a list of strings; whitespace-only or empty text yields [].
    """
    words = text.split()
    pieces = []
    start = 0
    while start < len(words):
        pieces.append(" ".join(words[start:start + chunk_size]))
        start += chunk_size
    return pieces
41
 
42
- # Function to find top-k most relevant chunks
43
  def retrieve_relevant_chunks(question, chunks, top_k=3):
44
  chunk_embeddings = embedder.encode(chunks)
45
  question_embedding = embedder.encode([question])
@@ -47,44 +41,38 @@ def retrieve_relevant_chunks(question, chunks, top_k=3):
47
  top_indices = np.argsort(scores)[-top_k:][::-1]
48
  return "\n\n".join([chunks[i] for i in top_indices])
49
 
50
- # Main function to process question and return answer
51
  def answer_question(uploaded_file, user_question):
52
  if uploaded_file is None:
53
  return "❌ Please upload a PDF file."
54
 
55
- file_path = uploaded_file.name
56
- document_text = read_pdf(file_path)
 
57
 
58
- if not document_text or not isinstance(document_text, str):
59
- return "❌ Document is empty or could not be read."
60
-
61
- chunks = chunk_text(document_text)
62
  if not chunks:
63
- return "❌ Document is too short to process."
64
-
65
- relevant_context = retrieve_relevant_chunks(user_question, chunks)
66
 
67
- prompt = f"""You are a helpful assistant. Use the context below to answer the user's question.\n\nContext:\n{relevant_context}\n\nQuestion: {user_question}\nAnswer:"""
68
 
 
 
69
  try:
70
- result = rag_pipeline(prompt, max_new_tokens=300, do_sample=True, temperature=0.7)
71
- answer = result[0]["generated_text"].split("Answer:")[-1].strip()
72
- return str(answer)
73
  except Exception as e:
74
- return f"❌ Error generating answer: {str(e)}"
75
 
76
# Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for PDF question answering.

    Wires the file upload and question textbox into ``answer_question``.
    """
    with gr.Blocks() as demo:
        # Fixed mojibake in the heading ("πŸ“„" was a UTF-8/Latin-1 mix-up of 📄).
        gr.Markdown("## 📄 Ask Questions from a PDF Document (RAG using Zephyr 7B)")

        file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        question_input = gr.Textbox(label="Enter your question")
        answer_output = gr.Textbox(label="Answer", lines=10)

        ask_button = gr.Button("Ask")
        ask_button.click(fn=answer_question, inputs=[file_input, question_input], outputs=answer_output)

    return demo
89
 
90
  if __name__ == "__main__":
 
1
  import gradio as gr
 
2
  from PyPDF2 import PdfReader
3
  from sentence_transformers import SentenceTransformer
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
5
  import numpy as np
6
  from sklearn.metrics.pairwise import cosine_similarity
7
 
8
# ---- Model setup ----
# Sentence-embedding model used to rank document chunks (small and fast).
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Generator: FLAN-T5 base, a seq2seq model that runs acceptably on CPU.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
rag_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
16
 
17
+ # Read PDF text
18
  def read_pdf(file_path):
19
  try:
20
  with open(file_path, "rb") as file:
 
28
  except Exception as e:
29
  return f"Error reading PDF: {str(e)}"
30
 
31
# Split into ~500-word chunks
def chunk_text(text, chunk_size=500):
    """Split *text* into consecutive chunks of at most ``chunk_size`` words each.

    The final chunk may be shorter; empty text yields an empty list.
    """
    words = text.split()
    starts = range(0, len(words), chunk_size)
    return [" ".join(words[i:i + chunk_size]) for i in starts]
35
 
36
+ # Find top-k relevant chunks via cosine similarity
37
  def retrieve_relevant_chunks(question, chunks, top_k=3):
38
  chunk_embeddings = embedder.encode(chunks)
39
  question_embedding = embedder.encode([question])
 
41
  top_indices = np.argsort(scores)[-top_k:][::-1]
42
  return "\n\n".join([chunks[i] for i in top_indices])
43
 
44
# Main QA function
def answer_question(uploaded_file, user_question):
    """Answer *user_question* from the uploaded PDF via retrieval + generation.

    Reads the PDF, splits it into chunks, retrieves the most relevant ones,
    and asks the text2text pipeline. Returns the generated answer, or a
    human-readable "❌ ..." error string on failure.
    """
    if uploaded_file is None:
        return "❌ Please upload a PDF file."

    text = read_pdf(uploaded_file.name)
    if not text or not isinstance(text, str):
        return "❌ Could not extract text from the document."
    # read_pdf reports failures as a plain string; without this check the
    # error message itself would be chunked and "answered" over.
    if text.startswith("Error reading PDF:"):
        return f"❌ {text}"

    chunks = chunk_text(text)
    if not chunks:
        return "❌ Document too short or empty."

    context = retrieve_relevant_chunks(user_question, chunks)

    prompt = f"Context: {context}\n\nQuestion: {user_question}\nAnswer:"

    try:
        result = rag_pipeline(prompt, max_new_tokens=256)
        return result[0]["generated_text"].strip()
    except Exception as e:
        # Surface generation failures to the UI rather than crashing the app.
        return f"❌ Error during generation: {str(e)}"
66
 
67
# Gradio Interface
def create_interface():
    """Assemble and return the Gradio Blocks app for PDF question answering.

    Connects the upload widget and question textbox to ``answer_question``.
    """
    with gr.Blocks() as demo:
        # Fixed mojibake in the heading ("πŸ“„" was a UTF-8/Latin-1 mix-up of 📄).
        gr.Markdown("## 📄 Ask Questions from a PDF Document (RAG using FLAN-T5)")

        file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        question_input = gr.Textbox(label="Enter your question")
        answer_output = gr.Textbox(label="Answer", lines=10)

        gr.Button("Ask").click(fn=answer_question, inputs=[file_input, question_input], outputs=[answer_output])

    return demo
77
 
78
  if __name__ == "__main__":