zm-f21 committed on
Commit
85551a0
·
verified ·
1 Parent(s): d5857d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -197
app.py CHANGED
@@ -1,204 +1,81 @@
1
- # ----------------------------- #
2
- # Imports
3
- # ----------------------------- #
4
- import os
5
- import re
6
- import zipfile
7
- from pathlib import Path
8
-
9
- import pandas as pd
10
- import numpy as np
11
-
12
- from sentence_transformers import SentenceTransformer
13
- from ctransformers import AutoModelForCausalLM
14
  import gradio as gr
 
 
 
 
 
15
 
16
- # ----------------------------- #
17
- # Load LLM (GGUF quantized Mistral)
18
- # ----------------------------- #
19
- # Make sure you have downloaded the model locally:
20
- # e.g., ./models/mistral-7B-v0.1.Q4_0.gguf
21
- llm = AutoModelForCausalLM.from_pretrained(
22
- "./models/mistral-7B-v0.1.Q4_0.gguf",
23
- model_type="mistral",
24
- )
25
-
26
- # ----------------------------- #
27
- # Load Embedding Model
28
- # ----------------------------- #
29
- embedding_model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
30
-
31
- # ----------------------------- #
32
- # Extract ZIP of provincial texts
33
- # ----------------------------- #
34
- zip_path = "provinces.zip"
35
- extract_folder = "provinces_texts"
36
-
37
- if not os.path.exists(extract_folder):
38
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
39
- zip_ref.extractall(extract_folder)
40
-
41
- # ----------------------------- #
42
- # Parse Files
43
- # ----------------------------- #
44
- def parse_metadata_and_content(raw_text):
45
- if "CONTENT:" not in raw_text:
46
- raise ValueError("File missing CONTENT: separator.")
47
-
48
- header, content = raw_text.split("CONTENT:", 1)
49
- metadata = {}
50
- lines = header.strip().split("\n")
51
- pdf_list = []
52
-
53
- for line in lines:
54
- if ":" in line and not line.strip().startswith("-"):
55
- key, value = line.split(":", 1)
56
- metadata[key.strip()] = value.strip()
57
- elif line.strip().startswith("-"):
58
- pdf_list.append(line.strip())
59
-
60
- if pdf_list:
61
- metadata["PDF_LINKS"] = "\n".join(pdf_list)
62
-
63
- return metadata, content.strip()
64
-
65
-
66
- documents = []
67
-
68
- for root, dirs, files in os.walk(extract_folder):
69
- for filename in files:
70
- if filename.startswith("._"):
71
- continue
72
- if filename.endswith(".txt"):
73
- filepath = os.path.join(root, filename)
74
- try:
75
- with open(filepath, "r", encoding="latin-1") as f:
76
- raw = f.read()
77
- metadata, content = parse_metadata_and_content(raw)
78
- paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
79
- for p in paragraphs:
80
- documents.append({
81
- "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
82
- "province": metadata.get("PROVINCE", "Unknown"),
83
- "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
84
- "url": metadata.get("URL", "N/A"),
85
- "pdf_links": metadata.get("PDF_LINKS", ""),
86
- "text": p
87
- })
88
- except Exception:
89
- continue
90
-
91
- # Build DataFrame and compute embeddings
92
- df = pd.DataFrame(documents)
93
- df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))
94
-
95
- # ----------------------------- #
96
- # Province Detection
97
- # ----------------------------- #
98
- def detect_province(query):
99
- provinces = {
100
- "yukon": "Yukon",
101
- "alberta": "Alberta",
102
- "bc": "British Columbia",
103
- "british columbia": "British Columbia",
104
- "manitoba": "Manitoba",
105
- "nl": "Newfoundland and Labrador",
106
- "newfoundland": "Newfoundland and Labrador",
107
- "sask": "Saskatchewan",
108
- "saskatchewan": "Saskatchewan",
109
- "ontario": "Ontario",
110
- "pei": "Prince Edward Island",
111
- "prince edward island": "Prince Edward Island",
112
- "quebec": "Quebec",
113
- "nb": "New Brunswick",
114
- "new brunswick": "New Brunswick",
115
- "nova scotia": "Nova Scotia",
116
- "nunavut": "Nunavut",
117
- "nwt": "Northwest Territories",
118
- "northwest territories": "Northwest Territories"
119
- }
120
- q = query.lower()
121
- for key, prov in provinces.items():
122
- if key in q:
123
- return prov
124
- return None
125
-
126
- # ----------------------------- #
127
- # Guardrails
128
- # ----------------------------- #
129
- def is_disallowed(query):
130
- banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
131
- return any(b in query.lower() for b in banned)
132
-
133
- def is_off_topic(query):
134
- tenancy_keywords = [
135
- "tenant", "landlord", "rent", "evict", "lease",
136
- "deposit", "tenancy", "rental", "apartment",
137
- "unit", "heating", "notice", "repair", "pets"
138
- ]
139
- q = query.lower()
140
- return not any(k in q for k in tenancy_keywords)
141
-
142
- INTRO_TEXT = (
143
- "Hi! I'm a Canadian rental housing assistant. I can help you find, summarize, "
144
- "and explain information from the Residential Tenancies Acts across all provinces.\n\n"
145
- "This is not legal advice — laws may vary and change.\n\n"
146
  )
147
 
148
- # ----------------------------- #
149
- # Retrieval Function
150
- # ----------------------------- #
151
- def retrieve_with_pandas(query, province=None, top_k=2):
152
- query_embedding = embedding_model.encode([query])[0]
153
- filtered_df = df[df['province'] == province].copy() if province else df.copy()
154
- filtered_df["Similarity"] = filtered_df["Embedding"].apply(
155
- lambda x: np.dot(query_embedding, x) /
156
- (np.linalg.norm(query_embedding) * np.linalg.norm(x))
157
- )
158
- results = filtered_df.sort_values("Similarity", ascending=False).head(top_k)
159
- return results
160
-
161
- # ----------------------------- #
162
- # Main RAG Generator
163
- # ----------------------------- #
164
- def generate_with_rag(query):
165
- if is_disallowed(query):
166
- return INTRO_TEXT + "Sorry — I can’t help with harmful topics."
167
- if is_off_topic(query):
168
- return INTRO_TEXT + "Sorry — I can only answer questions about tenancy and housing law."
169
-
170
- province = detect_province(query)
171
- top_docs_df = retrieve_with_pandas(query, province=province, top_k=2)
172
- if len(top_docs_df) == 0:
173
- return INTRO_TEXT + "I couldn't find relevant information."
174
-
175
- context = " ".join(top_docs_df["text"].tolist())
176
-
177
- prompt = f"""
178
- Use the context below to answer the question.
179
- CONTEXT:
180
- {context}
181
- QUESTION:
182
- {query}
183
- ANSWER:
184
- """
185
-
186
- # Generate response with ctransformers
187
- response = llm(prompt, max_new_tokens=300, temperature=0.2)
188
- return response[0]["generated_text"].split("ANSWER:")[-1].strip()
189
-
190
- # ----------------------------- #
191
- # Gradio UI
192
- # ----------------------------- #
193
- def ui_fn(query):
194
- return generate_with_rag(query)
195
-
196
- demo = gr.Interface(
197
- fn=ui_fn,
198
- inputs=gr.Textbox(lines=3, label="Ask a question"),
199
- outputs=gr.Textbox(label="Answer"),
200
- title="Canadian Tenancy RAG Assistant"
201
- )
202
 
203
  if __name__ == "__main__":
204
  demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+ import os
7
 
8
+ # -----------------------------
9
+ # Hugging Face token
10
+ # -----------------------------
11
+ os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN"
12
+ client = InferenceClient(token=os.environ["HF_TOKEN"], model="mistralai/Mistral-7B-Instruct-v0.2")
13
+
14
+ # -----------------------------
15
+ # Example RAG documents
16
+ # -----------------------------
17
+ documents = [
18
+ "Quantum computing uses quantum bits.",
19
+ "Transformers are a type of neural network architecture.",
20
+ "Python is a popular programming language.",
21
+ # Add more docs or load from your dataset
22
+ ]
23
+
24
+ # -----------------------------
25
+ # Embeddings + FAISS
26
+ # -----------------------------
27
+ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
28
+ embeddings = embed_model.encode(documents, convert_to_numpy=True)
29
+ dimension = embeddings.shape[1]
30
+ index = faiss.IndexFlatL2(dimension)
31
+ index.add(embeddings)
32
+
33
+ def retrieve(query, top_k=2):
34
+ query_emb = embed_model.encode([query], convert_to_numpy=True)
35
+ distances, indices = index.search(query_emb, top_k)
36
+ return [documents[i] for i in indices[0]]
37
+
38
+ # -----------------------------
39
+ # RAG answer function
40
+ # -----------------------------
41
+ def answer_with_rag(message, history, system_message, max_tokens, temperature, top_p):
42
+ context_docs = retrieve(message)
43
+ context = " ".join(context_docs)
44
+
45
+ prompt = f"Answer the question using the following context:\n{context}\n\nQuestion: {message}\nAnswer:"
46
+
47
+ response = ""
48
+ for msg in client.chat_completion(
49
+ prompt,
50
+ max_tokens=max_tokens,
51
+ stream=True,
52
+ temperature=temperature,
53
+ top_p=top_p
54
+ ):
55
+ choices = msg.choices
56
+ if len(choices) and choices[0].delta.content:
57
+ response += choices[0].delta.content
58
+ return response
59
+
60
+ # -----------------------------
61
+ # Gradio ChatInterface
62
+ # -----------------------------
63
+ chatbot = gr.ChatInterface(
64
+ answer_with_rag,
65
+ type="messages",
66
+ additional_inputs=[
67
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
68
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
69
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
70
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
71
+ ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  )
73
 
74
+ with gr.Blocks() as demo:
75
+ with gr.Sidebar():
76
+ gr.LoginButton()
77
+ chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  if __name__ == "__main__":
80
  demo.launch(share=True)
81
+