akazmi committed
Commit 51c2867 · verified · 1 Parent(s): ecc7d10

Update app.py

Files changed (1): app.py +41 -70
app.py CHANGED
@@ -1,85 +1,56 @@
- import gradio as gr
  import torch
- import pdfplumber
- import re
- import numpy as np
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  from sentence_transformers import SentenceTransformer
  from sklearn.metrics.pairwise import cosine_similarity

- # ===== Load Embedding Model =====
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
-
- # ===== Load QA Model =====
- model_name = "mistralai/Mistral-7B-Instruct-v0.1"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto")
- qa_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
-
- # ===== Read PDF and Clean =====
- def read_pdf(file_path):
-     try:
-         with pdfplumber.open(file_path) as pdf:
-             return "\n".join(page.extract_text() or "" for page in pdf.pages)
-     except Exception as e:
-         return f"Error reading PDF: {str(e)}"
-
- # ===== Smart Sentence Chunking =====
- def chunk_text(text, max_len=500):
-     sentences = re.split(r'(?<=[.؟!])\s+', text)
-     chunks, current = [], ""
-     for sentence in sentences:
-         if len(current) + len(sentence) <= max_len:
-             current += sentence + " "
-         else:
-             chunks.append(current.strip())
-             current = sentence + " "
-     if current:
-         chunks.append(current.strip())
-     return chunks
-
- # ===== Semantic Retrieval =====
- def get_relevant_chunks(question, chunks, top_k=2):
-     q_vec = embedder.encode([question])
-     c_vecs = embedder.encode(chunks)
-     sims = cosine_similarity(q_vec, c_vecs)[0]
-     top_indices = np.argsort(sims)[-top_k:][::-1]
-     return "\n\n".join([chunks[i] for i in top_indices])
-
- # ===== Generate Answer =====
- def answer_question(file, question):
-     if not file:
-         return "⚠️ Please upload a PDF."
-     if not question.strip():
-         return "⚠️ Please enter a question."
-
-     raw_text = read_pdf(file.name)
-     if raw_text.startswith("Error"):
-         return raw_text
-
-     chunks = chunk_text(raw_text)
-     context = get_relevant_chunks(question, chunks)
-
-     prompt = (
-         f"You are a legal expert. Based on the context below, answer the question in a detailed and explanatory manner.\n\n"
-         f"Context:\n{context}\n\n"
-         f"Question: {question}\n\n"
-         f"Answer:"
-     )
-
-     try:
-         response = qa_pipeline(prompt, max_new_tokens=300, do_sample=False, temperature=0.3)
-         return response[0]["generated_text"].split("Answer:")[-1].strip()
-     except Exception as e:
-         return f"Error generating answer: {e}"
-
- # ===== Gradio Interface =====
  with gr.Blocks() as demo:
-     gr.Markdown("## 📘 Document Question Answering (RAG-powered)")
-     file = gr.File(label="Upload PDF", file_types=[".pdf"])
-     question = gr.Textbox(label="Ask a question", placeholder="e.g., Is there any section for cost audit?")
-     answer = gr.Textbox(label="Answer", lines=10)
-     submit = gr.Button("Get Answer")
-     submit.click(fn=answer_question, inputs=[file, question], outputs=answer)

  demo.launch()
 
+ import os
  import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
  from sentence_transformers import SentenceTransformer
  from sklearn.metrics.pairwise import cosine_similarity
+ import PyPDF2
+
+ # Load LLM and Embedding model
+ qa_model = "google/flan-t5-large"
+ tokenizer = AutoTokenizer.from_pretrained(qa_model)
+ model = AutoModelForSeq2SeqLM.from_pretrained(qa_model)
+ qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+
+ embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+ # Global document store
+ documents = []
+ document_embeddings = []
+
+ def extract_text(file):
+     reader = PyPDF2.PdfReader(file)
+     return "\n".join(page.extract_text() for page in reader.pages if page.extract_text())
+
+ def add_document(file):
+     text = extract_text(file)
+     documents.append(text)
+     document_embeddings.append(embedder.encode(text))
+     return "Document uploaded and indexed successfully."
+
+ def generate_answer(query):
+     if not documents:
+         return "Please upload a document first."
+
+     query_embedding = embedder.encode(query)
+     similarities = cosine_similarity([query_embedding], document_embeddings)[0]
+     best_match_index = similarities.argmax()
+     relevant_text = documents[best_match_index][:3000]  # Truncate if too long
+
+     prompt = f"Answer this question based on the context:\n\nContext: {relevant_text}\n\nQuestion: {query}"
+     answer = qa_pipeline(prompt, max_new_tokens=300, temperature=0.3)[0]["generated_text"]
+     return answer.strip()
+
+ # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# 📄 Document Reader with RAG (Flan-T5)")
+     file_input = gr.File(label="Upload PDF", type="file")
+     upload_btn = gr.Button("Upload & Index")
+     query = gr.Textbox(label="Ask a question")
+     submit_btn = gr.Button("Get Answer")
+     answer_box = gr.Textbox(label="Answer")
+
+     upload_btn.click(fn=add_document, inputs=file_input, outputs=answer_box)
+     submit_btn.click(fn=generate_answer, inputs=query, outputs=answer_box)

  demo.launch()