# Burhan-AI / app.py — Hugging Face Space by frendyrachman (commit f513bae)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import chromadb
import torch
from sentence_transformers import SentenceTransformer
import os
from chromadb.utils import embedding_functions
# Initialize a persistent ChromaDB client over the pre-built vector store
# shipped with this Space (directory "new_hadith_rag_source").
client = chromadb.PersistentClient(path="new_hadith_rag_source")
# Open the existing hadith collection; raises if the collection was not
# created beforehand (this code never creates it).
collection = client.get_collection(name="hadiths_new_complete")
# Startup sanity check: log how many documents the collection holds.
print(f"Number of documents in collection: {collection.count()}")
# --- Model and tokenizer loading -------------------------------------------
model_name = "google/flan-t5-base"
# Optional Hugging Face auth token (for rate limits / gated repos); None is fine.
token = os.getenv("HUGGINGFACE_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
# NOTE: the previous version passed pad_token_id=tokenizer.eos_token_id here.
# That is a GPT-style workaround for models lacking a pad token; T5 ships with
# a dedicated pad token (id 0), and overriding it with EOS corrupts padding /
# attention-mask semantics for batched seq2seq generation. So: no override.
llm = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    token=token,
    device_map="auto",  # let accelerate place the weights (GPU if available)
)
# Sentence-embedding model used for retrieval, pinned to the same device choice.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
retrieval_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
def query_collection(query, n_results):
    """Embed *query* and fetch its nearest documents from the hadith collection.

    Args:
        query: Natural-language question to search for.
        n_results: Number of nearest neighbours to retrieve.

    Returns:
        The raw ChromaDB query result dict (keys such as 'documents', 'ids',
        'distances'), with one nested list per input query.
    """
    # encode() returns a numpy array by default; the previous version forced a
    # torch tensor and then converted back via .cpu().numpy() for no benefit.
    query_embedding = retrieval_model.encode([query], device=device)
    # ChromaDB accepts plain Python lists of floats as query embeddings.
    return collection.query(query_embeddings=query_embedding.tolist(), n_results=n_results)
def generate_response(context, question):
    """Generate an answer to *question* grounded in the retrieved *context*.

    Args:
        context: Retrieved documents joined into one string.
        question: The user's original question.

    Returns:
        The decoded model output with special tokens stripped.
    """
    prompt = f"Please provide a short, well-structured answer and avoids repetition from context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
    # Move inputs to wherever device_map="auto" actually placed the model,
    # not to the globally-computed `device` — the two can differ.
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    # Beam search is deterministic: `temperature` is silently ignored unless
    # do_sample=True, so it has been removed. The eos-as-pad override is gone
    # too — T5 has its own pad token and generate() uses it by default.
    # max_new_tokens bounds the generated answer (for encoder-decoder models
    # this matches the old max_length cap on decoder output).
    outputs = llm.generate(
        **inputs,
        max_new_tokens=2048,
        num_return_sequences=1,
        num_beams=5,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def chatbot_response(user_query, top_k=2):
    """Answer *user_query* with basic RAG: retrieve top_k documents, then generate."""
    # Step 1: pull the most relevant hadith documents from the vector store.
    retrieval = query_collection(user_query, top_k)
    # Step 2: ChromaDB nests documents one list per query — flatten them.
    flattened = [text for sublist in retrieval['documents'] for text in sublist]
    # Step 3: hand the concatenated evidence plus the question to the LLM.
    return generate_response("\n\n".join(flattened), user_query)
# Cooperative cancellation flag shared between chatbot() and stop().
# NOTE(review): the flag is only consulted after chatbot_response() has already
# returned, so pressing "Stop" cannot interrupt an in-flight generation — it
# merely suppresses the answer afterwards. Confirm whether that is intended.
stop_processing = False
def chatbot(query, num_candidates):
    """Gradio handler: validate the query, run RAG, and post-process the answer."""
    global stop_processing
    # Reset the cancellation flag at the start of each request.
    stop_processing = False
    # Guard clause: blank input gets the default prompt back immediately.
    if not query.strip():
        return "Please ask a question about hadiths."
    # Retrieval + generation via the basic RAG pipeline.
    answer = chatbot_response(query, num_candidates)
    # Honour the stop button if it was pressed while we were generating.
    if stop_processing:
        return "Processing was stopped by the user."
    # Map "I don't know"-style answers onto a friendlier fallback message.
    lowered = answer.lower()
    if "don't know" in lowered or "not sure" in lowered:
        return "Sorry. I don't have information about the hadiths related. It might be a dhoif, or maudhu, or I just don't have the knowledge."
    return answer
def stop():
    """Gradio handler for the Stop button: raise the shared cancellation flag."""
    global stop_processing
    stop_processing = True
    return "Processing stopped."
# Build the Gradio interface.
# NOTE(review): the original file's indentation was lost; the nesting below
# (submit button / output box at Blocks level rather than inside the Row) is
# the conventional layout — confirm against the deployed Space.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Burhan AI
Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths.
\n
Please note that this is a demo version and may not be perfect.
This chatbot is powered by the ChromaDB and Flan-T5-base models with RAG architecture.
Flan-T5-base is a small model and may not be as accurate as the bigger models.
If you have any feedback or suggestions, you can contact me at frendyrachman7@gmail.com
\n
Jazakallah Khairan!
"""
    )
    with gr.Row():
        # Free-text question input plus a slider for how many documents to retrieve.
        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
        num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    # Wire the submit button to the RAG handler.
    submit_button.click(chatbot, inputs=[query_input, num_candidates_input], outputs=output_text)
    # Button that raises the cooperative stop flag; its status textbox is hidden.
    stop_button = gr.Button("Stop Processing")
    stop_output = gr.Textbox(visible=False)
    stop_button.click(stop, inputs=[], outputs=stop_output)
# Launch the Gradio interface.
demo.launch()