import os
import torch
import gradio as gr
import faiss
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import spaces
# Ensure an HF Token is present for gated models (like Llama 3)
HF_TOKEN = os.getenv("HF_TOKEN")
class MyRAGPipeline:
    '''
    Wrapper around a simple RAG pipeline: retrieves the most relevant chunks from a
    FAISS vector store and passes them, together with the user's question, to a causal LM.
    '''
    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str,
                 tokenizer_name=None, max_new_tokens=500, temperature=0.7, do_sample=True):
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = max_new_tokens
print(f"Loading Model: {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
            torch_dtype=torch.bfloat16,
token=HF_TOKEN
)
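        # Llama-style tokenizers ship without a pad token; reuse the EOS token and
        # left-pad so batched generation lines up correctly for a decoder-only model.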
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.tokenizer.padding_side = "left"
print("Loading Embeddings...")
self.embedding_model = HuggingFaceEmbeddings(
model_name=self.embedding_model_name,
multi_process=False, # Set to False for stability in Spaces
model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
encode_kwargs={"normalize_embeddings": True},
)
print(f"Loading Vector DB from {vector_db_path}...")
# Check if index exists to prevent crash
if not os.path.exists(vector_db_path):
raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.")
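        # allow_dangerous_deserialization is required because FAISS.load_local unpickles
        # stored metadata; only enable it for indexes you built yourself.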
self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)
# FAISS GPU optimization (If available)
if torch.cuda.is_available():
try:
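                # Clone the CPU index onto GPU 0; float16 storage halves memory at a small accuracy cost.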
res = faiss.StandardGpuResources()
co = faiss.GpuClonerOptions()
co.useFloat16 = True
self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
except Exception as e:
print(f"Could not load FAISS to GPU, running on CPU: {e}")
# Initialize Pipeline
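        # Note: temperature only takes effect when do_sample=True.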
self.pipe = pipeline(
'text-generation',
model=self.model,
torch_dtype=torch.bfloat16,
device_map='auto',
tokenizer=self.tokenizer,
max_new_tokens=self.max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
pad_token_id=self.tokenizer.eos_token_id,
# return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
return_full_text=False
)
def retrieve(self, query, num_docs=3):
'''
        Return the num_docs most similar documents to the query.
'''
retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
return retrieved_docs
def _format_prompt(self, query, retrieved_docs):
context = "\nExtracted documents:\n"
        # Fall back to sensible defaults when metadata keys are missing
for doc in retrieved_docs:
section = doc.metadata.get('Section', 'N/A')
subtitle = doc.metadata.get('Subtitle', 'Context')
context += f"{section} - {subtitle}:::\n{doc.page_content}\n\n"
        prompt = f'''
        You are a helpful legal interpreter.
        You are given the following context:
        {context}

        Using the information contained in the context, give a comprehensive answer to the question.
        Respond only to the question asked; keep the response concise and relevant.
        Always provide the section number and title of the source document.
        Use plain English in your response, not legal jargon.
        Question: {query}
        '''
return prompt
def easy_generate(self, query, num_docs=3):
retrieved_docs = self.retrieve(query, num_docs=num_docs)
prompt = self._format_prompt(query, retrieved_docs)
# Because we used return_full_text=False in the pipeline,
# this returns only the answer.
result = self.pipe(prompt)[0]['generated_text']
return result
# --- INITIALIZATION ---
# Using standard paths and models
#MODEL_NAME = 'meta-llama/Llama-3.2-1B-Instruct'
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
VECDB_PATH = './index/'
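# The folder is expected to contain the index.faiss / index.pkl pair written by FAISS.save_local().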
# Initialize the RAG system globally so it doesn't reload on every message
try:
rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
except Exception as e:
rag = None
print(f"Error initializing RAG: {e}")
# --- GRADIO INTERFACE ---
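# On ZeroGPU Spaces, @spaces.GPU requests a GPU only while the decorated function runs;
# duration is in seconds and should comfortably cover one full generation call.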
@spaces.GPU(duration=10)
def chat_function(message, history):
if rag is None:
return "System Error: The RAG pipeline failed to initialize. Check logs and ensure the 'index/' folder is uploaded."
try:
response = rag.easy_generate(message)
return response
except Exception as e:
return f"An error occurred: {str(e)}"
demo = gr.ChatInterface(
fn=chat_function,
type="messages",
title="Legal RAG Assistant",
description="Ask a question about the legal documents indexed in the database.",
    examples=["Can the mayor move outside of the city limits?", "What are the zoning laws?", "Is there a maximum building height?", "How do I pay a parking ticket?", "How many chickens can I own?"]
)
if __name__ == "__main__":
    demo.launch()