# RAG / app.py
# Author: jake2004 (commit c4fb363, "Update app.py")
import os

import docx
import faiss
import gradio as gr
import numpy as np
import pdfplumber
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Load the RAG model. Tokenizer, retriever and model all come from the same
# pretrained checkpoint so their vocabularies and configs agree.
model_name = "facebook/rag-sequence-nq"
tokenizer = RagTokenizer.from_pretrained(model_name)
retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="exact",
    use_dummy_dataset=True,  # "dummy" variant of the retrieval dataset (see RagRetriever docs)
)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

# In-memory vector store for uploaded documents: an exact L2 FAISS index.
# Documents are appended to `stored_docs` in lockstep with `index.add`, so
# row i of the index corresponds to stored_docs[i].
dimension = 768  # width of every embedding vector added to `index`
index = faiss.IndexFlatL2(dimension)
stored_docs, chat_history = [], []
# Extract text from uploaded files
def extract_text(files):
    """Extract plain text from uploaded TXT, PDF and DOCX files and index it.

    Every supported file is also passed to ``store_in_faiss`` so it can be
    retrieved later when answering questions.

    Args:
        files: The ``gr.File`` value: a list of uploaded files, a single
            file, or None when the upload is cleared. Each item may be an
            object whose ``.name`` is the on-disk path, or a plain path string.

    Returns:
        The text of each supported file, joined by a ``---`` separator. If
        some files had an unsupported extension, a notice naming them is
        appended; the supported files are still extracted and indexed.
    """
    if not files:
        return ""
    if not isinstance(files, (list, tuple)):
        files = [files]

    texts = []
    skipped = []
    for file in files:
        # NOTE(review): depending on the Gradio version, the uploaded bytes
        # may not be readable through the object's own handle; the path in
        # `.name` is the reliable source, so always open by path.
        path = getattr(file, "name", file)
        file_ext = os.path.splitext(path)[1].lower().lstrip(".")
        if file_ext == "txt":
            with open(path, "rb") as fh:
                # Replace undecodable bytes instead of failing the whole upload.
                text = fh.read().decode("utf-8", errors="replace")
        elif file_ext == "pdf":
            with pdfplumber.open(path) as pdf:
                # extract_text() may return None (e.g. image-only pages).
                text = "\n".join(page.extract_text() or "" for page in pdf.pages)
        elif file_ext == "docx":
            doc = docx.Document(path)
            text = "\n".join(para.text for para in doc.paragraphs)
        else:
            skipped.append(os.path.basename(path))
            continue
        text = text.strip()
        texts.append(text)
        store_in_faiss(text)

    if skipped:
        texts.append(
            "Unsupported file format! Upload TXT, PDF, or DOCX. "
            f"Skipped: {', '.join(skipped)}"
        )
    return "\n\n---\n\n".join(texts)
# Store document in FAISS
def store_in_faiss(document):
    """Embed *document* and append it to the module-level FAISS index.

    The text is tokenized with the RAG tokenizer (truncated to 512 tokens,
    so only the beginning of a long document shapes its vector) and encoded
    by the question encoder bundled in the RAG model. The vector is added to
    ``index`` and the text to ``stored_docs`` together, keeping row i of the
    index aligned with ``stored_docs[i]``.

    Args:
        document: Plain text to index. None or whitespace-only text is ignored.
    """
    if not document or not document.strip():
        return
    inputs = tokenizer(document, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # RagRetriever is a lookup component, not an encoder; the embedding
        # comes from the RAG question encoder's pooled output (element [0]).
        pooled = model.rag.question_encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )[0]
    # FAISS expects a C-contiguous float32 matrix of shape (n_vectors, dimension).
    embeddings = np.ascontiguousarray(pooled.cpu().numpy(), dtype=np.float32)
    index.add(embeddings)
    stored_docs.append(document)
# Retrieve the most relevant document from FAISS
def retrieve_relevant_doc(query):
    """Return the stored document whose embedding is nearest to *query*.

    The query is embedded the same way ``store_in_faiss`` embeds documents
    (RAG tokenizer + RAG question encoder), then the L2-nearest vector is
    looked up in ``index``.

    Args:
        query: The user's question text.

    Returns:
        The best-matching entry of ``stored_docs``, or ``""`` when nothing
        has been indexed yet.
    """
    if index.ntotal == 0:
        return ""
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # Same encoder as store_in_faiss, so query and document vectors share one space.
        pooled = model.rag.question_encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )[0]
    query_embedding = np.ascontiguousarray(pooled.cpu().numpy(), dtype=np.float32)
    _, top_idx = index.search(query_embedding, k=1)
    best = int(top_idx[0][0])
    # FAISS pads missing results with -1; never use that as a list index.
    return stored_docs[best] if 0 <= best < len(stored_docs) else ""
# Answer questions using RAG with FAISS and maintain chat history
def _generate_answer(question, document):
    """Generate an answer to *question* conditioned on *document*.

    Feeds the RAG model's generator directly with a ``<passage> // <question>``
    input (the ``//`` separator RAG places between a passage and the
    question), so the answer is grounded in the uploaded document rather
    than in the retriever's built-in dataset.
    """
    gen_tok = tokenizer.generator
    generator = model.rag.generator
    max_len = generator.config.max_position_embeddings
    # Tokenize the question suffix first so that, for long documents, it is
    # the passage that gets truncated rather than the question.
    question_ids = gen_tok(
        " // " + question,
        add_special_tokens=False,
        truncation=True,
        max_length=max_len // 2,
    )["input_ids"]
    doc_budget = max(max_len - len(question_ids) - 2, 1)  # reserve room for start/end tokens
    doc_ids = gen_tok(
        document,
        add_special_tokens=False,
        truncation=True,
        max_length=doc_budget,
    )["input_ids"]
    input_ids = torch.tensor([gen_tok.build_inputs_with_special_tokens(doc_ids + question_ids)])
    with torch.no_grad():
        generated = generator.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
        )
    return gen_tok.batch_decode(generated, skip_special_tokens=True)[0]


def chat_with_ai(history, question):
    """Answer *question* from the best-matching uploaded document.

    Args:
        history: Current chat history as ``[user, bot]`` pairs (the
            ``gr.Chatbot`` value); None is treated as empty. It is not
            mutated; a new list is returned.
        question: The user's question text.

    Returns:
        ``(history, answer)``: the updated history with this turn appended,
        and the answer string on its own. Every path returns this same pair
        so the function can back a two-output Gradio event.
    """
    history = list(history or [])
    if not question or not question.strip():
        return history, ""
    if not stored_docs:
        message = "Please upload a document first."
        history.append([question, message])
        return history, message
    relevant_doc = retrieve_relevant_doc(question)
    answer = _generate_answer(question, relevant_doc)
    history.append([question, answer])
    return history, answer
# Gradio UI with chat interface, voice input and text-to-speech.
# Uses the Gradio 3.x event API (`_js`, `gr.File(type="file")`). The Voice and
# Speak buttons run only in the browser; buttons are styled via their elem_id
# (CSS `#id` selectors).
with gr.Blocks(theme=gr.themes.Soft(), css="""
.gradio-container {background-color: #1E1E1E; color: #FFFFFF;}
#voice-btn, #speak-btn {background-color: #FFA500; color: black; border-radius: 5px; padding: 5px;}
""") as app:
    gr.Markdown("# 🎙️ AI-Powered Document Chatbot with Voice Input & AI Speech", elem_id="title")
    with gr.Row():
        # Gradio enables multi-file upload through `file_count`.
        file_input = gr.File(label="Upload Documents (TXT, PDF, DOCX)", type="file", file_count="multiple")
        file_output = gr.Textbox(label="Extracted Text (Editable)", lines=10)
    file_input.change(extract_text, inputs=file_input, outputs=file_output)
    editor = gr.Textbox(label="Editor Canvas (Modify Extracted Text)", lines=10)
    file_output.change(lambda x: x, inputs=file_output, outputs=editor)
    chatbot = gr.Chatbot(label="AI Chat Assistant", elem_id="chatbot")
    question_input = gr.Textbox(
        label="Ask AI a Question",
        placeholder="Type or use voice...",
        elem_id="question-input",
    )
    # Hidden holder for the latest answer: receives chat_with_ai's second
    # return value and is read by the Speak button's browser-side handler.
    last_answer = gr.Textbox(visible=False)
    with gr.Row():
        send_btn = gr.Button("Send", elem_id="send-btn")
        voice_btn = gr.Button("🎀 Voice", elem_id="voice-btn")
        speak_btn = gr.Button("🗣️ Speak Answer", elem_id="speak-btn")
    send_btn.click(chat_with_ai, inputs=[chatbot, question_input], outputs=[chatbot, last_answer])
    # Browser speech recognition -> question box. Resolve the constructor
    # before calling `new` (the vendor-prefixed name may not exist), target
    # the question box specifically, and fire an input event so Gradio's
    # component state sees the new text.
    voice_btn.click(None, _js="""
() => {
    const Recognition = window.SpeechRecognition || window.webkitSpeechRecognition;
    if (!Recognition) {
        alert("Speech recognition is not supported in this browser.");
        return;
    }
    const recognition = new Recognition();
    recognition.lang = "en-US";
    recognition.onresult = function(event) {
        const transcript = event.results[0][0].transcript;
        const host = document.querySelector("gradio-app");
        const root = (host && host.shadowRoot) || document;
        const box = root.querySelector("#question-input textarea");
        if (!box) return;
        box.value = transcript;
        box.dispatchEvent(new Event("input", { bubbles: true }));
    };
    recognition.start();
}
""")
    # Read the latest answer aloud; Gradio passes `last_answer`'s value as the argument.
    speak_btn.click(None, inputs=last_answer, outputs=None, _js="""
(text) => {
    if (!text || !window.speechSynthesis) return;
    let utterance = new SpeechSynthesisUtterance(text);
    utterance.lang = "en-US";
    utterance.rate = 1.0;
    speechSynthesis.cancel();
    speechSynthesis.speak(utterance);
}
""")
# Start the web server only when this file is run as a script (e.g.
# `python app.py`), not when the module is imported.
if __name__ == "__main__":
    app.launch()