import os
import torch
import gradio as gr
import faiss
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import spaces

# Ensure an HF Token is present for gated models (like Llama 3)
HF_TOKEN = os.getenv("HF_TOKEN")

class MyRAGPipeline:
    '''
    Wrapper class for RAG pipeline.
    '''
    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str, tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.7, DO_SAMPLE=True):
        if tokenizer_name is None:
            tokenizer_name = model_name 
            
        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = MAX_NEW_TOKENS
        
        print(f"Loading Model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, 
            device_map="auto", 
            dtype=torch.bfloat16,
            token=HF_TOKEN
        )
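        # Llama tokenizers define no pad token; reuse EOS, and left-pad (the recommended side for decoder-only generation).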
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"
        
        print("Loading Embeddings...")
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            multi_process=False, # Set to False for stability in Spaces
            model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )     
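        # normalize_embeddings=True makes inner-product / L2 search equivalent to cosine-similarity ranking on unit vectors.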

        print(f"Loading Vector DB from {vector_db_path}...")
        # Check if index exists to prevent crash
        if not os.path.exists(vector_db_path):
            raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.")

        self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)

        # FAISS GPU optimization (If available)
        if torch.cuda.is_available():
            try:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
            except Exception as e:
                print(f"Could not load FAISS to GPU, running on CPU: {e}")
        
        # Initialize Pipeline
        self.pipe = pipeline(
            'text-generation',
            model=self.model,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            pad_token_id=self.tokenizer.eos_token_id,
            # return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
            return_full_text=False 
        )

    def retrieve(self, query, num_docs=3):
        '''
        Returns the k most similar documents to the query
        '''
        retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
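        # similarity_search returns LangChain Document objects (page_content + metadata), ranked by embedding similarity.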
        return retrieved_docs

    def _format_prompt(self, query, retrieved_docs):
        context = "\nExtracted documents:\n"
        # Adjusted extraction slightly to handle missing metadata keys gracefully
        for doc in retrieved_docs:
            section = doc.metadata.get('Section', 'N/A')
            subtitle = doc.metadata.get('Subtitle', 'Context')
            context += f"{section} - {subtitle}:::\n{doc.page_content}\n\n"

        prompt = f'''
        You are a helpful legal interpreter.
        You are given the following context:
        {context}\n\n
        Using the information contained in the context,
        give a comprehensive answer to the question.
        Respond only to the question asked. Your response should be concise and relevant to the question.
        Always provide the section number and title of the source document.
        Also please use plain English when responding, not legal jargon.
        
        Question: {query}"
        '''
        return prompt

    def easy_generate(self, query, num_docs=3):
        retrieved_docs = self.retrieve(query, num_docs=num_docs)
        prompt = self._format_prompt(query, retrieved_docs)
        
        # Because we used return_full_text=False in the pipeline, 
        # this returns only the answer.
        result = self.pipe(prompt)[0]['generated_text']
        return result
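    # Illustrative usage once an instance exists (mirrored by chat_function below):
    #   rag.easy_generate("How do I pay a parking ticket?", num_docs=3)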

# --- INITIALIZATION ---
# Using standard paths and models
#MODEL_NAME = 'meta-llama/Llama-3.2-1B-Instruct'
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
VECDB_PATH = './index/'
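# Illustrative sketch (not executed here): how an index compatible with VECDB_PATH could be
# built offline with the same embedding model. The source file and splitter settings are
# assumptions; any loader/splitter that yields LangChain Documents would work.
#   from langchain_community.document_loaders import TextLoader
#   from langchain_text_splitters import RecursiveCharacterTextSplitter
#   chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(
#       TextLoader("legal_corpus.txt").load())
#   FAISS.from_documents(chunks, HuggingFaceEmbeddings(model_name=EMBEDDING_NAME)).save_local(VECDB_PATH)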

# Initialize the RAG system globally so it doesn't reload on every message
try:
    rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
except Exception as e:
    rag = None
    print(f"Error initializing RAG: {e}")

# --- GRADIO INTERFACE ---
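# spaces.GPU(duration=...) requests that many seconds of ZeroGPU time per call; 10s may be tight
# for long generations with an 8B model, so consider raising it if responses get cut off.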
@spaces.GPU(duration=10)
def chat_function(message, history):
    if rag is None:
        return "System Error: The RAG pipeline failed to initialize. Check logs and ensure the 'index/' folder is uploaded."
    
    try:
        response = rag.easy_generate(message)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"

demo = gr.ChatInterface(
    fn=chat_function,
    type="messages",
    title="Legal RAG Assistant",
    description="Ask a question about the legal documents indexed in the database.",
    examples=["Can the mayor move outside of the city limits?", "What are the zoning laws?", "Is there a maximum building height?","How do I pay a parking ticket?", "How many chickens can I own?"]
)

if __name__ == "__main__":
    demo.launch()