jme-datasci's picture
added ZeroGPU support
1f3ab5b
raw
history blame
5.9 kB
import os
import torch
import gradio as gr
import faiss
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import spaces
# Hugging Face access token read from the environment; gated models
# (like Llama 3) need it. It is None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
class MyRAGPipeline:
    '''
    Wrapper class for a retrieval-augmented generation (RAG) pipeline.

    Bundles a causal language model, an embedding model and a FAISS vector
    store. A query is answered by retrieving similar documents, formatting
    them into a prompt and running a text-generation pipeline on it.
    '''

    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str, tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.7, DO_SAMPLE=True):
        '''
        Load the tokenizer, language model, embeddings and vector store.

        Args:
            model_name: Model identifier passed to AutoModelForCausalLM.
            embedding_model_name: Model identifier passed to HuggingFaceEmbeddings.
            vector_db_path: Path of the saved FAISS index to load.
            tokenizer_name: Tokenizer identifier; defaults to model_name.
            MAX_NEW_TOKENS: Forwarded to the pipeline as max_new_tokens.
            TEMPERATURE: Forwarded to the pipeline as temperature.
            DO_SAMPLE: Forwarded to the pipeline as do_sample.

        Raises:
            FileNotFoundError: If vector_db_path does not exist.
        '''
        # Fail fast: verify the index is present before spending time and
        # memory on loading the language and embedding models.
        if not os.path.exists(vector_db_path):
            raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.")

        if tokenizer_name is None:
            tokenizer_name = model_name
        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = MAX_NEW_TOKENS

        print(f"Loading Model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            dtype=torch.bfloat16,
            token=HF_TOKEN
        )
        # Reuse the end-of-sequence token for padding and pad on the left.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"

        print("Loading Embeddings...")
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            multi_process=False,  # Set to False for stability in Spaces
            model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        print(f"Loading Vector DB from {vector_db_path}...")
        # SECURITY NOTE: allow_dangerous_deserialization=True permits
        # pickle-based loading; only point this at an index you trust.
        self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)

        # FAISS GPU optimization (if available); best effort, a failure
        # leaves the index on the CPU.
        if torch.cuda.is_available():
            try:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
            except Exception as e:
                print(f"Could not load FAISS to GPU, running on CPU: {e}")

        # Initialize Pipeline
        # NOTE(review): the model above is already loaded with a dtype and
        # device_map; confirm torch_dtype/device_map are still needed here.
        self.pipe = pipeline(
            'text-generation',
            model=self.model,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            pad_token_id=self.tokenizer.eos_token_id,
            # return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
            return_full_text=False
        )

    def retrieve(self, query, num_docs=3):
        '''
        Returns the k most similar documents to the query.

        Args:
            query: Text to search the vector store with.
            num_docs: Number of documents to request (k).
        '''
        return self.vector_db.similarity_search(query, k=num_docs)

    def _format_prompt(self, query, retrieved_docs):
        '''
        Build the generation prompt from the query and retrieved documents.

        Each document contributes a "<Section> - <Subtitle>:::" header
        followed by its page content. Missing metadata keys fall back to
        'N/A' (Section) and 'Context' (Subtitle).

        Args:
            query: The user's question.
            retrieved_docs: Iterable of objects exposing a `metadata` dict
                and a `page_content` string.

        Returns:
            The complete prompt string.
        '''
        pieces = ["\nExtracted documents:\n"]
        for doc in retrieved_docs:
            section = doc.metadata.get('Section', 'N/A')
            subtitle = doc.metadata.get('Subtitle', 'Context')
            pieces.append(f"{section} - {subtitle}:::\n{doc.page_content}\n\n")
        context = "".join(pieces)

        # The question line previously ended with a stray double quote that
        # was sent to the model as part of the prompt; it is removed here.
        prompt = (
            "\n"
            "You are a helpful legal interpreter.\n"
            "You are given the following context:\n"
            f"{context}\n\n\n"
            "Using the information contained in the context,\n"
            "give a comprehensive answer to the question.\n"
            "Respond only to the question asked. Your response should be concise and relevant to the question.\n"
            "Always provide the section number and title of the source document.\n"
            "Also please use plain English when responding, not legal jargon.\n"
            f"Question: {query}\n"
        )
        return prompt

    def easy_generate(self, query, num_docs=3):
        '''
        Retrieve documents for the query and generate an answer.

        Args:
            query: The user's question.
            num_docs: Number of documents to retrieve for the prompt.

        Returns:
            The 'generated_text' of the first pipeline result.
        '''
        retrieved_docs = self.retrieve(query, num_docs=num_docs)
        prompt = self._format_prompt(query, retrieved_docs)
        # Because we used return_full_text=False in the pipeline,
        # this returns only the answer.
        return self.pipe(prompt)[0]['generated_text']
# --- INITIALIZATION ---
# Model, embedding and index locations used to build the pipeline.
# Alternative smaller model: 'meta-llama/Llama-3.2-1B-Instruct'
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
VECDB_PATH = './index/'

# Build the RAG system once at import time so it is not reloaded on every
# message. `rag` stays None when construction fails; the error is printed.
rag = None
try:
    rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
except Exception as init_error:
    print(f"Error initializing RAG: {init_error}")
# --- GRADIO INTERFACE ---
@spaces.GPU(duration=10)
def chat_function(message, history):
    '''
    Answer a single chat message using the global RAG pipeline.

    Args:
        message: The user's question; passed to rag.easy_generate.
        history: Previous conversation turns; accepted but not used.

    Returns:
        The generated answer, or an error string when the pipeline was not
        initialized or generation raised an exception.
    '''
    pipeline_ready = rag is not None
    if not pipeline_ready:
        return "System Error: The RAG pipeline failed to initialize. Check logs and ensure the 'index/' folder is uploaded."

    try:
        answer = rag.easy_generate(message)
    except Exception as err:
        # Surface the failure in the chat instead of raising.
        answer = f"An error occurred: {err!s}"
    return answer
# Chat UI wired to chat_function, with a set of example questions.
demo = gr.ChatInterface(
    fn=chat_function,
    type="messages",
    title="Legal RAG Assistant",
    description="Ask a question about the legal documents indexed in the database.",
    examples=[
        "Can the mayor move outside of the city limits?",
        "What are the zoning laws?",
        "Is there a maximum building height?",
        "How do I pay a parking ticket?",
        "How many chickens can I own?",
    ],
)

# Launch the interface only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()