jake2004 committed on
Commit
618a86b
Β·
verified Β·
1 Parent(s): 2a9b073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -12
app.py CHANGED
@@ -1,27 +1,91 @@
1
  import gradio as gr
2
  import torch
 
 
3
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
 
4
 
5
- # Load RAG model
6
  model_name = "facebook/rag-sequence-nq"
7
  tokenizer = RagTokenizer.from_pretrained(model_name)
8
  retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
9
  model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
10
 
11
- # Function to process uploaded document
12
- def process_file(file):
 
 
 
 
 
13
  if file is None:
14
  return "Please upload a document."
15
 
16
- file_text = file.decode("utf-8")
17
- return file_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Function to answer questions using RAG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def answer_question(document, question):
21
  if not document.strip():
22
  return "Please provide document content."
23
 
24
- inputs = tokenizer(question, document, return_tensors="pt", truncation=True)
 
 
 
 
25
  with torch.no_grad():
26
  generated = model.generate(**inputs)
27
  answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
@@ -30,20 +94,26 @@ def answer_question(document, question):
30
 
31
  # Gradio UI
32
  with gr.Blocks() as app:
33
- gr.Markdown("# πŸ“„ Advanced RAG NLP Document Editor")
34
 
35
  # File Uploader
36
- file_input = gr.File(label="Upload Document (TXT only)", type="binary")
37
- file_output = gr.Textbox(label="Extracted Text", lines=10)
 
 
 
 
 
38
 
39
- file_input.change(process_file, inputs=file_input, outputs=file_output)
 
40
 
41
  # Question Answering
42
  question_input = gr.Textbox(label="Ask a Question")
43
  answer_output = gr.Textbox(label="Answer", lines=2)
44
 
45
  submit_btn = gr.Button("Get Answer")
46
- submit_btn.click(answer_question, inputs=[file_output, question_input], outputs=answer_output)
47
 
48
  # Launch in Hugging Face Spaces
49
  app.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import faiss
4
+ import numpy as np
5
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
+ import pdfplumber
7
+ import docx
8
 
9
# Load RAG Model
# facebook/rag-sequence-nq couples a question encoder with a seq2seq generator.
# use_dummy_dataset=True keeps startup light (no full Wikipedia index download);
# NOTE(review): the dummy index returns placeholder passages — confirm this is
# acceptable for production retrieval quality.
model_name = "facebook/rag-sequence-nq"
tokenizer = RagTokenizer.from_pretrained(model_name)
retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

# FAISS Vector Store
# Flat (brute-force) exact-search index; no training step required before add().
dimension = 768 # Default embedding size for transformers
index = faiss.IndexFlatL2(dimension) # L2 distance-based index
stored_docs = [] # To store document texts alongside vectors
19
+
20
# Function to extract text from uploaded files
def extract_text(file):
    """Extract plain text from an uploaded TXT/PDF/DOCX file and index it in FAISS.

    Accepts a Gradio file object (``type="file"``) or a plain path string.
    Returns the extracted text, or a human-readable error message string for
    missing/unsupported uploads.
    """
    if file is None:
        return "Please upload a document."

    # Gradio file objects expose .name (the temp path); a str input *is* the path.
    file_name = file if isinstance(file, str) else file.name
    file_ext = file_name.split('.')[-1].lower()
    text = ""

    if file_ext == "txt":
        if isinstance(file, str):
            with open(file, "rb") as fh:
                raw = fh.read()
        else:
            raw = file.read()
        # Depending on how Gradio opened the temp file, read() may already
        # yield str; calling .decode() on str raised AttributeError before.
        text = raw.decode("utf-8") if isinstance(raw, bytes) else raw

    elif file_ext == "pdf":
        with pdfplumber.open(file) as pdf:
            for page in pdf.pages:
                # extract_text() returns None for image-only pages; the
                # original `page.extract_text() + "\n"` raised TypeError there.
                text += (page.extract_text() or "") + "\n"

    elif file_ext == "docx":
        doc = docx.Document(file)
        for para in doc.paragraphs:
            text += para.text + "\n"

    else:
        return "Unsupported file format! Please upload TXT, PDF, or DOCX."

    # Store document in FAISS index
    store_in_faiss(text.strip())

    return text.strip()
49
 
50
# Function to store document in FAISS
def store_in_faiss(document):
    """Embed *document* and add it to the module-level FAISS index.

    The raw text is appended to ``stored_docs`` at the same position as its
    vector, so a search hit index maps straight back to the source text.
    Whitespace-only input is ignored.
    """
    global index, stored_docs
    if not document.strip():
        return

    # Tokenize and embed with the RAG *question encoder*.
    # BUG FIX: the original called model.rag.retriever(input_ids=...) —
    # RagRetriever is not an embedder (its __call__ needs precomputed question
    # hidden states and returns a BatchEncoding, which has no .cpu()).
    # The DPR question encoder is what produces the dense (1, 768) vector.
    inputs = tokenizer(document, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        embeddings = model.rag.question_encoder(input_ids=inputs["input_ids"])[0].cpu().numpy()

    # Add embeddings to FAISS (faiss requires contiguous float32 input).
    index.add(np.ascontiguousarray(embeddings, dtype="float32"))
    stored_docs.append(document)
64
+
65
# Function to retrieve top relevant document from FAISS
def retrieve_relevant_doc(query):
    """Return the stored document closest (L2) to *query*, or "" when none exists."""
    if index.ntotal == 0:
        return ""

    # Tokenize query and embed with the question encoder.
    # BUG FIX: model.rag.retriever(input_ids=...) is not a valid embedding
    # call (see store_in_faiss); the question encoder yields the (1, 768)
    # dense vector that matches what was indexed.
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        query_embedding = model.rag.question_encoder(input_ids=inputs["input_ids"])[0].cpu().numpy()

    # Search in FAISS (contiguous float32 required).
    _, top_idx = index.search(np.ascontiguousarray(query_embedding, dtype="float32"), k=1)
    best = int(top_idx[0][0])
    # FAISS pads missing neighbours with -1; treat that as "no match" instead
    # of letting stored_docs[-1] silently return the last document.
    if best < 0:
        return ""
    return stored_docs[best]
78
+
79
# Function to answer questions using RAG with FAISS
def answer_question(document, question):
    # Guard: the UI can trigger this with an empty editor canvas.
    if not document.strip():
        return "Please provide document content."

    # Retrieve best-matching document
    # (nearest neighbour over previously indexed uploads; may be "" if nothing
    # has been indexed yet — the tokenizer then sees only the question).
    # NOTE(review): `document` itself is not passed to the model — only the
    # FAISS-retrieved text is; confirm that ignoring the edited canvas text
    # here is intentional.
    relevant_doc = retrieve_relevant_doc(question)

    # Use RAG model for answer generation
    inputs = tokenizer(question, relevant_doc, return_tensors="pt", truncation=True)
    with torch.no_grad():
        generated = model.generate(**inputs)
    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
 
94
 
95
# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# πŸ“„ Advanced RAG NLP Document Editor with FAISS")

    # File Uploader
    # NOTE(review): type="file" is valid only in Gradio 3.x (4.x accepts
    # "filepath"/"binary") — confirm the pinned gradio version.
    file_input = gr.File(label="Upload Document (TXT, PDF, DOCX)", type="file")
    file_output = gr.Textbox(label="Extracted Text (Editable)", lines=12)

    # Re-run extraction (and FAISS indexing) whenever a new file is chosen.
    file_input.change(extract_text, inputs=file_input, outputs=file_output)

    # Editable Text Editor Canvas
    editor = gr.Textbox(label="Editor Canvas (Modify Text Before Asking)", lines=12)

    # Update editor with extracted text (identity lambda mirrors the textbox).
    file_output.change(lambda x: x, inputs=file_output, outputs=editor)

    # Question Answering
    question_input = gr.Textbox(label="Ask a Question")
    answer_output = gr.Textbox(label="Answer", lines=2)

    submit_btn = gr.Button("Get Answer")
    # The (possibly user-edited) editor text is passed as the document argument.
    submit_btn.click(answer_question, inputs=[editor, question_input], outputs=answer_output)

# Launch in Hugging Face Spaces
app.launch()