|
|
import gradio as gr |
|
|
import torch |
|
|
import faiss |
|
|
import numpy as np |
|
|
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration |
|
|
import pdfplumber |
|
|
import docx |
|
|
|
|
|
|
|
|
# --- RAG model setup -------------------------------------------------------
# Pretrained Facebook RAG checkpoint (question encoder + generator).
model_name = "facebook/rag-sequence-nq"

# Tokenizer handles both the question-encoder and generator vocabularies.
tokenizer = RagTokenizer.from_pretrained(model_name)

# NOTE(review): use_dummy_dataset=True loads a tiny placeholder wiki index —
# fine for demos, not for real retrieval quality.
retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)

model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

# --- Local FAISS store for user-uploaded documents -------------------------
# 768 matches the hidden size of the RAG (DPR) question encoder.
dimension = 768

# Exact L2-distance index over document embeddings.
index = faiss.IndexFlatL2(dimension)

# Raw document texts, parallel to the vectors added to `index`
# (position i in stored_docs corresponds to vector i in the index).
stored_docs = []

# Running conversation history (question/answer pairs) for the chat UI.
chat_history = []
|
|
|
|
|
|
|
|
def extract_text(files):
    """Extract plain text from uploaded TXT, PDF, or DOCX files.

    Each extracted document is also pushed into the FAISS store via
    ``store_in_faiss`` so it becomes searchable immediately.

    Args:
        files: iterable of Gradio file wrappers (each exposes ``.name``,
               the path of the uploaded temp file).

    Returns:
        str: all extracted texts joined by a ``---`` separator, or an
             error message for an unsupported extension.
    """
    texts = []
    for file in files:
        file_ext = file.name.split('.')[-1].lower()
        text = ""

        if file_ext == "txt":
            # Read via the temp-file path: the Gradio wrapper's own read()
            # mode (text vs. binary) varies between versions.
            with open(file.name, "r", encoding="utf-8", errors="replace") as fh:
                text = fh.read()

        elif file_ext == "pdf":
            with pdfplumber.open(file.name) as pdf:
                for page in pdf.pages:
                    # BUG FIX: extract_text() returns None for image-only
                    # pages; guard against TypeError on concatenation.
                    text += (page.extract_text() or "") + "\n"

        elif file_ext == "docx":
            doc = docx.Document(file.name)
            for para in doc.paragraphs:
                text += para.text + "\n"

        else:
            return "Unsupported file format! Upload TXT, PDF, or DOCX."

        texts.append(text.strip())
        store_in_faiss(text.strip())

    return "\n\n---\n\n".join(texts)
|
|
|
|
|
|
|
|
def store_in_faiss(document):
    """Embed *document* and append it to the global FAISS index.

    Args:
        document: raw document text; blank/whitespace-only input is ignored.

    Side effects:
        Adds one vector to ``index`` and the matching text to ``stored_docs``
        (the two stay position-aligned).
    """
    global index, stored_docs

    if not document.strip():
        return

    inputs = tokenizer(document, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # BUG FIX: model.rag.retriever(...) returns retrieved passages, not a
        # tensor (it has no .cpu()). The dense embedding FAISS expects comes
        # from the RAG question encoder; [0] is the pooled output.
        embeddings = model.question_encoder(input_ids=inputs["input_ids"])[0].cpu().numpy()

    # FAISS requires contiguous float32 input.
    embeddings = np.ascontiguousarray(embeddings, dtype="float32")

    index.add(embeddings)
    stored_docs.append(document)
|
|
|
|
|
|
|
|
def retrieve_relevant_doc(query):
    """Return the stored document closest to *query*, or "" if none match.

    Args:
        query: the user's question text.

    Returns:
        str: the nearest stored document, or an empty string when the index
             is empty or FAISS reports no valid neighbor.
    """
    if index.ntotal == 0:
        return ""

    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # BUG FIX: must use the question encoder for a dense query vector;
        # model.rag.retriever(...) returns passages, not an embedding tensor.
        query_embedding = model.question_encoder(input_ids=inputs["input_ids"])[0].cpu().numpy()

    # FAISS requires contiguous float32 input.
    query_embedding = np.ascontiguousarray(query_embedding, dtype="float32")

    _, top_idx = index.search(query_embedding, k=1)
    best = int(top_idx[0][0])
    # FAISS returns -1 when it cannot produce a neighbor.
    if best < 0:
        return ""
    return stored_docs[best]
|
|
|
|
|
|
|
|
def chat_with_ai(history, question):
    """Answer *question* against the stored documents and extend the chat.

    Args:
        history: list of [question, answer] pairs from the Chatbot component.
        question: the new user question.

    Returns:
        list: the updated history. BUG FIX: the original returned one value
        on the early branch and a (history, answer) tuple on the main path,
        so the single-Chatbot output wiring could never receive a consistent
        shape. Both paths now return just the history list.
    """
    history = history or []

    if not stored_docs:
        return history + [[question, "Please upload a document first."]]

    relevant_doc = retrieve_relevant_doc(question)

    # Encode the question/document pair; RagTokenizer handles the pair
    # encoding, so no manual prompt assembly is needed (the original built
    # an unused `full_input` string here).
    inputs = tokenizer(question, relevant_doc, return_tensors="pt", truncation=True)
    with torch.no_grad():
        generated = model.generate(**inputs)
    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    history.append([question, answer])
    return history
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), css="""
.gradio-container {background-color: #1E1E1E; color: #FFFFFF;}
.voice-btn, .speak-btn {background-color: #FFA500; color: black; border-radius: 5px; padding: 5px;}
""") as app:
    gr.Markdown("# 🎙️ AI-Powered Document Chatbot with Voice Input & AI Speech", elem_id="title")

    with gr.Row():
        # BUG FIX: gr.File has no `multiple=` parameter; multi-file upload is
        # requested via file_count="multiple".
        file_input = gr.File(label="Upload Documents (TXT, PDF, DOCX)", file_count="multiple")
        file_output = gr.Textbox(label="Extracted Text (Editable)", lines=10)

    file_input.change(extract_text, inputs=file_input, outputs=file_output)

    # Mirror the extracted text into an editable canvas.
    editor = gr.Textbox(label="Editor Canvas (Modify Extracted Text)", lines=10)
    file_output.change(lambda x: x, inputs=file_output, outputs=editor)

    chatbot = gr.Chatbot(label="AI Chat Assistant", elem_id="chatbot")
    question_input = gr.Textbox(label="Ask AI a Question", placeholder="Type or use voice...")

    with gr.Row():
        send_btn = gr.Button("Send", elem_id="send-btn")
        voice_btn = gr.Button("🎤 Voice", elem_id="voice-btn")
        speak_btn = gr.Button("🗣️ Speak Answer", elem_id="speak-btn")

    # BUG FIX: `None` is not a valid output component (outputs=[chatbot, None]
    # raised at build time); chat_with_ai feeds the single Chatbot.
    send_btn.click(chat_with_ai, inputs=[chatbot, question_input], outputs=chatbot)

    # Browser-side speech-to-text via the Web Speech API.
    # NOTE(review): `_js=` is the Gradio 3.x keyword; newer releases renamed
    # it to `js=` — confirm against the installed version.
    voice_btn.click(None, _js="""
    () => {
        // BUG FIX: `new webkitSpeechRecognition() || ...` throws a
        // ReferenceError before the fallback applies when the prefixed
        // constructor is undefined; pick the constructor first.
        const SR = window.SpeechRecognition || window.webkitSpeechRecognition;
        if (!SR) { return; }
        const recognition = new SR();
        recognition.lang = "en-US";
        recognition.start();
        recognition.onresult = function(event) {
            let transcript = event.results[0][0].transcript;
            document.querySelector('textarea').value = transcript;
        };
    }
    """)

    # Browser-side text-to-speech of the latest answer.
    speak_btn.click(None, _js="""
    () => {
        let lastMsg = document.querySelectorAll('.chat-message:last-child .chat-response')[0].innerText;
        let utterance = new SpeechSynthesisUtterance(lastMsg);
        utterance.lang = "en-US";
        utterance.rate = 1.0;
        speechSynthesis.speak(utterance);
    }
    """)

app.launch()
|
|
|