Spaces:
Sleeping
Sleeping
import os

import chromadb
import gradio as gr
import torch
from chromadb.utils import embedding_functions
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Initialize a persistent ChromaDB client backed by the on-disk store
# shipped with this Space.
client = chromadb.PersistentClient(path="new_hadith_rag_source")

# Load the pre-built collection of embedded hadith documents.
collection = client.get_collection(name="hadiths_new_complete")

# Debugging print to verify the number of documents in the collection.
print(f"Number of documents in collection: {collection.count()}")
# --- Model and tokenizer loading ---
model_name = "google/flan-t5-base"
token = os.getenv("HUGGINGFACE_TOKEN")

# Use the `model_name` variable (the original hard-coded the string again
# here) so tokenizer and model are guaranteed to stay in sync if the
# checkpoint is ever changed.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
llm = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    pad_token_id=tokenizer.eos_token_id,
    token=token,
    device_map="auto",  # let accelerate pick GPU/CPU placement
)

# Sentence-embedding model used for retrieval; moved to GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
retrieval_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
def query_collection(query, n_results):
    """Embed *query* and return the n_results nearest documents from ChromaDB."""
    # Encode the query on the active device, then bring the vector back to
    # the CPU as a NumPy array, which is what ChromaDB expects.
    embedding = (
        retrieval_model
        .encode([query], convert_to_tensor=True, device=device)
        .cpu()
        .numpy()
    )
    return collection.query(query_embeddings=embedding, n_results=n_results)
def generate_response(context, question):
    """Generate an answer to *question* grounded in the retrieved *context*.

    Decoding is deterministic beam search (5 beams). The original passed
    `temperature=0.9`, which is ignored unless `do_sample=True` and only
    triggers a transformers warning, so it has been removed.
    """
    prompt = f"Please provide a short, well-structured answer and avoids repetition from context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
    # truncation=True clips over-long prompts to the model's maximum input
    # length instead of failing when many/long documents are retrieved.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    outputs = llm.generate(
        **inputs,
        max_length=2048,
        num_return_sequences=1,
        num_beams=5,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def chatbot_response(user_query, top_k=2):
    """Basic RAG pipeline: retrieve top_k documents, then generate an answer."""
    # Step 1: fetch the most relevant documents for the query.
    retrieved = query_collection(user_query, top_k)

    # Step 2: flatten the nested result lists and join them into one context.
    flattened = []
    for doc_list in retrieved['documents']:
        flattened.extend(doc_list)
    combined_context = "\n\n".join(flattened)

    # Step 3: answer the question against the combined context.
    return generate_response(combined_context, user_query)
# Global flag flipped by the "Stop Processing" button; the chatbot handler
# checks it after generation finishes to decide which message to return.
stop_processing = False
def chatbot(query, num_candidates):
    """Gradio handler: validate the query, run the RAG pipeline, format the answer."""
    global stop_processing
    # Reset the stop flag at the start of every request.
    stop_processing = False

    # Guard clause: an empty query gets a default prompt back.
    if not query.strip():
        return "Please ask a question about hadiths."

    # Retrieval + generation via the RAG pipeline.
    answer = chatbot_response(query, num_candidates)

    # The stop button may have been pressed while we were generating.
    if stop_processing:
        return "Processing was stopped by the user."

    # If the model admits ignorance, substitute a fixed apology.
    lowered = answer.lower()
    if "don't know" in lowered or "not sure" in lowered:
        return "Sorry. I don't have information about the hadiths related. It might be a dhoif, or maudhu, or I just don't have the knowledge."
    return answer
def stop():
    """Signal the in-flight query to stop; wired to the Stop Processing button."""
    global stop_processing
    stop_processing = True
    return "Processing stopped."
# --- Build the Gradio interface ---
with gr.Blocks() as demo:
    # Intro / disclaimer banner shown at the top of the app.
    gr.Markdown(
        """
        # Burhan AI
        Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths.
        \n
        Please note that this is a demo version and may not be perfect.
        This chatbot is powered by the ChromaDB and Flan-T5-base models with RAG architecture.
        Flan-T5-base is a small model and may not be as accurate as the bigger models.
        If you have any feedback or suggestions, you can contact me at frendyrachman7@gmail.com
        \n
        Jazakallah Khairan!
        """
    )
    with gr.Row():
        # Question box plus a slider controlling how many documents to retrieve.
        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
        num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    # Wire the submit button to the chatbot handler.
    submit_button.click(chatbot, inputs=[query_input, num_candidates_input], outputs=output_text)
    # Button that sets the global stop flag; its status textbox stays hidden.
    stop_button = gr.Button("Stop Processing")
    stop_output = gr.Textbox(visible=False)
    stop_button.click(stop, inputs=[], outputs=stop_output)

# Launch the Gradio interface.
demo.launch()