File size: 11,765 Bytes
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e9a6c
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e9a6c
0a8bbc5
 
 
 
19e9a6c
0a8bbc5
 
19e9a6c
 
 
 
 
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e9a6c
0a8bbc5
19e9a6c
0a8bbc5
 
 
 
 
 
19e9a6c
 
 
 
 
 
 
 
 
0a8bbc5
 
 
 
 
19e9a6c
0a8bbc5
 
19e9a6c
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4783950
 
 
0a8bbc5
 
 
 
4783950
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
19e9a6c
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4783950
0a8bbc5
 
 
 
 
4783950
0a8bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4783950
0a8bbc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import os
import tempfile
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_community.llms import HuggingFacePipeline

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Model identifiers: EMBEDDING_MODEL is handed to HuggingFaceEmbeddings and
# LLM_MODEL is loaded with AutoTokenizer / AutoModelForSeq2SeqLM.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "google/flan-t5-large"
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESHOLD = 0.7  # Passed as the retriever's "score_threshold"
# Splitter settings; the splitter uses length_function=len, so both values
# are measured in characters.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
# Generation settings forwarded to the transformers text2text pipeline.
TEMPERATURE = 0.1
MAX_NEW_TOKENS = 512
TOP_K = 3  # Passed as the retriever's "k" (chunks fetched per question)

# ---------------------------------------------------------------------------
# Module-level state shared by the Gradio callbacks
# ---------------------------------------------------------------------------
# Maps a session id to that session's list of (question, answer) tuples.
conversation_history = {}
# Session id, vector store and display name of the most recently processed
# document; all three stay None until an upload succeeds.
current_session_id = None
current_document_store = None
current_document_name = None
# Lower-cased file extension -> loader class used to read that file type.
FILE_EXTENSIONS = {
    ".pdf": PyPDFLoader,
    ".txt": TextLoader,
    ".docx": Docx2txtLoader,
    ".pptx": UnstructuredPowerPointLoader,
}

class DocumentAIBot:
    """Retrieval-augmented question answering over a single document.

    On construction the embedding model, the seq2seq LLM pipeline and the
    text splitter are loaded once; afterwards ``process_document`` turns a
    file into a FAISS vector store and ``get_answer`` answers questions
    against such a store.
    """

    def __init__(self):
        self.setup_models()

    def setup_models(self):
        """Load the embedding model, the LLM pipeline and the text splitter."""
        print("Setting up models...")
        # Set up embedding model
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={"device": DEVICE},
            encode_kwargs={"normalize_embeddings": True}
        )

        # Set up LLM model
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
        self.llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL).to(DEVICE)

        # Create text generation pipeline.
        # NOTE(review): no `do_sample` flag is passed; transformers may ignore
        # `temperature` under greedy decoding -- confirm this is intended.
        self.text_generation_pipeline = pipeline(
            "text2text-generation",
            model=self.llm_model,
            tokenizer=self.tokenizer,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            device=0 if DEVICE == "cuda" else -1
        )

        # Create HuggingFace pipeline for LangChain
        self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)

        # Text splitter for document chunking (sizes are in characters)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len
        )

        print("Models loaded successfully!")

    def process_document(self, file_path):
        """Load ``file_path``, split it into chunks and index them in FAISS.

        Args:
            file_path: Path to the document; its extension selects the loader.

        Returns:
            A FAISS vector store built from the document's chunks.

        Raises:
            ValueError: If the extension is not in ``FILE_EXTENSIONS`` or the
                loader/splitter produced no chunks.
        """
        print(f"Processing document: {file_path}")
        file_extension = os.path.splitext(file_path)[1].lower()

        if file_extension not in FILE_EXTENSIONS:
            raise ValueError(f"Unsupported file format: {file_extension}")

        # Select appropriate loader
        loader_class = FILE_EXTENSIONS[file_extension]
        loader = loader_class(file_path)

        # Load and split the document
        documents = loader.load()
        chunks = self.text_splitter.split_documents(documents)

        if not chunks:
            raise ValueError("No content extracted from the document")

        print(f"Document split into {len(chunks)} chunks")

        # Create vector store
        vector_store = FAISS.from_documents(chunks, self.embedding_model)
        return vector_store

    def setup_retrieval_chain(self, vector_store):
        """Build a ConversationalRetrievalChain over ``vector_store``.

        The retriever returns at most ``TOP_K`` chunks whose relevance score
        is at least ``THRESHOLD``; the chain also returns the source chunks.
        """
        retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": TOP_K,
                "score_threshold": THRESHOLD
            }
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=retriever,
            return_source_documents=True,
            verbose=True
        )

        return chain

    @staticmethod
    def _format_sources(source_documents):
        """Render a de-duplicated "Sources:" footer, or "" when there are none."""
        if not source_documents:
            return ""

        lines = ["\n\nSources:"]
        seen_sources = set()

        for doc in source_documents:
            source = doc.metadata.get("source")
            page = doc.metadata.get("page")

            source_key = (source, page)
            if source_key in seen_sources:
                continue
            seen_sources.add(source_key)

            label = os.path.basename(source) if source else "Document chunk"
            if page is None:
                # Loaders without page metadata (e.g. plain text) get no page
                # suffix instead of a confusing "(page Unknown page)".
                lines.append(f"\n- {label}")
                continue

            # LangChain's PyPDFLoader documents `page` as a zero-based index;
            # show it one-based so it matches what a PDF viewer displays.
            if isinstance(page, int) and not isinstance(page, bool):
                page += 1
            lines.append(f"\n- {label} (page {page})")

        return "".join(lines)

    def get_answer(self, question, session_id, vector_store, chat_history):
        """Answer ``question`` against ``vector_store`` using prior turns as context.

        Args:
            question: The user's question; ``None`` or blank input is rejected.
            session_id: Accepted for interface compatibility; not used here.
            vector_store: Store previously returned by ``process_document``.
            chat_history: List of (question, answer) pairs from earlier turns.

        Returns:
            ``(answer, updated_history)``. ``answer`` carries a "Sources:"
            footer when the chain returned source documents, and
            ``updated_history`` is ``chat_history`` plus the new turn. For a
            blank question a prompt message and the unchanged history are
            returned.
        """
        if not question or not question.strip():
            return "Please enter a question related to the document.", chat_history

        # A fresh chain is built per call so it always targets the given store.
        retrieval_chain = self.setup_retrieval_chain(vector_store)

        # The chain expects (question, answer) tuples; Gradio may supply lists.
        formatted_chat_history = [(q, a) for q, a in chat_history]

        # `invoke` replaces the deprecated `Chain.__call__`.
        response = retrieval_chain.invoke(
            {"question": question, "chat_history": formatted_chat_history}
        )

        answer = response["answer"] + self._format_sources(
            response.get("source_documents", [])
        )

        return answer, chat_history + [(question, answer)]

def generate_session_id():
    """Return a freshly generated random UUID (version 4) as a string."""
    from uuid import uuid4

    new_id = uuid4()
    return str(new_id)

def process_uploaded_document(file_path):
    """Index an uploaded document and make it the active session.

    Args:
        file_path: Path of the uploaded file as supplied by the Gradio file
            input (``type="filepath"``), or ``None`` when nothing was chosen.

    Returns:
        ``(chat_value, status_message)``. On success ``chat_value`` is an
        empty list (resets the chat display); when no file was given or
        processing failed it is ``None`` and the message explains why.
    """
    global current_session_id, current_document_store, current_document_name

    if file_path is None:
        return None, "Please upload a document first."

    try:
        # Gradio has already saved the upload; only the name is needed for display.
        filename = os.path.basename(file_path)

        # The bot (and its models) is created lazily on first use and cached
        # as an attribute of this function so later uploads reuse it.
        if not hasattr(process_uploaded_document, "bot"):
            process_uploaded_document.bot = DocumentAIBot()

        # Process the document
        vector_store = process_uploaded_document.bot.process_document(file_path)

        # Start a new session. The previous session can no longer be reached
        # once current_session_id changes, so drop its history rather than
        # letting the dict grow with every upload.
        session_id = generate_session_id()
        conversation_history.pop(current_session_id, None)
        conversation_history[session_id] = []

        # Publish the new document only after everything above succeeded, so
        # a failed upload leaves the previously loaded document usable.
        current_session_id = session_id
        current_document_store = vector_store
        current_document_name = filename

        return [], f"Document '{filename}' processed successfully. You can now ask questions about it."

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error processing document: {str(e)}"

def clear_conversation():
    """Reset the stored history of the current session.

    Returns:
        ``(chat_value, status_message)`` where ``chat_value`` is an empty
        list (clears the chat display). The message names the loaded
        document, or asks for an upload when no document is loaded yet.
    """
    if current_session_id and current_session_id in conversation_history:
        conversation_history[current_session_id] = []

    if current_document_name is None:
        # Nothing has been processed yet; avoid printing "about 'None'".
        return [], "Conversation cleared. Upload a document to start asking questions."

    return [], f"Conversation cleared. You can continue asking questions about '{current_document_name}'."

def answer_question(question, history):
    """Answer a question about the currently loaded document.

    Args:
        question: Text from the question box.
        history: Chat display history as a list of (question, answer) pairs;
            ``None`` is treated as an empty history.

    Returns:
        ``("", new_history)``: the empty string clears the question box and
        ``new_history`` is ``history`` plus the new turn (the answer, or an
        explanatory message when no document is loaded or an error occurred).
    """
    # Work on a copy so the caller's list is never mutated and None is accepted.
    history = list(history) if history else []

    # Compare against None explicitly: whether a vector store object is
    # "truthy" is an implementation detail of its class.
    if current_document_store is None:
        return "", history + [(question, "Please upload a document first.")]

    if not hasattr(process_uploaded_document, "bot"):
        return "", history + [(question, "Document AI bot not initialized. Please reload the page and try again.")]

    try:
        # History kept for the model, separate from what the UI displays.
        chat_history = conversation_history.get(current_session_id, [])

        answer, updated_history = process_uploaded_document.bot.get_answer(
            question,
            current_session_id,
            current_document_store,
            chat_history
        )

        # Update conversation history
        conversation_history[current_session_id] = updated_history

        return "", history + [(question, answer)]

    except Exception as e:
        import traceback
        traceback.print_exc()
        return "", history + [(question, f"Error generating answer: {str(e)}")]

def build_interface():
    """Assemble the Gradio Blocks UI and return it (the caller launches it).

    Layout: a narrow left column with the upload controls, status box,
    clear button and a system-information panel; a wider right column with
    the chat window, the question box and its submit button.
    """
    accepted_types = [".pdf", ".txt", ".docx", ".pptx"]

    with gr.Blocks(title="Document AI Chatbot") as app:
        gr.Markdown("# 📄 Document AI Chatbot")
        gr.Markdown("Upload a document (PDF, TXT, DOCX, PPTX) and ask questions about its content.")

        with gr.Row():
            # Left column: upload, status and static configuration summary.
            with gr.Column(scale=1):
                doc_file = gr.File(
                    label="Upload Document",
                    file_types=accepted_types,
                    type="filepath",  # callback receives the saved file's path
                )
                process_btn = gr.Button("Process Document", variant="primary")
                status_box = gr.Textbox(label="Upload Status", interactive=False)
                reset_btn = gr.Button("Clear Conversation")

                gr.Markdown("### System Information")
                gr.Markdown(f"""
                - Embedding Model: {EMBEDDING_MODEL}
                - Language Model: {LLM_MODEL}
                - Running on: {DEVICE}
                - Chunk Size: {CHUNK_SIZE}
                - Relevance Threshold: {THRESHOLD}
                """)

            # Right column: the conversation itself.
            with gr.Column(scale=2):
                chat_panel = gr.Chatbot(label="Conversation", height=500, show_label=True)

                with gr.Row():
                    question_box = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="What is the main topic of this document?",
                        lines=2,
                        show_label=True,
                    )
                    ask_btn = gr.Button("Submit", variant="primary")

        # Wiring. Processing a document resets the chat and reports status.
        process_btn.click(
            process_uploaded_document,
            inputs=[doc_file],
            outputs=[chat_panel, status_box],
        )

        # Clicking Submit and pressing Enter in the question box behave alike:
        # both clear the box and append the new turn to the chat.
        for register in (ask_btn.click, question_box.submit):
            register(
                answer_question,
                inputs=[question_box, chat_panel],
                outputs=[question_box, chat_panel],
            )

        reset_btn.click(
            clear_conversation,
            inputs=[],
            outputs=[chat_panel, status_box],
        )

    return app

# Main execution: build the UI and serve it when run as a script.
if __name__ == "__main__":
    demo = build_interface()
    demo.launch(
        # 0.0.0.0 listens on all network interfaces, so the app is reachable
        # from other hosts, not just localhost.
        server_name="0.0.0.0",
        server_port=7860,
        # No public share link is requested.
        share=False
    )