jake2004 committed on
Commit
c4fb363
·
verified ·
1 Parent(s): 928e0d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -64
app.py CHANGED
@@ -13,41 +13,41 @@ retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dum
13
  model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
14
 
15
  # FAISS Vector Store
16
- dimension = 768 # Default embedding size for transformers
17
- index = faiss.IndexFlatL2(dimension) # L2 distance-based index
18
- stored_docs = [] # To store document texts alongside vectors
19
-
20
- # Function to extract text from uploaded files
21
- def extract_text(file):
22
- if file is None:
23
- return "Please upload a document."
24
-
25
- file_name = file.name
26
- file_ext = file_name.split('.')[-1].lower()
27
- text = ""
28
-
29
- if file_ext == "txt":
30
- text = file.read().decode("utf-8")
31
-
32
- elif file_ext == "pdf":
33
- with pdfplumber.open(file) as pdf:
34
- for page in pdf.pages:
35
- text += page.extract_text() + "\n"
36
-
37
- elif file_ext == "docx":
38
- doc = docx.Document(file)
39
- for para in doc.paragraphs:
40
- text += para.text + "\n"
41
-
42
- else:
43
- return "Unsupported file format! Please upload TXT, PDF, or DOCX."
44
-
45
- # Store document in FAISS index
46
- store_in_faiss(text.strip())
47
-
48
- return text.strip()
49
-
50
- # Function to store document in FAISS
51
  def store_in_faiss(document):
52
  global index, stored_docs
53
  if not document.strip():
@@ -58,11 +58,10 @@ def store_in_faiss(document):
58
  with torch.no_grad():
59
  embeddings = model.rag.retriever(input_ids=inputs["input_ids"]).cpu().numpy()
60
 
61
- # Add embeddings to FAISS
62
  index.add(embeddings)
63
  stored_docs.append(document)
64
 
65
- # Function to retrieve top relevant document from FAISS
66
  def retrieve_relevant_doc(query):
67
  if index.ntotal == 0:
68
  return ""
@@ -72,48 +71,73 @@ def retrieve_relevant_doc(query):
72
  with torch.no_grad():
73
  query_embedding = model.rag.retriever(input_ids=inputs["input_ids"]).cpu().numpy()
74
 
75
- # Search in FAISS
76
  _, top_idx = index.search(query_embedding, k=1)
77
  return stored_docs[top_idx[0][0]]
78
 
79
- # Function to answer questions using RAG with FAISS
80
- def answer_question(document, question):
81
- if not document.strip():
82
- return "Please provide document content."
83
 
84
- # Retrieve best-matching document
85
  relevant_doc = retrieve_relevant_doc(question)
86
-
87
- # Use RAG model for answer generation
 
 
88
  inputs = tokenizer(question, relevant_doc, return_tensors="pt", truncation=True)
89
  with torch.no_grad():
90
  generated = model.generate(**inputs)
91
  answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
92
-
93
- return answer
94
 
95
- # Gradio UI
96
- with gr.Blocks() as app:
97
- gr.Markdown("# 📄 Advanced RAG NLP Document Editor with FAISS")
98
 
99
- # File Uploader
100
- file_input = gr.File(label="Upload Document (TXT, PDF, DOCX)", type="file")
101
- file_output = gr.Textbox(label="Extracted Text (Editable)", lines=12)
 
 
 
102
 
103
- file_input.change(extract_text, inputs=file_input, outputs=file_output)
 
 
104
 
105
- # Editable Text Editor Canvas
106
- editor = gr.Textbox(label="Editor Canvas (Modify Text Before Asking)", lines=12)
107
 
108
- # Update editor with extracted text
109
  file_output.change(lambda x: x, inputs=file_output, outputs=editor)
110
 
111
- # Question Answering
112
- question_input = gr.Textbox(label="Ask a Question")
113
- answer_output = gr.Textbox(label="Answer", lines=2)
114
-
115
- submit_btn = gr.Button("Get Answer")
116
- submit_btn.click(answer_question, inputs=[editor, question_input], outputs=answer_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- # Launch in Hugging Face Spaces
119
  app.launch()
 
13
  model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
14
 
15
  # FAISS Vector Store
16
+ dimension = 768
17
+ index = faiss.IndexFlatL2(dimension)
18
+ stored_docs = []
19
+ chat_history = []
20
+
21
+ # Extract text from uploaded files
22
+ def extract_text(files):
23
+ texts = []
24
+ for file in files:
25
+ file_name = file.name
26
+ file_ext = file_name.split('.')[-1].lower()
27
+ text = ""
28
+
29
+ if file_ext == "txt":
30
+ text = file.read().decode("utf-8")
31
+
32
+ elif file_ext == "pdf":
33
+ with pdfplumber.open(file) as pdf:
34
+ for page in pdf.pages:
35
+ text += page.extract_text() + "\n"
36
+
37
+ elif file_ext == "docx":
38
+ doc = docx.Document(file)
39
+ for para in doc.paragraphs:
40
+ text += para.text + "\n"
41
+
42
+ else:
43
+ return "Unsupported file format! Upload TXT, PDF, or DOCX."
44
+
45
+ texts.append(text.strip())
46
+ store_in_faiss(text.strip())
47
+
48
+ return "\n\n---\n\n".join(texts)
49
+
50
+ # Store document in FAISS
51
  def store_in_faiss(document):
52
  global index, stored_docs
53
  if not document.strip():
 
58
  with torch.no_grad():
59
  embeddings = model.rag.retriever(input_ids=inputs["input_ids"]).cpu().numpy()
60
 
 
61
  index.add(embeddings)
62
  stored_docs.append(document)
63
 
64
+ # Retrieve the most relevant document from FAISS
65
  def retrieve_relevant_doc(query):
66
  if index.ntotal == 0:
67
  return ""
 
71
  with torch.no_grad():
72
  query_embedding = model.rag.retriever(input_ids=inputs["input_ids"]).cpu().numpy()
73
 
 
74
  _, top_idx = index.search(query_embedding, k=1)
75
  return stored_docs[top_idx[0][0]]
76
 
77
+ # Answer questions using RAG with FAISS and maintain chat history
78
+ def chat_with_ai(history, question):
79
+ if not stored_docs:
80
+ return history + [[question, "Please upload a document first."]]
81
 
 
82
  relevant_doc = retrieve_relevant_doc(question)
83
+
84
+ chat_context = "\n".join(["User: " + q + "\nAI: " + a for q, a in history])
85
+ full_input = f"Context: {chat_context}\n\nDocument: {relevant_doc}\n\nQuestion: {question}"
86
+
87
  inputs = tokenizer(question, relevant_doc, return_tensors="pt", truncation=True)
88
  with torch.no_grad():
89
  generated = model.generate(**inputs)
90
  answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
 
 
91
 
92
+ history.append([question, answer])
93
+ return history, answer
 
94
 
95
+ # Gradio UI with Chat Interface, Voice Input & Text-to-Speech
96
+ with gr.Blocks(theme=gr.themes.Soft(), css="""
97
+ .gradio-container {background-color: #1E1E1E; color: #FFFFFF;}
98
+ .voice-btn, .speak-btn {background-color: #FFA500; color: black; border-radius: 5px; padding: 5px;}
99
+ """) as app:
100
+ gr.Markdown("# 🎙️ AI-Powered Document Chatbot with Voice Input & AI Speech", elem_id="title")
101
 
102
+ with gr.Row():
103
+ file_input = gr.File(label="Upload Documents (TXT, PDF, DOCX)", type="file", multiple=True)
104
+ file_output = gr.Textbox(label="Extracted Text (Editable)", lines=10)
105
 
106
+ file_input.change(extract_text, inputs=file_input, outputs=file_output)
 
107
 
108
+ editor = gr.Textbox(label="Editor Canvas (Modify Extracted Text)", lines=10)
109
  file_output.change(lambda x: x, inputs=file_output, outputs=editor)
110
 
111
+ chatbot = gr.Chatbot(label="AI Chat Assistant", elem_id="chatbot")
112
+ question_input = gr.Textbox(label="Ask AI a Question", placeholder="Type or use voice...")
113
+
114
+ with gr.Row():
115
+ send_btn = gr.Button("Send", elem_id="send-btn")
116
+ voice_btn = gr.Button("🎀 Voice", elem_id="voice-btn")
117
+ speak_btn = gr.Button("🗣️ Speak Answer", elem_id="speak-btn")
118
+
119
+ send_btn.click(chat_with_ai, inputs=[chatbot, question_input], outputs=[chatbot, None])
120
+
121
+ voice_btn.click(None, _js="""
122
+ () => {
123
+ const recognition = new webkitSpeechRecognition() || new SpeechRecognition();
124
+ recognition.lang = "en-US";
125
+ recognition.start();
126
+ recognition.onresult = function(event) {
127
+ let transcript = event.results[0][0].transcript;
128
+ document.querySelector('textarea').value = transcript;
129
+ };
130
+ }
131
+ """)
132
+
133
+ speak_btn.click(None, _js="""
134
+ () => {
135
+ let lastMsg = document.querySelectorAll('.chat-message:last-child .chat-response')[0].innerText;
136
+ let utterance = new SpeechSynthesisUtterance(lastMsg);
137
+ utterance.lang = "en-US";
138
+ utterance.rate = 1.0;
139
+ speechSynthesis.speak(utterance);
140
+ }
141
+ """)
142
 
 
143
  app.launch()