menikev commited on
Commit
5b69f3e
·
verified ·
1 Parent(s): 369c8e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -56
app.py CHANGED
@@ -1,85 +1,223 @@
1
- import gradio as gr
 
2
  import os
3
- import time
4
- from langchain.chains import RetrievalQA
5
- from langchain.prompts import PromptTemplate
6
- from langchain_community.llms import HuggingFacePipeline
7
  from transformers import pipeline
 
8
 
9
- # Import your retriever (update import path as needed)
10
- try:
11
- from retriever import get_retriever
12
- except ImportError:
13
- from src.retriever import get_retriever
14
 
15
- # Use a lightweight model for fast inference
 
 
 
 
 
16
  pipe = pipeline(
17
  "text-generation",
18
- model="microsoft/DialoGPT-medium",
19
  device_map="auto",
20
- max_new_tokens=150,
 
21
  do_sample=False,
22
- pad_token_id=tokenizer.eos_token_id
23
  )
24
 
25
- llm = HuggingFacePipeline(pipeline=pipe)
26
- retriever = get_retriever()
 
 
 
 
 
 
 
 
 
27
 
28
- # Simple prompt template
29
- template = """Answer this legal question about Nigeria using the context provided.
30
- If asked to respond in Nigerian Pidgin, use Nigerian Pidgin.
 
31
 
32
  Context: {context}
 
33
  Question: {question}
34
- Answer:"""
35
 
36
- prompt = PromptTemplate(input_variables=["question", "context"], template=template)
 
37
 
38
- qa_chain = RetrievalQA.from_chain_type(
39
- llm=llm,
40
- retriever=retriever,
41
- chain_type="stuff",
42
- return_source_documents=True,
43
- chain_type_kwargs={"prompt": prompt}
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- def answer_question(user_input, lang_choice):
47
- if not user_input.strip():
48
- return "Please enter a question."
49
-
50
- start_time = time.time()
51
-
52
  try:
53
- if lang_choice == "pidgin":
54
- user_input = f"Respond in Nigerian Pidgin: {user_input}"
 
 
 
55
 
56
- result = qa_chain.invoke({"query": user_input})
 
 
 
 
57
 
58
- processing_time = time.time() - start_time
59
- answer_text = result["result"]
 
60
 
61
- # Collect sources
62
- sources = list({doc.metadata.get("source", "Unknown")
63
- for doc in result["source_documents"][:3]})
64
- sources_text = "\n".join(f"📄 {os.path.basename(src)}" for src in sources)
 
65
 
66
- return f"{answer_text}\n\n**References:**\n{sources_text}\n\n*Response time: {processing_time:.1f}s*"
67
 
68
  except Exception as e:
69
- return f"Error: {str(e)}"
 
 
 
70
 
71
- # Gradio interface
72
- with gr.Blocks(title="KnowYourRight Bot") as demo:
73
- gr.Markdown("# 🇳🇬 KnowYourRight Bot\nAsk legal questions in English or Nigerian Pidgin")
 
 
 
 
74
 
75
- with gr.Row():
76
- question = gr.Textbox(label="Your question", lines=2)
77
- language = gr.Radio(["english", "pidgin"], label="Language", value="english")
78
 
79
- submit = gr.Button("Ask Question", variant="primary")
80
- output = gr.Textbox(label="Answer", lines=8)
 
 
 
 
 
81
 
82
- submit.click(answer_question, inputs=[question, language], outputs=output)
83
- question.submit(answer_question, inputs=[question, language], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- demo.launch()
 
1
+ # src/knowyourright_bot.py
2
+
3
  import os
4
+ from sentence_transformers import SentenceTransformer
5
+ import chromadb
6
+ from chromadb.config import Settings
 
7
  from transformers import pipeline
8
+ import gradio as gr
9
 
10
+ # Configuration
11
+ VECTOR_DIR = "vector_db"
12
+ MODEL_NAME = "microsoft/DialoGPT-medium" # Free, fast model
 
 
13
 
14
+ # Initialize embedding model and vector database
15
+ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
16
+ client = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=VECTOR_DIR))
17
+ collection = client.get_collection("laws")
18
+
19
+ # Initialize language model
20
  pipe = pipeline(
21
  "text-generation",
22
+ model=MODEL_NAME,
23
  device_map="auto",
24
+ max_new_tokens=300,
25
+ temperature=0.1,
26
  do_sample=False,
27
+ pad_token_id=50256
28
  )
29
 
30
+ # English Prompt Template
31
+ ENGLISH_TEMPLATE = """
32
+ You are a knowledgeable legal assistant for Nigerian law. Answer the question using only the provided context.
33
+ Be concise, accurate, and cite specific sections when possible.
34
+
35
+ Context: {context}
36
+
37
+ Question: {question}
38
+
39
+ Answer (in clear English):
40
+ """
41
 
42
+ # Pidgin Prompt Template
43
+ PIDGIN_TEMPLATE = """
44
+ You be legal assistant wey sabi Nigerian law well well. Use only the context wey dem give you answer the question.
45
+ Make your answer short, correct, and talk the specific law section if e dey.
46
 
47
  Context: {context}
48
+
49
  Question: {question}
 
50
 
51
+ Answer for Nigerian Pidgin:
52
+ """
53
 
54
+ def get_relevant_context(question, k=4):
55
+ """Retrieve relevant legal context from vector database"""
56
+ try:
57
+ q_emb = embed_model.encode([question], convert_to_numpy=True)
58
+ results = collection.query(
59
+ query_embeddings=q_emb,
60
+ n_results=k,
61
+ include=["documents", "metadatas"]
62
+ )
63
+
64
+ # Format context with sources
65
+ context_chunks = []
66
+ sources = []
67
+
68
+ for i, doc in enumerate(results['documents'][0]):
69
+ source = results['metadatas'][0][i].get("source", "Unknown")
70
+ context_chunks.append(doc)
71
+ sources.append(source)
72
+
73
+ context = "\n\n".join(context_chunks)
74
+ return context, sources
75
+
76
+ except Exception as e:
77
+ print(f"Error retrieving context: {e}")
78
+ return "", []
79
 
80
+ def generate_response(question, language="english"):
81
+ """Generate response using appropriate prompt template"""
 
 
 
 
82
  try:
83
+ # Get relevant context
84
+ context, sources = get_relevant_context(question)
85
+
86
+ if not context:
87
+ return "Sorry, I couldn't find relevant information to answer your question.", []
88
 
89
+ # Choose prompt template based on language
90
+ if language.lower() == "pidgin":
91
+ prompt = PIDGIN_TEMPLATE.format(context=context, question=question)
92
+ else:
93
+ prompt = ENGLISH_TEMPLATE.format(context=context, question=question)
94
 
95
+ # Generate response
96
+ response = pipe(prompt, max_new_tokens=256, do_sample=False, pad_token_id=50256)
97
+ answer = response[0]['generated_text']
98
 
99
+ # Extract only the generated part (remove the prompt)
100
+ if "Answer" in answer:
101
+ answer = answer.split("Answer")[-1].strip()
102
+ if answer.startswith("(in clear English):") or answer.startswith("for Nigerian Pidgin:"):
103
+ answer = answer.split(":", 1)[-1].strip()
104
 
105
+ return answer, sources
106
 
107
  except Exception as e:
108
+ error_msg = f"Sorry, I encountered an error: {str(e)}"
109
+ if language.lower() == "pidgin":
110
+ error_msg = "Sorry o, something happen when I dey answer your question. Try ask again."
111
+ return error_msg, []
112
 
113
+ def answer_question(user_input, lang_choice):
114
+ """Main function for processing questions"""
115
+ if not user_input or len(user_input.strip()) < 3:
116
+ return "Please ask a more specific question about your legal rights."
117
+
118
+ if len(user_input) > 1000:
119
+ return "Please ask a shorter question (maximum 1000 characters)."
120
 
121
+ # Generate response
122
+ answer, sources = generate_response(user_input.strip(), lang_choice)
 
123
 
124
+ # Format sources
125
+ if sources:
126
+ unique_sources = list(set([os.path.basename(src) for src in sources[:3]]))
127
+ sources_text = "\n".join(f"📄 {src}" for src in unique_sources)
128
+ formatted_response = f"{answer}\n\n**References:**\n{sources_text}"
129
+ else:
130
+ formatted_response = f"{answer}\n\n**References:**\n📄 No sources found"
131
 
132
+ return formatted_response
133
+
134
+ def create_gradio_interface():
135
+ """Create Gradio interface for testing"""
136
+ with gr.Blocks(
137
+ title="KnowYourRight Bot - Nigerian Legal Assistant",
138
+ theme=gr.themes.Soft()
139
+ ) as demo:
140
+
141
+ gr.Markdown(
142
+ """
143
+ # 🇳🇬 KnowYourRight Bot
144
+ ## Your AI Legal Assistant for Nigerian Law
145
+
146
+ Ask questions about your rights under:
147
+ - Nigerian Constitution
148
+ - Labor Laws
149
+ - Data Protection Regulation (NDPR)
150
+ - Consumer Protection Act (FCCPA)
151
+
152
+ **Available in English and Nigerian Pidgin**
153
+ """
154
+ )
155
+
156
+ with gr.Row():
157
+ with gr.Column(scale=3):
158
+ question_input = gr.Textbox(
159
+ label="Ask about your legal rights",
160
+ placeholder="e.g., Can my landlord evict me without notice?",
161
+ lines=3,
162
+ max_lines=5
163
+ )
164
+
165
+ with gr.Column(scale=1):
166
+ language_choice = gr.Radio(
167
+ choices=["english", "pidgin"],
168
+ label="Language / Language wey you wan use",
169
+ value="english"
170
+ )
171
+
172
+ submit_btn = gr.Button("Ask Question / Ask Question", variant="primary", size="lg")
173
+
174
+ answer_output = gr.Textbox(
175
+ label="Answer / Answer",
176
+ lines=10,
177
+ max_lines=15
178
+ )
179
+
180
+ # Example questions
181
+ gr.Markdown("### Example Questions / Example Questions")
182
+ examples = [
183
+ ["Can my employer sack me without notice?", "english"],
184
+ ["Wetin be my right as tenant?", "pidgin"],
185
+ ["What does NDPR say about data privacy?", "english"],
186
+ ["How can I report consumer fraud?", "english"],
187
+ ["Wetin happen if person collect my data without permission?", "pidgin"]
188
+ ]
189
+
190
+ gr.Examples(
191
+ examples=examples,
192
+ inputs=[question_input, language_choice],
193
+ outputs=answer_output,
194
+ fn=answer_question
195
+ )
196
+
197
+ # Event handlers
198
+ submit_btn.click(
199
+ fn=answer_question,
200
+ inputs=[question_input, language_choice],
201
+ outputs=answer_output
202
+ )
203
+
204
+ question_input.submit(
205
+ fn=answer_question,
206
+ inputs=[question_input, language_choice],
207
+ outputs=answer_output
208
+ )
209
+
210
+ # Footer
211
+ gr.Markdown(
212
+ """
213
+ ---
214
+ **Disclaimer:** This is an AI assistant for informational purposes only.
215
+ For legal advice, consult a qualified lawyer.
216
+
217
+ Built by **AI Club Lagos** | Open Source Project
218
+ """
219
+ )
220
+
221
+ return demo
222
+
223