import os

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.llms import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from transformers import pipeline

from retriever import get_retriever

load_dotenv()

# Load retriever
retriever = get_retriever()

# Load Hugging Face LLM.
# `temperature` only takes effect when sampling is enabled; without
# `do_sample=True` transformers uses greedy decoding and ignores it (with a
# warning), so sampling is enabled explicitly to honor the configured value.
pipe = pipeline(
    "text-generation",
    model="tiiuae/falcon-7b-instruct",
    trust_remote_code=True,
    device_map="auto",
    max_new_tokens=512,
    do_sample=True,
    temperature=0.2,
)

# Wrap in LangChain LLM
llm = HuggingFacePipeline(pipeline=pipe)

# Prompt templates.
# Each template is given to the "stuff" chain as its prompt, so it must have a
# {context} slot (filled with the retrieved documents' text) and a {question}
# slot (filled with the user's raw question). The chain shows the LLM only the
# documents' text, not their metadata, so the templates do not ask it to cite
# sources (it would have to invent them); chat() prints the retrieved sources.
english_prompt_template = """
You are a helpful Nigerian legal assistant.
Answer clearly in English, keeping the legal facts correct.

Context:
{context}

Question: {question}
Answer:
"""

pidgin_prompt_template = """
You be legal assistant wey sabi Nigerian law well well.
The user fit talk for English or Pidgin, but you go always answer for Nigerian Pidgin.
No change the legal facts, but make am simple so person wey no study law fit understand.

Context:
{context}

Question: {question}
Answer for Nigerian Pidgin:
"""

# Inputs (compared case-insensitively, after stripping whitespace) that end the session.
EXIT_COMMANDS = ("exit", "quit")


def _build_qa_chain(prompt_template):
    """Create a "stuff" RetrievalQA chain that answers using *prompt_template*.

    The chain retrieves documents with the incoming query as-is and formats
    them into the template's {context} slot; the query fills {question}.
    Source documents are returned alongside the answer.
    """
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": PromptTemplate.from_template(prompt_template)},
        return_source_documents=True,
    )


# Create QA chains, one per answer language.
english_qa_chain = _build_qa_chain(english_prompt_template)
pidgin_qa_chain = _build_qa_chain(pidgin_prompt_template)
# Kept so code that imports `qa_chain` still works; it answers in English.
qa_chain = english_qa_chain


def _read_language_choice():
    """Ask until the user picks "1" (English) or "2" (Pidgin).

    Returns the chosen digit as a string, or None if the user typed an exit
    command instead.
    """
    while True:
        choice = input("Choose mode: [1] English [2] Pidgin: ").strip()
        if choice in ("1", "2"):
            return choice
        if choice.lower() in EXIT_COMMANDS:
            return None
        print("❌ Invalid choice. Please type 1 or 2.")


def _unique_sources(docs):
    """Return each document's "source" metadata once, in retrieval order.

    Documents without a "source" entry are reported as "Unknown".
    """
    return list(dict.fromkeys(doc.metadata.get("source", "Unknown") for doc in docs))


def chat():
    """Run the interactive KnowYourRight question-answering loop on stdin/stdout.

    Asks once for the answer language, then answers each question with the
    matching chain and prints the answer followed by the distinct sources of
    the retrieved documents. Blank input is ignored; 'exit'/'quit', EOF or
    Ctrl+C end the session.
    """
    print("📜 KnowYourRight Bot")
    print("Type 'exit' to stop.\n")

    try:
        lang_choice = _read_language_choice()
        if lang_choice is None:
            return
        chain = pidgin_qa_chain if lang_choice == "2" else english_qa_chain

        # Start chat loop
        while True:
            query = input("\nYou: ").strip()
            if not query:
                continue
            if query.lower() in EXIT_COMMANDS:
                break

            # Send the raw question: the chain's prompt supplies the
            # instructions, so retrieval searches on the question text only.
            result = chain.invoke({"query": query})

            # Print answer
            print("\nBot:", result["result"].strip())

            # Print sources
            print("\n📚 Sources:")
            for source in _unique_sources(result["source_documents"]):
                print("-", source)

            print("\n" + "-" * 50)
    except (EOFError, KeyboardInterrupt):
        # End the session cleanly instead of dumping a traceback.
        print()


if __name__ == "__main__":
    chat()