File size: 4,790 Bytes
a562fbe
0c5530c
99d6fc6
 
 
 
 
f513bae
9aed1e1
f513bae
 
 
 
 
 
 
8354b04
 
5484f67
6ec114c
baabbfa
54c3a62
 
58be8d6
 
 
fc3faf9
58be8d6
99d6fc6
f513bae
 
 
8354b04
f513bae
 
 
 
 
 
 
 
99d6fc6
 
f513bae
 
 
 
 
 
 
 
 
 
 
8354b04
f513bae
 
 
 
 
 
 
 
 
 
 
 
8354b04
a4af23f
f513bae
 
 
 
99d6fc6
 
f513bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8354b04
f513bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import chromadb
import torch
from sentence_transformers import SentenceTransformer
import os
from chromadb.utils import embedding_functions

# Initialize ChromaDB client with the existing path
# (persists to the local "new_hadith_rag_source" directory).
client = chromadb.PersistentClient(path="new_hadith_rag_source")

# Load the existing collection
# NOTE(review): get_collection raises if the collection does not exist —
# presumably it was built by a separate ingestion script; verify before deploy.
collection = client.get_collection(name="hadiths_new_complete")

# Debugging print to verify the number of documents in the collection
print(f"Number of documents in collection: {collection.count()}")

# Model and Tokenizer Loading
# Flan-T5-base is a seq2seq (encoder-decoder) model; the HF token is read from
# the environment so no credential is hard-coded.
model_name = "google/flan-t5-base"
token = os.getenv("HUGGINGFACE_TOKEN")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", token=token)
llm = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    # NOTE(review): T5 ships its own pad token; using eos as pad looks
    # intentional here but is worth confirming against generation output.
    pad_token_id=tokenizer.eos_token_id,
    token=token,
    device_map="auto"  # let accelerate place the model on available devices
)

# Load the pre-trained model and tokenizer
# Sentence-transformer used only for retrieval embeddings (384-dim MiniLM).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
retrieval_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

# Function to query the collection
def query_collection(query, n_results):
    """Retrieve the ``n_results`` most similar documents for ``query``.

    The query is embedded with the sentence-transformer retrieval model and
    matched against the ChromaDB collection by vector similarity.
    """
    # Encode on the configured device, then hand a CPU numpy array to ChromaDB.
    embedding = (
        retrieval_model
        .encode([query], convert_to_tensor=True, device=device)
        .cpu()
        .numpy()
    )
    return collection.query(query_embeddings=embedding, n_results=n_results)

# Generate a response using the retrieved documents as context
def generate_response(context, question):
    prompt = f"Please provide a short, well-structured answer and avoids repetition from context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm.generate(**inputs, max_length=2048, num_return_sequences=1, num_beams=5, temperature=0.9, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Main chatbot function with basic RAG
def chatbot_response(user_query, top_k=2):
    # Step 1: Retrieve relevant documents
    results = query_collection(user_query, top_k)
    
    # Step 2: Combine retrieved documents into context
    documents = [doc for doc_list in results['documents'] for doc in doc_list]
    combined_context = "\n\n".join(documents)
    
    # Step 3: Generate a response using the combined context
    response = generate_response(combined_context, user_query)
    
    return response

# Global variable to control the processing state.
# NOTE(review): this flag is only inspected *after* chatbot_response() has
# fully completed, so pressing "Stop" cannot interrupt an in-flight query —
# it only changes the message shown once generation finishes.
stop_processing = False

def chatbot(query, num_candidates):
    """Gradio handler: validate the query, run RAG, and format the answer.

    Returns a default prompt for empty input, a stop notice if the stop flag
    was raised during processing, a polite fallback when the model admits
    ignorance, or the generated answer otherwise.
    """
    global stop_processing
    # A fresh query clears any stop request left over from the previous one.
    stop_processing = False

    # Guard clause: nothing to answer.
    if not query.strip():
        return "Please ask a question about hadiths."

    # Retrieval + generation over the requested number of references.
    answer = chatbot_response(query, num_candidates)

    # The stop button may have been pressed while we were generating.
    if stop_processing:
        return "Processing was stopped by the user."

    # Replace explicit "I don't know"-style answers with a friendlier message.
    lowered = answer.lower()
    if "don't know" in lowered or "not sure" in lowered:
        return "Sorry. I don't have information about the hadiths related. It might be a dhoif, or maudhu, or I just don't have the knowledge."
    return answer

def stop():
    global stop_processing
    stop_processing = True
    return "Processing stopped."

# Build the Gradio interface.
with gr.Blocks() as demo:
    # Static introduction shown above the input controls.
    gr.Markdown(
        """
        # Burhan AI
        Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths. 
        \n 
        Please note that this is a demo version and may not be perfect.
        This chatbot is powered by the ChromaDB and Flan-T5-base models with RAG architecture.
        Flan-T5-base is a small model and may not be as accurate as the bigger models.
        If you have any feedback or suggestions, you can contact me at frendyrachman7@gmail.com
        \n
        Jazakallah Khairan!
        """
    )
    # Input row: question text, reference-count slider, and submit button.
    with gr.Row():
        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
        num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
        submit_button = gr.Button("Submit")
    
    output_text = gr.Textbox(label="Response")

    # Wire the submit button to the RAG handler.
    submit_button.click(chatbot, inputs=[query_input, num_candidates_input], outputs=output_text)

    # Add a button to stop processing.
    # NOTE(review): stop() only sets a flag that chatbot() checks after
    # generation has finished — it does not cancel the running request.
    stop_button = gr.Button("Stop Processing")
    stop_output = gr.Textbox(visible=False)  # hidden sink for stop()'s return value
    stop_button.click(stop, inputs=[], outputs=stop_output)

# Launch the Gradio app (blocks until the server is shut down).
demo.launch()