akazmi committed
Commit ffd7a87 · verified · 1 Parent(s): dc8bce1

Update app.py

Files changed (1): app.py (+56 -39)
app.py CHANGED
@@ -1,56 +1,73 @@
-import os
-import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from sentence_transformers import SentenceTransformer
-from sklearn.metrics.pairwise import cosine_similarity
-import PyPDF2
-
-# Load LLM and Embedding model
-qa_model = "google/flan-t5-large"
-tokenizer = AutoTokenizer.from_pretrained(qa_model)
-model = AutoModelForSeq2SeqLM.from_pretrained(qa_model)
-qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
-
-embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-
-# Global document store
-documents = []
-document_embeddings = []
-
-def extract_text(file):
-    reader = PyPDF2.PdfReader(file)
-    return "\n".join(page.extract_text() for page in reader.pages if page.extract_text())
-
-def add_document(file):
-    text = extract_text(file)
-    documents.append(text)
-    document_embeddings.append(embedder.encode(text))
-    return "Document uploaded and indexed successfully."
-
-def generate_answer(query):
-    if not documents:
-        return "Please upload a document first."
-
-    query_embedding = embedder.encode(query)
-    similarities = cosine_similarity([query_embedding], document_embeddings)[0]
-    best_match_index = similarities.argmax()
-    relevant_text = documents[best_match_index][:3000]  # Truncate if too long
-
-    prompt = f"Answer this question based on the context:\n\nContext: {relevant_text}\n\nQuestion: {query}"
-    answer = qa_pipeline(prompt, max_new_tokens=300, temperature=0.3)[0]["generated_text"]
-    return answer.strip()
-
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("# 📄 Document Reader with RAG (Flan-T5)")
-    file_input = gr.File(label="Upload PDF", type="filepath")  # <-- FIXED HERE
     upload_btn = gr.Button("Upload & Index")
-    query = gr.Textbox(label="Ask a question")
-    submit_btn = gr.Button("Get Answer")
-    answer_box = gr.Textbox(label="Answer")
-
-    upload_btn.click(fn=add_document, inputs=file_input, outputs=answer_box)
-    submit_btn.click(fn=generate_answer, inputs=query, outputs=answer_box)
-
 demo.launch()
 
 
 
 import gradio as gr
+from PyPDF2 import PdfReader
 from sentence_transformers import SentenceTransformer
+import faiss
+import numpy as np
+from transformers import pipeline
+
+# Load models once
+embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+qa_model = pipeline("text2text-generation", model="google/flan-t5-base")
+
+# Store docs and vectors
+doc_chunks = []
+doc_embeddings = []
+index = None
+
+def read_pdf(file_path):
+    try:
+        reader = PdfReader(file_path)
+        text = ""
+        for page in reader.pages:
+            text += page.extract_text() or ""
+        return text
+    except Exception as e:
+        return f"Error reading PDF: {e}"
+
+def add_document(file_path):
+    global doc_chunks, doc_embeddings, index
+    text = read_pdf(file_path)
+    if not text.strip():
+        return "❌ Could not extract text from PDF."
+
+    # Chunking the text (you can improve chunking logic)
+    chunks = [text[i:i+500] for i in range(0, len(text), 500)]
+    embeddings = embedding_model.encode(chunks)
+
+    # Save to global
+    doc_chunks = chunks
+    doc_embeddings = embeddings
+
+    # Create FAISS index
+    dim = len(embeddings[0])
+    index = faiss.IndexFlatL2(dim)
+    index.add(np.array(embeddings))
+
+    return f"✅ Uploaded & indexed {len(chunks)} chunks."
+
+def generate_answer(query):
+    if index is None:
+        return "⚠️ Please upload a document first."
+
+    query_vec = embedding_model.encode([query])
+    D, I = index.search(np.array(query_vec), k=3)
+    context = " ".join([doc_chunks[i] for i in I[0]])
+
+    # Use QA model
+    prompt = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
+    result = qa_model(prompt, max_new_tokens=128)[0]["generated_text"]
+    return result.strip()
+
 # Gradio UI
 with gr.Blocks() as demo:
+    gr.Markdown("## 📄 Document Q&A with PDF Upload")
+    file_input = gr.File(label="Upload PDF", type="filepath")
     upload_btn = gr.Button("Upload & Index")
+    query_input = gr.Textbox(label="Ask your question here")
+    submit_btn = gr.Button("Answer")
+    output_box = gr.Textbox(label="Answer")
+
+    upload_btn.click(fn=add_document, inputs=file_input, outputs=output_box)
+    submit_btn.click(fn=generate_answer, inputs=query_input, outputs=output_box)
+
 demo.launch()
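
The new add_document flags its own chunking as improvable ("you can improve chunking logic"). One common refinement is overlapping windows, so a sentence that straddles a chunk boundary still appears whole in at least one chunk. A minimal sketch, using a hypothetical chunk_text helper that is not part of this commit:

def chunk_text(text: str, size: int = 500, overlap: int = 100) -> list[str]:
    # Slide a window of `size` characters, advancing by size - overlap,
    # so consecutive chunks share `overlap` characters of context.
    step = size - overlap
    return [text[i:i + size] for i in range(0, max(len(text) - overlap, 1), step)]

Swapping this in for the list comprehension in add_document leaves the rest of the indexing path unchanged, since it still yields a list of strings for embedding_model.encode.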
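One behavioral difference between the two versions: the old code ranked documents by cosine similarity (sklearn's cosine_similarity), while faiss.IndexFlatL2 ranks by Euclidean distance. The two orderings agree only for unit-length vectors, so if cosine ranking is still the intent, a common pattern is to normalize the embeddings and search an inner-product index instead. A sketch under that assumption, not what this commit does:

# Cosine ranking with FAISS: normalize vectors, then use inner product.
embeddings = embedding_model.encode(chunks, normalize_embeddings=True)
index = faiss.IndexFlatIP(dim)  # inner product equals cosine on unit vectors
index.add(np.array(embeddings, dtype="float32"))

The query vector would need the same normalization (encode with normalize_embeddings=True) before calling index.search.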