| | import streamlit as st |
| | import logging |
| | from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain |
| | import warnings |
| |
|
# Install a blanket "ignore" filter: every Python warning raised after this
# point is suppressed for the whole process.
# NOTE(review): this also hides deprecation warnings coming from dependencies;
# narrow the filter (category= / module=) if those should stay visible.
warnings.filterwarnings("ignore")
| |
|
| | |
| | |
# Initial values for the sidebar widgets rendered by main().

# Model identifiers pre-filled in the "Chat Model ID" / "Embed Model ID" inputs.
DEFAULT_CHAT_MODEL_ID = "hassanaliemon/bn_rag_llama3-8b"
DEFAULT_EMBED_MODEL_ID = "l3cube-pune/bengali-sentence-similarity-sbert"

# Retrieval: initial position of the "Number of Documents to Retrieve (k)" slider.
DEFAULT_K = 4

# Generation settings: initial positions of the "Top K", "Top P",
# "Temperature" and "Max New Tokens" sliders.
DEFAULT_TOP_K = 2
DEFAULT_TOP_P = 0.6
DEFAULT_TEMPERATURE = 0.6
DEFAULT_MAX_NEW_TOKENS = 256

# Text splitting: initial positions of the "Chunk Size" / "Chunk Overlap" sliders.
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 150

# Pre-filled value of the "Offload Directory" input.
DEFAULT_OFFLOAD_DIR = "/tmp"
| |
|
| | |
# Configure the root logger once at import time: INFO level and
# "<timestamp> - <level name> - <message>" formatted records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
| |
|
| | |
@st.cache_resource(show_spinner=False)
def load_model(
    chat_model_id,
    embed_model_id,
    text_path,
    k,
    top_k,
    top_p,
    temperature,
    chunk_size,
    chunk_overlap,
    hf_token,
    max_new_tokens,
    quantization,
    offload_dir,
):
    """Build a ``BanglaRAGChain`` and load it with the given settings.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned chain across reruns as long as every argument value is unchanged;
    changing any argument produces a new cache entry and a fresh load. The
    built-in cache spinner is disabled (``show_spinner=False``).

    Every argument is forwarded unchanged, by keyword, to
    ``BanglaRAGChain.load``; their exact semantics are defined there.

    Args:
        chat_model_id: Identifier of the chat/generation model.
        embed_model_id: Identifier of the embedding model.
        text_path: Path of the text file handed to the chain.
        k: Retrieval count setting.
        top_k: Top-k generation setting.
        top_p: Top-p generation setting.
        temperature: Generation temperature setting.
        chunk_size: Chunk size used when splitting the text.
        chunk_overlap: Overlap between consecutive chunks.
        hf_token: Access token passed through to the chain; may be ``None``.
        max_new_tokens: Generation length limit.
        quantization: Whether quantization is requested.
        offload_dir: Directory passed through as the offload location.

    Returns:
        The loaded ``BanglaRAGChain`` instance.
    """
    rag_chain = BanglaRAGChain()
    rag_chain.load(
        chat_model_id=chat_model_id,
        embed_model_id=embed_model_id,
        text_path=text_path,
        k=k,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        hf_token=hf_token,
        max_new_tokens=max_new_tokens,
        quantization=quantization,
        offload_dir=offload_dir,
    )
    return rag_chain
| |
|
def main():
    """Render the Bangla RAG chatbot page.

    Flow of one script run:

    1. Draw the sidebar widgets and read the model / retrieval / generation
       settings from them.
    2. Validate the chunking settings, then obtain the (cached) RAG chain via
       ``load_model``.
    3. Show the question input; when "Generate Answer" is pressed, call the
       chain with the query and display the answer (and, if requested, the
       retrieved context).

    Problems are reported in the page with ``st.error`` / ``st.warning`` and
    logged; the function then returns early instead of raising.
    """
    st.title("Bangla RAG Chatbot")

    # --- Sidebar: every setting is re-read on each Streamlit rerun. ---
    st.sidebar.header("Model Configuration")

    chat_model_id = st.sidebar.text_input("Chat Model ID", DEFAULT_CHAT_MODEL_ID)
    embed_model_id = st.sidebar.text_input("Embed Model ID", DEFAULT_EMBED_MODEL_ID)
    k = st.sidebar.slider("Number of Documents to Retrieve (k)", 1, 10, DEFAULT_K)
    top_k = st.sidebar.slider("Top K", 1, 10, DEFAULT_TOP_K)
    top_p = st.sidebar.slider("Top P", 0.0, 1.0, DEFAULT_TOP_P)
    temperature = st.sidebar.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
    max_new_tokens = st.sidebar.slider("Max New Tokens", 1, 512, DEFAULT_MAX_NEW_TOKENS)
    chunk_size = st.sidebar.slider("Chunk Size", 100, 1000, DEFAULT_CHUNK_SIZE)
    chunk_overlap = st.sidebar.slider("Chunk Overlap", 0, 500, DEFAULT_CHUNK_OVERLAP)
    text_path = st.sidebar.text_input("Text File Path", "text.txt")
    quantization = st.sidebar.checkbox("Enable Quantization (4-bit)", value=False)
    show_context = st.sidebar.checkbox("Show Retrieved Context", value=False)
    offload_dir = st.sidebar.text_input("Offload Directory", DEFAULT_OFFLOAD_DIR)

    # The two sliders are independent (overlap goes up to 500 while size can
    # go down to 100), so reject combinations where consecutive chunks would
    # not advance through the text before attempting an expensive load.
    if chunk_overlap >= chunk_size:
        st.error("Chunk Overlap must be smaller than Chunk Size.")
        return

    # Loading can fail for reasons the user controls (wrong text path, unknown
    # model id, ...); show the reason in the page rather than a raw traceback.
    try:
        rag_chain = load_model(
            chat_model_id=chat_model_id,
            embed_model_id=embed_model_id,
            text_path=text_path,
            k=k,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            hf_token=None,
            max_new_tokens=max_new_tokens,
            quantization=quantization,
            offload_dir=offload_dir,
        )
    except Exception as e:  # UI boundary: report, log, and stop this run.
        logging.exception("Failed to load the RAG pipeline")
        st.error(f"Couldn't load the model: {e}")
        return

    st.write("### Enter your question:")
    query = st.text_input("আপনার প্রশ্ন")

    if not st.button("Generate Answer"):
        return

    # Treat a whitespace-only question the same as an empty one.
    if not query or not query.strip():
        st.warning("Please enter a query.")
        return

    try:
        answer, context = rag_chain(query)
    except Exception as e:  # UI boundary: report, log, and stop this run.
        logging.exception("Failed to generate an answer")
        st.error(f"Couldn't generate an answer: {e}")
        return

    st.write(f"**Answer:** {answer}")
    if show_context:
        st.write(f"**Context:** {context}")
| | |
# Run the app only when this file is executed as the main script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
| |
|