zm-f21 committed
Commit 0e38642 · verified · 1 Parent(s): 87fa24c

Update app.py

Files changed (1)
  1. app.py +75 -23
app.py CHANGED
@@ -1,25 +1,45 @@
- import pandas as pd
- import numpy as np
+ # ----------------------------- #
+ # Imports
+ # ----------------------------- #
  import re
- from sentence_transformers import SentenceTransformer
+ import os
+ import zipfile
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
  import gradio as gr
- from ctransformers import AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+
+ # Mistral Inference
+ from mistral_inference import MistralForCausalLM, MistralTokenizer

  # ----------------------------- #
- # Load Hosted Mistral 7B Q4_0
+ # Load Local Mistral Model
  # ----------------------------- #
- llm = AutoModelForCausalLM.from_pretrained(
-     "TheBloke/Mistral-7B-v0.1-Q4_0",  # hosted HF model
-     model_type="mistral",             # model type
-     gpu_layers=32                     # adjust based on GPU/VRAM
- )
+ model_path = Path.home().joinpath('mistral_models', '7B-Instruct-v0.3')
+
+ tokenizer = MistralTokenizer.from_pretrained(model_path)
+ llm = MistralForCausalLM.from_pretrained(model_path)

+ # ----------------------------- #
+ # Load Embedding Model
+ # ----------------------------- #
  embedding_model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')

  # ----------------------------- #
- # Parse & Prepare Your Documents
+ # Extract ZIP
+ # ----------------------------- #
+ zip_path = "provinces.zip"
+ extract_folder = "provinces_texts"
+
+ if not os.path.exists(extract_folder):
+     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+         zip_ref.extractall(extract_folder)
+
+ # ----------------------------- #
+ # Parse Files
  # ----------------------------- #
- # Example parsing function (from your previous code)
  date_pattern = re.compile(r"(\d{4}[-]\d{2}[-_]\d{2})")

  def parse_metadata_and_content(raw_text):
@@ -29,6 +49,7 @@ def parse_metadata_and_content(raw_text):
      header, content = raw_text.split("CONTENT:", 1)
      metadata = {}
      lines = header.strip().split("\n")
+
      pdf_list = []

      for line in lines:
@@ -43,12 +64,37 @@ def parse_metadata_and_content(raw_text):

      return metadata, content.strip()

- # Load your text documents into df as before
- # df = pd.DataFrame(documents)
- # df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))
+ documents = []
+
+ for root, dirs, files in os.walk(extract_folder):
+     for filename in files:
+         if filename.startswith("._"):
+             continue
+         if filename.endswith(".txt"):
+             filepath = os.path.join(root, filename)
+             try:
+                 with open(filepath, "r", encoding="latin-1") as f:
+                     raw = f.read()
+                 metadata, content = parse_metadata_and_content(raw)
+                 paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
+
+                 for p in paragraphs:
+                     documents.append({
+                         "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
+                         "province": metadata.get("PROVINCE", "Unknown"),
+                         "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
+                         "url": metadata.get("URL", "N/A"),
+                         "pdf_links": metadata.get("PDF_LINKS", ""),
+                         "text": p
+                     })
+             except Exception:
+                 continue
+
+ df = pd.DataFrame(documents)
+ df["Embedding"] = df["text"].apply(lambda x: embedding_model.encode(x))

  # ----------------------------- #
- # Province Detection
+ # Province Detection
  # ----------------------------- #
  def detect_province(query):
      provinces = {
@@ -72,6 +118,7 @@ def detect_province(query):
          "nwt": "Northwest Territories",
          "northwest territories": "Northwest Territories"
      }
+
      q = query.lower()
      for key, prov in provinces.items():
          if key in q:
@@ -79,7 +126,7 @@
      return None

  # ----------------------------- #
- # Guardrails
+ # Guardrails
  # ----------------------------- #
  def is_disallowed(query):
      banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
@@ -101,7 +148,7 @@ INTRO_TEXT = (
  )

  # ----------------------------- #
- # Retrieval
+ # Retrieval
  # ----------------------------- #
  def retrieve_with_pandas(query, province=None, top_k=2):
      query_embedding = embedding_model.encode([query])[0]
@@ -120,7 +167,7 @@ def retrieve_with_pandas(query, province=None, top_k=2):
      return results

  # ----------------------------- #
- # RAG Generator
+ # Main RAG Generator using MistralInference
  # ----------------------------- #
  def generate_with_rag(query):
      if is_disallowed(query):
@@ -138,6 +185,7 @@ def generate_with_rag(query):
          return INTRO_TEXT + "I couldn't find relevant information."

      context = " ".join(top_docs_df["text"].tolist())
+
      prompt = f"""
      Use the context below to answer the question.
      CONTEXT:
@@ -147,15 +195,18 @@ QUESTION:
      ANSWER:
      """

-     response = llm(
-         prompt,
+     # Generate response
+     response = llm.generate(
+         tokenizer.encode(prompt, return_tensors="pt"),
          max_new_tokens=300,
          temperature=0.2
      )
-     return response[0]["generated_text"].split("ANSWER:")[-1].strip()
+
+     answer = tokenizer.decode(response[0], skip_special_tokens=True)
+     return answer.split("ANSWER:")[-1].strip()

  # ----------------------------- #
- # Gradio UI
+ # Gradio UI
  # ----------------------------- #
  def ui_fn(query):
      return generate_with_rag(query)
@@ -170,3 +221,4 @@ demo = gr.Interface(
  if __name__ == "__main__":
      demo.launch(share=True)

+
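
Note on the new model-loading and generation code: the import "from mistral_inference import MistralForCausalLM, MistralTokenizer" and the transformers-style calls (from_pretrained, tokenizer.encode with return_tensors="pt", llm.generate with max_new_tokens) do not correspond to the entry points the mistral_inference package documents. Below is a minimal sketch of the same load-and-generate step written against the documented mistral_inference / mistral_common API; the tokenizer filename and the instruct-style request wrapping are assumptions on my part, not something this commit specifies.

# Sketch only: assumes the documented mistral_inference / mistral_common API,
# not the commit's MistralForCausalLM / MistralTokenizer import.
from pathlib import Path

from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.transformer import Transformer

model_path = Path.home().joinpath("mistral_models", "7B-Instruct-v0.3")

# tokenizer.model.v3 is the filename shipped with Mistral's 7B-Instruct-v0.3
# download; adjust if the local folder stores a different tokenizer file.
tokenizer = MistralTokenizer.from_file(str(model_path / "tokenizer.model.v3"))
model = Transformer.from_folder(str(model_path))

def generate_answer(prompt: str) -> str:
    # Wrap the RAG prompt as a single user message, generate, then decode.
    request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(request).tokens
    out_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=300,
        temperature=0.2,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])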
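
A small optional refinement to the embedding step, not part of the commit: SentenceTransformer.encode accepts a list and batches internally, so the per-row .apply() used to build df["Embedding"] can be replaced with one batched call.

# Optional alternative to the row-wise .apply() in the diff above:
# encode every paragraph in a single batched call.
embeddings = embedding_model.encode(df["text"].tolist(), show_progress_bar=True)
df["Embedding"] = list(embeddings)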
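
The body of retrieve_with_pandas falls outside the diff context, so only its signature and first line are visible. The sketch below shows what such a retriever typically does (cosine similarity against the precomputed Embedding column, with an optional province filter); the scoring and filtering logic is a hypothetical reconstruction, not the file's actual implementation.

# Hypothetical retriever sketch; column names follow the diff, the similarity
# and filtering logic are assumptions.
import numpy as np

def retrieve_with_pandas_sketch(query, province=None, top_k=2):
    query_embedding = embedding_model.encode([query])[0]

    # Optionally restrict to one province before scoring.
    subset = df if province is None else df[df["province"] == province]

    def cosine(vec):
        denom = np.linalg.norm(vec) * np.linalg.norm(query_embedding) + 1e-10
        return float(np.dot(vec, query_embedding) / denom)

    scored = subset.assign(score=subset["Embedding"].apply(cosine))
    return scored.sort_values("score", ascending=False).head(top_k)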