import os
import pickle
import re
import json
import time
from dataclasses import dataclass
from functools import lru_cache
from typing import List

import torch
import tiktoken
from json_repair import repair_json
from nltk.tokenize import sent_tokenize
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import InferenceClient, hf_hub_download, HfApi
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever  # Main LangChain package
from langchain.schema import Document as LangchainDocument

import utils
from utils import post_with_retry
from constants import (
    RELEVANCE_SCORE,
    UTILIZATION_SCORE,
    COMPLETENESS_SCORE,
    ADHERENCE_SCORE,
    HF_DATASET_REPO_NAME,
    HF_REPO_TYPE,
)


# Define document structures. The @dataclass decorator is required so that
# Document(...) and Chunk(...) can be constructed with keyword arguments below.
@dataclass
class Document:
    doc_id: str
    text: str
    source: str  # Refers to the subset
    metadata: dict


@dataclass
class Chunk:
    chunk_id: str
    text: str
    doc_id: str
    source: str
    chunk_num: int
    total_chunks: int
    metadata: dict


class RAGSystem:
    def __init__(
        self,
        subset: str,
        dataset_type: str,
        strategy: str,
        chunks: List[Chunk],
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        generator_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
        retriever_model_name: str = "BAAI/bge-large-en-v1.5",
        reranker_model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
        hf_api_token: str = None
    ):
        self.subset = subset
        self.dataset_type = dataset_type
        self.strategy = strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.generator_model_name = generator_model_name
        self.retriever_model_name = retriever_model_name
        self.reranker_model_name = reranker_model_name
        self.chunks = chunks
        self.hf_api_token = hf_api_token or os.getenv("HF_API_TOKEN")

        # Initialize components
        self.vector_store = None
        self.embedder = None
        self.hybrid_retriever = None
        self.generator_client = None

        # Set up API-based generator
        self._init_generator_api()

    def _init_generator_api(self):
        self.generator_client = InferenceClient(
            model=self.generator_model_name,
            token=self.hf_api_token,
            timeout=120,
            headers={"x-use-cache": "0"}
        )
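        # NOTE: these defaults mirror the raw text-generation endpoint.
        # _generate_with_api() below calls chat_completion(), which expects
        # 'max_tokens' rather than 'max_new_tokens', so it builds its own
        # parameter dict and does not read self.generation_params.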
        self.generation_params = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "repetition_penalty": 1.1
        }

    def _load_embeddings(self):
        if not self.embedder:
            self.embedder = HuggingFaceEmbeddings(
                model_name=self.retriever_model_name,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )

    def _check_huggingface_repo(self):
        """Verify that the Hugging Face repository exists and is accessible."""
        try:
            print("Checking Hugging Face repository...")
            api = HfApi()
            exists = api.repo_exists(
                repo_id=HF_DATASET_REPO_NAME,
                repo_type=HF_REPO_TYPE,
                token=os.getenv("HF_TOKEN")
            )
            if not exists:
                print(f"Repository {HF_DATASET_REPO_NAME} does not exist.")
                return False
            print(f"Repository {HF_DATASET_REPO_NAME} exists.")

            # List files to confirm the repository is readable
            repo_files = api.list_repo_files(
                repo_id=HF_DATASET_REPO_NAME,
                repo_type=HF_REPO_TYPE,
            )
            print(f"Repository {HF_DATASET_REPO_NAME} is accessible. Number of files: {len(repo_files)}")
        except Exception as e:
            print(f"Error accessing Hugging Face repo: {e}")
            return False
        return True

    def _download_file(self, filename: str, folder_path: str) -> str:
        """Download a file from the Hugging Face hub to the specified folder."""
        try:
            file_path = hf_hub_download(
                repo_id=HF_DATASET_REPO_NAME,
                filename=filename,
                repo_type=HF_REPO_TYPE,
                local_dir=folder_path,
                token=os.getenv("HF_TOKEN")
            )
            print(f"Downloaded {filename} to {file_path}")
            return file_path
        except Exception as e:
            print(f"Error downloading {filename}: {e}")
            return None

    def _upload_file(self, filename: str, folder_path: str) -> str:
        """Upload a file to the Hugging Face hub from the specified folder."""
        try:
            file_path = os.path.join(folder_path, filename)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File {file_path} does not exist.")
            api = HfApi()
            api.upload_file(
                path_or_fileobj=file_path,
                path_in_repo=f"{folder_path}/{filename}",
                repo_id=HF_DATASET_REPO_NAME,
                repo_type=HF_REPO_TYPE,
                token=os.getenv("HF_TOKEN")
            )
            print(f"Uploaded {file_path} to {HF_DATASET_REPO_NAME}")
            return file_path
        except Exception as e:
            print(f"Error uploading {filename}: {e}")
            return None

    def _store_faiss_files(self, folder_path: str):
        """Store FAISS index files to the Hugging Face hub."""
        try:
            # Ensure the folder exists
            if not os.path.exists(folder_path):
                os.makedirs(folder_path)
            # Save the FAISS index locally
            self.vector_store.save_local(folder_path)
            # Upload the required files to the Hugging Face hub
            for filename in ["index.faiss", "index.pkl"]:
                file_path = os.path.join(folder_path, filename)
                if os.path.exists(file_path):
                    self._upload_file(filename, folder_path)
                else:
                    print(f"File {file_path} does not exist, skipping upload.")
        except Exception as e:
            print(f"Error storing FAISS files: {e}")

    def _download_FAISS_files(self, folder_path: str):
        """Download all required FAISS files from the Hugging Face hub."""
        REQUIRED_FILES = [
            "index.faiss",
            "index.pkl"
        ]
        try:
            # Download the index files to the local cache
            downloaded_files = []
            for filename in REQUIRED_FILES:
                file_path = self._download_file(
                    filename=f"{folder_path}/{filename}",
                    folder_path=folder_path
                )
                if file_path:
                    downloaded_files.append(file_path)
                    print(f"Downloaded: {filename} → {file_path}")
                else:
                    return False  # If any file fails to download, give up
            # Get the common directory
            index_dir = os.path.dirname(downloaded_files[0])
            print(f"Final index directory: {index_dir}")
            print(f"Files in directory: {os.listdir(index_dir)}")
            # Load FAISS
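            # allow_dangerous_deserialization is required because the FAISS
            # docstore is pickled; only enable it for index files you trust
            # (here, files this pipeline wrote itself).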
            self.vector_store = FAISS.load_local(
                folder_path=index_dir,
                embeddings=self.embedder,
                allow_dangerous_deserialization=True
            )
        except Exception as e:
            print(f"Error loading index: {e}")
            return False
        return True

    def load_embeddings_database(self, retriever_type="Vector"):
        """Load the pre-built FAISS index and retrievers."""
        if self._check_huggingface_repo() is False:
            print(f"Repository {HF_DATASET_REPO_NAME} does not exist or is inaccessible.")
            return
        self._load_embeddings()
        chunkFilePath = f"{self.subset}/chunks/chunks_{self.strategy}.pkl"
        print(f"Chunk file path: {chunkFilePath} for strategy {self.strategy}")
        chunkFile = self._download_file(chunkFilePath, "")
        bChunkFileAvailable = False
        # Guard against a failed download: chunkFile may be None
        if chunkFile and os.path.exists(chunkFile):
            with open(chunkFile, "rb") as f:
                langchain_docs = pickle.load(f)
            bChunkFileAvailable = True
            print(f"Successfully loaded chunks from {chunkFile}, length: {len(langchain_docs)}")
        # Check for FAISS index files (index.faiss, index.pkl)
        faissFolderPath = f"{self.subset}/embeddings/{self.retriever_model_name.replace('/', ':')}/{self.strategy}"
        print(f"FAISS folder path: {faissFolderPath}")
        if self._download_FAISS_files(faissFolderPath):
            print(f"FAISS index loaded successfully from {faissFolderPath}")
        else:
            print(f"Failed to load FAISS index from {faissFolderPath}, rebuilding from documents")
            if bChunkFileAvailable:
                print("Building FAISS index from downloaded chunks")
                start = time.time()
                self.vector_store = FAISS.from_documents(langchain_docs, self.embedder)
                duration = time.time() - start
                # Convert to minutes and seconds
                minutes = int(duration // 60)
                seconds = int(duration % 60)
                print(f"FAISS index built from chunks in {minutes} minutes and {seconds} seconds, saving to {faissFolderPath}")
                self._store_faiss_files(faissFolderPath)
        if self.vector_store is None:
            print("No FAISS index available and no chunks to build one from.")
            return
        if bChunkFileAvailable and retriever_type == "BM25":
            bm25 = BM25Retriever.from_documents(langchain_docs)
            bm25.k = 20
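            # Hybrid retrieval: blend dense (FAISS) similarity with sparse
            # (BM25) keyword matching, weighted 0.7 / 0.3 in favor of dense.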
            self.hybrid_retriever = EnsembleRetriever(
                retrievers=[
                    self.vector_store.as_retriever(search_kwargs={"k": 20}),
                    bm25
                ],
                weights=[0.7, 0.3]
            )
        else:
            print("Using FAISS-only retriever (BM25 ensemble not requested or chunks unavailable).")
            self.hybrid_retriever = self.vector_store.as_retriever(search_kwargs={"k": 20})

    def store_embeddings_database(self, save_faiss: bool = True):
        """Build and store a FAISS index from chunks."""
        if not self.embedder:
            self.embedder = HuggingFaceEmbeddings(model_name=self.retriever_model_name)
        index_path = f"./faiss_index_{self.subset}_{self.dataset_type}_{self.strategy}"
        if os.path.exists(f"{index_path}/index.faiss"):
            print("📂 Reusing existing FAISS index")
            self.vector_store = FAISS.load_local(
                index_path, self.embedder, allow_dangerous_deserialization=True
            )
        else:
            print("⚙️ Building new FAISS index")
            langchain_docs = [
                LangchainDocument(
                    page_content=chunk.text,
                    metadata={
                        **chunk.metadata,
                        "chunk_id": chunk.chunk_id,
                        "doc_id": chunk.doc_id,
                        "source": chunk.source,
                        "chunk_num": chunk.chunk_num,
                        "total_chunks": chunk.total_chunks
                    }
                ) for chunk in self.chunks
            ]
            self.vector_store = FAISS.from_documents(langchain_docs, self.embedder)
            if save_faiss:
                os.makedirs(index_path, exist_ok=True)
                self.vector_store.save_local(index_path)
                with open(f"{index_path}/langchain_docs.pkl", "wb") as f:
                    pickle.dump(langchain_docs, f)

        # Initialize hybrid retriever
        bm25 = BM25Retriever.from_documents([
            LangchainDocument(page_content=chunk.text, metadata=chunk.metadata)
            for chunk in self.chunks
        ])
        bm25.k = 20
        self.hybrid_retriever = EnsembleRetriever(
            retrievers=[self.vector_store.as_retriever(search_kwargs={"k": 20}), bm25],
            weights=[0.7, 0.3]
        )

    def generate_hypothetical_answer(self, question: str) -> str:
        """Generate a HyDE hypothetical answer using the API."""
        prompt = f"Generate a detailed hypothetical answer for: {question}"
        # chat_completion() reads 'max_tokens', not 'max_new_tokens'
        return self._generate_with_api(prompt, max_tokens=100)

    def _generate_with_api(self, prompt: str, **kwargs) -> str:
        """Generate text using the HF Inference API with the correct parameters."""
        # Default parameters compatible with the chat_completion API
        params = {
            "max_tokens": kwargs.get("max_tokens", 512),  # Note: 'max_tokens', not 'max_new_tokens'
            "temperature": kwargs.get("temperature", 0.7),
            "top_p": kwargs.get("top_p", 0.95),
            # Note: 'repetition_penalty' is not available in chat_completion()
        }
        try:
            response = self.generator_client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                **params
            )
            time.sleep(3)  # brief pause to stay under API rate limits
            return response.choices[0].message.content
        except Exception as e:
            print(f"Generation failed: {str(e)}")
            return "I couldn't generate an answer."

    @lru_cache(maxsize=2)
    def get_reranker(self, model_name: str, device: str):
        # Cache the cross-encoder so the model is loaded only once per
        # (model_name, device) rather than on every retrieve() call.
        return CrossEncoder(model_name, device=device)

    def _use_reranker(self, docs: List[LangchainDocument], query: str, top_k: int) -> List[Chunk]:
        """Use the cross-encoder reranker to re-rank retrieved documents at sentence level."""
        if not self.reranker_model_name:
            return docs
        sentence_chunks = []
        for doc in docs:
            # sent_tokenize handles abbreviations and decimals better than
            # naively splitting on '.'
            for sentence in sent_tokenize(doc.page_content.strip()):
                sentence = sentence.strip()
                if len(sentence) > 15:
                    sentence_chunks.append((sentence, doc.metadata))
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.reranker = self.get_reranker(self.reranker_model_name, device)
        pairs = [[query, sent] for sent, _ in sentence_chunks]
        scores = self.reranker.predict(pairs)
        top_pairs = sorted(zip(sentence_chunks, scores), key=lambda x: x[1], reverse=True)[:top_k]
        top_chunks = []
        for (sentence, meta), score in top_pairs:
            top_chunks.append(Chunk(
                chunk_id=meta.get("chunk_id", ""),
                text=sentence,
                doc_id=meta.get("doc_id", ""),
                source=meta.get("source", ""),
                chunk_num=meta.get("chunk_num", -1),
                total_chunks=meta.get("total_chunks", -1),
                metadata={**meta, "reranker_score": score}
            ))
        print(f"Reranked {len(top_chunks)} chunks from {len(docs)} documents")
        return top_chunks

    def retrieve(self, query: str, top_k: int = 10) -> List[Chunk]:
        """Retrieve relevant chunks using HyDE."""
        pseudo_answer = self.generate_hypothetical_answer(query)
        docs = self.hybrid_retriever.invoke(pseudo_answer)
        if self.reranker_model_name is not None:
            return self._use_reranker(docs, query, top_k)
        else:
            return [
                Chunk(
                    chunk_id=doc.metadata.get("chunk_id", ""),
                    text=doc.page_content,
                    doc_id=doc.metadata.get("doc_id", ""),
                    source=doc.metadata.get("source", ""),
                    chunk_num=doc.metadata.get("chunk_num", -1),
                    total_chunks=doc.metadata.get("total_chunks", -1),
                    metadata=doc.metadata
                ) for doc in docs[:top_k]
            ]

    def generate(self, question: str, context: List[str] = None) -> str:
        """Generate the final answer with RAG context."""
        if context is None:
            retrieved_chunks = self.retrieve(question)
            context = [chunk.text for chunk in retrieved_chunks]
        formatted_context = "\n\n".join(context)
        prompt = f"""[INST] You are a helpful assistant. Use *only* the context to answer.
If unsure, say "I don't know."

Context:
{formatted_context}

Question: {question}

Answer: [/INST]"""
        return self._generate_with_api(prompt)


class RAGEvaluator:
    CONTEXT_WINDOW = 8192   # Groq llama3-70b-8192 context window
    SAFETY_MARGIN = 1024    # Leave room for the response
    MAX_INPUT_TOKENS = CONTEXT_WINDOW - SAFETY_MARGIN

    def __init__(self,
                 local_model_name="meta-llama/Llama-2-7b-chat-hf",
                 use_groq=True,
                 groq_api_key=None,
                 groq_model="llama3-70b-8192"):
        self.use_groq = use_groq
        self.groq_model = groq_model
        self.groq_api_key = groq_api_key
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if not use_groq:
            self.tokenizer = AutoTokenizer.from_pretrained(local_model_name)
            # device_map="auto" already places the model, so no .to() call is needed
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto"
            )
        # Init tokenizer for Groq token estimation
        self.groq_tokenizer = tiktoken.encoding_for_model("gpt-4o")  # Approximation works well for llama3

    def build_trace_prompt(self, documents, question, response):
        return utils.get_evaluator_trace_prompt(documents, question, response)

    def _evaluate_with_groq(self, prompt):
        headers = {
            "Authorization": f"Bearer {self.groq_api_key}",
            "Content-Type": "application/json"
        }
        safe_prompt = self.truncate_prompt(prompt)
        payload = {
            "model": self.groq_model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant that returns structured JSON in the format specified."},
                {"role": "user", "content": safe_prompt}
            ],
            "temperature": 0.6,
            "top_p": 0.95,
            "max_tokens": self.SAFETY_MARGIN,
            "stream": False,
            "stop": None
        }
        response = post_with_retry("https://api.groq.com/openai/v1/chat/completions", headers, payload)
        if response.status_code != 200:
            raise RuntimeError(f"Groq API Error: {response.status_code}: {response.text}")
        try:
            content_str = response.json()["choices"][0]["message"]["content"]
            return self._extract_and_clean_json(content_str)
        except Exception as e:
            print(f"Exception while parsing the response content: {e}")
            return None

    def estimate_tokens(self, text):
        return len(self.groq_tokenizer.encode(text))

    def truncate_prompt(self, prompt):
        tokens = self.estimate_tokens(prompt)
        if tokens <= self.MAX_INPUT_TOKENS:
            return prompt
        else:
            # Simple character-based truncation, using the rough heuristic of
            # ~4 characters per token for English text
            approx_char_limit = int(self.MAX_INPUT_TOKENS * 4)
            truncated_prompt = prompt[:approx_char_limit]
            print(f"[WARNING] Prompt truncated from {tokens} to {self.estimate_tokens(truncated_prompt)} tokens")
            return truncated_prompt

    def _extract_and_clean_json(self, text):
        json_str = self._extract_first_json_block(text)
        repaired = repair_json(json_str)
        return json.loads(repaired)

    def _extract_json(self, text):
        if isinstance(text, dict):
            return text
        json_start = text.find("{")
        if json_start == -1:
            raise ValueError("No JSON object found in text")
        json_text = text[json_start:]
        return json.loads(json_text)

    def evaluate(self, documents, question, response, max_new_tokens=1024):
        prompt = self.build_trace_prompt(documents, question, response)
        return self._evaluate_with_groq(prompt)

    def extract_trace_metrics_from_json(self, trace_json: dict, totalDocuments) -> dict:
        if not trace_json:
            raise ValueError("Input is empty")
        if isinstance(trace_json, list):
            trace_json = trace_json[0] if len(trace_json) > 0 else {}
        relevant_keys = set(trace_json.get("all_relevant_sentence_keys", []))
        utilized_keys = set(trace_json.get("all_utilized_sentence_keys", []))
        adherence = trace_json.get("overall_supported", False)
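        # TRACe-style metrics over the retrieved set of size N = totalDocuments:
        #   relevance    = |R| / N        (share of documents judged relevant)
        #   utilization  = |U| / N        (share of documents the answer used)
        #   completeness = |R ∩ U| / |R|  (share of relevant documents used)
        # adherence is the judge's boolean "overall_supported" verdict.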
        len_R = len(relevant_keys)
        len_U = len(utilized_keys)
        len_R_intersect_U = len(relevant_keys.intersection(utilized_keys))
        relevance = None if totalDocuments is None else len_R / totalDocuments
        utilization = None if totalDocuments is None else len_U / totalDocuments
        completeness = None if len_R == 0 else len_R_intersect_U / len_R
        return {
            RELEVANCE_SCORE: round(relevance, 3) if relevance is not None else None,
            UTILIZATION_SCORE: round(utilization, 3) if utilization is not None else None,
            COMPLETENESS_SCORE: round(completeness, 3) if completeness is not None else None,
            ADHERENCE_SCORE: adherence
        }

    def _extract_first_json_block(self, text):
        json_start = text.find('{')
        json_end = text.rfind('}')
        if json_start == -1 or json_end == -1 or json_start >= json_end:
            raise ValueError("No valid JSON block found.")
        return text[json_start:json_end + 1]

    def _clean_json_text(self, text):
        text = text.strip().replace("'", '"')
        text = re.sub(r',\s*}', '}', text)
        text = re.sub(r',\s*]', ']', text)
        return text
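

# Minimal usage sketch (illustrative only): the subset name "hotpotqa" and the
# strategy "fixed" are assumptions, and running this requires valid
# HF_API_TOKEN / GROQ_API_KEY environment variables plus pre-built chunk and
# index files in HF_DATASET_REPO_NAME.
if __name__ == "__main__":
    rag = RAGSystem(
        subset="hotpotqa",   # hypothetical subset name
        dataset_type="test",
        strategy="fixed",    # hypothetical chunking strategy
        chunks=[],           # empty: pre-built chunks are loaded from the hub
    )
    rag.load_embeddings_database(retriever_type="BM25")

    question = "What is retrieval-augmented generation?"
    retrieved = rag.retrieve(question)
    answer = rag.generate(question, context=[c.text for c in retrieved])
    print(answer)

    evaluator = RAGEvaluator(groq_api_key=os.getenv("GROQ_API_KEY"))
    trace = evaluator.evaluate([c.text for c in retrieved], question, answer)
    if trace:
        print(evaluator.extract_trace_metrics_from_json(trace, totalDocuments=len(retrieved)))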