Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,164 +1,206 @@
|
|
| 1 |
import os
|
| 2 |
from pathlib import Path
|
| 3 |
import gradio as gr
|
| 4 |
-
|
| 5 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 6 |
-
from langchain_community.llms import HuggingFacePipeline
|
| 7 |
from langchain.prompts import PromptTemplate
|
| 8 |
from langchain_community.vectorstores import Chroma
|
| 9 |
-
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
# ------
|
| 12 |
-
|
| 13 |
-
|
| 14 |
PERSIST_DIR = Path("data/processed/vector_db")
|
|
|
|
| 15 |
if not PERSIST_DIR.exists() or not any(PERSIST_DIR.iterdir()):
|
| 16 |
print("⚠️ Vector DB not found. Run complete_ingestion.py first.")
|
| 17 |
raise SystemExit(1)
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
vectordb = Chroma(
|
| 21 |
persist_directory=str(PERSIST_DIR),
|
| 22 |
embedding_function=embedding_model,
|
| 23 |
collection_name="legal_documents"
|
| 24 |
)
|
| 25 |
|
| 26 |
-
retriever
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
max_new_tokens=120, # reduced for speed
|
| 40 |
-
temperature=0.2,
|
| 41 |
-
top_p=0.85,
|
| 42 |
-
do_sample=True,
|
| 43 |
-
repetition_penalty=1.05,
|
| 44 |
-
return_full_text=False,
|
| 45 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
-
# ------
|
| 64 |
-
|
| 65 |
-
# ----------------------------
|
| 66 |
-
def _format_history(turns, max_turns=4):
|
| 67 |
-
if not turns:
|
| 68 |
-
return ""
|
| 69 |
-
turns = turns[-max_turns:]
|
| 70 |
-
return "\n".join([f"User: {u}\nAssistant: {a}" for u, a in turns])
|
| 71 |
-
|
| 72 |
-
def _retrieve(question, k=3):
|
| 73 |
-
docs = retriever.invoke(question) # ✅ fixed deprecation
|
| 74 |
-
texts = [d.page_content.strip() for d in docs[:k]]
|
| 75 |
-
context = "\n\n---\n\n".join(texts)
|
| 76 |
-
return context, docs
|
| 77 |
-
|
| 78 |
-
def _generate(question, history):
|
| 79 |
-
hist = _format_history(history, max_turns=4)
|
| 80 |
-
context, docs = _retrieve(question, k=3)
|
| 81 |
-
prompt = RAG_PROMPT.format(question=question, context=context, history=hist)
|
| 82 |
-
out = llm.invoke(prompt) # ✅ fixed deprecation
|
| 83 |
-
if isinstance(out, list) and out and "generated_text" in out[0]:
|
| 84 |
-
text = out[0]["generated_text"]
|
| 85 |
-
else:
|
| 86 |
-
text = str(out)
|
| 87 |
-
return text.strip(), docs
|
| 88 |
-
|
| 89 |
-
# ----------------------------
|
| 90 |
-
# Main logic
|
| 91 |
-
# ----------------------------
|
| 92 |
def answer_question(user_input, lang_choice, history=[]):
|
|
|
|
| 93 |
try:
|
| 94 |
-
|
| 95 |
-
if not
|
| 96 |
return history, history
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
| 102 |
history.append((user_input, ans))
|
| 103 |
return history, history
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
answer = "I
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
return history[-8:], history[-8:]
|
| 131 |
|
| 132 |
except Exception as e:
|
| 133 |
-
print(f"
|
| 134 |
-
|
| 135 |
-
history.append((user_input,
|
| 136 |
return history, history
|
| 137 |
|
| 138 |
def _reset():
|
|
|
|
| 139 |
return [], []
|
| 140 |
|
| 141 |
-
# ------
|
| 142 |
-
|
| 143 |
-
# ----------------------------
|
| 144 |
def build_ui():
|
| 145 |
-
|
|
|
|
| 146 |
gr.Markdown("# 📜 KnowYourRight Bot — Nigerian Legal Assistant")
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
with gr.Row():
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
return demo
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
import os
from pathlib import Path

import gradio as gr
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.vectorstores import Chroma
# NOTE(review): langchain_huggingface exports HuggingFaceEndpoint / HuggingFacePipeline;
# the deprecated HuggingFaceHub wrapper historically lives in langchain_community.
# Confirm this import resolves against the pinned package versions.
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceHub

# --- 1. CONFIGURATION & INITIALIZATION ---

# Load environment variables (for the Hugging Face API token) before anything
# below tries to talk to the Inference API.
load_dotenv()

# Fail fast when the token is missing: the remote LLM calls cannot work without it.
if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
    print(" HUGGINGFACEHUB_API_TOKEN not found in secrets. Please add it.")
    # raise SystemExit instead of exit(): consistent with the vector-DB check
    # below, and exit() depends on the optional site module.
    raise SystemExit(1)
| 20 |
|
# --- 2. LOAD VECTOR DATABASE (Retriever) ---

print("Loading vector database...")

# Directory the ingestion script persisted the Chroma collection into.
PERSIST_DIR = Path("data/processed/vector_db")

# Bail out early when the store was never built (directory missing or empty).
if not PERSIST_DIR.exists() or not any(PERSIST_DIR.iterdir()):
    print("⚠️ Vector DB not found. Run complete_ingestion.py first.")
    raise SystemExit(1)

# The query-time embedding model must match the one used at ingestion time,
# otherwise similarity lookups are meaningless.
embedding_model = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en",
    model_kwargs={"device": "cpu"},  # CPU is sufficient for query-time embedding
)

# Re-open the persisted Chroma collection with the matching embedding function.
vectordb = Chroma(
    persist_directory=str(PERSIST_DIR),
    collection_name="legal_documents",
    embedding_function=embedding_model,
)

# Top-4 chunks per query: a slightly larger k gives the LLM more context.
retriever = vectordb.as_retriever(search_kwargs={"k": 4})
print("Vector database loaded successfully.")
# --- 3. SETUP THE LIGHTWEIGHT LLM (via Inference API) ---

print("Initializing LLM via Hugging Face Hub...")
# We use the Inference API to avoid loading the model locally, which is much faster.
# Mixtral is a powerful model available on the free tier.
#
# FIX: the deprecated HuggingFaceHub wrapper is not exported by
# langchain_huggingface (importing it fails); HuggingFaceEndpoint is its
# maintained replacement and takes generation parameters as first-class
# constructor arguments instead of a model_kwargs dict. (The old
# "max_length" kwarg is superseded by max_new_tokens and is dropped.)
from langchain_huggingface import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.1,
    max_new_tokens=512,
)
print("LLM initialized.")
# --- 4. CREATE THE IMPROVED PROMPT TEMPLATE ---

# A directive prompt: it forces explanation over quotation, pins the answer to
# the retrieved context, and requests citations — all of which shape the output.
RAG_PROMPT_TEMPLATE = """
You are an expert Nigerian Legal Assistant. Your primary goal is to help users understand Nigerian law by providing clear, concise, and helpful explanations.

**TASK:** Analyze the provided legal context below to answer the user's question.

**CONTEXT:**
{context}

**RULES:**
1. **Explain, Don't Just Quote:** Do not just copy the text from the context. You MUST synthesize, summarize, and explain the relevant laws in simple, easy-to-understand language.
2. **Be Conversational:** Respond in a helpful and advisory tone.
3. **Use Only Provided Context:** Base your answer SOLELY on the provided context. If the context does not contain the information needed to answer the question, you MUST say "The provided legal documents do not contain specific information on this topic." Do not use outside knowledge.
4. **Language:** Respond in the user's chosen language (English or Nigerian Pidgin).
5. **Citations:** At the end of your answer, always list the sources you used from the context.

**QUESTION:** {question}

**ANSWER:**
"""

# Explicit input_variables mirror the {context}/{question} placeholders above.
RAG_PROMPT = PromptTemplate(
    template=RAG_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)
# --- 5. DEFINE THE RAG CHAIN ---

def format_docs(docs):
    """Render retrieved documents as one context string for the prompt.

    Each document contributes its source and section metadata (falling back
    to 'Unknown') plus its content; entries are separated by '---' dividers.
    """
    rendered = []
    for d in docs:
        src = d.metadata.get('source', 'Unknown')
        sec = d.metadata.get('section', 'Unknown')
        rendered.append(f"Source: {src}\nSection: {sec}\nContent: {d.page_content}")
    return "\n\n---\n\n".join(rendered)
# Create the LangChain RAG chain (LCEL): the incoming question fans out to the
# retriever (whose docs become {context}) and passes through unchanged as
# {question}; the filled prompt goes to the LLM and is parsed to a plain string.
_chain_inputs = {
    "context": retriever | format_docs,
    "question": RunnablePassthrough(),
}
rag_chain = _chain_inputs | RAG_PROMPT | llm | StrOutputParser()
# --- 6. MAIN APPLICATION LOGIC ---

def answer_question(user_input, lang_choice, history=None):
    """Handle one user turn: retrieve, run the RAG chain, and format the reply.

    Args:
        user_input: Raw question text from the textbox.
        lang_choice: "english" or "pidgin"; selects greeting and disclaimer text.
        history: Chat history as (user, assistant) tuples; a fresh list is
            created when omitted.

    Returns:
        (history, history): the same updated list twice — one copy feeds the
        Chatbot component, the other the State component.
    """
    # FIX: the old signature used a mutable default (history=[]), which is one
    # shared list across every call that omits the argument and silently
    # accumulates turns. Use None as the sentinel and build a fresh list.
    if history is None:
        history = []
    try:
        query = (user_input or "").strip()
        if not query:
            return history, history

        # Simple conversational starters
        if query.lower() in ["hi", "hello", "hey"]:
            ans = ("Hello! I'm your Nigerian Legal AI Assistant. How can I help you today?"
                   if lang_choice == "english" else
                   "Howfa! I be your Nigerian Legal AI Assistant. How I fit help you today? No be legal advice o.")
            history.append((user_input, ans))
            return history, history

        print(f"Received query: {query}")

        # Retrieve documents first to build references.
        # NOTE(review): rag_chain performs its own retrieval below, so every
        # query hits the vector store twice — consider reusing `docs` there.
        docs = retriever.invoke(query)
        if not docs:
            print("No documents retrieved.")
            answer = "I could not find any relevant information in the legal documents for your query. Please try rephrasing."
        else:
            # Invoke the RAG chain to get the answer
            print("Invoking RAG chain...")
            answer = rag_chain.invoke(query)
            print("RAG chain finished.")

        # Language-appropriate disclaimer appended to every substantive answer.
        disclaimer = ("\n\n--- \n*⚠️ Disclaimer: This is AI-generated information and not legal advice. Please consult a qualified lawyer for professional guidance.*"
                      if lang_choice == "english" else
                      "\n\n--- \n*⚠️ No be legal advice o, abeg find lawyer for proper advice.*")

        # Build references from document metadata; a set avoids duplicates.
        references = set()
        for doc in docs:
            source = doc.metadata.get("source", "Unknown Source")
            section = doc.metadata.get("section", "Unknown Section")
            # Only cite entries whose source AND section are both known.
            if source != "Unknown Source" and section != "Unknown Section":
                references.add(f"- {source} ({section})")

        if references:
            answer += "\n\n**References:**\n" + "\n".join(sorted(references))

        answer += disclaimer

        history.append((user_input, answer.strip()))

        # Keep chat history to a reasonable length
        return history[-8:], history[-8:]

    except Exception as e:
        # Surface a friendly message in the chat instead of crashing the UI.
        print(f"An error occurred: {e}")
        error_message = "Sorry, an unexpected error occurred. Please try again or rephrase your question."
        history.append((user_input, error_message))
        return history, history
| 159 |
|
| 160 |
def _reset():
|
| 161 |
+
"""Resets the chat state."""
|
| 162 |
return [], []
|
| 163 |
|
| 164 |
+
# --- 7. GRADIO UI ---
|
| 165 |
+
|
|
|
|
| 166 |
def build_ui():
|
| 167 |
+
"""Builds the Gradio web interface."""
|
| 168 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="KnowYourRight Bot") as demo:
|
| 169 |
gr.Markdown("# 📜 KnowYourRight Bot — Nigerian Legal Assistant")
|
| 170 |
+
gr.Markdown("Ask questions about the Nigerian Constitution, Labour Act, and more. *Powered by AI.*")
|
| 171 |
+
|
| 172 |
+
chatbot = gr.Chatbot(label="Chat History", height=600, bubble_full_width=False, avatar_images=("user.png", "bot.png"))
|
| 173 |
+
|
| 174 |
with gr.Row():
|
| 175 |
+
msg = gr.Textbox(
|
| 176 |
+
label="Your Question",
|
| 177 |
+
placeholder="e.g., 'What are my rights if I am arrested?'",
|
| 178 |
+
lines=2,
|
| 179 |
+
scale=4,
|
| 180 |
+
)
|
| 181 |
+
submit_btn = gr.Button("▶️ Send", variant="primary", scale=1)
|
| 182 |
+
|
| 183 |
+
lang_choice = gr.Radio(["english", "pidgin"], value="english", label="Response Language")
|
| 184 |
+
clear_btn = gr.Button("🗑️ Clear Chat")
|
| 185 |
+
|
| 186 |
+
# State to store the conversation history
|
| 187 |
+
chat_state = gr.State([])
|
| 188 |
+
|
| 189 |
+
# Event handlers
|
| 190 |
+
submit_btn.click(answer_question, [msg, lang_choice, chat_state], [chatbot, chat_state])
|
| 191 |
+
msg.submit(answer_question, [msg, lang_choice, chat_state], [chatbot, chat_state])
|
| 192 |
+
|
| 193 |
+
# Clear the input textbox after submission
|
| 194 |
+
clear_on_submit = [submit_btn, msg]
|
| 195 |
+
for component in clear_on_submit:
|
| 196 |
+
component.click(lambda: "", None, msg)
|
| 197 |
+
|
| 198 |
+
clear_btn.click(_reset, None, [chatbot, chat_state])
|
| 199 |
+
|
| 200 |
return demo
|
if __name__ == "__main__":
    # Script entry point: assemble the interface, then serve it.
    print("Building Gradio UI...")
    app = build_ui()
    print("Launching Gradio app...")
    app.launch(debug=True)  # Set debug=False for production