jme-datasci committed
Commit 30d6e30 · 1 Parent(s): effe941

added examples, updated readme

Files changed (2)
  1. README.md +63 -68
  2. app.py +11 -2
README.md CHANGED
@@ -14,8 +14,6 @@ license: apache-2.0
 short_description: RAG Enabled ChatBot for Charlottesville Municipal Code
 ---
 
-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
-
 # Charlottesville Local Ordinance Assistant
 
 ## 1. Introduction
@@ -77,117 +75,114 @@ import pandas as pd
 import random
 
 class MyRAGPipeline:
-    '''
-    Wrapper class for RAG pipeline.
-    '''
-    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str, tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.7, DO_SAMPLE=True):
-        if tokenizer_name is None:
-            tokenizer_name = model_name
-
+    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str):
         self.embedding_model_name = embedding_model_name
-        self.max_new_tokens = MAX_NEW_TOKENS
+        self.max_new_tokens = 500
 
         print(f"Loading Model: {model_name}...")
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
+
+        # --- CRITICAL: Load to CPU first ---
+        # ZeroGPU does not have a GPU available during global startup.
+        # We load the weights into system RAM now, and move them to the GPU later.
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            device_map="auto",
-            dtype=torch.bfloat16,
+            device_map="cpu",  # Force CPU loading
+            torch_dtype=torch.bfloat16,
             token=HF_TOKEN
         )
+
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
         self.tokenizer.padding_side = "left"
 
         print("Loading Embeddings...")
         self.embedding_model = HuggingFaceEmbeddings(
             model_name=self.embedding_model_name,
-            multi_process=False,  # Set to False for stability in Spaces
-            model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
+            model_kwargs={"device": "cpu"},  # Keep embeddings on CPU
             encode_kwargs={"normalize_embeddings": True},
         )
 
         print(f"Loading Vector DB from {vector_db_path}...")
-        # Check if index exists to prevent crash
         if not os.path.exists(vector_db_path):
             raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.")
 
         self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)
-
-        # FAISS GPU optimization (If available)
-        if torch.cuda.is_available():
-            try:
-                res = faiss.StandardGpuResources()
-                co = faiss.GpuClonerOptions()
-                co.useFloat16 = True
-                self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
-            except Exception as e:
-                print(f"Could not load FAISS to GPU, running on CPU: {e}")
-
-        # Initialize Pipeline
-        self.pipe = pipeline(
-            'text-generation',
-            model=self.model,
-            torch_dtype=torch.bfloat16,
-            device_map='auto',
-            tokenizer=self.tokenizer,
-            max_new_tokens=self.max_new_tokens,
-            temperature=TEMPERATURE,
-            do_sample=DO_SAMPLE,
-            pad_token_id=self.tokenizer.eos_token_id,
-            # return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
-            return_full_text=False
-        )
+        print("RAG Pipeline Initialized (CPU Mode)")
 
     def retrieve(self, query, num_docs=3):
-        '''
-        Returns the k most similar documents to the query
-        '''
-        retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
-        return retrieved_docs
+        return self.vector_db.similarity_search(query, k=num_docs)
 
     def _format_prompt(self, query, retrieved_docs):
-        context = "\nExtracted documents:\n"
-        # Adjusted extraction slightly to handle missing metadata keys gracefully
+        # 1. Build the context block
+        context = "Extracted documents:\n"
         for doc in retrieved_docs:
             section = doc.metadata.get('Section', 'N/A')
             subtitle = doc.metadata.get('Subtitle', 'Context')
             context += f"{section} - {subtitle}:::\n{doc.page_content}\n\n"
 
-        prompt = f'''
-        You are a helpful legal interpreter.
-        You are given the following context:
-        {context}\n\n
-        Using the information contained in the context,
-        give a comprehensive answer to the question.
-        Respond only to the question asked. Your response should be concise and relevant to the question.
-        Always provide the section number and title of the source document.
-        Also please use plain English when responding, not legal jargon.
-
-        Question: {query}"
-        '''
+        # 2. Universal chat template (works for Qwen, Llama, Mistral, etc.)
+        messages = [
+            {
+                "role": "system",
+                "content": f"You are a helpful legal interpreter. Use the following context to answer the user's question.\nContext:\n{context}"
+            },
+            {
+                "role": "system",
+                "content": "Using the information contained in the context, give a comprehensive answer to the question. Respond only to the question asked. Your response should be concise and relevant to the question. Always provide the section number and title of the source document. Also please use plain English when responding, not legal jargon.\nNow answer the following question."
+            },
+            {
+                "role": "user",
+                "content": query
+            }
+        ]
+
+        # This applies the correct prompt format for whatever model is loaded
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
         return prompt
 
-    def easy_generate(self, query, num_docs=3):
-        retrieved_docs = self.retrieve(query, num_docs=num_docs)
-        prompt = self._format_prompt(query, retrieved_docs)
-
-        # Because we used return_full_text=False in the pipeline,
-        # this returns only the answer.
-        result = self.pipe(prompt)[0]['generated_text']
-        return result
+    def generate(self, query, num_docs=3):
+        # 1. Retrieve
+        retrieved_docs = self.retrieve(query, num_docs)
+
+        # 2. Format the prompt
+        prompt_str = self._format_prompt(query, retrieved_docs)
+
+        # 3. Tokenize
+        inputs = self.tokenizer(prompt_str, return_tensors="pt").to(self.model.device)
+
+        # 4. Generate (streaming is possible with direct model usage, but here we do a blocking call)
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                temperature=0.7,
+                do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+        # 5. Decode
+        # Slicing [input_len:] ensures we return only the new text, not the prompt
+        input_len = inputs.input_ids.shape[1]
+        generated_text = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
+
+        return generated_text
 
 # --- INITIALIZATION ---
 # Using standard paths and models
 MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
 EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
-VECDB_PATH = './index/'
+VECDB_PATH = 'index/'
 
 rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
 
 prompt = "My neighbor is playing loud music on their porch. What time does the 'quiet period' start, and what is the maximum decibel level allowed in a residential zone?"
 
-print(rag.easy_generate(prompt))
+print(rag.generate(prompt))
 ```
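
The new `__init__` intentionally finishes with everything on the CPU because, as its comments note, ZeroGPU Spaces have no GPU attached during global startup. The commit does not show the later hand-off to the GPU, so here is a minimal sketch of what that step could look like, assuming the `spaces` package that ZeroGPU Spaces provide; the `answer` wrapper and the reuse of the README's `rag` object are illustrative assumptions, not code from this repository.

```python
# Hypothetical sketch (not this Space's actual code): ZeroGPU attaches a GPU
# only while a @spaces.GPU-decorated function runs, so the weights that
# __init__ left in system RAM are moved to the GPU at request time.
import spaces  # available on ZeroGPU Spaces
import torch

rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)  # CPU-only startup

@spaces.GPU  # a GPU exists only inside this call
def answer(query: str) -> str:
    if torch.cuda.is_available():
        rag.model.to("cuda")  # move weights from system RAM onto the GPU
    return rag.generate(query)  # generate() already follows self.model.device
```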
app.py CHANGED
@@ -141,8 +141,17 @@ def chat_function(message, history):
 demo = gr.ChatInterface(
     fn=chat_function,
     type="messages",
-    title="Legal RAG Assistant (Qwen 2.5)",
-    description="Ask a question about the legal documents.",
+    title="Charlottesville Municipal RAG Assistant",
+    description="Ask a question about the City of Charlottesville municipal code. This app is intended to increase accessibility to the municipal code and is not a replacement for a legal professional. AI makes mistakes; check important information.",
+    examples=[
+        "My neighbor is playing loud music on their porch. What time does the 'quiet period' start, and what is the maximum decibel level allowed in a residential zone?",
+        "There is a massive oak tree on my property I want to cut down. Do I need permission from the city to remove it?",
+        "I got a parking ticket near the Downtown Mall. What is the deadline to pay the fine, and how do I contest it if I think it was issued in error?",
+        "I want to build a privacy fence in my backyard. How tall can it be before I need a permit, and are there different rules for the front yard versus the back yard?",
+        "I found a deer in my backyard. Can I keep it as a pet if I put a leash on it?",
+        "I'm having trouble catching fish in the Rivanna River. Is it legal to use explosives to help catch them?",
+        "Can I legally attach a flamethrower to my car to melt the snow on my driveway?",
+        "Is it legal for me to practice my bagpipes on the sidewalk at 2:00 AM if I'm technically walking and not 'loitering'?"]
 )
 
 if __name__ == "__main__":
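
The diff touches only the `gr.ChatInterface` configuration; `chat_function` itself is defined earlier in app.py and is not part of this commit. For orientation, a minimal handler compatible with `type="messages"` might look like the sketch below, whose body is a guess that reuses the pipeline object from the README example rather than the file's actual implementation.

```python
# Hypothetical sketch of a handler compatible with the ChatInterface above.
# With type="messages", history arrives as a list of {"role", "content"} dicts;
# this sketch ignores it and answers each turn from retrieval alone.
def chat_function(message, history):
    return rag.generate(message)
```

Returning a plain string is sufficient here: Gradio appends it to the conversation as the assistant's message.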