import os
import torch
import gradio as gr
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import spaces  # HF Spaces helper; needed for ZeroGPU decorators (unused otherwise)

# Ensure an HF token is present for gated models (like Llama 3)
HF_TOKEN = os.getenv("HF_TOKEN")

class MyRAGPipeline:
    '''
    Wrapper class for a retrieval-augmented generation (RAG) pipeline:
    loads the LLM, the embedding model, and the FAISS vector store.
    '''
    def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str,
                 tokenizer_name=None, max_new_tokens=500, temperature=0.7, do_sample=True):
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.embedding_model_name = embedding_model_name
        self.max_new_tokens = max_new_tokens

        print(f"Loading model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        # Llama tokenizers ship without a pad token, so reuse EOS; left padding
        # keeps generation contiguous with the end of the prompt.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"
| print("Loading Embeddings...") | |
| self.embedding_model = HuggingFaceEmbeddings( | |
| model_name=self.embedding_model_name, | |
| multi_process=False, # Set to False for stability in Spaces | |
| model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}, | |
| encode_kwargs={"normalize_embeddings": True}, | |
| ) | |
| print(f"Loading Vector DB from {vector_db_path}...") | |
| # Check if index exists to prevent crash | |
| if not os.path.exists(vector_db_path): | |
| raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.") | |
| self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True) | |
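        # For reference, a minimal sketch of how such an index could be built
        # offline (hypothetical; assumes a list of LangChain Document objects
        # named `docs` from your own loader/splitter, not run by this app):
        #     db = FAISS.from_documents(docs, self.embedding_model)
        #     db.save_local(vector_db_path)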

        # FAISS GPU optimization (if a GPU is available)
        if torch.cuda.is_available():
            try:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True  # halves index memory on the GPU
                self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
            except Exception as e:
                print(f"Could not load FAISS to GPU, running on CPU: {e}")

        # Initialize the generation pipeline. The model above is already loaded
        # with device_map="auto" and bfloat16, so those are not repeated here.
        self.pipe = pipeline(
            'text-generation',
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
            # return_full_text=False is critical for chatbots so the output
            # does not repeat the prompt
            return_full_text=False,
        )

    def retrieve(self, query, num_docs=3):
        '''
        Return the num_docs most similar document chunks to the query.
        '''
        return self.vector_db.similarity_search(query, k=num_docs)

    def _format_prompt(self, query, retrieved_docs):
        '''
        Build the generation prompt from the retrieved chunks and the question.
        '''
        context = "\nExtracted documents:\n"
        # .get() handles chunks whose metadata lacks 'Section'/'Subtitle' keys
        for doc in retrieved_docs:
            section = doc.metadata.get('Section', 'N/A')
            subtitle = doc.metadata.get('Subtitle', 'Context')
            context += f"{section} - {subtitle}:::\n{doc.page_content}\n\n"
        prompt = f'''
You are a helpful legal interpreter.
You are given the following context:
{context}

Using the information contained in the context,
give a comprehensive answer to the question.
Respond only to the question asked. Your response should be concise and relevant to the question.
Always provide the section number and title of the source document.
Also please use plain English when responding, not legal jargon.
Question: {query}
'''
        return prompt
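
    # Note: for instruct-tuned models like Llama 3.1, wrapping the prompt with
    # the tokenizer's chat template is generally preferable to a raw string.
    # A hypothetical sketch (not wired into easy_generate below):
    #     messages = [{"role": "user", "content": prompt}]
    #     text = self.tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )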

    def easy_generate(self, query, num_docs=3):
        '''
        Retrieve context for the query and generate an answer.
        '''
        retrieved_docs = self.retrieve(query, num_docs=num_docs)
        prompt = self._format_prompt(query, retrieved_docs)
        # Because the pipeline uses return_full_text=False,
        # this returns only the answer, not the prompt.
        result = self.pipe(prompt)[0]['generated_text']
        return result
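

# Example of calling the pipeline directly, outside Gradio (hypothetical
# session; assumes the index folder and an HF_TOKEN for the gated model):
#     rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
#     print(rag.easy_generate("How do I pay a parking ticket?"))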

# --- INITIALIZATION ---
# Standard paths and models
# MODEL_NAME = 'meta-llama/Llama-3.2-1B-Instruct'  # smaller alternative
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
VECDB_PATH = './index/'

# Initialize the RAG system globally so it doesn't reload on every message
try:
    rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
except Exception as e:
    rag = None
    print(f"Error initializing RAG: {e}")

# --- GRADIO INTERFACE ---
def chat_function(message, history):
    if rag is None:
        return ("System Error: The RAG pipeline failed to initialize. "
                "Check the logs and ensure the 'index/' folder is uploaded.")
    try:
        return rag.easy_generate(message)
    except Exception as e:
        return f"An error occurred: {str(e)}"


demo = gr.ChatInterface(
    fn=chat_function,
    type="messages",
    title="Legal RAG Assistant",
    description="Ask a question about the legal documents indexed in the database.",
    examples=[
        "Can the mayor move outside of the city limits?",
        "What are the zoning laws?",
        "Is there a maximum building height?",
        "How do I pay a parking ticket?",
        "How many chickens can I own?",
    ],
)

if __name__ == "__main__":
    demo.launch()