Spaces:
Paused
Paused
| import os | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from model2vec import StaticModel | |
| from transformers import AutoConfig | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from src.utils.api_key_manager import APIKeyManager | |
| from src.helpers.helper import chunk_text | |
class LateChunker:
    """Select the passages of a long text that are most relevant to a query.

    The text is split into macro chunks by the project's ``chunk_text`` helper,
    each chunk is embedded, and chunks are ranked by cosine similarity to the
    query embedding. The best-ranked chunks are kept until the LLM token budget
    is used up, then re-joined in their original order.

    Embeddings come from a SentenceTransformer model when ``model_name`` loads
    as one, otherwise from a Model2Vec ``StaticModel``.
    """

    def __init__(
        self,
        model_name='minishlab/potion-base-8M',
        max_workers=None,
        verbose=False
    ):
        """Load the embedding model and create the encoding thread pool.

        Args:
            model_name: Hugging Face model id; tried first as a
                SentenceTransformer, then as a Model2Vec StaticModel.
            max_workers: Size of the thread pool used for encoding.
                ``None`` (the default) means twice the CPU count.
            verbose: When true, progress messages are printed to stdout.

        Raises:
            RuntimeError: If neither backend can load ``model_name``.
        """
        self.verbose = verbose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.llm = APIKeyManager().get_llm()
        self.model_name = model_name
        # Initialize model using the fallback strategy
        self.model, self.context_length = self._initialize_model()
        if max_workers is None:
            # Computed here rather than as a default argument: os.cpu_count()
            # may return None, which would crash at import time.
            max_workers = (os.cpu_count() or 1) * 2
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _log(self, message):
        """Print ``message`` only when the chunker is verbose."""
        if self.verbose:
            print(message)

    def close(self, wait=True):
        """Shut down the encoding thread pool; no encoding is possible afterwards."""
        self.executor.shutdown(wait=wait)

    def _initialize_model(self):
        """Load ``self.model_name`` and return ``(model, context_length)``.

        A SentenceTransformer is tried first. If any step of that fails, the
        name is loaded as a Model2Vec ``StaticModel`` instead.

        Raises:
            RuntimeError: If both backends fail; the message contains both errors.
        """
        try:
            # Get the model config to check max context length
            config = AutoConfig.from_pretrained(self.model_name)
            max_length = config.max_position_embeddings
            model = SentenceTransformer(self.model_name, trust_remote_code=True)
            # Some architectures (e.g. RoBERTa-style) reserve position slots, so
            # max_position_embeddings can exceed what the tokenizer accepts.
            tokenizer = getattr(model, "tokenizer", None)
            tokenizer_limit = getattr(tokenizer, "model_max_length", None)
            if isinstance(tokenizer_limit, int) and 0 < tokenizer_limit < max_length:
                max_length = tokenizer_limit
            model.max_seq_length = max_length
            model.to(self.device)
            if self.device == "cuda":
                # fp16 saves GPU memory; many CPU kernels lack fp16 support.
                model.half()
            return model, model.max_seq_length
        except Exception as st_exc:  # any failure triggers the Model2Vec fallback
            sentence_transformer_error = st_exc

        try:
            model = StaticModel.from_pretrained(self.model_name)
            # Get max sequence length from static model config
            return model, model.config['seq_length']
        except Exception as m2v_exc:
            error_msg = (
                f"Failed to load model {self.model_name}.\n"
                f"SentenceTransformer error: {sentence_transformer_error}\n"
                f"Model2Vec error: {m2v_exc}"
            )
            raise RuntimeError(error_msg) from m2v_exc

    async def late_chunking(self, text, span_annotations, current_chunk_idx=None, total_chunks=None):
        """Embed each ``(start, end)`` character span of ``text``.

        Encoding runs on ``self.executor`` so the event loop stays free while
        several macro chunks are processed concurrently.

        Args:
            text: The macro-chunk text that the spans index into.
            span_annotations: Iterable of ``(start, end)`` offsets into ``text``.
            current_chunk_idx: Zero-based index of this chunk, used only for logging.
            total_chunks: Number of chunks, used only for logging.

        Returns:
            A list with one embedding tensor (on ``self.device``) per span, or
            ``None`` when ``span_annotations`` is empty.
        """
        if current_chunk_idx is None:
            self._log("Processing chunk...")
        else:
            self._log(f"Processing chunk {current_chunk_idx + 1}/{total_chunks}...")
        loop = asyncio.get_running_loop()
        chunk_embeddings = []
        for start, end in span_annotations:
            self._log("Generating chunk embeddings...")
            # The span text is passed as an argument, not captured in a closure.
            chunk_embedding = await loop.run_in_executor(
                self.executor, self.get_text_embedding, text[start:end]
            )
            self._log(f"Chunk embedding shape: {chunk_embedding.shape}")
            chunk_embeddings.append(chunk_embedding)
        self._log("Late Chunking applied successfully!")
        return chunk_embeddings or None

    def get_text_embedding(self, text):
        """Encode ``text`` with the loaded model and return a tensor on ``self.device``."""
        embeddings = self.model.encode(text, convert_to_tensor=True)
        if isinstance(embeddings, torch.Tensor):
            return embeddings.clone().detach().to(self.device)
        # Model2Vec returns a numpy array rather than a tensor.
        return torch.tensor(embeddings).to(self.device)

    def calculate_embedding_similarities(self, text1_embedding, text2_embedding):
        """Cosine similarity between a query embedding and a set of embeddings.

        Args:
            text1_embedding: Query tensor, ``(dim,)`` or ``(1, dim)``; a column
                vector ``(dim, 1)`` is transposed.
            text2_embedding: Tensor of embeddings, ``(n, dim)`` or ``(dim,)``;
                a ``(dim, n)`` layout is transposed.

        Returns:
            1-D numpy array of similarities between the first query row and
            each row of ``text2_embedding``.

        Raises:
            ValueError: If the embeddings cannot be aligned to a common dimension.
        """
        query = text1_embedding.detach().cpu().numpy()
        others = text2_embedding.detach().cpu().numpy()
        if query.ndim == 1:
            query = query.reshape(1, -1)
        if others.ndim == 1:
            others = others.reshape(1, -1)
        if query.shape[0] != 1 and query.shape[1] == 1:
            query = query.T
        if others.shape[1] != query.shape[1] and others.shape[0] == query.shape[1]:
            others = others.T
        if query.shape[1] != others.shape[1]:
            raise ValueError(
                f"Embedding dimensions do not match: {query.shape} vs {others.shape}"
            )
        return cosine_similarity(query, others)[0]

    def select_relevant_chunks(self, similarities, chunks, max_tokens):
        """Greedily keep the most similar chunks that fit in ``max_tokens``.

        Chunks are visited in descending similarity order. Selection stops at
        the first chunk that would push the running token count (measured with
        ``self.llm.get_num_tokens``) over ``max_tokens``. Kept chunks are
        returned in their original order, joined by single spaces.
        """
        ranked = np.argsort(similarities)[::-1]
        selected = []
        total_tokens = 0
        for rank, idx in enumerate(ranked, start=1):
            self._log(
                f"Considering chunk {rank}/{len(ranked)} with similarity {similarities[idx]:.2f}"
            )
            chunk_tokens = self.llm.get_num_tokens(chunks[idx])
            self._log(f"Chunk tokens: {chunk_tokens}")
            if total_tokens + chunk_tokens > max_tokens:
                self._log(
                    f"Adding this chunk would exceed the max tokens allowed "
                    f"({total_tokens + chunk_tokens} > {max_tokens}). Stopping chunk selection."
                )
                break
            selected.append((idx, chunks[idx]))
            total_tokens += chunk_tokens
        self._log("Sorting selected chunks...")
        selected.sort(key=lambda item: item[0])
        self._log("Selected chunks sorted successfully!")
        # NOTE(review): if chunk_text produces overlapping chunks, adjacent
        # selected chunks repeat the overlap in this join; confirm that is intended.
        return " ".join(chunk for _, chunk in selected)

    @staticmethod
    def _spans_within(span_annotations, start_offset, length):
        """Return the spans inside ``[start_offset, start_offset + length]``, rebased to 0."""
        end_limit = start_offset + length
        return [
            (start - start_offset, end - start_offset)
            for start, end in span_annotations
            if start >= start_offset and end <= end_limit
        ]

    async def chunker(self, text, query, max_chunk_length=1000, max_tokens=2048, overlap=200):
        """Return the parts of ``text`` most relevant to ``query`` within ``max_tokens``.

        Text whose token count (``self.llm.get_num_tokens``) is already within
        ``max_tokens`` is returned unchanged. Otherwise the text is split with
        ``chunk_text`` (``max_chunk_length`` and ``overlap`` are forwarded; their
        units are defined there). Each chunk is embedded and ranked against the
        query, and the best chunks are kept by ``select_relevant_chunks``.
        """
        total_tokens = self.llm.get_num_tokens(text)
        if total_tokens <= max_tokens:
            self._log(
                f"Text is less than the max tokens allowed ({total_tokens} <= {max_tokens}). "
                f"Returning original text."
            )
            return text

        self._log(
            f"Text is greater than the max tokens allowed ({total_tokens} > {max_tokens}). "
            f"Chunking text..."
        )
        chunks, span_annotations = chunk_text(
            text,
            max_chunk_length=max_chunk_length,
            overlap=overlap,
            # Use the smaller of either context length or max tokens
            context_length=min(self.context_length, max_tokens)
        )
        self._log(f"Text chunked into {len(chunks)} macro chunks.")

        # Presumably span_annotations[i] is the (start, end) of chunks[i] in
        # text; TODO confirm against chunk_text.
        tasks = []
        for i, macro_chunk in enumerate(chunks):
            start_offset = span_annotations[i][0]
            adjusted_spans = self._spans_within(span_annotations, start_offset, len(macro_chunk))
            if not adjusted_spans:
                # With no contained span, late_chunking would return None and
                # torch.stack below would fail; embed the whole chunk instead.
                adjusted_spans = [(0, len(macro_chunk))]
            tasks.append(self.late_chunking(macro_chunk, adjusted_spans, i, len(chunks)))

        results = await asyncio.gather(*tasks)
        # The first contained span's embedding represents each macro chunk.
        chunk_embeddings = torch.stack([result[0] for result in results])

        self._log("Generating query embedding...")
        query_embedding = self.get_text_embedding(query)
        self._log(f"Query embedding shape: {query_embedding.shape}")

        self._log("Calculating embedding similarities...")
        similarities = self.calculate_embedding_similarities(query_embedding, chunk_embeddings)
        self._log(f"Similarities shape: {similarities.shape}")

        self._log("Selecting relevant chunks...")
        return self.select_relevant_chunks(similarities, chunks, max_tokens)
def _tfidf_cosine_similarity(text1, text2):
    """Return the TF-IDF cosine similarity of two texts (between 0.0 and 1.0)."""
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    vectors = TfidfVectorizer().fit_transform([text1, text2])
    return cosine_similarity(vectors)[0][1]


async def _demo():
    """Search, crawl, and late-chunk a climate-change corpus, then print statistics."""
    import time

    from src.crawl.crawler import CustomCrawler
    from src.search.search_engine import SearchEngine

    queries = [
        "What is the history of climate change and pollution since the pre-industrial revolution?",
        "What is the impact of climate change on the Indian economy?",
        "What are some of the latest, state of art techniques used to fight climate change?",
        "What does the projection for climate change look like in the next 50 years?",
        "What efforts are being made by governments all around the world to combat climate change?",
    ]

    search_engine = SearchEngine()
    crawler = CustomCrawler()
    chunking = LateChunker(verbose=True)
    try:
        urls = []
        for search_query in queries:
            search_results = await search_engine.search(
                search_query,
                num_results=20,
                exclude_filetypes=["pdf"]
            )
            urls.extend(result["link"] for result in search_results)

        results = await crawler.fetch_page_contents(
            urls=urls,
            max_attempts=1,
            delay=0
        )
        text = "\n".join([f"Document {i}:\n{result}\n" for i, result in enumerate(results)])
        num_tokens_before_chunking = chunking.llm.get_num_tokens(text)

        start_time = time.perf_counter()
        response = await chunking.chunker(
            text,
            query="What is this text about? Give me a detailed answer",
            max_tokens=128000
        )
        end_time = time.perf_counter()
        num_tokens_after_chunking = chunking.llm.get_num_tokens(response)

        print(f"\nResponse:\n{response}")
        print(f"\nNumber of URLs: {len(urls)}")
        print(f"\nNumber of tokens before late chunking: {num_tokens_before_chunking}")
        print(f"\nNumber of tokens after late chunking: {num_tokens_after_chunking}")
        print(f"\nTime taken: {end_time - start_time:.2f} seconds")

        # Lexical (TF-IDF) similarity between the original text and the selection.
        similarity = _tfidf_cosine_similarity(text, response)
        print(f"\nCosine similarity between original text and late chunked text: {similarity * 100:.2f}%")
    finally:
        # The chunker's thread pool is otherwise never shut down.
        chunking.executor.shutdown(wait=False)


if __name__ == "__main__":
    asyncio.run(_demo())