Brian269 committed on
Commit
551c35e
Β·
verified Β·
1 Parent(s): 7c56873

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -34
app.py CHANGED
@@ -10,17 +10,25 @@ from langchain.docstore.document import Document
10
  from langchain.embeddings import HuggingFaceEmbeddings
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
12
 
13
- # =============================
14
  # PAGE CONFIG
15
- # =============================
16
- st.set_page_config(page_title="Kenya Legal Assistant", layout="wide")
 
 
 
17
 
18
- # =============================
19
- # CACHE EMBEDDINGS + VECTOR DB
20
- # =============================
21
- @st.cache_resource
 
 
 
22
  def load_vectorstore():
23
 
 
 
24
  dataset = load_dataset(
25
  "Brian269/Kenyan_Judgements",
26
  split="train",
@@ -29,12 +37,16 @@ def load_vectorstore():
29
 
30
  documents = []
31
  for i, item in enumerate(dataset):
32
- if i > 200: # prevent HF timeout
33
  break
 
34
  documents.append(
35
  Document(
36
  page_content=item["text"],
37
- metadata={"source": item["file_name"], "page": 1},
 
 
 
38
  )
39
  )
40
 
@@ -46,7 +58,9 @@ def load_vectorstore():
46
  chunks = []
47
  for doc in documents:
48
  for chunk in splitter.split_text(doc.page_content):
49
- chunks.append(Document(page_content=chunk, metadata=doc.metadata))
 
 
50
 
51
  embeddings = HuggingFaceEmbeddings(
52
  model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -54,25 +68,30 @@ def load_vectorstore():
54
 
55
  INDEX_PATH = "faiss_index"
56
 
 
57
  if os.path.exists(INDEX_PATH):
 
58
  vectorstore = FAISS.load_local(
59
  INDEX_PATH,
60
  embeddings,
61
  allow_dangerous_deserialization=True
62
  )
63
  else:
 
64
  vectorstore = FAISS.from_documents(chunks, embeddings)
65
  vectorstore.save_local(INDEX_PATH)
66
 
67
  return vectorstore
68
 
69
 
70
- # =============================
71
- # CACHE MODEL (LOAD ONCE)
72
- # =============================
73
- @st.cache_resource
74
  def load_llm():
75
 
 
 
76
  model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
77
 
78
  tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -93,12 +112,13 @@ def load_llm():
93
  return pipe
94
 
95
 
 
96
  vectorstore = load_vectorstore()
97
  pipe = load_llm()
98
 
99
- # =============================
100
  # HELPERS
101
- # =============================
102
  def detect_language(text):
103
  try:
104
  return detect(text)
@@ -106,17 +126,17 @@ def detect_language(text):
106
  return "en"
107
 
108
 
109
- def translate(text, target):
110
- return GoogleTranslator(source="auto", target=target).translate(text)
111
 
112
 
113
  def build_prompt(question, context):
114
  return f"""
115
  You are a Kenyan legal assistant.
116
 
117
- Answer ONLY using provided context.
118
- Include case citations.
119
- Do not hallucinate.
120
 
121
  Context:
122
  {context}
@@ -130,34 +150,38 @@ Structured Answer:
130
 
131
  def ask_kenya_law(question):
132
 
133
- lang = detect_language(question)
134
- q_en = translate(question, "en") if lang == "sw" else question
 
 
 
135
 
136
- docs = vectorstore.similarity_search(q_en, k=4)
137
- context = "\n\n".join([d.page_content for d in docs])
138
 
139
- prompt = build_prompt(q_en, context)
 
 
140
 
141
  result = pipe(prompt)[0]["generated_text"]
142
 
143
- if lang == "sw":
144
  result = translate(result, "sw")
145
 
146
  sources = "\n".join(
147
- [f'{d.metadata["source"]} - Page {d.metadata["page"]}' for d in docs]
 
148
  )
149
 
150
  return result, sources
151
 
152
 
153
- # =============================
154
  # STREAMLIT CHAT UI
155
- # =============================
156
- st.title("πŸ‡°πŸ‡ͺ Kenya Legal Assistant")
157
-
158
  if "messages" not in st.session_state:
159
  st.session_state.messages = []
160
 
 
161
  for msg in st.session_state.messages:
162
  with st.chat_message(msg["role"]):
163
  st.markdown(msg["content"])
@@ -165,7 +189,10 @@ for msg in st.session_state.messages:
165
  prompt = st.chat_input("Ask a legal question...")
166
 
167
  if prompt:
168
- st.session_state.messages.append({"role": "user", "content": prompt})
 
 
 
169
 
170
  with st.chat_message("user"):
171
  st.markdown(prompt)
@@ -177,11 +204,14 @@ if prompt:
177
  response = f"""
178
  {answer}
179
 
 
 
180
  πŸ“š **Sources**
181
  {sources}
182
 
183
  ⚠️ DISCLAIMER:
184
- Educational legal information only β€” not legal advice.
 
185
  """
186
 
187
  st.markdown(response)
 
10
  from langchain.embeddings import HuggingFaceEmbeddings
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
12
 
13
+ # ===================================
14
  # PAGE CONFIG
15
+ # ===================================
16
+ st.set_page_config(
17
+ page_title="Kenya Legal Assistant",
18
+ layout="wide"
19
+ )
20
 
21
+ st.title("πŸ‡°πŸ‡ͺ Kenya Legal Assistant")
22
+ st.caption("Ask questions about Kenyan court judgments (English or Swahili)")
23
+
24
+ # ===================================
25
+ # LOAD VECTOR DATABASE (CACHED)
26
+ # ===================================
27
+ @st.cache_resource(show_spinner=True)
28
  def load_vectorstore():
29
 
30
+ st.write("πŸ”Ž Loading legal knowledge base...")
31
+
32
  dataset = load_dataset(
33
  "Brian269/Kenyan_Judgements",
34
  split="train",
 
37
 
38
  documents = []
39
  for i, item in enumerate(dataset):
40
+ if i > 200: # prevents HF startup timeout
41
  break
42
+
43
  documents.append(
44
  Document(
45
  page_content=item["text"],
46
+ metadata={
47
+ "source": item["file_name"],
48
+ "page": 1
49
+ },
50
  )
51
  )
52
 
 
58
  chunks = []
59
  for doc in documents:
60
  for chunk in splitter.split_text(doc.page_content):
61
+ chunks.append(
62
+ Document(page_content=chunk, metadata=doc.metadata)
63
+ )
64
 
65
  embeddings = HuggingFaceEmbeddings(
66
  model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 
68
 
69
  INDEX_PATH = "faiss_index"
70
 
71
+ # βœ… Load prebuilt FAISS index if uploaded
72
  if os.path.exists(INDEX_PATH):
73
+ st.write("βœ… Loading FAISS index...")
74
  vectorstore = FAISS.load_local(
75
  INDEX_PATH,
76
  embeddings,
77
  allow_dangerous_deserialization=True
78
  )
79
  else:
80
+ st.warning("⚠️ FAISS index not found β€” building (first run only)...")
81
  vectorstore = FAISS.from_documents(chunks, embeddings)
82
  vectorstore.save_local(INDEX_PATH)
83
 
84
  return vectorstore
85
 
86
 
87
+ # ===================================
88
+ # LOAD LANGUAGE MODEL (CACHED)
89
+ # ===================================
90
+ @st.cache_resource(show_spinner=True)
91
  def load_llm():
92
 
93
+ st.write("🧠 Loading language model...")
94
+
95
  model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
96
 
97
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
112
  return pipe
113
 
114
 
115
+ # Load once
116
  vectorstore = load_vectorstore()
117
  pipe = load_llm()
118
 
119
+ # ===================================
120
  # HELPERS
121
+ # ===================================
122
  def detect_language(text):
123
  try:
124
  return detect(text)
 
126
  return "en"
127
 
128
 
129
def translate(text, target_lang):
    """Translate *text* into *target_lang*, auto-detecting the source language."""
    translator = GoogleTranslator(source="auto", target=target_lang)
    return translator.translate(text)
131
 
132
 
133
  def build_prompt(question, context):
134
  return f"""
135
  You are a Kenyan legal assistant.
136
 
137
+ Answer ONLY using the provided context.
138
+ Include proper case citations.
139
+ Do not fabricate information.
140
 
141
  Context:
142
  {context}
 
150
 
151
def ask_kenya_law(question):
    """Answer a legal question with retrieval-augmented generation.

    Detects the question's language; Swahili questions are translated to
    English before retrieval, and the model's answer is translated back.
    Returns a (answer_text, sources_text) tuple.
    """
    lang = detect_language(question)

    # Retrieval and prompting happen in English.
    if lang == "sw":
        query = translate(question, "en")
    else:
        query = question

    # Pull the 4 most similar chunks from the FAISS index.
    docs = vectorstore.similarity_search(query, k=4)
    context_text = "\n\n".join(doc.page_content for doc in docs)

    prompt = build_prompt(query, context_text)
    answer = pipe(prompt)[0]["generated_text"]

    # Round-trip back to Swahili when the user asked in Swahili.
    if lang == "sw":
        answer = translate(answer, "sw")

    source_lines = []
    for doc in docs:
        source_lines.append(
            f'{doc.metadata["source"]} - Page {doc.metadata["page"]}'
        )
    sources = "\n".join(source_lines)

    return answer, sources
176
 
177
 
178
+ # ===================================
179
  # STREAMLIT CHAT UI
180
+ # ===================================
 
 
181
  if "messages" not in st.session_state:
182
  st.session_state.messages = []
183
 
184
+ # Display history
185
  for msg in st.session_state.messages:
186
  with st.chat_message(msg["role"]):
187
  st.markdown(msg["content"])
 
189
  prompt = st.chat_input("Ask a legal question...")
190
 
191
  if prompt:
192
+
193
+ st.session_state.messages.append(
194
+ {"role": "user", "content": prompt}
195
+ )
196
 
197
  with st.chat_message("user"):
198
  st.markdown(prompt)
 
204
  response = f"""
205
  {answer}
206
 
207
+ ---
208
+
209
  πŸ“š **Sources**
210
  {sources}
211
 
212
  ⚠️ DISCLAIMER:
213
+ This AI provides legal information for educational purposes only.
214
+ It does NOT constitute legal advice.
215
  """
216
 
217
  st.markdown(response)