Commit d7dee9d · updated app and readme
Parent: a4fc9ce

README.md CHANGED

@@ -15,3 +15,227 @@ short_description: RAG Enabled ChatBot for Charlottesville Municipal Code

---

An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).

# Charlottesville Local Ordinance Assistant

## 1. Introduction

Local laws are often written in dense legal terminology that the average person struggles to interpret, turning simple questions about parking or zoning into a maze of irrelevant sections and complex jargon. While current Large Language Models (LLMs) like ChatGPT have seen municipal codes, they are trained on codes from across the country, leading to generalized answers that may blend details from different jurisdictions and hallucinate non-existent regulations. To solve this, I developed the **Charlottesville Local Ordinance Assistant**, a system designed specifically to answer questions about Charlottesville, VA municipal code in plain English. This project utilizes a Retrieval-Augmented Generation (RAG) pipeline to ensure legal accuracy by retrieving up-to-date ordinances, coupled with a specific system prompt designed to translate that "legalese" into clear, accessible language without the need for computationally expensive fine-tuning. The results demonstrate that constraining the model to local data and utilizing strong prompt engineering significantly reduces hallucinations compared to off-the-shelf generalist models.

## 2. Data

For the RAG pipeline, the knowledge base consists of the unedited Charlottesville Municipal Code text, scraped and pre-processed from [Municode](https://library.municode.com/va/charlottesville/codes/code_of_ordinances). These chunks were not rephrased, ensuring that the retrieval mechanism pulls the exact letter of the law. To evaluate the RAG pipeline, I utilized a set of questions and answers generated from the original sections of the municipal code to validate retrieval accuracy (checking if the retrieved node matched the ground truth node for a given query).
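
The evaluation script itself is not reproduced in this README; the snippet below is a minimal sketch of the node-match check described above. It assumes `rag` is an instance of the `MyRAGPipeline` class shown in Section 5, and `eval_pairs` is a hypothetical list of `(question, gold_section)` tuples generated from the municipal code.

```python
# Minimal sketch of retrieval accuracy (hit@k): a query counts as a hit if any
# retrieved chunk comes from the same code section as the ground-truth node.
def retrieval_hit_rate(rag, eval_pairs, k=5):
    hits = 0
    for question, gold_section in eval_pairs:
        retrieved = rag.retrieve(question, num_docs=k)
        if any(doc.metadata.get("Section") == gold_section for doc in retrieved):
            hits += 1
    return hits / len(eval_pairs)

# Example usage (eval_pairs is hypothetical):
# accuracy = retrieval_hit_rate(rag, eval_pairs, k=5)
```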

## 3. Methodology

For the RAG methodology, I implemented a dense retrieval system. I selected **Qwen3-Embedding-0.6B** as the embedding model due to the relatively small size of the RAG corpus (municipal code). This model allows for high-precision retrieval without the latency of larger embedding models. The retrieved context is passed to the **Llama-3.2-1B** generator to synthesize the final answer.
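
The pipeline in Section 5 loads a prebuilt FAISS index from `index/`. The build script is not included here, so the following is only a minimal sketch of how such an index could be constructed with the same embedding model; the chunk text and its `Section`/`Subtitle` metadata are placeholders for the scraped Municode chunks described in Section 2.

```python
# Sketch: embed municipal-code chunks with Qwen3-Embedding-0.6B and save a FAISS
# index that MyRAGPipeline can later load from 'index/'. Chunk content is a placeholder.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

chunks = [
    Document(
        page_content="Placeholder ordinance text for one section of the municipal code...",
        metadata={"Section": "Sec. 2-86", "Subtitle": "Placeholder subtitle"},
    ),
    # ... one Document per pre-processed chunk of the scraped code
]

embeddings = HuggingFaceEmbeddings(
    model_name="Qwen/Qwen3-Embedding-0.6B",
    encode_kwargs={"normalize_embeddings": True},  # normalized vectors -> cosine similarity
)

vector_db = FAISS.from_documents(chunks, embeddings)
vector_db.save_local("index/")
```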

## 4. Evaluation

### Benchmark Results

To strictly evaluate the legal reasoning and retrieval capabilities of the model, I utilized three established benchmarks: [LegalBench-RAG](https://github.com/hazyresearch/legalbench), [RAGBench](https://arxiv.org/abs/2306.16092), and [RAGTruth](https://arxiv.org/abs/2401.00396). I chose these because they specifically target the weaknesses of legal LLMs: the ability to reason over specific documents and the frequency of hallucinations.

LegalBench-RAG, RAGBench, and my custom test split were all evaluated using **meta-llama/Llama-3.1-8B** as the judge for eight different metrics (a minimal sketch of the judge call appears after the list):

* **Context Relevance**: Measures the proportion of retrieved information that is actually pertinent to the user's query.
* **Context Recall**: Assesses if the retrieved context contains all the necessary ground-truth information required to answer.
* **Chunk Relevance**: Evaluates the precision of individual retrieved document segments relative to the input query.
* **Faithfulness**: Checks if the generated answer is factually derived solely from the retrieved context (hallucination detection).
* **Answer Relevance**: Determines how well the generated response directly addresses the user's original prompt.
* **Answer Correctness**: Scores the accuracy of the generated answer against a known gold-standard reference.
* **Answer Completeness**: Checks if the response addresses all parts of the query without omitting key details.
* **Safety**: Measures the model's ability to refuse generating harmful, toxic, or inappropriate content.
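
The exact judge rubrics are not reproduced in this README, so the sketch below only illustrates the general shape of an LLM-as-judge call, using faithfulness as the example. The prompt wording and the 0-100 scale are illustrative assumptions, not the actual evaluation prompts.

```python
# Illustrative sketch of an LLM-as-judge metric (faithfulness). Prompt wording and
# scoring scale are assumptions; only the judge model name comes from this README.
import re
from transformers import pipeline

judge = pipeline("text-generation", model="meta-llama/Llama-3.1-8B", max_new_tokens=10)

def judge_faithfulness(question, context, answer):
    prompt = (
        "You are grading a RAG system.\n"
        f"Question: {question}\n"
        f"Retrieved context: {context}\n"
        f"Answer: {answer}\n"
        "On a scale of 0-100, how much of the answer is supported by the retrieved "
        "context alone? Reply with a single number."
    )
    reply = judge(prompt, return_full_text=False)[0]["generated_text"]
    match = re.search(r"\d+", reply)  # parse the first integer the judge emits
    return int(match.group()) if match else 0
```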

Finally, because RAGTruth retrievals are frozen, only faithfulness was evaluated on that benchmark. Benchmark results are shown below.

| Metric | LegalBench-RAG | RAGBench | RAGTruth-QA | Custom Test Split |
|--------------------:|:--------------:|:--------:|:-----------:|:-----------------:|
| Context Relevance | 87.33 | 22.17 | - | |
| Context Recall | 47.63 | 20.56 | - | 78.87 |
| Chunk Relevance | 85.76 | 23.88 | - | |
| Faithfulness | 60.72 | 69.68 | 74.11 | 72.64 |
| Answer Relevance | 76.71 | 65.82 | - | 61.87 |
| Answer Correctness | 41.20 | 10.33 | - | 38.17 |
| Answer Completeness | 75.59 | 66.49 | - | 76.46 |
| Safety | 97.12 | 97.49 | - | 98.70 |

I qualitatively compared my primary model (Llama-3.2-1B) against **Qwen3-0.6B** (chosen as a smaller, efficient baseline) and **Qwen3-4B-Instruct-2507** (chosen as a larger, more capable baseline). In these comparisons, the Llama-3.2-1B-based Ordinance Assistant performed well, and the larger Qwen 4B model did not appear to give noticeably better answers.

## 5. Usage and Intended Uses

The intended use case for this model is to assist residents of Charlottesville, VA, in understanding local ordinances regarding zoning, parking, and noise complaints without needing a legal background. It is **not** a replacement for a lawyer but rather a tool for accessibility.

Below is an example of how the RAG pipeline class is constructed and used to generate responses with retrieval.

```python
import torch
import numpy as np
import faiss as fai
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


class MyRAGPipeline:
    '''
    Wrapper class for RAG pipeline.
    '''
    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str, tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.9, DO_SAMPLE=True):
        if tokenizer_name is None:
            tokenizer_name = model_name  # default behavior is to use the same tokenizer as the model

        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = MAX_NEW_TOKENS
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", dtype=torch.bfloat16)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"

        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            multi_process=True,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},  # set `True` for cosine similarity
        )

        self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)

        # Move the FAISS index to the GPU when one is available
        if torch.cuda.is_available():
            res = fai.StandardGpuResources()
            co = fai.GpuClonerOptions()
            co.useFloat16 = True
            self.vector_db.index = fai.index_cpu_to_gpu(res, 0, self.vector_db.index, co)

        self.pipe = pipeline(
            'text-generation',
            model=self.model,
            dtype=torch.bfloat16,
            device_map='auto',
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            pad_token_id=self.tokenizer.eos_token_id,
            batch_size=8,
        )

    def retrieve(self, query, num_docs=5):
        '''
        Returns the k most similar documents to the query.
        '''
        retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
        return retrieved_docs

    def _format_prompt(self, query, retrieved_docs):
        context = "\nExtracted documents:\n"
        context += "".join([f"{doc.metadata['Section']} - {doc.metadata['Subtitle']}:::\n" + doc.page_content + "\n\n" for doc in retrieved_docs])

        prompt = f'''
        You are a helpful legal interpreter.
        You are given the following context:
        {context}\n\n
        Using the information contained in the context,
        give a comprehensive answer to the question.
        Respond only to the question asked. Your response should be concise and relevant to the question.
        Always provide the section number and title of the source document.
        Now please answer the following question in plain English using less than {self.max_new_tokens} words.
        Question: {query}
        '''
        return prompt

    def _simple_format(self, query, retrieved_docs):
        context = "\nExtracted documents:\n"
        context += "".join([f"{doc.page_content}" + "\n\n" for doc in retrieved_docs])
        prompt = f'''
        You are given the following context:
        {context}\n\n
        Using the information contained in the context,
        give a comprehensive answer to the question.
        Respond only to the question asked. Your response should be concise and relevant to the question.
        Now please answer the following question in plain English using less than {self.max_new_tokens} words.
        Question: {query}
        '''
        return prompt

    def easy_generate(self, query, num_docs=5):
        retrieved_docs = self.retrieve(query, num_docs=num_docs)
        prompt = self._format_prompt(query, retrieved_docs)
        return self.pipe(prompt)[0]['generated_text']

    def generate(self, query, retrieved_docs):
        prompt = self._simple_format(query, retrieved_docs)
        return self.pipe(prompt)[0]['generated_text']

    def batch_generate(self, prompt_list, batch_size=8):
        return self.pipe(prompt_list, return_full_text=False, batch_size=batch_size)

    def batch_retrieve(self, queries, num_docs=5, batch_size=256):
        """
        Retrieves documents using GPU acceleration with a progress bar.
        Processes queries in chunks to allow monitoring without sacrificing speed.
        """
        all_retrieved_docs = []
        docstore = self.vector_db.docstore
        index_to_id = self.vector_db.index_to_docstore_id

        for i in tqdm(range(0, len(queries), batch_size), desc="Batch Search"):
            batch_queries = queries[i : i + batch_size]
            query_vectors = self.embedding_model.embed_documents(batch_queries)
            query_matrix = np.array(query_vectors, dtype=np.float32)
            D, I = self.vector_db.index.search(query_matrix, num_docs)

            for row_indices in I:
                docs_for_query = []
                for idx in row_indices:
                    if idx == -1:
                        continue
                    _id = index_to_id[idx]
                    doc = docstore.search(_id)
                    docs_for_query.append(doc)
                all_retrieved_docs.append(docs_for_query)
        return all_retrieved_docs


model_name = 'meta-llama/Llama-3.2-1B-Instruct'
embedding_name = 'Qwen/Qwen3-Embedding-0.6B'
vecdb_path = 'index/'

rag = MyRAGPipeline(model_name, embedding_name, vecdb_path)

prompt = "Can the mayor move outside of the city limits?"

print(rag.easy_generate(prompt))
```

## Prompt Format

The model relies on a strict system prompt to ensure the output is simplified but factually accurate. The prompt injects the retrieved RAG context directly into the system message.

```
You are a helpful legal interpreter.
You are given the following context:
{context}

Using the information contained in the context,
give a comprehensive answer to the question.
Respond only to the question asked. Your response should be concise and relevant to the question.
Always provide the section number and title of the source document.
Now please answer the following question in plain English using less than {max_new_tokens} words.
Question: {query}
```

## Expected Output Format

The model is expected to output a plain-English translation of the input text, simplifying sentence structure while retaining critical entities (dates, fines, locations).

```
The Clerk of the Council is responsible for keeping the city's official seal.
They must stamp this seal on any papers or documents when the Council's laws
or decisions require it.
```

## Limitations

The primary limitation of this model is that while it reduces hallucinations, it does not eliminate them; users should verify important legal details with the official [Municode](https://library.municode.com/va/charlottesville/codes/code_of_ordinances) source. Additionally, the model is strictly limited to the Charlottesville context; applying it to Albemarle County or other jurisdictions will result in incorrect information. Finally, because the model was not fine-tuned, it may occasionally slip back into dense terminology if the retrieved ordinance is exceptionally complex.

app.py CHANGED

@@ -1,70 +1,146 @@

Removed lines:

additional_inputs=[
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
],
)
demo.launch()

New app.py:

import os
import torch
import gradio as gr
import faiss
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Ensure an HF Token is present for gated models (like Llama 3)
HF_TOKEN = os.getenv("HF_TOKEN")

class MyRAGPipeline:
    '''
    Wrapper class for RAG pipeline.
    '''
    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str, tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.7, DO_SAMPLE=True):
        if tokenizer_name is None:
            tokenizer_name = model_name

        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = MAX_NEW_TOKENS

        print(f"Loading Model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            dtype=torch.bfloat16,
            token=HF_TOKEN
        )
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"

        print("Loading Embeddings...")
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            multi_process=False,  # Set to False for stability in Spaces
            model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        print(f"Loading Vector DB from {vector_db_path}...")
        # Check if index exists to prevent crash
        if not os.path.exists(vector_db_path):
            raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.")

        self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)

        # FAISS GPU optimization (if available)
        if torch.cuda.is_available():
            try:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
            except Exception as e:
                print(f"Could not load FAISS to GPU, running on CPU: {e}")

        # Initialize Pipeline
        self.pipe = pipeline(
            'text-generation',
            model=self.model,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            pad_token_id=self.tokenizer.eos_token_id,
            # return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
            return_full_text=False
        )

    def retrieve(self, query, num_docs=3):
        '''
        Returns the k most similar documents to the query.
        '''
        retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
        return retrieved_docs

    def _format_prompt(self, query, retrieved_docs):
        context = "\nExtracted documents:\n"
        # Adjusted extraction slightly to handle missing metadata keys gracefully
        for doc in retrieved_docs:
            section = doc.metadata.get('Section', 'N/A')
            subtitle = doc.metadata.get('Subtitle', 'Context')
            context += f"{section} - {subtitle}:::\n{doc.page_content}\n\n"

        prompt = f'''
        You are a helpful legal interpreter.
        You are given the following context:
        {context}\n\n
        Using the information contained in the context,
        give a comprehensive answer to the question.
        Respond only to the question asked. Your response should be concise and relevant to the question.
        Always provide the section number and title of the source document.

        Question: {query}
        '''
        return prompt

    def easy_generate(self, query, num_docs=3):
        retrieved_docs = self.retrieve(query, num_docs=num_docs)
        prompt = self._format_prompt(query, retrieved_docs)

        # Because we used return_full_text=False in the pipeline,
        # this returns only the answer.
        result = self.pipe(prompt)[0]['generated_text']
        return result

# --- INITIALIZATION ---
# Using standard paths and models
MODEL_NAME = 'meta-llama/Llama-3.2-1B-Instruct'
EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
VECDB_PATH = 'index/'  # Make sure you upload this folder to your Space!

# Initialize the RAG system globally so it doesn't reload on every message
try:
    rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
except Exception as e:
    rag = None
    print(f"Error initializing RAG: {e}")

# --- GRADIO INTERFACE ---
def chat_function(message, history):
    if rag is None:
        return "System Error: The RAG pipeline failed to initialize. Check logs and ensure the 'index/' folder is uploaded."

    try:
        response = rag.easy_generate(message)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"

demo = gr.ChatInterface(
    fn=chat_function,
    type="messages",
    title="Legal RAG Assistant",
    description="Ask a question about the legal documents indexed in the database.",
    examples=["Can the mayor move outside of the city limits?", "What are the zoning laws?"],
)

if __name__ == "__main__":
    demo.launch()