menikev commited on
Commit
8307462
·
verified ·
1 Parent(s): 5ff90e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -123
app.py CHANGED
@@ -1,164 +1,206 @@
1
  import os
2
  from pathlib import Path
3
  import gradio as gr
4
-
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
- from langchain_community.llms import HuggingFacePipeline
7
  from langchain.prompts import PromptTemplate
8
  from langchain_community.vectorstores import Chroma
9
- from langchain_huggingface import HuggingFaceEmbeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # ----------------------------
12
- # Load vector DB
13
- # ----------------------------
14
  PERSIST_DIR = Path("data/processed/vector_db")
 
15
  if not PERSIST_DIR.exists() or not any(PERSIST_DIR.iterdir()):
16
  print("⚠️ Vector DB not found. Run complete_ingestion.py first.")
17
  raise SystemExit(1)
18
 
19
- embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
 
 
 
 
 
 
20
  vectordb = Chroma(
21
  persist_directory=str(PERSIST_DIR),
22
  embedding_function=embedding_model,
23
  collection_name="legal_documents"
24
  )
25
 
26
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
27
-
28
- # ----------------------------
29
- # Lightweight LLM
30
- # ----------------------------
31
- MODEL_ID = os.getenv("LLM_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
32
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
33
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
34
-
35
- gen_pipe = pipeline(
36
- "text-generation",
37
- model=model,
38
- tokenizer=tokenizer,
39
- max_new_tokens=120, # reduced for speed
40
- temperature=0.2,
41
- top_p=0.85,
42
- do_sample=True,
43
- repetition_penalty=1.05,
44
- return_full_text=False,
45
  )
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- llm = HuggingFacePipeline(pipeline=gen_pipe)
48
-
49
- # ----------------------------
50
- # Prompt
51
- # ----------------------------
52
- RAG_PROMPT = PromptTemplate.from_template(
53
- "You are a helpful Nigerian Legal Assistant.\n"
54
- "Respond conversationally, summarize clearly, and explain in plain English (or Pidgin if chosen).\n"
55
- "Always include the referenced section(s) at the end.\n"
56
- "If the answer is not in the context, say you don't know.\n\n"
57
- "Conversation history:\n{history}\n\n"
58
- "Question: {question}\n\n"
59
- "Context from legal documents:\n{context}\n\n"
60
- "Answer:"
 
 
 
 
 
 
 
 
 
 
 
 
61
  )
62
 
63
- # ----------------------------
64
- # Helpers
65
- # ----------------------------
66
- def _format_history(turns, max_turns=4):
67
- if not turns:
68
- return ""
69
- turns = turns[-max_turns:]
70
- return "\n".join([f"User: {u}\nAssistant: {a}" for u, a in turns])
71
-
72
- def _retrieve(question, k=3):
73
- docs = retriever.invoke(question) # ✅ fixed deprecation
74
- texts = [d.page_content.strip() for d in docs[:k]]
75
- context = "\n\n---\n\n".join(texts)
76
- return context, docs
77
-
78
- def _generate(question, history):
79
- hist = _format_history(history, max_turns=4)
80
- context, docs = _retrieve(question, k=3)
81
- prompt = RAG_PROMPT.format(question=question, context=context, history=hist)
82
- out = llm.invoke(prompt) # ✅ fixed deprecation
83
- if isinstance(out, list) and out and "generated_text" in out[0]:
84
- text = out[0]["generated_text"]
85
- else:
86
- text = str(out)
87
- return text.strip(), docs
88
-
89
- # ----------------------------
90
- # Main logic
91
- # ----------------------------
92
  def answer_question(user_input, lang_choice, history=[]):
 
93
  try:
94
- q = (user_input or "").strip()
95
- if not q:
96
  return history, history
97
 
98
- if q.lower() in ["hi", "hello", "hey"]:
99
- ans = "Hello! I'm your Nigerian Legal AI Assistant. How can I help you?" \
100
- if lang_choice == "english" else \
101
- "Hello! I be your Nigerian Legal AI Assistant. How I fit help you? No be legal advice o."
 
102
  history.append((user_input, ans))
103
  return history, history
104
 
105
- if len(q) > 300:
106
- q = q[:300] + "..."
107
-
108
- answer, docs = _generate(q, history)
109
-
110
- if not answer or len(answer) < 5:
111
- answer = "I don't know from the available context. Please try rephrasing your question." \
112
- if lang_choice == "english" else \
113
- "I no sure from the context wey I get. Abeg rephrase your question."
114
-
115
- disclaimer = "⚠️ This is not legal advice. Please consult a qualified lawyer." \
116
- if lang_choice == "english" else \
117
- "⚠️ No be legal advice o, abeg meet lawyer."
118
- answer += f"\n\n{disclaimer}"
119
-
120
- # references improved
121
- refs = []
122
- for d in docs[:2]:
123
- src = d.metadata.get("source", "Unknown Source")
124
- sec = d.metadata.get("section", "Unknown Section")
125
- refs.append(f"{src} {sec}")
126
- if refs:
127
- answer += "\n\nReferenced: " + "; ".join(refs)
128
-
129
- history.append((user_input, answer))
 
 
 
 
 
 
 
 
 
 
 
 
130
  return history[-8:], history[-8:]
131
 
132
  except Exception as e:
133
- print(f"Error: {e}")
134
- err = "Sorry, an error occurred. Please try again."
135
- history.append((user_input, err))
136
  return history, history
137
 
138
  def _reset():
 
139
  return [], []
140
 
141
- # ----------------------------
142
- # UI
143
- # ----------------------------
144
  def build_ui():
145
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
 
146
  gr.Markdown("# 📜 KnowYourRight Bot — Nigerian Legal Assistant")
147
- chatbot = gr.Chatbot(label="Chat with Legal AI", height=600, bubble_full_width=False)
148
- msg = gr.Textbox(label="Ask your question...", placeholder="Type your legal question here", lines=2)
149
- lang_choice = gr.Radio(["english", "pidgin"], value="english", label="Language")
150
-
151
  with gr.Row():
152
- submit = gr.Button("Send", variant="primary")
153
- clear = gr.Button("Clear Chat")
154
-
155
- state = gr.State([])
156
- submit.click(answer_question, [msg, lang_choice, state], [chatbot, state])
157
- submit.click(lambda: "", None, msg)
158
- msg.submit(answer_question, [msg, lang_choice, state], [chatbot, state])
159
- msg.submit(lambda: "", None, msg)
160
- clear.click(_reset, None, [chatbot, state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  return demo
162
 
163
- demo = build_ui()
164
- demo.launch()
 
 
 
 
1
  import os
2
  from pathlib import Path
3
  import gradio as gr
 
 
 
4
  from langchain.prompts import PromptTemplate
5
  from langchain_community.vectorstores import Chroma
6
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceHub
7
+ from langchain.schema.runnable import RunnablePassthrough
8
+ from langchain.schema.output_parser import StrOutputParser
9
+
10
# --- 1. CONFIGURATION & INITIALIZATION ---

# Load environment variables (for the Hugging Face API token).
from dotenv import load_dotenv
load_dotenv()

# Fail fast if the Inference API token is missing: the HuggingFaceHub LLM
# below cannot authenticate without it. Use SystemExit(1) so the process
# reports failure (bare exit() returns status 0), matching the vector-DB
# check elsewhere in this file.
if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
    print("⚠️ HUGGINGFACEHUB_API_TOKEN not found in secrets. Please add it.")
    raise SystemExit(1)
20
 
21
# --- 2. LOAD VECTOR DATABASE (Retriever) ---

print("Loading vector database...")
# Directory produced by the ingestion script (complete_ingestion.py).
PERSIST_DIR = Path("data/processed/vector_db")

# Abort at startup if the vector store is missing or empty — the app
# cannot answer questions without it.
if not PERSIST_DIR.exists() or not any(PERSIST_DIR.iterdir()):
    print("⚠️ Vector DB not found. Run complete_ingestion.py first.")
    raise SystemExit(1)

# Use the same embedding model as in the ingestion script; a different
# model would make query embeddings incompatible with the stored vectors.
embedding_model = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en",
    model_kwargs={'device': 'cpu'}  # Run embeddings on CPU
)

# Open the persisted Chroma collection with that embedding function.
vectordb = Chroma(
    persist_directory=str(PERSIST_DIR),
    embedding_function=embedding_model,
    collection_name="legal_documents"
)

# Create a retriever to fetch relevant documents.
# Increasing k to 4 gives the LLM more context to work with.
retriever = vectordb.as_retriever(search_kwargs={"k": 4})
print("Vector database loaded successfully.")
47
+
48
# --- 3. SETUP THE LIGHTWEIGHT LLM (via Inference API) ---

print("Initializing LLM via Hugging Face Hub...")
# We use the Inference API to avoid loading the model locally, which is much faster.
# Mixtral is a powerful model available on the free tier.
# NOTE(review): HuggingFaceHub is the legacy LangChain wrapper; current
# langchain_huggingface releases expose HuggingFaceEndpoint instead —
# confirm this class actually imports from langchain_huggingface in the
# pinned dependency versions.
llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    # Low temperature for mostly-deterministic answers; max_new_tokens caps
    # the length of each generated response.
    model_kwargs={"temperature": 0.1, "max_length": 1024, "max_new_tokens": 512}
)
print("LLM initialized.")
58
+
59
# --- 4. CREATE THE IMPROVED PROMPT TEMPLATE ---

# This prompt is directive and shapes the output: it forces synthesis over
# quoting, restricts answers to the retrieved context, and requests
# citations. {context} and {question} are filled by the RAG chain below.
RAG_PROMPT_TEMPLATE = """
You are an expert Nigerian Legal Assistant. Your primary goal is to help users understand Nigerian law by providing clear, concise, and helpful explanations.

**TASK:** Analyze the provided legal context below to answer the user's question.

**CONTEXT:**
{context}

**RULES:**
1. **Explain, Don't Just Quote:** Do not just copy the text from the context. You MUST synthesize, summarize, and explain the relevant laws in simple, easy-to-understand language.
2. **Be Conversational:** Respond in a helpful and advisory tone.
3. **Use Only Provided Context:** Base your answer SOLELY on the provided context. If the context does not contain the information needed to answer the question, you MUST say "The provided legal documents do not contain specific information on this topic." Do not use outside knowledge.
4. **Language:** Respond in the user's chosen language (English or Nigerian Pidgin).
5. **Citations:** At the end of your answer, always list the sources you used from the context.

**QUESTION:** {question}

**ANSWER:**
"""

RAG_PROMPT = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
83
+
84
+ # --- 5. DEFINE THE RAG CHAIN ---
85
+
86
def format_docs(docs):
    """Render retrieved documents as one string, with '---' dividers between them."""
    rendered = []
    for doc in docs:
        source = doc.metadata.get('source', 'Unknown')
        section = doc.metadata.get('section', 'Unknown')
        rendered.append(f"Source: {source}\nSection: {section}\nContent: {doc.page_content}")
    return "\n\n---\n\n".join(rendered)
89
+
90
# Create the LangChain RAG chain (LCEL pipe syntax):
# question -> {formatted retrieved context, original question}
#          -> prompt -> LLM -> plain string output.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | RAG_PROMPT
    | llm
    | StrOutputParser()
)
97
 
98
+ # --- 6. MAIN APPLICATION LOGIC ---
99
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
def answer_question(user_input, lang_choice, history=None):
    """Handle one user turn: retrieve context, run the RAG chain, format the reply.

    Args:
        user_input: Raw text from the textbox.
        lang_choice: "english" or "pidgin"; selects greeting/disclaimer wording.
        history: List of (user, assistant) tuples; a fresh list when None.

    Returns:
        (chatbot_history, state_history) — both capped at the last 8 turns
        on the success path.
    """
    # Bug fix: the old signature used `history=[]`, a mutable default shared
    # across every call, so separate sessions leaked turns into each other.
    if history is None:
        history = []
    try:
        query = (user_input or "").strip()
        if not query:
            return history, history

        # Simple conversational starters — answered without hitting the
        # retriever or the LLM.
        if query.lower() in ["hi", "hello", "hey"]:
            ans = ("Hello! I'm your Nigerian Legal AI Assistant. How can I help you today?"
                   if lang_choice == "english" else
                   "Howfa! I be your Nigerian Legal AI Assistant. How I fit help you today? No be legal advice o.")
            history.append((user_input, ans))
            return history, history

        print(f"Received query: {query}")

        # Retrieve documents first so references can be built from metadata
        # even though the chain re-retrieves internally.
        docs = retriever.invoke(query)
        if not docs:
            print("No documents retrieved.")
            answer = "I could not find any relevant information in the legal documents for your query. Please try rephrasing."
        else:
            # Invoke the RAG chain to get the answer.
            print("Invoking RAG chain...")
            answer = rag_chain.invoke(query)
            print("RAG chain finished.")

        # Language-appropriate disclaimer, appended after the references.
        disclaimer = ("\n\n--- \n*⚠️ Disclaimer: This is AI-generated information and not legal advice. Please consult a qualified lawyer for professional guidance.*"
                      if lang_choice == "english" else
                      "\n\n--- \n*⚠️ No be legal advice o, abeg find lawyer for proper advice.*")

        # Build robust references; a set avoids duplicates, and only docs
        # with both source and section known are cited.
        references = set()
        for doc in docs:
            source = doc.metadata.get("source", "Unknown Source")
            section = doc.metadata.get("section", "Unknown Section")
            if source != "Unknown Source" and section != "Unknown Section":
                references.add(f"- {source} ({section})")

        if references:
            answer += "\n\n**References:**\n" + "\n".join(sorted(references))

        answer += disclaimer

        history.append((user_input, answer.strip()))

        # Keep chat history to a reasonable length.
        return history[-8:], history[-8:]

    except Exception as e:
        # Top-level boundary: surface a friendly message instead of crashing
        # the Gradio handler.
        print(f"An error occurred: {e}")
        error_message = "Sorry, an unexpected error occurred. Please try again or rephrase your question."
        history.append((user_input, error_message))
        return history, history
159
 
160
  def _reset():
161
+ """Resets the chat state."""
162
  return [], []
163
 
164
+ # --- 7. GRADIO UI ---
165
+
 
166
def build_ui():
    """Build and return the Gradio Blocks web interface."""
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="KnowYourRight Bot") as demo:
        gr.Markdown("# 📜 KnowYourRight Bot — Nigerian Legal Assistant")
        gr.Markdown("Ask questions about the Nigerian Constitution, Labour Act, and more. *Powered by AI.*")

        # NOTE(review): avatar_images expects these image files to ship
        # alongside app.py — confirm they exist in the Space.
        chatbot = gr.Chatbot(label="Chat History", height=600, bubble_full_width=False, avatar_images=("user.png", "bot.png"))

        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="e.g., 'What are my rights if I am arrested?'",
                lines=2,
                scale=4,
            )
            submit_btn = gr.Button("▶️ Send", variant="primary", scale=1)

        lang_choice = gr.Radio(["english", "pidgin"], value="english", label="Response Language")
        clear_btn = gr.Button("🗑️ Clear Chat")

        # State to store the conversation history.
        chat_state = gr.State([])

        # Event handlers: answer on button click or Enter in the textbox.
        submit_btn.click(answer_question, [msg, lang_choice, chat_state], [chatbot, chat_state])
        msg.submit(answer_question, [msg, lang_choice, chat_state], [chatbot, chat_state])

        # Clear the input textbox after either submission path.
        # Bug fix: the previous code looped over [submit_btn, msg] calling
        # `.click` on each, but gr.Textbox has no `.click` event, which
        # raised AttributeError at build time. The textbox's event is
        # `.submit`; only the button has `.click`.
        submit_btn.click(lambda: "", None, msg)
        msg.submit(lambda: "", None, msg)

        clear_btn.click(_reset, None, [chatbot, chat_state])

    return demo
201
 
202
# Script entry point: build the UI and start the Gradio server. Guarded so
# importing this module (e.g. for tests) does not launch the app.
if __name__ == "__main__":
    print("Building Gradio UI...")
    demo = build_ui()
    print("Launching Gradio app...")
    demo.launch(debug=True)  # Set debug=False for production