akazmi commited on
Commit
7aec612
Β·
verified Β·
1 Parent(s): 45d1cd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -41
app.py CHANGED
@@ -1,77 +1,86 @@
1
  import gradio as gr
2
  import pdfplumber
 
3
  from sentence_transformers import SentenceTransformer
4
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
5
  import numpy as np
6
  from sklearn.metrics.pairwise import cosine_similarity
7
- import torch
8
 
9
- # βœ… Load models
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  embedder = SentenceTransformer("all-MiniLM-L6-v2", device=device)
12
 
13
- model_name = "google/flan-t5-small"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
16
  qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
17
 
18
- # βœ… Step 1: Read PDF using pdfplumber
19
  def read_pdf(file_path):
20
  try:
21
  with pdfplumber.open(file_path) as pdf:
22
  text = "\n".join([page.extract_text() or "" for page in pdf.pages])
23
- return text.strip()
24
  except Exception as e:
25
- return f"❌ Failed to read PDF: {e}"
26
 
27
- # βœ… Step 2: Chunk document text
28
- def chunk_text(text, chunk_size=500):
29
- words = text.split()
30
- return [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
 
 
 
 
 
 
 
 
 
 
31
 
32
- # βœ… Step 3: Semantic retrieval of relevant chunks
33
- def get_relevant_chunks(question, chunks, top_k=2):
34
- question_embedding = embedder.encode([question])
35
- chunk_embeddings = embedder.encode(chunks)
36
- similarities = cosine_similarity(question_embedding, chunk_embeddings)[0]
37
- top_k_indices = np.argsort(similarities)[-top_k:][::-1]
38
- return "\n\n".join([chunks[i] for i in top_k_indices])
39
 
40
- # βœ… Step 4: Generate answer with retrieved context
41
  def answer_question(pdf_file, user_question):
42
- if pdf_file is None or user_question.strip() == "":
43
- return "⚠️ Please upload a document and enter a question."
44
 
45
  text = read_pdf(pdf_file.name)
46
- if not text:
47
- return "⚠️ PDF has no readable text."
48
 
49
  chunks = chunk_text(text)
50
- relevant_context = get_relevant_chunks(user_question, chunks, top_k=2)
51
-
52
- prompt = f"""You are a helpful assistant. Use the context to answer the question.
53
-
54
- Context:
55
- {relevant_context}
56
 
57
- Question: {user_question}
58
- Answer:"""
 
 
 
 
59
 
60
  try:
61
- result = qa_pipeline(prompt, max_new_tokens=200)
62
  return result[0]["generated_text"].split("Answer:")[-1].strip()
63
  except Exception as e:
64
- return f"❌ Error during generation: {e}"
65
 
66
- # βœ… Step 5: Gradio UI
67
  with gr.Blocks() as demo:
68
- gr.Markdown("### πŸ“„ Ask Questions from Your PDF Document (RAG-based QA)")
69
- with gr.Row():
70
- pdf_input = gr.File(label="πŸ“ Upload PDF", file_types=[".pdf"])
71
- question_input = gr.Textbox(label="❓ Ask a question about the document")
72
- answer_output = gr.Textbox(label="🧠 Answer", lines=8)
73
- submit_btn = gr.Button("πŸ” Get Answer")
74
 
75
- submit_btn.click(fn=answer_question, inputs=[pdf_input, question_input], outputs=answer_output)
76
 
77
  demo.launch()
 
1
  import gradio as gr
2
  import pdfplumber
3
+ import torch
4
  from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
  import numpy as np
7
  from sklearn.metrics.pairwise import cosine_similarity
8
+ import re
9
 
10
+ # Load models
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  embedder = SentenceTransformer("all-MiniLM-L6-v2", device=device)
13
 
14
+ model_name = "google/flan-t5-base" # stronger than 'small'
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
17
  qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
18
 
19
+ # Extract and clean PDF text
20
  def read_pdf(file_path):
21
  try:
22
  with pdfplumber.open(file_path) as pdf:
23
  text = "\n".join([page.extract_text() or "" for page in pdf.pages])
24
+ return re.sub(r'\n+', '\n', text.strip())
25
  except Exception as e:
26
+ return f"❌ PDF reading failed: {e}"
27
 
28
+ # Chunk the text into clean sentence-like blocks
29
+ def chunk_text(text, max_length=500):
30
+ sentences = re.split(r'(?<=[.!?])\s+', text)
31
+ chunks = []
32
+ current_chunk = ""
33
+ for sentence in sentences:
34
+ if len(current_chunk) + len(sentence) <= max_length:
35
+ current_chunk += sentence + " "
36
+ else:
37
+ chunks.append(current_chunk.strip())
38
+ current_chunk = sentence + " "
39
+ if current_chunk:
40
+ chunks.append(current_chunk.strip())
41
+ return chunks
42
 
43
+ # Embed and get top chunks
44
+ def get_top_chunks(question, chunks, k=2):
45
+ q_embed = embedder.encode([question])
46
+ chunk_embeds = embedder.encode(chunks)
47
+ sims = cosine_similarity(q_embed, chunk_embeds)[0]
48
+ top_k_idx = np.argsort(sims)[-k:][::-1]
49
+ return "\n\n".join([chunks[i] for i in top_k_idx])
50
 
51
+ # Generate answer
52
  def answer_question(pdf_file, user_question):
53
+ if not pdf_file or not user_question.strip():
54
+ return "⚠️ Upload a PDF and enter your question."
55
 
56
  text = read_pdf(pdf_file.name)
57
+ if not text or text.startswith("❌"):
58
+ return text
59
 
60
  chunks = chunk_text(text)
61
+ relevant = get_top_chunks(user_question, chunks)
 
 
 
 
 
62
 
63
+ prompt = (
64
+ f"You are a legal document assistant. Based on the context below, "
65
+ f"answer the question briefly and clearly.\n\n"
66
+ f"Context:\n{relevant}\n\n"
67
+ f"Question: {user_question}\n\nAnswer:"
68
+ )
69
 
70
  try:
71
+ result = qa_pipeline(prompt, max_new_tokens=256, do_sample=False)
72
  return result[0]["generated_text"].split("Answer:")[-1].strip()
73
  except Exception as e:
74
+ return f"❌ Generation error: {e}"
75
 
76
+ # Gradio interface
77
  with gr.Blocks() as demo:
78
+ gr.Markdown("## πŸ“š Legal Document Q&A Assistant")
79
+ pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
80
+ question_input = gr.Textbox(label="Ask a question")
81
+ answer_output = gr.Textbox(label="Answer", lines=8)
82
+ ask_button = gr.Button("Get Answer")
 
83
 
84
+ ask_button.click(answer_question, inputs=[pdf_input, question_input], outputs=answer_output)
85
 
86
  demo.launch()