File size: 5,143 Bytes
01a6a25
 
618a86b
 
01a6a25
618a86b
 
01a6a25
618a86b
01a6a25
 
 
 
 
618a86b
c4fb363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618a86b
 
 
 
 
 
 
 
 
 
 
 
 
c4fb363
618a86b
 
 
 
 
 
 
 
 
 
 
 
c4fb363
 
 
 
01a6a25
618a86b
c4fb363
 
 
 
618a86b
01a6a25
 
 
 
c4fb363
 
01a6a25
c4fb363
 
 
 
 
 
618a86b
c4fb363
 
 
618a86b
c4fb363
01a6a25
c4fb363
618a86b
01a6a25
c4fb363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01a6a25
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import torch
import faiss
import numpy as np
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import pdfplumber
import docx

# Load RAG Model
# NOTE: from_pretrained downloads weights on first run; use_dummy_dataset=True
# keeps the built-in retriever index tiny (we use our own FAISS index below).
model_name = "facebook/rag-sequence-nq"
tokenizer = RagTokenizer.from_pretrained(model_name)
retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

# FAISS Vector Store
# 768 matches the DPR question-encoder hidden size used by facebook/rag-sequence-nq.
dimension = 768  
index = faiss.IndexFlatL2(dimension)  # exact L2 search over document embeddings
stored_docs = []  # raw document texts, parallel to rows added to `index`
chat_history = []  # NOTE(review): appears unused — chat state lives in the Gradio Chatbot component

# Extract text from uploaded files
def extract_text(files):
    """Extract plain text from uploaded TXT, PDF, or DOCX files.

    Each file's text is also embedded and stored in the FAISS index via
    ``store_in_faiss``. Returns the extracted texts joined by a separator,
    or an error message string for an unsupported extension.

    Args:
        files: iterable of uploaded file objects (each with a ``.name``
            attribute and file-like read access, as provided by Gradio).

    Returns:
        str: all extracted texts joined by ``\\n\\n---\\n\\n``, or an
        error message for an unsupported file format.
    """
    texts = []
    for file in files:
        file_ext = file.name.rsplit('.', 1)[-1].lower()
        text = ""

        if file_ext == "txt":
            text = file.read().decode("utf-8")

        elif file_ext == "pdf":
            with pdfplumber.open(file) as pdf:
                for page in pdf.pages:
                    # BUG FIX: extract_text() returns None for image-only
                    # pages; the old `None + "\n"` raised TypeError.
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n"

        elif file_ext == "docx":
            doc = docx.Document(file)
            text = "\n".join(para.text for para in doc.paragraphs)

        else:
            # Preserves original behavior: abort on first unsupported file.
            return "Unsupported file format! Upload TXT, PDF, or DOCX."

        text = text.strip()
        texts.append(text)
        store_in_faiss(text)

    return "\n\n---\n\n".join(texts)  

# Store document in FAISS
def store_in_faiss(document):
    """Embed *document* and add it to the module-level FAISS index.

    Blank/whitespace-only documents are ignored. The raw text is appended
    to ``stored_docs`` so that search results can be mapped back to text.

    Args:
        document (str): raw document text to index.
    """
    global index, stored_docs
    if not document.strip():
        return

    # Tokenize and embed with the DPR question encoder.
    inputs = tokenizer(document, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # BUG FIX: `model.rag.retriever(input_ids=...)` is not how embeddings
        # are produced — RagRetriever is not a callable encoder and yields no
        # tensor with .cpu(). The 768-dim dense embedding comes from the DPR
        # question encoder; index [0] is its pooled output.
        embeddings = model.rag.question_encoder(inputs["input_ids"])[0].cpu().numpy()

    # FAISS requires float32 input.
    index.add(np.ascontiguousarray(embeddings, dtype=np.float32))
    stored_docs.append(document)

# Retrieve the most relevant document from FAISS
def retrieve_relevant_doc(query):
    """Return the stored document most similar to *query* (L2 distance).

    Args:
        query (str): user question to embed and search with.

    Returns:
        str: the best-matching stored document, or "" if the index is empty.
    """
    if index.ntotal == 0:
        return ""

    # Embed the query the same way documents are embedded in store_in_faiss,
    # so distances in the shared FAISS index are meaningful.
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # BUG FIX: same as store_in_faiss — use the DPR question encoder,
        # not a call on the RagRetriever object.
        query_embedding = model.rag.question_encoder(inputs["input_ids"])[0].cpu().numpy()

    _, top_idx = index.search(np.ascontiguousarray(query_embedding, dtype=np.float32), k=1)
    return stored_docs[top_idx[0][0]]

# Answer questions using RAG with FAISS and maintain chat history
def chat_with_ai(history, question):
    """Answer *question* with RAG over the FAISS-stored documents.

    Retrieves the most relevant stored document, generates an answer with
    the RAG model, and appends the [question, answer] pair to the chat
    history.

    Args:
        history (list[list[str]]): Gradio chatbot history of [q, a] pairs.
        question (str): the user's question.

    Returns:
        list[list[str]]: the updated chat history.
    """
    history = history or []

    if not stored_docs:
        return history + [[question, "Please upload a document first."]]

    relevant_doc = retrieve_relevant_doc(question)

    # Question + retrieved document as a text pair for the generator.
    inputs = tokenizer(question, relevant_doc, return_tensors="pt", truncation=True)
    with torch.no_grad():
        generated = model.generate(**inputs)
    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    # BUG FIX: the original returned a single list on the empty-docs path but
    # a (history, answer) tuple here, so the Gradio output wiring could never
    # match both. Both paths now return just the updated history.
    return history + [[question, answer]]

# Gradio UI with Chat Interface, Voice Input & Text-to-Speech
# Gradio UI with Chat Interface, Voice Input & Text-to-Speech
with gr.Blocks(theme=gr.themes.Soft(), css="""
    .gradio-container {background-color: #1E1E1E; color: #FFFFFF;}
    .voice-btn, .speak-btn {background-color: #FFA500; color: black; border-radius: 5px; padding: 5px;}
""") as app:
    gr.Markdown("# 🎙️ AI-Powered Document Chatbot with Voice Input & AI Speech", elem_id="title")

    with gr.Row():
        # BUG FIX: `multiple=True` is not a gr.File parameter; multi-file
        # upload is selected with file_count="multiple".
        file_input = gr.File(label="Upload Documents (TXT, PDF, DOCX)", type="file", file_count="multiple")
        file_output = gr.Textbox(label="Extracted Text (Editable)", lines=10)

    file_input.change(extract_text, inputs=file_input, outputs=file_output)

    # Mirror the extracted text into an editable canvas.
    editor = gr.Textbox(label="Editor Canvas (Modify Extracted Text)", lines=10)
    file_output.change(lambda x: x, inputs=file_output, outputs=editor)

    chatbot = gr.Chatbot(label="AI Chat Assistant", elem_id="chatbot")
    question_input = gr.Textbox(label="Ask AI a Question", placeholder="Type or use voice...")

    with gr.Row():
        send_btn = gr.Button("Send", elem_id="send-btn")
        voice_btn = gr.Button("🎤 Voice", elem_id="voice-btn")
        speak_btn = gr.Button("🗣️ Speak Answer", elem_id="speak-btn")

    # BUG FIX: `outputs=[chatbot, None]` is invalid (None is not a component);
    # chat_with_ai returns the updated history for the single chatbot output.
    send_btn.click(chat_with_ai, inputs=[chatbot, question_input], outputs=[chatbot])

    # Browser-side speech recognition fills the question textbox.
    voice_btn.click(None, _js="""
        () => {
            const recognition = new webkitSpeechRecognition() || new SpeechRecognition();
            recognition.lang = "en-US";
            recognition.start();
            recognition.onresult = function(event) {
                let transcript = event.results[0][0].transcript;
                document.querySelector('textarea').value = transcript;
            };
        }
    """)

    # Browser-side text-to-speech reads the latest chatbot answer aloud.
    speak_btn.click(None, _js="""
        () => {
            let lastMsg = document.querySelectorAll('.chat-message:last-child .chat-response')[0].innerText;
            let utterance = new SpeechSynthesisUtterance(lastMsg);
            utterance.lang = "en-US";
            utterance.rate = 1.0;
            speechSynthesis.speak(utterance);
        }
    """)

app.launch()