# Uploaded by menikev ("Upload 8 files", commit 84fb4ed).
import os

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.llms import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from transformers import pipeline

from retriever import get_retriever
load_dotenv()

# Load retriever (project helper from retriever.py).
retriever = get_retriever()

# Hugging Face model id. Defaults to Falcon-7B-Instruct; can be overridden with
# the HF_MODEL_ID environment variable (or a .env entry, read by load_dotenv()).
MODEL_ID = os.getenv("HF_MODEL_ID", "tiiuae/falcon-7b-instruct")

# Load the text-generation pipeline.
pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    # NOTE(review): trust_remote_code runs Python code shipped in the model
    # repo. Consider dropping it if your transformers version supports this
    # model architecture natively.
    trust_remote_code=True,
    device_map="auto",
    max_new_tokens=512,
    # `temperature` only has an effect when sampling is enabled; without
    # do_sample=True generation is greedy and the temperature is ignored.
    do_sample=True,
    temperature=0.2,
)

# Wrap the pipeline as a LangChain LLM.
llm = HuggingFacePipeline(pipeline=pipe)
# Prompt templates, one per answer language.
# Each is filled by `{question}` substitution (str.format-style), so it must not
# contain any other braces.
# The model is deliberately NOT asked to list its sources. The chat loop prints
# the real sources from the retrieved documents' metadata, and the model only
# sees document text, so any sources it listed would be guesses.
english_prompt_template = """
You are a helpful Nigerian legal assistant.
Answer clearly in English, keeping the legal facts correct.
Use only the legal context you were given; if it does not contain the answer, say that you don't know.
Question: {question}
Answer:
"""
pidgin_prompt_template = """
You be legal assistant wey sabi Nigerian law well well.
The user fit talk for English or Pidgin, but you go always answer for Nigerian Pidgin.
No change the legal facts, but make am simple so person wey no study law fit understand.
Use only the law wey dem give you for the context; if the answer no dey there, talk say you no know.
Question: {question}
Answer for Nigerian Pidgin:
"""
# Create QA chains, one per answer language.
# The language-specific instructions live in each chain's prompt, next to the
# retrieved {context}. Only the user's raw question is passed to the chain, so
# the retriever searches with the question alone instead of the whole
# instruction template.
_CONTEXT_HEADER = "Relevant legal context:\n{context}\n"


def _build_qa_chain(template):
    """Build a "stuff" RetrievalQA chain whose prompt is *template* plus context.

    *template* must contain a ``{question}`` placeholder; the retrieved
    documents are inserted ahead of it through the ``{context}`` placeholder.
    """
    prompt = PromptTemplate(
        template=_CONTEXT_HEADER + template,
        input_variables=["context", "question"],
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )


english_qa_chain = _build_qa_chain(english_prompt_template)
pidgin_qa_chain = _build_qa_chain(pidgin_prompt_template)
# Kept so code that imports `qa_chain` from this module still works.
qa_chain = english_qa_chain


def _read_line(prompt):
    """Return the next line typed by the user, or None at end of input (Ctrl-D)."""
    try:
        return input(prompt)
    except EOFError:
        return None


def chat():
    """Run the interactive KnowYourRight console chat.

    First asks for the answer language ([1] English, [2] Pidgin). Then it
    repeatedly reads a question, answers it with the matching QA chain, and
    prints the answer followed by the distinct sources of the retrieved
    documents. Typing 'exit' or 'quit' (any case), or closing the input
    stream, ends the session.
    """
    print("📜 KnowYourRight Bot")
    print("Type 'exit' to stop.\n")

    # Ask for the language mode until a valid choice is given.
    while True:
        lang_choice = _read_line("Choose mode: [1] English [2] Pidgin: ")
        if lang_choice is None:
            return
        lang_choice = lang_choice.strip()
        if lang_choice in ("1", "2"):
            break
        print("❌ Invalid choice. Please type 1 or 2.")

    chain = pidgin_qa_chain if lang_choice == "2" else english_qa_chain

    # Start chat loop
    while True:
        query = _read_line("\nYou: ")
        if query is None:
            break
        query = query.strip()
        if not query:
            continue  # nothing asked; don't spend a model call on it
        if query.lower() in ("exit", "quit"):
            break

        # Pass the raw question; the chain's prompt supplies the instructions.
        result = chain.invoke({"query": query})

        print("\nBot:", result["result"])

        # Print each source once, in retrieval order. Several retrieved chunks
        # can come from the same document and would otherwise repeat it.
        print("\n📚 Sources:")
        sources = dict.fromkeys(
            doc.metadata.get("source", "Unknown")
            for doc in result["source_documents"]
        )
        for source in sources:
            print("-", source)
        print("\n" + "-" * 50)


if __name__ == "__main__":
    try:
        chat()
    except KeyboardInterrupt:
        print()  # finish the prompt line cleanly after Ctrl-C