import os
import traceback

import torch
import gradio as gr
import faiss
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import spaces

# Ensure an HF Token is present for gated models (like Llama 3)
HF_TOKEN = os.getenv("HF_TOKEN")


class MyRAGPipeline:
    '''
    Wrapper class for a retrieval-augmented generation (RAG) pipeline.

    Loads a causal LM with its tokenizer, an embedding model, and a local FAISS
    vector store. ``easy_generate`` retrieves the chunks most similar to a
    query and asks the LM to answer using them as context.
    '''

    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str,
                 tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.7, DO_SAMPLE=True):
        '''
        Load the LM, the embedding model and the FAISS index, then build the
        text-generation pipeline.

        Args:
            model_name: HF Hub id or local path of the causal LM.
            embedding_model_name: HF id of the embedding model used to embed
                queries (presumably the same model that built the index --
                verify when rebuilding the index).
            vector_db_path: directory to load with ``FAISS.load_local``.
            tokenizer_name: tokenizer id; defaults to ``model_name``.
            MAX_NEW_TOKENS: maximum number of tokens generated per answer.
            TEMPERATURE: sampling temperature passed to the pipeline.
            DO_SAMPLE: whether generation samples (True) or decodes greedily.

        Raises:
            FileNotFoundError: if ``vector_db_path`` does not exist.
        '''
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = MAX_NEW_TOKENS

        print(f"Loading Model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        # Causal-LM tokenizers (e.g. Llama) often define no pad token: reuse EOS,
        # and pad on the left so generation continues right after the prompt.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"

        print("Loading Embeddings...")
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            multi_process=False,  # Set to False for stability in Spaces
            model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        print(f"Loading Vector DB from {vector_db_path}...")
        # Check if index exists to fail with an actionable message instead of an
        # opaque load error.
        if not os.path.exists(vector_db_path):
            raise FileNotFoundError(
                f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder."
            )
        # SECURITY: load_local unpickles the stored docstore. This is acceptable
        # only because the index is a trusted artifact shipped with the app --
        # never point vector_db_path at user-supplied files.
        self.vector_db = FAISS.load_local(
            vector_db_path, self.embedding_model, allow_dangerous_deserialization=True
        )

        # FAISS GPU optimization (If available). Best-effort: CPU-only faiss
        # builds lack the GPU API, so any failure just keeps the CPU index.
        if torch.cuda.is_available():
            try:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
            except Exception as e:
                print(f"Could not load FAISS to GPU, running on CPU: {e}")

        # Initialize Pipeline. The model is already instantiated and placed on
        # devices (device_map/dtype above), so only generation settings are
        # passed; re-specifying device_map / torch_dtype here is redundant.
        self.pipe = pipeline(
            'text-generation',
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            pad_token_id=self.tokenizer.eos_token_id,
            # return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
            return_full_text=False,
        )

    def retrieve(self, query, num_docs=3):
        '''
        Return the ``num_docs`` documents most similar to ``query`` from the
        vector store (LangChain ``Document`` objects).
        '''
        retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
        return retrieved_docs

    def _format_prompt(self, query, retrieved_docs):
        '''
        Build the instruction prompt: the retrieved chunks as context, followed
        by answering instructions and the user's question.
        '''
        # Missing 'Section' / 'Subtitle' metadata fall back to placeholders
        # instead of raising KeyError.
        chunks = []
        for doc in retrieved_docs:
            section = doc.metadata.get('Section', 'N/A')
            subtitle = doc.metadata.get('Subtitle', 'Context')
            chunks.append(f"{section} - {subtitle}:::\n{doc.page_content}\n\n")
        context = "\nExtracted documents:\n" + "".join(chunks)

        prompt = (
            "You are a helpful legal interpreter. You are given the following context:\n"
            f"{context}\n\n"
            "Using the information contained in the context, give a comprehensive answer to the question.\n"
            "Respond only to the question asked. Your response should be concise and relevant to the question.\n"
            "Always provide the section number and title of the source document.\n"
            "Also please use plain English when responding, not legal jargon.\n"
            f"Question: {query}"
        )
        return prompt

    def easy_generate(self, query, num_docs=3):
        '''
        Retrieve ``num_docs`` context chunks for ``query`` and return only the
        model's generated answer text.
        '''
        retrieved_docs = self.retrieve(query, num_docs=num_docs)
        prompt = self._format_prompt(query, retrieved_docs)
        # Instruct models are trained on their chat template. Passing a message
        # list makes the pipeline apply tokenizer.apply_chat_template (role
        # headers + assistant generation prompt); a raw string would skip it.
        # Tokenizers without a template (base models) get the raw prompt.
        if getattr(self.tokenizer, "chat_template", None):
            model_input = [{"role": "user", "content": prompt}]
        else:
            model_input = prompt
        # Because we used return_full_text=False in the pipeline, this returns
        # only the newly generated answer as a string.
        result = self.pipe(model_input)[0]['generated_text']
        return result


# --- INITIALIZATION ---
# Using standard paths and models
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
VECDB_PATH = './index/'

# Initialize the RAG system globally so it doesn't reload on every message.
# On failure keep the app up (rag=None) and report the error in the chat.
try:
    rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
except Exception as e:
    rag = None
    print(f"Error initializing RAG: {e}")
    traceback.print_exc()  # keep the full stack in the Space logs


# --- GRADIO INTERFACE ---
# NOTE(review): duration=10 s may be too short for an 8B model producing up to
# 500 new tokens on ZeroGPU; confirm generation latency and raise if calls time out.
@spaces.GPU(duration=10)
def chat_function(message, history):
    '''
    Gradio ChatInterface callback: answer ``message`` with the RAG pipeline.

    ``history`` is required by ChatInterface's signature but unused -- each
    question is answered independently. Errors are returned as chat text
    rather than raised, so the UI stays responsive.
    '''
    if rag is None:
        return "System Error: The RAG pipeline failed to initialize. Check logs and ensure the 'index/' folder is uploaded."

    try:
        response = rag.easy_generate(message)
        return response
    except Exception as e:
        traceback.print_exc()  # surface the real cause in the logs
        return f"An error occurred: {str(e)}"


demo = gr.ChatInterface(
    fn=chat_function,
    type="messages",
    title="Legal RAG Assistant",
    description="Ask a question about the legal documents indexed in the database.",
    examples=["Can the mayor move outside of the city limits?", "What are the zoning laws?", "Is there a maximum building height?", "How do I pay a parking ticket?", "How many chickens can I own?"],
)

if __name__ == "__main__":
    demo.launch()