zm-f21 committed on
Commit
d5857d2
·
verified ·
1 Parent(s): 7b7b8cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -42
app.py CHANGED
@@ -1,26 +1,27 @@
1
  # ----------------------------- #
2
  # Imports
3
  # ----------------------------- #
4
- import re
5
  import os
 
6
  import zipfile
7
  from pathlib import Path
8
 
9
- import numpy as np
10
  import pandas as pd
11
- import gradio as gr
12
- from sentence_transformers import SentenceTransformer
13
 
14
- # Mistral Inference
15
- from mistral_inference import MistralForCausalLM, MistralTokenizer
 
16
 
17
  # ----------------------------- #
18
- # Load Local Mistral Model
19
  # ----------------------------- #
20
- model_path = Path.home().joinpath('mistral_models', '7B-Instruct-v0.3')
21
-
22
- tokenizer = MistralTokenizer.from_pretrained(model_path)
23
- llm = MistralForCausalLM.from_pretrained(model_path)
 
 
24
 
25
  # ----------------------------- #
26
  # Load Embedding Model
@@ -28,7 +29,7 @@ llm = MistralForCausalLM.from_pretrained(model_path)
28
  embedding_model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
29
 
30
  # ----------------------------- #
31
- # Extract ZIP
32
  # ----------------------------- #
33
  zip_path = "provinces.zip"
34
  extract_folder = "provinces_texts"
@@ -40,8 +41,6 @@ if not os.path.exists(extract_folder):
40
  # ----------------------------- #
41
  # Parse Files
42
  # ----------------------------- #
43
- date_pattern = re.compile(r"(\d{4}[-]\d{2}[-_]\d{2})")
44
-
45
  def parse_metadata_and_content(raw_text):
46
  if "CONTENT:" not in raw_text:
47
  raise ValueError("File missing CONTENT: separator.")
@@ -49,7 +48,6 @@ def parse_metadata_and_content(raw_text):
49
  header, content = raw_text.split("CONTENT:", 1)
50
  metadata = {}
51
  lines = header.strip().split("\n")
52
-
53
  pdf_list = []
54
 
55
  for line in lines:
@@ -64,6 +62,7 @@ def parse_metadata_and_content(raw_text):
64
 
65
  return metadata, content.strip()
66
 
 
67
  documents = []
68
 
69
  for root, dirs, files in os.walk(extract_folder):
@@ -77,7 +76,6 @@ for root, dirs, files in os.walk(extract_folder):
77
  raw = f.read()
78
  metadata, content = parse_metadata_and_content(raw)
79
  paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
80
-
81
  for p in paragraphs:
82
  documents.append({
83
  "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
@@ -90,6 +88,7 @@ for root, dirs, files in os.walk(extract_folder):
90
  except Exception:
91
  continue
92
 
 
93
  df = pd.DataFrame(documents)
94
  df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))
95
 
@@ -118,7 +117,6 @@ def detect_province(query):
118
  "nwt": "Northwest Territories",
119
  "northwest territories": "Northwest Territories"
120
  }
121
-
122
  q = query.lower()
123
  for key, prov in provinces.items():
124
  if key in q:
@@ -148,39 +146,29 @@ INTRO_TEXT = (
148
  )
149
 
150
  # ----------------------------- #
151
- # Retrieval
152
  # ----------------------------- #
153
  def retrieve_with_pandas(query, province=None, top_k=2):
154
  query_embedding = embedding_model.encode([query])[0]
155
-
156
- if province:
157
- filtered_df = df[df['province'] == province].copy()
158
- else:
159
- filtered_df = df.copy()
160
-
161
  filtered_df["Similarity"] = filtered_df["Embedding"].apply(
162
  lambda x: np.dot(query_embedding, x) /
163
  (np.linalg.norm(query_embedding) * np.linalg.norm(x))
164
  )
165
-
166
  results = filtered_df.sort_values("Similarity", ascending=False).head(top_k)
167
  return results
168
 
169
  # ----------------------------- #
170
- # Main RAG Generator using MistralInference
171
  # ----------------------------- #
172
  def generate_with_rag(query):
173
  if is_disallowed(query):
174
  return INTRO_TEXT + "Sorry — I can’t help with harmful topics."
175
-
176
  if is_off_topic(query):
177
- return INTRO_TEXT + (
178
- "Sorry — I can only answer questions about tenancy and housing law."
179
- )
180
 
181
  province = detect_province(query)
182
  top_docs_df = retrieve_with_pandas(query, province=province, top_k=2)
183
-
184
  if len(top_docs_df) == 0:
185
  return INTRO_TEXT + "I couldn't find relevant information."
186
 
@@ -195,15 +183,9 @@ QUESTION:
195
  ANSWER:
196
  """
197
 
198
- # Generate response
199
- response = llm.generate(
200
- tokenizer.encode(prompt, return_tensors="pt"),
201
- max_new_tokens=300,
202
- temperature=0.2
203
- )
204
-
205
- answer = tokenizer.decode(response[0], skip_special_tokens=True)
206
- return answer.split("ANSWER:")[-1].strip()
207
 
208
  # ----------------------------- #
209
  # Gradio UI
@@ -220,5 +202,3 @@ demo = gr.Interface(
220
 
221
  if __name__ == "__main__":
222
  demo.launch(share=True)
223
-
224
-
 
1
  # ----------------------------- #
2
  # Imports
3
  # ----------------------------- #
 
4
  import os
5
+ import re
6
  import zipfile
7
  from pathlib import Path
8
 
 
9
  import pandas as pd
10
+ import numpy as np
 
11
 
12
+ from sentence_transformers import SentenceTransformer
13
+ from ctransformers import AutoModelForCausalLM
14
+ import gradio as gr
15
 
16
  # ----------------------------- #
17
+ # Load LLM (GGUF quantized Mistral)
18
  # ----------------------------- #
19
+ # Make sure you have downloaded the model locally:
20
+ # e.g., ./models/mistral-7B-v0.1.Q4_0.gguf
21
+ llm = AutoModelForCausalLM.from_pretrained(
22
+ "./models/mistral-7B-v0.1.Q4_0.gguf",
23
+ model_type="mistral",
24
+ )
25
 
26
  # ----------------------------- #
27
  # Load Embedding Model
 
29
  embedding_model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
30
 
31
  # ----------------------------- #
32
+ # Extract ZIP of provincial texts
33
  # ----------------------------- #
34
  zip_path = "provinces.zip"
35
  extract_folder = "provinces_texts"
 
41
  # ----------------------------- #
42
  # Parse Files
43
  # ----------------------------- #
 
 
44
  def parse_metadata_and_content(raw_text):
45
  if "CONTENT:" not in raw_text:
46
  raise ValueError("File missing CONTENT: separator.")
 
48
  header, content = raw_text.split("CONTENT:", 1)
49
  metadata = {}
50
  lines = header.strip().split("\n")
 
51
  pdf_list = []
52
 
53
  for line in lines:
 
62
 
63
  return metadata, content.strip()
64
 
65
+
66
  documents = []
67
 
68
  for root, dirs, files in os.walk(extract_folder):
 
76
  raw = f.read()
77
  metadata, content = parse_metadata_and_content(raw)
78
  paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
 
79
  for p in paragraphs:
80
  documents.append({
81
  "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
 
88
  except Exception:
89
  continue
90
 
91
+ # Build DataFrame and compute embeddings
92
  df = pd.DataFrame(documents)
93
  df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))
94
 
 
117
  "nwt": "Northwest Territories",
118
  "northwest territories": "Northwest Territories"
119
  }
 
120
  q = query.lower()
121
  for key, prov in provinces.items():
122
  if key in q:
 
146
  )
147
 
148
  # ----------------------------- #
149
+ # Retrieval Function
150
  # ----------------------------- #
151
  def retrieve_with_pandas(query, province=None, top_k=2):
152
  query_embedding = embedding_model.encode([query])[0]
153
+ filtered_df = df[df['province'] == province].copy() if province else df.copy()
 
 
 
 
 
154
  filtered_df["Similarity"] = filtered_df["Embedding"].apply(
155
  lambda x: np.dot(query_embedding, x) /
156
  (np.linalg.norm(query_embedding) * np.linalg.norm(x))
157
  )
 
158
  results = filtered_df.sort_values("Similarity", ascending=False).head(top_k)
159
  return results
160
 
161
  # ----------------------------- #
162
+ # Main RAG Generator
163
  # ----------------------------- #
164
  def generate_with_rag(query):
165
  if is_disallowed(query):
166
  return INTRO_TEXT + "Sorry — I can’t help with harmful topics."
 
167
  if is_off_topic(query):
168
+ return INTRO_TEXT + "Sorry — I can only answer questions about tenancy and housing law."
 
 
169
 
170
  province = detect_province(query)
171
  top_docs_df = retrieve_with_pandas(query, province=province, top_k=2)
 
172
  if len(top_docs_df) == 0:
173
  return INTRO_TEXT + "I couldn't find relevant information."
174
 
 
183
  ANSWER:
184
  """
185
 
186
+ # Generate response with ctransformers
187
+ response = llm(prompt, max_new_tokens=300, temperature=0.2)
188
+ return response[0]["generated_text"].split("ANSWER:")[-1].strip()
 
 
 
 
 
 
189
 
190
  # ----------------------------- #
191
  # Gradio UI
 
202
 
203
  if __name__ == "__main__":
204
  demo.launch(share=True)