import os
import gc
import time
import pickle
from pathlib import Path
import warnings
import threading
from concurrent.futures import ThreadPoolExecutor, TimeoutError

import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import nltk
import faiss
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from ctransformers import AutoModelForCausalLM

# --- Basic Configuration ---
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"

nltk.download('punkt', quiet=True)

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

DEVICE = "cpu"


class RAGPipeline:
    def __init__(self):
        self.chunks_df = None
        self.bm25 = None
        self.index_faiss = None
        self.embedding_model = None
        self.llm_model = None
        # Create a sample of the data for faster search
        self.sample_indices = None
        self.sample_size = 50000  # Work with a subset of the data
        self.load_artifacts()
        self.load_models()

    def load_artifacts(self):
        print("--> Loading artifacts from root directory")
        self.chunks_df = pd.read_parquet("chunks_df.parquet")
        print(f"Loaded {len(self.chunks_df)} chunks.")

        # Create a random sample for faster search
        self.sample_indices = np.random.choice(
            len(self.chunks_df),
            size=min(self.sample_size, len(self.chunks_df)),
            replace=False
        )
        print(f"Created sample of {len(self.sample_indices)} chunks for faster search")

        with open("bm25_index.pkl", "rb") as f:
            self.bm25 = pickle.load(f)
        print("Loaded BM25 index.")

        self.index_faiss = faiss.read_index("news_chunks.faiss_index")
        print(f"Loaded FAISS index with {self.index_faiss.ntotal} vectors.")

    def load_models(self):
        print("--> Loading models...")

        # Dense retriever - use a smaller model
        EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'  # Faster than the multi-qa variant
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
        self.embedding_model.max_seq_length = 128  # Very short for speed
        print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")

        # LLM - try Phi-3 Mini first, then fall back to smaller options
        print("Loading Phi-3 Mini...")

        # Multiple model options to try, in order of preference
        model_options = [
            # Option 1: Phi-3 Mini 4K
            {
                "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
                "model_file": "Phi-3-mini-4k-instruct-q4.gguf",
                "model_type": "phi3"
            },
            # Option 2: Back to TinyLlama but with different settings
            {
                "repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                "model_file": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
                "model_type": "llama"
            },
            # Option 3: TinyLlama with a smaller quantization
            {
                "repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                "model_file": "tinyllama-1.1b-chat-v1.0.Q2_K.gguf",
                "model_type": "llama"
            }
        ]

        self.llm_model_type = None
        for i, model_config in enumerate(model_options):
            try:
                print(f"Trying model option {i+1}: {model_config['repo_id']}")
                self.llm_model = AutoModelForCausalLM.from_pretrained(
                    model_config["repo_id"],
                    model_file=model_config["model_file"],
                    model_type=model_config["model_type"],
                    temperature=0.1,
                    max_new_tokens=100,
                    context_length=512,  # Very short context
                    gpu_layers=0,
                    threads=2,           # Minimal threads
                    batch_size=1,        # Smallest batch
                    stream=False,
                    local_files_only=False
                )
                # Remember which prompt format the loaded model expects
                self.llm_model_type = model_config["model_type"]
                print(f"Successfully loaded model: {model_config['repo_id']}")
                break
            except Exception as e:
                print(f"Failed to load model option {i+1}: {e}")
                if i == len(model_options) - 1:
                    raise Exception("Failed to load any LLM model")

    def search_bm25_fast(self, query: str, k: int = 5):
        """Ultra-fast BM25 search restricted to the pre-drawn sample."""
        tokenized_query = query.lower().split()

        # Only score an even smaller subset of the sampled chunk indices
        subset = [int(idx) for idx in self.sample_indices[:10000]]
        scores = self.bm25.get_batch_scores(tokenized_query, subset)
        sample_scores = list(zip(subset, scores))

        # Get top k by score
        sample_scores.sort(key=lambda x: x[1], reverse=True)

        results = []
        for idx, score in sample_scores[:k]:
            chunk_info = self.chunks_df.iloc[idx]
            results.append({
                'chunk_id': chunk_info['chunk_id'],
                'doc_id': chunk_info['doc_id'],
                'score': float(score),
                'text': chunk_info['chunk_text'][:500],  # Truncate text
                'title': chunk_info['original_title'],
                'url': chunk_info['original_url']
            })
        return results

    def search_faiss_fast(self, query: str, k: int = 5):
        """Fast FAISS search; the caller applies the timeout."""
        try:
            # Quick embedding; truncate the query if it is too long
            query_embedding = self.embedding_model.encode(
                query[:100],
                convert_to_tensor=False,
                show_progress_bar=False
            )
            query_embedding = query_embedding.reshape(1, -1).astype('float32')
            faiss.normalize_L2(query_embedding)

            # Search with a reduced k
            distances, indices = self.index_faiss.search(query_embedding, k)

            results = []
            for i in range(min(k, len(indices[0]))):
                idx = indices[0][i]
                if idx < 0 or idx >= len(self.chunks_df):
                    continue
                score = float(distances[0][i])
                chunk_info = self.chunks_df.iloc[idx]
                results.append({
                    'chunk_id': chunk_info['chunk_id'],
                    'doc_id': chunk_info['doc_id'],
                    'score': score,
                    'text': chunk_info['chunk_text'][:500],
                    'title': chunk_info['original_title'],
                    'url': chunk_info['original_url']
                })
            return results
        except Exception as e:
            print(f"FAISS search error: {e}")
            return []

    def simple_search(self, query: str, k: int = 3):
        """Ultra-simple search - just use FAISS."""
        print(" - Performing simple FAISS-only search...")
        return self.search_faiss_fast(query, k=k)

    def format_rag_prompt_phi3(self, query: str, context_chunks: list):
        """Format the prompt for Phi-3."""
        # Very short context
        context = " ".join([chunk['text'][:200] for chunk in context_chunks[:2]])

        # Phi-3 instruct format
        prompt = f"""<|system|>
You are a helpful assistant. Answer based only on the context provided. Be very brief.
<|end|>
<|user|>
Context: {context}
Question: {query}
<|end|>
<|assistant|>"""
        return prompt

    def format_rag_prompt_tinyllama(self, query: str, context_chunks: list):
        """Format the prompt for TinyLlama."""
        context = " ".join([chunk['text'][:200] for chunk in context_chunks[:2]])
        prompt = f"<|system|>\nAnswer briefly based on context.\n</s>\n<|user|>\nContext: {context}\n\nQ: {query}\n</s>\n<|assistant|>\n"
        return prompt

    def generate_llm_answer_with_timeout(self, query: str, context_chunks: list, timeout_seconds: int = 30):
        """Generate an answer with timeout protection."""
        if not context_chunks:
            return "No relevant context found.", []

        # Choose the prompt format based on which model was loaded
        if self.llm_model_type == 'phi3':
            formatted_prompt = self.format_rag_prompt_phi3(query, context_chunks)
        else:
            formatted_prompt = self.format_rag_prompt_tinyllama(query, context_chunks)

        print(f" - Generating answer (max {timeout_seconds}s)...")
        result = {"answer": None, "error": None}

        def generate():
            try:
                answer = self.llm_model(
                    formatted_prompt,
                    max_new_tokens=50,  # Very short
                    stop=["<|end|>", "</s>", "\n\n"],
                    stream=False
                )
                result["answer"] = answer.strip()
            except Exception as e:
                result["error"] = str(e)

        # Run generation in a thread with a timeout
        thread = threading.Thread(target=generate)
        thread.start()
        thread.join(timeout=timeout_seconds)

        if thread.is_alive():
            print(" - Generation timed out!")
            return "Generation timed out. The model is too slow for this environment.", context_chunks[:2]

        if result["error"]:
            print(f" - Generation error: {result['error']}")
            return f"Error: {result['error']}", context_chunks[:2]

        answer = result["answer"] or "Could not generate answer."
        return answer, context_chunks[:2]

    def answer_query(self, query: str):
        """Main query answering with aggressive timeouts."""
        print(f"Received query: {query}")
        total_start = time.time()

        try:
            # 1. Super fast retrieval
            start_time = time.time()
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(self.simple_search, query, 3)
                try:
                    retrieved_context = future.result(timeout=5)  # 5 second timeout
                except TimeoutError:
                    print(" Search timed out!")
                    return "Search timed out. Please try a simpler query.", [], ""
            print(f" Retrieval completed in {time.time() - start_time:.2f}s")

            if not retrieved_context:
                return "No relevant documents found.", [], ""

            # 2. Generate an answer with a timeout
            llm_answer, used_chunks = self.generate_llm_answer_with_timeout(
                query,
                retrieved_context,
                timeout_seconds=20
            )

            # 3. Format sources
            sources_text = "\n\n**Sources:**\n"
            for chunk in used_chunks:
                sources_text += f"- [{chunk['title']}]({chunk['url']})\n"

            total_time = time.time() - total_start
            print(f"Total processing time: {total_time:.2f}s")
            return llm_answer, used_chunks, sources_text

        except Exception as e:
            print(f"Error in answer_query: {e}")
            import traceback
            traceback.print_exc()
            return f"System error: {str(e)}", [], ""


# For testing without the full pipeline
def test_simple_generation():
    """Test whether LLM generation works at all."""
    try:
        from ctransformers import AutoModelForCausalLM

        print("Testing simple generation...")
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
            model_file="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
            model_type="llama",
            gpu_layers=0,
            threads=2,
            context_length=128,
            max_new_tokens=20
        )
        result = model("Hello, how are", max_new_tokens=10, stream=False)
        print(f"Test result: {result}")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        return False
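

# A minimal local smoke-test entry point, sketched as an assumption about how the
# pipeline might be exercised outside the Space UI; it is not part of the original code.
# It assumes "chunks_df.parquet", "bm25_index.pkl", and "news_chunks.faiss_index" are
# present in the working directory, and the sample query is purely hypothetical.
if __name__ == "__main__":
    if test_simple_generation():
        pipeline = RAGPipeline()
        # Hypothetical example query; replace with anything covered by the news corpus
        answer, chunks, sources = pipeline.answer_query("What happened in the latest election?")
        print(answer)
        print(sources)
    else:
        print("LLM smoke test failed; skipping full pipeline run.")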