import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import chromadb
import torch
from sentence_transformers import SentenceTransformer
import os
from chromadb.utils import embedding_functions
# Initialize ChromaDB client with the existing path
# (persistent on-disk store; the directory must already contain the index).
client = chromadb.PersistentClient(path="new_hadith_rag_source")
# Load the existing collection — get_collection raises if it does not exist,
# so this doubles as a startup sanity check.
collection = client.get_collection(name="hadiths_new_complete")
# Debugging print to verify the number of documents in the collection
print(f"Number of documents in collection: {collection.count()}")
# Model and Tokenizer Loading
model_name = "google/flan-t5-base"
# HF token read from the environment (may be None for public models).
token = os.getenv("HUGGINGFACE_TOKEN")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", token=token)
llm = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    # NOTE(review): eos token used as pad token — a GPT-style workaround;
    # T5 tokenizers normally define their own pad token. Confirm intentional.
    pad_token_id=tokenizer.eos_token_id,
    token=token,
    # Let accelerate place the model on the available device(s).
    device_map="auto"
)
# Load the pre-trained retrieval (embedding) model used for document search.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
retrieval_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
# Retrieval helper: embed the query and search the vector store.
def query_collection(query, n_results):
    """Embed *query* and return the n_results nearest documents from the collection."""
    # Encode on the configured device, then hand ChromaDB a CPU numpy array.
    embedded = retrieval_model.encode([query], convert_to_tensor=True, device=device)
    return collection.query(query_embeddings=embedded.cpu().numpy(), n_results=n_results)
# Generation step: answer the question from the retrieved context.
def generate_response(context, question):
    """Generate an answer to *question* grounded in *context* with the seq2seq LLM.

    Args:
        context: Retrieved documents joined into a single string.
        question: The user's original question.

    Returns:
        The decoded model output as a plain string.
    """
    # Grammar fixed ("and avoids" -> "that avoids") so the instruction reads cleanly.
    prompt = (
        "Please provide a short, well-structured answer that avoids repetition, "
        f"based on the following context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
    )
    # Move inputs to the device the model actually lives on: with
    # device_map="auto" the LLM's device may differ from the retrieval model's.
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    # BUGFIX: the original passed temperature=0.9 without do_sample=True;
    # transformers ignores temperature under beam search (and warns), so the
    # ineffective argument is dropped — decoding remains deterministic.
    outputs = llm.generate(
        **inputs,
        max_length=2048,
        num_return_sequences=1,
        num_beams=5,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Main chatbot function with basic RAG: retrieve, assemble context, generate.
def chatbot_response(user_query, top_k=2):
    """Run the basic RAG pipeline: fetch top_k documents, then answer from them."""
    # Step 1: retrieve the most relevant documents.
    retrieved = query_collection(user_query, top_k)
    # Step 2: ChromaDB returns one list of documents per query; flatten them
    # and join into a single context string.
    flattened = []
    for doc_list in retrieved['documents']:
        flattened.extend(doc_list)
    combined_context = "\n\n".join(flattened)
    # Step 3: generate the final answer from the combined context.
    return generate_response(combined_context, user_query)
# Global flag polled by chatbot(); set to True by stop() to request cancellation.
# NOTE(review): the flag is only checked *after* generation completes, so it
# cannot interrupt an in-flight request — it only suppresses the final answer.
stop_processing = False
def chatbot(query, num_candidates):
    """Gradio handler: answer *query* using num_candidates retrieved references."""
    global stop_processing
    # Clear any stale stop request before starting this query.
    stop_processing = False
    # Guard clause: a blank question gets the default prompt back.
    if not query.strip():
        return "Please ask a question about hadiths."
    # Retrieval + generation via the basic RAG pipeline.
    answer = chatbot_response(query, num_candidates)
    # Honour a stop request raised while we were generating.
    if stop_processing:
        return "Processing was stopped by the user."
    # If the model admits ignorance, reply with the canned disclaimer instead.
    lowered = answer.lower()
    if "don't know" in lowered or "not sure" in lowered:
        return "Sorry. I don't have information about the hadiths related. It might be a dhoif, or maudhu, or I just don't have the knowledge."
    return answer
def stop():
    """Signal the chatbot to abort; returns a status message for the hidden textbox."""
    global stop_processing
    # chatbot() inspects this flag once its current generation finishes.
    stop_processing = True
    return "Processing stopped."
# Build the Gradio interface
with gr.Blocks() as demo:
    # Static header / about text shown above the controls.
    gr.Markdown(
        """
        # Burhan AI
        Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths.
        \n
        Please note that this is a demo version and may not be perfect.
        This chatbot is powered by the ChromaDB and Flan-T5-base models with RAG architecture.
        Flan-T5-base is a small model and may not be as accurate as the bigger models.
        If you have any feedback or suggestions, you can contact me at frendyrachman7@gmail.com
        \n
        Jazakallah Khairan!
        """
    )
    with gr.Row():
        # Question input and retrieval top-k slider, side by side.
        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
        num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    # Wire the submit button to the RAG handler.
    submit_button.click(chatbot, inputs=[query_input, num_candidates_input], outputs=output_text)
    # Add a button to stop processing
    # NOTE(review): stop() only flips a flag that chatbot() checks after
    # generation finishes, so it does not interrupt an in-flight request.
    stop_button = gr.Button("Stop Processing")
    stop_output = gr.Textbox(visible=False)
    stop_button.click(stop, inputs=[], outputs=stop_output)
# Launch the Gradio interface
demo.launch()