zm-f21 committed on
Commit
8babbb9
·
verified ·
1 Parent(s): 85551a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -53
app.py CHANGED
@@ -1,81 +1,246 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
  from sentence_transformers import SentenceTransformer
4
- import faiss
5
  import numpy as np
 
6
  import os
 
 
7
 
8
  # -----------------------------
9
- # Hugging Face token
10
  # -----------------------------
11
- os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN"
12
- client = InferenceClient(token=os.environ["HF_TOKEN"], model="mistralai/Mistral-7B-Instruct-v0.2")
 
 
 
 
13
 
14
  # -----------------------------
15
- # Example RAG documents
 
 
 
16
  # -----------------------------
17
- documents = [
18
- "Quantum computing uses quantum bits.",
19
- "Transformers are a type of neural network architecture.",
20
- "Python is a popular programming language.",
21
- # Add more docs or load from your dataset
22
- ]
 
 
 
 
 
 
 
 
 
23
 
24
  # -----------------------------
25
- # Embeddings + FAISS
26
  # -----------------------------
27
- embed_model = SentenceTransformer("all-MiniLM-L6-v2")
28
- embeddings = embed_model.encode(documents, convert_to_numpy=True)
29
- dimension = embeddings.shape[1]
30
- index = faiss.IndexFlatL2(dimension)
31
- index.add(embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- def retrieve(query, top_k=2):
34
- query_emb = embed_model.encode([query], convert_to_numpy=True)
35
- distances, indices = index.search(query_emb, top_k)
36
- return [documents[i] for i in indices[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # -----------------------------
39
- # RAG answer function
40
  # -----------------------------
41
- def answer_with_rag(message, history, system_message, max_tokens, temperature, top_p):
42
- context_docs = retrieve(message)
43
- context = " ".join(context_docs)
 
 
 
 
44
 
45
- prompt = f"Answer the question using the following context:\n{context}\n\nQuestion: {message}\nAnswer:"
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- response = ""
48
- for msg in client.chat_completion(
49
- prompt,
50
- max_tokens=max_tokens,
51
- stream=True,
52
- temperature=temperature,
53
- top_p=top_p
54
- ):
55
- choices = msg.choices
56
- if len(choices) and choices[0].delta.content:
57
- response += choices[0].delta.content
58
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # -----------------------------
61
- # Gradio ChatInterface
62
  # -----------------------------
63
- chatbot = gr.ChatInterface(
64
- answer_with_rag,
65
- type="messages",
66
- additional_inputs=[
67
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
68
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
69
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
70
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
71
- ],
 
 
 
 
 
 
 
 
72
  )
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  with gr.Blocks() as demo:
75
- with gr.Sidebar():
76
- gr.LoginButton()
77
- chatbot.render()
 
 
 
 
78
 
79
  if __name__ == "__main__":
80
  demo.launch(share=True)
81
-
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  from sentence_transformers import SentenceTransformer
4
+ import pandas as pd
5
  import numpy as np
6
+ import zipfile
7
  import os
8
+ import re
9
+ import torch
10
 
11
  # -----------------------------
12
+ # Load Mistral pipeline
13
  # -----------------------------
14
# Instruction-tuned Mistral 7B used as the answer generator.
# fp16 weights halve memory; device_map="auto" lets the loader place
# layers on whatever accelerator(s) are available.
llm = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    device_map="auto"
)
20
 
21
  # -----------------------------
22
+ # Load SentenceTransformer embeddings
23
+ # -----------------------------
24
# Legal-domain BERT encoder; embeds both corpus paragraphs and user queries
# into the same vector space for cosine-similarity retrieval.
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
25
+
26
  # -----------------------------
27
+ # Extract Provinces ZIP
28
+ # -----------------------------
29
# Unpack the bundled per-province legislation texts.
zip_path = "/app/provinces.zip"  # Make sure you upload this to your HF Space
extract_folder = "/app/provinces_texts"

# Remove old folder if exists, so re-runs start from a clean extract.
if os.path.exists(extract_folder):
    import shutil
    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall(extract_folder)
39
+
40
# Regex to capture YYYY_MM_DD or YYYY-MM-DD anywhere in a filename.
# Both separators accept "-" or "_": the previous first group was "[-]"
# only, so underscore-separated dates (2024_06_01) were silently missed.
date_pattern = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")
42
 
43
  # -----------------------------
44
+ # Parse TXT files and create documents
45
  # -----------------------------
46
def parse_metadata_and_content(raw_text):
    """Split a province TXT file into a metadata dict and its body text.

    The file format is a header of "KEY: value" lines (plus optional
    "- ..." PDF-link lines) followed by a literal "CONTENT:" separator.
    Link lines are collected under the "PDF_LINKS" key, newline-joined.

    Raises:
        ValueError: if the "CONTENT:" separator is absent.
    """
    if "CONTENT:" not in raw_text:
        raise ValueError("File missing CONTENT: separator.")

    header, content = raw_text.split("CONTENT:", 1)
    metadata = {}
    pdf_links = []

    for raw_line in header.strip().split("\n"):
        stripped = raw_line.strip()
        if stripped.startswith("-"):
            # Bullet lines are PDF links, even when they contain a colon.
            pdf_links.append(stripped)
        elif ":" in raw_line:
            key, _, value = raw_line.partition(":")
            metadata[key.strip().upper()] = value.strip()

    if pdf_links:
        metadata["PDF_LINKS"] = "\n".join(pdf_links)
    return metadata, content.strip()
64
+
65
documents = []

# Walk every extracted province folder and turn each TXT file into
# paragraph-level documents, each carrying its source metadata.
for root, dirs, files in os.walk(extract_folder):
    for filename in files:
        # Skip macOS "._" resource-fork entries and any non-text files.
        if filename.startswith("._") or not filename.endswith(".txt"):
            continue
        filepath = os.path.join(root, filename)
        try:
            # latin-1 decodes any byte sequence, so a malformed file
            # cannot abort the whole corpus load. TODO confirm the
            # source files are actually latin-1 and not UTF-8.
            with open(filepath, "r", encoding="latin-1") as f:
                raw = f.read()
            metadata, content = parse_metadata_and_content(raw)
            # One retrievable document per blank-line-separated paragraph.
            paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
            for p in paragraphs:
                documents.append({
                    "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                    "province": metadata.get("PROVINCE", "Unknown"),
                    "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                    "url": metadata.get("URL", "N/A"),
                    "pdf_links": metadata.get("PDF_LINKS", ""),
                    "text": p
                })
        except ValueError as e:
            # Raised by parse_metadata_and_content for files without CONTENT:
            print(f"Skipping {filepath}: {e}")
            continue

print(f"Loaded {len(documents)} paragraphs from all provinces.")
91
 
92
  # -----------------------------
93
+ # Create embeddings and dataframe
94
  # -----------------------------
95
# Encode every paragraph once at startup; float16 halves the memory
# footprint of the stored vectors (at some precision cost to the
# similarity scores computed later).
texts = [d["text"] for d in documents]
embeddings = embedding_model.encode(texts).astype("float16")

# Keep each embedding next to its metadata in one dataframe so retrieval
# can filter by province before scoring.
df = pd.DataFrame(documents)
df["Embedding"] = list(embeddings)

print("Indexing complete. Total:", len(df))
102
 
103
+ # -----------------------------
104
+ # Retrieve with Pandas
105
+ # -----------------------------
106
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the top_k most similar paragraphs to `query` as a dataframe.

    Args:
        query: Free-text question to embed and score against the corpus.
        province: Optional canonical province name; when given, only rows
            whose 'province' column matches are scored.
        top_k: Number of best-scoring rows to return.

    Returns:
        A copy of the (possibly filtered) dataframe's top_k rows, sorted
        by a new 'Similarity' column (cosine similarity, descending).
    """
    query_emb = embedding_model.encode([query])[0]
    if province is not None:
        filtered_df = df[df['province'] == province].copy()
    else:
        filtered_df = df.copy()
    # Hoist the query-vector norm out of the per-row lambda: it is
    # loop-invariant, so recomputing it for every row was wasted work.
    query_norm = np.linalg.norm(query_emb)
    filtered_df['Similarity'] = filtered_df['Embedding'].apply(
        lambda emb: np.dot(query_emb, emb) / (query_norm * np.linalg.norm(emb))
    )
    return filtered_df.sort_values("Similarity", ascending=False).head(top_k)
116
 
117
+ # -----------------------------
118
+ # Province detection
119
+ # -----------------------------
120
def detect_province(query):
    """Map a free-text query to a canonical province/territory name.

    Matching is case-insensitive and anchored on word boundaries so that
    short aliases such as "nb", "bc" or "nl" cannot fire inside unrelated
    words (the previous substring test matched "nb" inside "unbearable").

    Returns:
        The canonical province name, or None if none is mentioned.
    """
    provinces = {
        "yukon": "Yukon",
        "alberta": "Alberta",
        "bc": "British Columbia",
        "british columbia": "British Columbia",
        "manitoba": "Manitoba",
        "nl": "Newfoundland and Labrador",
        "newfoundland": "Newfoundland and Labrador",
        "sask": "Saskatchewan",
        "saskatchewan": "Saskatchewan",
        "ontario": "Ontario",
        "pei": "Prince Edward Island",
        "prince edward island": "Prince Edward Island",
        "quebec": "Quebec",
        "nb": "New Brunswick",
        "new brunswick": "New Brunswick",
        "nova scotia": "Nova Scotia",
        "nunavut": "Nunavut",
        "nwt": "Northwest Territories",
        "northwest territories": "Northwest Territories"
    }
    q = query.lower()
    for key, prov in provinces.items():
        # \b anchors prevent substring false positives ("bc" in "abc").
        if re.search(r"\b" + re.escape(key) + r"\b", q):
            return prov
    return None
147
 
148
  # -----------------------------
149
+ # Guardrails
150
  # -----------------------------
151
def is_disallowed(query):
    """Return True when `query` mentions a banned (harmful) topic.

    Terms are anchored at a word start so "skill" no longer trips the
    "kill" filter, while inflected forms ("weapons", "killing") still match.
    """
    banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
    q = query.lower()
    return any(re.search(r"\b" + re.escape(term), q) for term in banned)
154
+
155
def is_off_topic(query):
    """Return True when `query` mentions no tenancy-related term.

    Keywords are stems anchored at a word start, so "evict" still catches
    "eviction" but "rent" no longer fires inside "parents" or "current"
    (the previous bare-substring test did).
    """
    tenancy_keywords = [
        "tenant", "landlord", "rent", "evict", "lease",
        "deposit", "tenancy", "rental", "apartment",
        "unit", "heating", "notice", "repair", "pets"
    ]
    q = query.lower()
    return not any(re.search(r"\b" + re.escape(k), q) for k in tenancy_keywords)
163
+
164
# Disclaimer prefix prepended to every reply, including refusal messages.
INTRO_TEXT = (
    "Hi! I'm a Canadian rental housing assistant. I can help you find, summarize, "
    "and explain information from the Residential Tenancies Acts across all provinces and territories.\n\n"
    "**Important:** I'm not a lawyer and this is **not legal advice**. Use your own judgment.\n\n"
)
169
 
170
+ # -----------------------------
171
+ # RAG generation function
172
+ # -----------------------------
173
def generate_with_rag(query, province=None, top_k=2):
    """Answer a tenancy question via retrieval-augmented generation.

    Pipeline: guardrails -> province detection -> paragraph retrieval ->
    prompt assembly -> LLM generation -> answer plus source citations.

    Args:
        query: The user's question.
        province: Canonical province name to filter retrieval; inferred
            from the query text when None.
        top_k: Number of context paragraphs to retrieve.

    Returns:
        A single display string: INTRO_TEXT + answer + "Sources Used" list
        (or INTRO_TEXT + a refusal/"not found" message).
    """
    # Guardrails: refuse harmful or off-topic questions before any model call.
    if is_disallowed(query):
        return INTRO_TEXT + "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return INTRO_TEXT + "Sorry — I can only answer questions about Canadian tenancy and housing law."

    if province is None:
        province = detect_province(query)

    top_docs = retrieve_with_pandas(query, province=province, top_k=top_k)
    if top_docs is None or len(top_docs) == 0:
        return INTRO_TEXT + "Sorry — I couldn't find any matching information in the tenancy database."

    # Concatenate the retrieved paragraphs into one context string.
    context = " ".join(top_docs["text"].tolist())

    # Few-shot style examples (style guide)
    qa_examples = """
Q: I asked my landlord three months ago to install handrails in my bathroom. Can the landlord take a long time to respond?
A: Landlords should respond promptly to reasonable accommodation requests. If they delay unreasonably, you can file a discrimination complaint.

Q: My building manager keeps complaining about my children’s noise. Can I be evicted?
A: Reasonable noise from children is expected. If you're treated differently because you have children, you may file a complaint based on family status.
"""

    prompt = f"""
Use the examples as a STYLE GUIDE ONLY.
DO NOT repeat the example questions.
DO NOT invent laws — only use the context provided.
If the context does not contain the answer, say you cannot confidently answer.

{qa_examples}

Context:
{context}

Question:
{query}

Answer conversationally:
"""

    # The generation output contains the prompt text as well; keep only
    # what follows the final "Answer conversationally:" marker.
    raw_output = llm(prompt, max_new_tokens=150)[0]["generated_text"]
    answer = raw_output.split("Answer conversationally:", 1)[-1].strip() if "Answer conversationally:" in raw_output else raw_output.strip()

    # Human-readable citation block built from the retrieved rows' metadata.
    metadata_block = ""
    for _, row in top_docs.iterrows():
        metadata_block += (
            f"- Province: {row['province']}\n"
            f"  Source: {row['source_title']}\n"
            f"  Updated: {row['last_updated']}\n"
            f"  URL: {row['url']}\n"
        )

    return INTRO_TEXT + f"{answer}\n\nSources Used:\n{metadata_block}"
227
+
228
+ # -----------------------------
229
+ # Gradio Chat
230
+ # -----------------------------
231
def respond(message, history):
    """Gradio chat callback: answer `message` and append the exchange.

    The history is returned twice because the submit handler wires two
    outputs to the same Chatbot component.
    """
    reply = generate_with_rag(message)
    history.append((message, reply))
    return history, history
235
+
236
# Minimal chat UI: a transcript pane plus one question textbox.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your question")
    # respond returns (history, history); both outputs target the same
    # Chatbot, so the second is redundant. NOTE(review): the textbox is
    # never cleared after submit — consider outputs=[msg, chatbot] with
    # respond returning ("", history).
    msg.submit(respond, [msg, chatbot], [chatbot, chatbot])
    gr.Markdown(
        "Ask questions about Canadian tenancy and housing law.\n\n"
        "**Note:** I am not a lawyer. Responses are generated from official documents."
    )
244
 
245
if __name__ == "__main__":
    # share=True requests a public Gradio link in addition to the local URL.
    demo.launch(share=True)