zm-f21 committed
Commit 9cfac5d · verified · 1 Parent(s): bc508c6

Update app.py

Files changed (1)
  1. app.py +323 -53
app.py CHANGED
@@ -1,70 +1,340 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
     """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
     """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
 
-    messages = [{"role": "system", "content": system_message}]
 
-    messages.extend(history)
 
-    messages.append({"role": "user", "content": message})
 
-    response = ""
 
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
 
-        response += token
-        yield response
 
 
 """
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
 
-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()
 
 
 if __name__ == "__main__":
-    demo.launch()
+# app.py (copy-paste ready)
+import os
+import zipfile
+import shutil
+import re
+import math
+import json
+import logging
+
+# silence transformers warnings
+import transformers
+transformers.logging.set_verbosity_error()
+logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR)
+
+from transformers import pipeline
+from sentence_transformers import SentenceTransformer
+import pandas as pd
+import numpy as np
 import gradio as gr
+
+# ----------------------------- #
+# Configuration - edit these if needed
+# ----------------------------- #
+MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"  # original model (kept as requested)
+ZIP_PATH = "/app/yukon.zip"  # where you uploaded the zip in the Space
+EXTRACT_FOLDER = "/app/yukon_texts"
+EMBEDDING_MODEL_ID = "nlpaueb/legal-bert-base-uncased"
+TOP_K = 2  # default number of retrieved docs
+
+# ----------------------------- #
+# Load LLM pipeline (try device_map first, fall back to CPU)
+# ----------------------------- #
+def create_llm_pipeline():
+    try:
+        # Try to load with device_map="auto" (requires accelerate)
+        llm = pipeline(
+            "text-generation",
+            model=MODEL_ID,
+            torch_dtype="auto",
+            device_map="auto",
+            max_new_tokens=150,
+        )
+        return llm
+    except Exception as e:
+        # Fall back to CPU (slower); keep the error in the logs for debugging.
+        print(f"[warning] device_map auto failed ({e}). Falling back to CPU pipeline (slower).")
+        llm = pipeline(
+            "text-generation",
+            model=MODEL_ID,
+            torch_dtype=None,  # let transformers choose
+            device_map=None,
+        )
+        return llm
+
+llm = create_llm_pipeline()
+
+# ----------------------------- #
+# Load embedding model
+# ----------------------------- #
+embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)
+
+# ----------------------------- #
+# Helpers: unzip dataset and normalize paths
+# ----------------------------- #
+def safe_extract_zip(zip_path, extract_to):
+    # remove the old extracted folder if it exists
+    if os.path.exists(extract_to):
+        try:
+            shutil.rmtree(extract_to)
+        except Exception:
+            pass
+    os.makedirs(extract_to, exist_ok=True)
+
+    with zipfile.ZipFile(zip_path, "r") as zf:
+        # Some zips contain a top-level folder; extract everything
+        zf.extractall(extract_to)
+
+# If the ZIP exists in the Space, extract it
+if os.path.exists(ZIP_PATH):
+    safe_extract_zip(ZIP_PATH, EXTRACT_FOLDER)
+else:
+    print(f"[warning] ZIP file not found at {ZIP_PATH}. Make sure you uploaded your dataset zip to this path.")
+
+# ----------------------------- #
+# Parse metadata/content from files (existing format with a "CONTENT:" separator)
+# ----------------------------- #
+def parse_metadata_and_content(raw_text):
     """
+    Splits header metadata and content on the 'CONTENT:' marker.
+    Returns (metadata_dict, content_str).
     """
+    if "CONTENT:" not in raw_text:
+        # If the file doesn't follow the exact format, attempt a graceful fallback:
+        # extract simple "Key: Value" lines at the top and treat the rest as content.
+        metadata = {}
+        lines = raw_text.split("\n")
+        content_lines = []
+        for line in lines:
+            if ":" in line and len(line.split(":", 1)[0].strip()) <= 30 and len(metadata) < 12:
+                key, value = line.split(":", 1)
+                metadata[key.strip().upper()] = value.strip()
+            else:
+                content_lines.append(line)
+        content = "\n".join(content_lines).strip()
+        return metadata, content
 
+    header, content = raw_text.split("CONTENT:", 1)
+    metadata = {}
+    pdf_list = []
+    for line in header.strip().split("\n"):
+        if ":" in line and not line.strip().startswith("-"):
+            key, value = line.split(":", 1)
+            metadata[key.strip().upper()] = value.strip()
+        elif line.strip().startswith("-"):
+            pdf_list.append(line.strip())
+    if pdf_list:
+        metadata["PDF_LINKS"] = "\n".join(pdf_list)
+    return metadata, content.strip()
 
+# ----------------------------- #
+# Build documents list (paragraph-level)
+# ----------------------------- #
+documents = []
+
+# Walk extracted folder for .txt files
+for root, dirs, files in os.walk(EXTRACT_FOLDER):
+    for filename in files:
+        if filename.startswith("._"):  # skip mac metadata
+            continue
+        if not filename.lower().endswith(".txt"):
+            continue
+        filepath = os.path.join(root, filename)
+        try:
+            # Try UTF-8 first and fall back to latin-1; latin-1 decodes any byte
+            # sequence, so it must come last or the fallback branch is dead code.
+            with open(filepath, "r", encoding="utf-8") as f:
+                raw = f.read()
+        except Exception:
+            try:
+                with open(filepath, "r", encoding="latin-1") as f:
+                    raw = f.read()
+            except Exception as e:
+                print(f"[warning] failed reading {filepath}: {e}")
+                continue
+
+        # parse metadata + content
+        metadata, content = parse_metadata_and_content(raw)
+        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
+        for p in paragraphs:
+            documents.append({
+                "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
+                "province": metadata.get("PROVINCE", "Unknown"),
+                "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
+                "url": metadata.get("URL", "N/A"),
+                "pdf_links": metadata.get("PDF_LINKS", ""),
+                "text": p,
+            })
+
+print(f"[info] Loaded {len(documents)} document paragraphs.")
 
+# ----------------------------- #
+# Create embeddings and dataframe
+# ----------------------------- #
+texts = [d["text"] for d in documents]
+if len(texts) == 0:
+    df = pd.DataFrame(columns=["source_title", "province", "last_updated", "url", "pdf_links", "text", "Embedding"])
+else:
+    # create embeddings (this is potentially slow for many docs)
+    embeddings = embedding_model.encode(texts, show_progress_bar=True)
+    df = pd.DataFrame(documents)
+    df["Embedding"] = list(np.asarray(embeddings, dtype="float32"))
+    print("[info] Embeddings indexed. Total:", len(df))
 
+# ----------------------------- #
+# Retrieval function (with optional province filter)
+# ----------------------------- #
+def retrieve_with_pandas(query, province=None, top_k=TOP_K):
+    if df is None or len(df) == 0:
+        return pd.DataFrame()  # empty
+
+    query_emb = embedding_model.encode([query])[0].astype("float32")
+
+    if province is not None:
+        filtered = df[df["province"].str.lower() == str(province).lower()].copy()
+    else:
+        filtered = df.copy()
+
+    if filtered.empty:
+        return pd.DataFrame()
+
+    # cosine similarity
+    def cos_sim(a, b):
+        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
+
+    filtered["Similarity"] = filtered["Embedding"].apply(lambda x: cos_sim(query_emb, np.asarray(x)))
+    results = filtered.sort_values("Similarity", ascending=False).head(top_k)
+    return results[["text", "last_updated", "Similarity", "province", "source_title", "url"]]
+
+# ----------------------------- #
+# Utilities: province detection, guardrails, intros
+# ----------------------------- #
+def detect_province(query):
+    provinces = {
+        "yukon": "Yukon",
+        "alberta": "Alberta",
+        "bc": "British Columbia",
+        "british columbia": "British Columbia",
+        "manitoba": "Manitoba",
+        "nl": "Newfoundland and Labrador",
+        "newfoundland": "Newfoundland and Labrador",
+        "sask": "Saskatchewan",
+        "saskatchewan": "Saskatchewan",
+        "ontario": "Ontario",
+        "pei": "Prince Edward Island",
+        "prince edward island": "Prince Edward Island",
+        "quebec": "Quebec",
+        "nb": "New Brunswick",
+        "new brunswick": "New Brunswick",
+        "nova scotia": "Nova Scotia",
+        "nunavut": "Nunavut",
+        "nwt": "Northwest Territories",
+        "northwest territories": "Northwest Territories",
+    }
+    q = query.lower()
+    for key, prov in provinces.items():
+        if key in q:
+            return prov
+    return None
+
+def is_disallowed(query):
+    banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
+    q = query.lower()
+    return any(b in q for b in banned)
+
+def is_off_topic(query):
+    tenancy_keywords = [
+        "tenant", "landlord", "rent", "evict", "lease",
+        "deposit", "tenancy", "rental", "apartment",
+        "unit", "heating", "notice", "repair", "pets",
+    ]
+    q = query.lower()
+    return not any(k in q for k in tenancy_keywords)
+
+INTRO_TEXT = (
+    "Hi! I'm a Canadian rental housing assistant. I can help you find, summarize, "
+    "and explain information from the Residential Tenancies Acts across provinces and territories.\n\n"
+    "**Important:** I'm not a lawyer and this is NOT legal advice. I may be wrong and laws change — "
+    "please verify with official sources or a legal professional when in doubt.\n\n"
+)
 
+# ----------------------------- #
+# The RAG generator function
+# ----------------------------- #
+def generate_with_rag(query, province=None, top_k=TOP_K):
+    # Guardrails
+    if is_disallowed(query):
+        return INTRO_TEXT + "Sorry — I can't help with harmful or dangerous topics. Try asking about tenancy/housing instead."
+
+    if is_off_topic(query):
+        return INTRO_TEXT + "Sorry — I can only answer questions about Canadian tenancy and housing law. Try rephrasing with tenancy keywords or mention a province."
+
+    if province is None:
+        province = detect_province(query)
+
+    top_docs_df = retrieve_with_pandas(query, province=province, top_k=top_k)
+    if top_docs_df is None or len(top_docs_df) == 0:
+        return INTRO_TEXT + "Sorry — I couldn't find matching info in the tenancy database. Try rephrasing or include a province."
+
+    context = " ".join(top_docs_df["text"].tolist())
+
+    # Few-shot examples (style guide only)
+    qa_examples = """
+    Q: I asked my landlord three months ago to install handrails in my bathroom. Can the landlord take a long time to respond?
+    A: Landlords should respond promptly to reasonable accommodation requests. If they delay unreasonably, you may be able to file a complaint.
+
+    Q: My building manager keeps complaining about my children’s noise. Can I be evicted?
+    A: Reasonable noise from children is expected. Differential treatment based on family status may violate housing protections.
     """
+
+    prompt = f"""
+    Use the examples as a STYLE GUIDE ONLY.
+    DO NOT repeat the example questions.
+    DO NOT invent laws — only use the context provided.
+    If the context does not contain the answer, say you cannot confidently answer.
+
+    {qa_examples}
+
+    Context:
+    {context}
+
+    Question:
+    {query}
+
+    Answer conversationally:
     """
 
+    # Call the model (the pipeline already sets a default max_new_tokens; pass extra args as needed)
+    try:
+        raw_output = llm(prompt, max_new_tokens=200, do_sample=False)[0]["generated_text"]
+    except Exception as e:
+        # If the pipeline fails (OOM or other), return a helpful message
+        print(f"[error] LLM generation failed: {e}")
+        return INTRO_TEXT + "Sorry — the language model failed to produce an answer. Try again or contact the maintainer."
+
+    # Clean the model output: extract only the part after the "Answer conversationally:" instruction
+    if "Answer conversationally:" in raw_output:
+        answer = raw_output.split("Answer conversationally:", 1)[-1].strip()
+    else:
+        answer = raw_output.strip()
+
+    # Metadata formatting
+    metadata_block = ""
+    for _, row in top_docs_df.iterrows():
+        metadata_block += (
+            f"- Province: {row.get('province', 'Unknown')}\n"
+            f"  Source: {row.get('source_title', 'Unknown')}\n"
+            f"  Updated: {row.get('last_updated', 'Unknown')}\n"
+            f"  URL: {row.get('url', 'N/A')}\n"
+        )
+
+    return INTRO_TEXT + f"{answer}\n\nSources Used:\n{metadata_block}"
+
+# ----------------------------- #
+# Gradio UI
+# ----------------------------- #
+def respond_gradio(message, chat_history):
+    answer = generate_with_rag(message)
+    chat_history = chat_history or []
+    chat_history.append((message, answer))
+    # Clear the textbox and return the updated history once
+    return "", chat_history
+
+with gr.Blocks() as demo:
+    gr.Markdown("## Yukon / Canada Tenancy RAG Chatbot")
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(label="Your question", placeholder="e.g. Can my landlord increase rent in Yukon?")
+    msg.submit(respond_gradio, [msg, chatbot], [msg, chatbot])
+    gr.Markdown("**Note:** This assistant is informational only, not legal advice.")
+demo.queue(default_concurrency_limit=2)  # enable queueing so Spaces handles requests sequentially (concurrency_count was renamed in Gradio 4)
 
 if __name__ == "__main__":
+    demo.launch(share=True)
+
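
For a quick local check of the new RAG path, a minimal smoke test might look like the sketch below. It assumes the dataset zip is already at ZIP_PATH, since importing app runs the extraction, embedding, and model-loading steps at import time; the file name smoke_test.py and both example queries are illustrative, not part of the commit.

    # smoke_test.py: a minimal sketch, assuming app.py and the dataset zip are in place.
    # Importing app triggers extraction, embedding, and LLM loading as side effects,
    # so the import alone can take several minutes on a CPU Space.
    import app

    # Retrieval only: exercises the province filter and cosine-similarity ranking.
    hits = app.retrieve_with_pandas("Can my landlord increase rent?", province="Yukon", top_k=2)
    print(hits[["province", "Similarity", "source_title"]])

    # Full path: guardrails -> retrieval -> prompt assembly -> Mistral generation.
    print(app.generate_with_rag("How much notice does a landlord need to give before raising rent in Yukon?"))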