zm-f21 commited on
Commit
87fa24c
·
verified ·
1 Parent(s): 373a1dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -79
app.py CHANGED
@@ -1,52 +1,37 @@
1
- # ----------------------------- #
2
- # Imports
3
- # ----------------------------- #
4
- import os
5
- import zipfile
6
- import re
7
  import pandas as pd
8
  import numpy as np
 
9
  from sentence_transformers import SentenceTransformer
10
- from ctransformers import AutoModelForCausalLM
11
  import gradio as gr
 
12
 
13
  # ----------------------------- #
14
- # Load GGUF Mistral Model
15
  # ----------------------------- #
16
- # Make sure your GGUF file is in ./models/mistral.gguf
17
  llm = AutoModelForCausalLM.from_pretrained(
18
- "./models/mistral.gguf",
19
- model_type="mistral", # important for GGUF
20
- n_threads=8 # adjust based on your environment
21
  )
22
 
23
- # ----------------------------- #
24
- # Load Embedding Model
25
- # ----------------------------- #
26
  embedding_model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
27
 
28
  # ----------------------------- #
29
- # Extract ZIP
30
  # ----------------------------- #
31
- zip_path = "provinces.zip"
32
- extract_folder = "provinces_texts"
33
 
34
- if not os.path.exists(extract_folder):
35
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
36
- zip_ref.extractall(extract_folder)
37
-
38
- # ----------------------------- #
39
- # Parse Files
40
- # ----------------------------- #
41
  def parse_metadata_and_content(raw_text):
42
  if "CONTENT:" not in raw_text:
43
  raise ValueError("File missing CONTENT: separator.")
44
 
45
  header, content = raw_text.split("CONTENT:", 1)
46
  metadata = {}
 
47
  pdf_list = []
48
 
49
- for line in header.strip().split("\n"):
50
  if ":" in line and not line.strip().startswith("-"):
51
  key, value = line.split(":", 1)
52
  metadata[key.strip()] = value.strip()
@@ -58,37 +43,12 @@ def parse_metadata_and_content(raw_text):
58
 
59
  return metadata, content.strip()
60
 
61
- documents = []
62
-
63
- for root, dirs, files in os.walk(extract_folder):
64
- for filename in files:
65
- if filename.startswith("._"):
66
- continue
67
- if filename.endswith(".txt"):
68
- filepath = os.path.join(root, filename)
69
- try:
70
- with open(filepath, "r", encoding="latin-1") as f:
71
- raw = f.read()
72
- metadata, content = parse_metadata_and_content(raw)
73
- paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
74
- for p in paragraphs:
75
- documents.append({
76
- "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
77
- "province": metadata.get("PROVINCE", "Unknown"),
78
- "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
79
- "url": metadata.get("URL", "N/A"),
80
- "pdf_links": metadata.get("PDF_LINKS", ""),
81
- "text": p
82
- })
83
- except Exception:
84
- continue
85
-
86
- # Build DataFrame and embeddings
87
- df = pd.DataFrame(documents)
88
- df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))
89
 
90
  # ----------------------------- #
91
- # Province Detection
92
  # ----------------------------- #
93
  def detect_province(query):
94
  provinces = {
@@ -112,7 +72,6 @@ def detect_province(query):
112
  "nwt": "Northwest Territories",
113
  "northwest territories": "Northwest Territories"
114
  }
115
-
116
  q = query.lower()
117
  for key, prov in provinces.items():
118
  if key in q:
@@ -120,7 +79,7 @@ def detect_province(query):
120
  return None
121
 
122
  # ----------------------------- #
123
- # Guardrails
124
  # ----------------------------- #
125
  def is_disallowed(query):
126
  banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
@@ -142,30 +101,35 @@ INTRO_TEXT = (
142
  )
143
 
144
  # ----------------------------- #
145
- # Retrieval
146
  # ----------------------------- #
147
  def retrieve_with_pandas(query, province=None, top_k=2):
148
  query_embedding = embedding_model.encode([query])[0]
149
 
150
- filtered_df = df.copy()
151
  if province:
152
- filtered_df = filtered_df[filtered_df['province'] == province]
 
 
153
 
154
  filtered_df["Similarity"] = filtered_df["Embedding"].apply(
155
  lambda x: np.dot(query_embedding, x) /
156
  (np.linalg.norm(query_embedding) * np.linalg.norm(x))
157
  )
158
 
159
- return filtered_df.sort_values("Similarity", ascending=False).head(top_k)
 
160
 
161
  # ----------------------------- #
162
- # Main RAG Generator
163
  # ----------------------------- #
164
  def generate_with_rag(query):
165
  if is_disallowed(query):
166
  return INTRO_TEXT + "Sorry — I can’t help with harmful topics."
 
167
  if is_off_topic(query):
168
- return INTRO_TEXT + "Sorry — I can only answer questions about tenancy and housing law."
 
 
169
 
170
  province = detect_province(query)
171
  top_docs_df = retrieve_with_pandas(query, province=province, top_k=2)
@@ -174,7 +138,6 @@ def generate_with_rag(query):
174
  return INTRO_TEXT + "I couldn't find relevant information."
175
 
176
  context = " ".join(top_docs_df["text"].tolist())
177
-
178
  prompt = f"""
179
  Use the context below to answer the question.
180
  CONTEXT:
@@ -184,23 +147,15 @@ QUESTION:
184
  ANSWER:
185
  """
186
 
187
- response = llm(prompt, max_new_tokens=300, temperature=0.2)
188
- answer = response[0]["generated_text"].split("ANSWER:")[-1].strip()
189
-
190
- # Add metadata
191
- metadata_block = ""
192
- for _, row in top_docs_df.iterrows():
193
- metadata_block += (
194
- f"- Province: {row['province']}\n"
195
- f" Source: {row['source_title']}\n"
196
- f" Updated: {row['last_updated']}\n"
197
- f" URL: {row['url']}\n"
198
- )
199
-
200
- return INTRO_TEXT + f"{answer}\n\nSources Used:\n{metadata_block}"
201
 
202
  # ----------------------------- #
203
- # Gradio UI
204
  # ----------------------------- #
205
  def ui_fn(query):
206
  return generate_with_rag(query)
@@ -214,3 +169,4 @@ demo = gr.Interface(
214
 
215
  if __name__ == "__main__":
216
  demo.launch(share=True)
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import numpy as np
3
+ import re
4
  from sentence_transformers import SentenceTransformer
 
5
  import gradio as gr
6
+ from ctransformers import AutoModelForCausalLM
7
 
8
  # ----------------------------- #
9
+ # Load Hosted Mistral 7B Q4_0
10
  # ----------------------------- #
 
11
  llm = AutoModelForCausalLM.from_pretrained(
12
+ "TheBloke/Mistral-7B-v0.1-Q4_0", # hosted HF model
13
+ model_type="mistral", # model type
14
+ gpu_layers=32 # adjust based on GPU/VRAM
15
  )
16
 
 
 
 
17
  embedding_model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
18
 
19
  # ----------------------------- #
20
+ # Parse & Prepare Your Documents
21
  # ----------------------------- #
22
+ # Example parsing function (from your previous code)
23
+ date_pattern = re.compile(r"(\d{4}[-]\d{2}[-_]\d{2})")
24
 
 
 
 
 
 
 
 
25
  def parse_metadata_and_content(raw_text):
26
  if "CONTENT:" not in raw_text:
27
  raise ValueError("File missing CONTENT: separator.")
28
 
29
  header, content = raw_text.split("CONTENT:", 1)
30
  metadata = {}
31
+ lines = header.strip().split("\n")
32
  pdf_list = []
33
 
34
+ for line in lines:
35
  if ":" in line and not line.strip().startswith("-"):
36
  key, value = line.split(":", 1)
37
  metadata[key.strip()] = value.strip()
 
43
 
44
  return metadata, content.strip()
45
 
46
+ # Load your text documents into df as before
47
+ # df = pd.DataFrame(documents)
48
+ # df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # ----------------------------- #
51
+ # Province Detection
52
  # ----------------------------- #
53
  def detect_province(query):
54
  provinces = {
 
72
  "nwt": "Northwest Territories",
73
  "northwest territories": "Northwest Territories"
74
  }
 
75
  q = query.lower()
76
  for key, prov in provinces.items():
77
  if key in q:
 
79
  return None
80
 
81
  # ----------------------------- #
82
+ # Guardrails
83
  # ----------------------------- #
84
  def is_disallowed(query):
85
  banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
 
101
  )
102
 
103
  # ----------------------------- #
104
+ # Retrieval
105
  # ----------------------------- #
106
  def retrieve_with_pandas(query, province=None, top_k=2):
107
  query_embedding = embedding_model.encode([query])[0]
108
 
 
109
  if province:
110
+ filtered_df = df[df['province'] == province].copy()
111
+ else:
112
+ filtered_df = df.copy()
113
 
114
  filtered_df["Similarity"] = filtered_df["Embedding"].apply(
115
  lambda x: np.dot(query_embedding, x) /
116
  (np.linalg.norm(query_embedding) * np.linalg.norm(x))
117
  )
118
 
119
+ results = filtered_df.sort_values("Similarity", ascending=False).head(top_k)
120
+ return results
121
 
122
  # ----------------------------- #
123
+ # RAG Generator
124
  # ----------------------------- #
125
  def generate_with_rag(query):
126
  if is_disallowed(query):
127
  return INTRO_TEXT + "Sorry — I can’t help with harmful topics."
128
+
129
  if is_off_topic(query):
130
+ return INTRO_TEXT + (
131
+ "Sorry — I can only answer questions about tenancy and housing law."
132
+ )
133
 
134
  province = detect_province(query)
135
  top_docs_df = retrieve_with_pandas(query, province=province, top_k=2)
 
138
  return INTRO_TEXT + "I couldn't find relevant information."
139
 
140
  context = " ".join(top_docs_df["text"].tolist())
 
141
  prompt = f"""
142
  Use the context below to answer the question.
143
  CONTEXT:
 
147
  ANSWER:
148
  """
149
 
150
+ response = llm(
151
+ prompt,
152
+ max_new_tokens=300,
153
+ temperature=0.2
154
+ )
155
+ return response[0]["generated_text"].split("ANSWER:")[-1].strip()
 
 
 
 
 
 
 
 
156
 
157
  # ----------------------------- #
158
+ # Gradio UI
159
  # ----------------------------- #
160
  def ui_fn(query):
161
  return generate_with_rag(query)
 
169
 
170
  if __name__ == "__main__":
171
  demo.launch(share=True)
172
+