# cm_penal_code / app.py
# Author: paulinusjua (last commit 068d18a, "Update app.py", verified)
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, pipeline ,AutoModelForSeq2SeqLM,AutoModelForCausalLM,GenerationConfig
from langchain_huggingface import HuggingFacePipeline
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import Tool
from langchain.agents import initialize_agent, AgentType
from langdetect import detect
import re
import os
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
from ctransformers import AutoModelForCausalLM, AutoTokenizer
from langchain.llms import CTransformers
# Directory the FAISS indexes are looked up in. This is the process's current
# working directory at import time, not necessarily the directory of this file.
BASE_PATH = os.getcwd()

# Language code -> on-disk FAISS index folder for that language's corpus.
INDEX_PATHS = {
    lang: os.path.join(BASE_PATH, folder)
    for lang, folder in (("en", "faiss_index_en"), ("fr", "faiss_index_fr"))
}
# Multilingual sentence-embedding model used to query both FAISS indexes.
# NOTE(review): presumably the same model the indexes were built with — verify,
# otherwise similarity search results are meaningless.
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Quantised TinyLlama chat model (GGUF, Q4_0) run locally through ctransformers
# and wrapped as a LangChain LLM. temperature 0.0 requests deterministic output.
# NOTE(review): TinyLlama-1.1B appears to have been trained with a 2048-token
# context; context_length=4096 may degrade answers past that point — confirm.
llm = CTransformers(
    model="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    model_file="tinyllama-1.1b-chat-v1.0.Q4_0.gguf",
    model_type="llama",
    config={"max_new_tokens": 512, "temperature": 0.0, "context_length": 4096},
)

# The CTransformers wrapper already loads the GGUF weights (into ``llm.client``).
# The previous separate AutoModelForCausalLM.from_pretrained(...) call loaded a
# second, never-used copy of the same model, doubling memory and startup time.
# Keep the module-level ``model`` name for compatibility, bound to that instance.
# NOTE(review): it now carries llm's generation config (max_new_tokens etc.).
model = llm.client
# System persona for the question-condensing prompt (condense_question_prompt).
# Adjacent string literals are concatenated with no separator, so every sentence
# except the last ends with a trailing space; previously sentences ran together
# ("data.You exist", "documentation.When", ...). The em dashes were also stored
# as mis-encoded text ("β€”").
system_template = (
    "You are Bot — an intelligent assistant trained on Cameroon penal code data. "
    "You exist to help individuals answer questions about the Cameroonian Penal Code. "
    "You always provide the source penal code section or article number, clear, compliant, and factual answers grounded in official penal code documentation. "
    "When given a law question and information, you explain all components. "
    "If a query is ambiguous or unsupported, you politely defer or recommend reviewing the relevant penal code manually. "
    "You do not speculate or make law interpretations — you clarify with precision and data."
)
# Chat prompt intended for a history-aware retriever: the system persona, then
# any prior turns spliced in via the {chat_history} placeholder, then the new
# user question as {input}. Only referenced by the commented-out
# create_history_aware_retriever call in this file.
condense_question_prompt = ChatPromptTemplate.from_messages(
    messages=[
        ("system", system_template),  # persona / rules
        ("placeholder", "{chat_history}"),  # earlier turns, if any
        ("human", "{input}"),  # the current question
    ]
)
#history_aware_retriever = create_history_aware_retriever(
# llm, retriever, condense_question_prompt
#)
# Generic concise-QA instructions followed by a {context} slot, separated by a
# blank line. NOTE(review): not referenced by the chains built in this file
# (qa_prompt is used instead).
system_prompt = "\n\n".join(
    [
        (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use three sentences maximum and keep the "
            "answer concise."
        ),
        "{context}",
    ]
)
from langchain.prompts import PromptTemplate

# Answer-generation prompt. Its variables are filled as follows:
#   {context}      - the retrieved documents, stuffed in by the documents chain
#   {chat_history} - a pre-formatted "User: ... / Assistant: ..." transcript
#   {input}        - the user's current question
# Sections are separated by blank lines; "Answer:" follows the question directly.
qa_prompt = PromptTemplate.from_template(
    template="\n\n".join(
        (
            "You are a legal assistant. Use and highlight the following penal or article code number and context and conversation history to answer the current question.",
            "Context:\n{context}",
            "Conversation History:\n{chat_history}",
            "Current Question:\n{input}\nAnswer:",
        )
    )
)

# Chain that formats the retrieved documents into {context} and calls the LLM.
llm_chain = create_stuff_documents_chain(llm=llm, prompt=qa_prompt)
#rag_chain = create_retrieval_chain(history_aware_retriever, llm_chain)
# Load each language's FAISS index once at import time and expose it as a
# retriever returning the top 2 matching chunks per query.
# allow_dangerous_deserialization=True is required because load_local unpickles
# the stored docstore; this is only acceptable because the index folders are
# trusted files shipped with the app — never point INDEX_PATHS at user uploads.
retrievers = {
    lang: FAISS.load_local(
        folder_path=INDEX_PATHS[lang],
        embeddings=embedding,
        allow_dangerous_deserialization=True,
    ).as_retriever(search_kwargs={"k": 2})
    for lang in ("en", "fr")
}
# Truncate long history
def truncate_history(chat_history, max_chars=1500):
    """Keep the most recent (question, answer) pairs that fit a character budget.

    Walks ``chat_history`` from newest to oldest, keeping whole pairs while the
    running total of ``len(q) + len(a)`` stays within ``max_chars``. It stops at
    the first pair that would exceed the budget, so everything older than that
    pair is dropped even if a shorter pair further back would still fit.

    Args:
        chat_history: Sequence of ``(question, answer)`` string pairs, oldest
            first. ``None`` or an empty sequence yields an empty result.
        max_chars: Maximum combined length of the kept questions and answers.

    Returns:
        A new list of ``(question, answer)`` tuples in chronological order.
    """
    # rag_tool_func's chat_history defaults to None; reversed(None) would raise.
    if not chat_history:
        return []
    kept = []
    total = 0
    for q, a in reversed(chat_history):
        pair_len = len(q) + len(a)
        if total + pair_len > max_chars:
            break
        kept.append((q, a))
        total += pair_len
    # Pairs were collected newest-first; restore chronological order in one
    # O(n) pass instead of repeated O(n) insert(0, ...) calls.
    kept.reverse()
    return kept
# Retrieval-augmented answer for a single question.
def rag_tool_func(input_question: str, chat_history: list = None) -> str:
    """Answer a penal-code question using the language-matched FAISS index.

    The question's language is detected with langdetect. French questions are
    routed to the "fr" retriever and everything else to the "en" retriever. The
    (truncated) chat history is rendered as a plain transcript and passed to the
    QA prompt together with the retrieved context.

    Args:
        input_question: The user's question.
        chat_history: Optional list of ``(question, answer)`` pairs, oldest first.
            ``None`` means no prior conversation.

    Returns:
        The ``"answer"`` field produced by the retrieval chain.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        detected = detect(input_question)
    except LangDetectException:
        # detect() raises on empty or feature-less text (e.g. "", "?", "123");
        # fall back to the English index instead of failing the request.
        detected = "en"
    # NOTE(review): langdetect is non-deterministic for short inputs unless
    # DetectorFactory.seed is set once at startup — consider doing so.
    lang = "fr" if detected == "fr" else "en"
    retriever = retrievers[lang]

    # Render prior turns as "User: ...\nAssistant: ...\n" lines for {chat_history}.
    history_str = "".join(
        f"User: {q}\nAssistant: {a}\n"
        for q, a in truncate_history(chat_history or [])
    )

    # llm_chain (module level) is the same stuff-documents chain over llm and
    # qa_prompt that was previously rebuilt on every call.
    rag_chain = create_retrieval_chain(retriever, llm_chain)
    result = rag_chain.invoke({"input": input_question, "chat_history": history_str})
    return result["answer"]
# Module-level history list. NOTE(review): not read by the handlers below —
# per-session history is carried in the Gradio gr.State instead.
chat_history = []


def chatbot_interface(user_input, history):
    """Gradio handler: answer ``user_input`` and append the turn to ``history``.

    Args:
        user_input: Text from the question box.
        history: Session history (list of ``(question, answer)`` pairs) from
            gr.State. Anything that is not a list is replaced by a fresh list.

    Returns:
        ``(history, history)``: the same list for the Chatbot display and for the
        state. The list is extended in place.
    """
    if not isinstance(history, list):  # also covers None
        history = []
    # Blank submissions have nothing to retrieve for (and langdetect rejects
    # empty text); leave the conversation unchanged instead of calling the model.
    if not user_input or not user_input.strip():
        return history, history
    trimmed_history = truncate_history(history)
    answer = rag_tool_func(user_input, trimmed_history)
    history.append((user_input, answer))
    return history, history  # For chatbot + state
# Gradio UI: a chat transcript, a question box with a Send button, and a
# per-session gr.State holding the (question, answer) history.
with gr.Blocks() as demo:
    # Title previously contained the flag emoji as mis-encoded text ("πŸ‡¨πŸ‡²").
    gr.Markdown("# 🇨🇲 Cameroon Penal Code Chatbot")
    chatbot_ui = gr.Chatbot(label="Ask me anything about the Cameroon Penal Code")
    with gr.Row():
        question_box = gr.Textbox(placeholder="Ask a legal question...", label="Your question")
        send_btn = gr.Button("Send")
    chat_state = gr.State([])
    # Clicking Send and pressing Enter in the textbox run the same handler.
    send_btn.click(fn=chatbot_interface, inputs=[question_box, chat_state], outputs=[chatbot_ui, chat_state])
    question_box.submit(fn=chatbot_interface, inputs=[question_box, chat_state], outputs=[chatbot_ui, chat_state])

if __name__ == "__main__":
    demo.launch()