import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import chromadb
import torch
from sentence_transformers import SentenceTransformer
import os
from chromadb.utils import embedding_functions

# --- Vector store -----------------------------------------------------------
# Initialize ChromaDB client with the existing on-disk index path.
client = chromadb.PersistentClient(path="new_hadith_rag_source")

# Load the existing collection of pre-embedded hadith documents.
collection = client.get_collection(name="hadiths_new_complete")

# Debugging print to verify the number of documents in the collection.
print(f"Number of documents in collection: {collection.count()}")

# --- Generator model (LLM) ---------------------------------------------------
model_name = "google/flan-t5-base"
token = os.getenv("HUGGINGFACE_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
llm = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    pad_token_id=tokenizer.eos_token_id,
    token=token,
    device_map="auto",
)

# --- Retrieval model ---------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
retrieval_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)


def query_collection(query, n_results):
    """Embed *query* and return the *n_results* nearest documents from ChromaDB.

    Returns the raw ChromaDB query result dict (keys include 'documents').
    """
    # Compute the embedding for the query on the retrieval model's device,
    # then move it to CPU/NumPy as ChromaDB expects plain arrays.
    query_embedding = (
        retrieval_model.encode([query], convert_to_tensor=True, device=device)
        .cpu()
        .numpy()
    )
    return collection.query(query_embeddings=query_embedding, n_results=n_results)


def generate_response(context, question):
    """Generate an answer to *question* grounded in the retrieved *context*."""
    # BUGFIX: the prompt previously read "answer and avoids repetition from
    # context", which is broken English fed straight to the model.
    prompt = (
        "Please provide a short, well-structured answer and avoid repetition, "
        f"based on the following context:\n{context}\n\n"
        f"Question:\n{question}\n\nAnswer:"
    )
    # BUGFIX: move inputs to the device the LLM actually lives on; with
    # device_map="auto" that is not necessarily the retrieval model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    # NOTE: `temperature=0.9` was dropped — it is silently ignored when
    # sampling is disabled (pure beam search is deterministic) and only
    # triggered a transformers warning.
    outputs = llm.generate(
        **inputs,
        max_length=2048,
        num_return_sequences=1,
        num_beams=5,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def chatbot_response(user_query, top_k=2):
    """Basic RAG pipeline: retrieve, build context, generate an answer."""
    # Step 1: Retrieve relevant documents.
    results = query_collection(user_query, top_k)
    # Step 2: Flatten the per-query result lists into one context string.
    # (Handles the empty-result case gracefully: joins to "".)
    documents = [doc for doc_list in results["documents"] for doc in doc_list]
    combined_context = "\n\n".join(documents)
    # Step 3: Generate a response using the combined context.
    return generate_response(combined_context, user_query)


# Global flag raised by the "Stop Processing" button.
stop_processing = False


def chatbot(query, num_candidates):
    """Gradio handler: validate input, run the RAG pipeline, format the answer."""
    global stop_processing
    stop_processing = False  # Reset stop flag at the beginning of each query.

    # If the query is empty, return a default message.
    if not query.strip():
        return "Please ask a question about hadiths."

    # Run retrieval and generation.
    answer = chatbot_response(query, num_candidates)

    # NOTE(review): this check only happens AFTER generation has completed,
    # so the Stop button can never interrupt an in-flight request; true
    # cancellation would require running generation on a worker thread.
    if stop_processing:
        return "Processing was stopped by the user."

    # Format the answer: fall back to a fixed reply when the model is unsure.
    if "don't know" in answer.lower() or "not sure" in answer.lower():
        return (
            "Sorry. I don't have information about the hadiths related. "
            "It might be a dhoif, or maudhu, or I just don't have the knowledge."
        )
    return answer


def stop():
    """Gradio handler for the stop button: raise the global stop flag."""
    global stop_processing
    stop_processing = True
    return "Processing stopped."


# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Burhan AI
        Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths. \n
        Please note that this is a demo version and may not be perfect.
        This chatbot is powered by the ChromaDB and Flan-T5-base models with RAG architecture.
        Flan-T5-base is a small model and may not be as accurate as the bigger models.
        If you have any feedback or suggestions, you can contact me at frendyrachman7@gmail.com \n
        Jazakallah Khairan!
        """
    )
    with gr.Row():
        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
        num_candidates_input = gr.Slider(
            minimum=1, maximum=10, value=2, step=1, label="Number of References"
        )

    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    submit_button.click(
        chatbot,
        inputs=[query_input, num_candidates_input],
        outputs=output_text,
    )

    # Add a button to stop processing (see NOTE in `chatbot` about its limits).
    stop_button = gr.Button("Stop Processing")
    stop_output = gr.Textbox(visible=False)
    stop_button.click(stop, inputs=[], outputs=stop_output)

# Launch the Gradio interface.
demo.launch()