# NLP-RAG-world-news / pipeline.py
import os
import gc
import time
import pickle
from pathlib import Path
import warnings
import threading
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import nltk
import faiss
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from ctransformers import AutoModelForCausalLM
# --- Basic Configuration ---
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
nltk.download('punkt', quiet=True)
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
DEVICE = "cpu"
class RAGPipeline:
def __init__(self):
self.chunks_df = None
self.bm25 = None
self.index_faiss = None
self.embedding_model = None
self.llm_model = None
# Create a sample of the data for faster search
self.sample_indices = None
self.sample_size = 50000 # Work with subset of data
self.load_artifacts()
self.load_models()
def load_artifacts(self):
print(f"--> Loading artifacts from root directory")
self.chunks_df = pd.read_parquet("chunks_df.parquet")
print(f"Loaded {len(self.chunks_df)} chunks.")
# Create a random sample for faster search
self.sample_indices = np.random.choice(
len(self.chunks_df),
size=min(self.sample_size, len(self.chunks_df)),
replace=False
)
print(f"Created sample of {len(self.sample_indices)} chunks for faster search")
with open("bm25_index.pkl", "rb") as f:
self.bm25 = pickle.load(f)
print("Loaded BM25 index.")
self.index_faiss = faiss.read_index("news_chunks.faiss_index")
print(f"Loaded FAISS index with {self.index_faiss.ntotal} vectors.")
def load_models(self):
print("--> Loading models...")
# Dense Retriever - use smaller model
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2' # Faster than multi-qa variant
self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
self.embedding_model.max_seq_length = 128 # Very short for speed
print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")
        # LLM - try Phi-3 Mini first, then fall back to smaller TinyLlama builds
        print("Loading LLM (Phi-3 Mini with TinyLlama fallbacks)...")
# Multiple model options to try
model_options = [
# Option 1: Phi-3 Mini 4K
{
"repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
"model_file": "Phi-3-mini-4k-instruct-q4.gguf",
"model_type": "phi3"
},
# Option 2: Back to TinyLlama but with different settings
{
"repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
"model_file": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
"model_type": "llama"
},
# Option 3: Even smaller model
{
"repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
"model_file": "tinyllama-1.1b-chat-v1.0.Q2_K.gguf", # Smaller quantization
"model_type": "llama"
}
]
for i, model_config in enumerate(model_options):
try:
print(f"Trying model option {i+1}: {model_config['repo_id']}")
self.llm_model = AutoModelForCausalLM.from_pretrained(
model_config["repo_id"],
model_file=model_config["model_file"],
model_type=model_config["model_type"],
temperature=0.1,
max_new_tokens=100,
context_length=512, # Very short context
gpu_layers=0,
threads=2, # Minimal threads
batch_size=1, # Smallest batch
stream=False,
local_files_only=False
)
print(f"Successfully loaded model: {model_config['repo_id']}")
break
except Exception as e:
print(f"Failed to load model option {i+1}: {e}")
if i == len(model_options) - 1:
raise Exception("Failed to load any LLM model")
    def search_bm25_fast(self, query: str, k: int = 5):
        """Ultra-fast BM25 search on a sample of the corpus"""
        tokenized_query = query.lower().split()
        # Only search a small subset of the sample; rank_bm25 exposes no
        # per-document _score() helper, so score the whole subset in one call
        subset = self.sample_indices[:10000].tolist()  # Even smaller subset
        subset_scores = self.bm25.get_batch_scores(tokenized_query, subset)
        sample_scores = list(zip(subset, subset_scores))
        # Get top k
        sample_scores.sort(key=lambda x: x[1], reverse=True)
results = []
for idx, score in sample_scores[:k]:
chunk_info = self.chunks_df.iloc[idx]
results.append({
'chunk_id': chunk_info['chunk_id'],
'doc_id': chunk_info['doc_id'],
'score': float(score),
'text': chunk_info['chunk_text'][:500], # Truncate text
'title': chunk_info['original_title'],
'url': chunk_info['original_url']
})
return results
def search_faiss_fast(self, query: str, k: int = 5):
"""Fast FAISS search with timeout"""
try:
# Quick embedding
query_embedding = self.embedding_model.encode(
query[:100], # Truncate query if too long
convert_to_tensor=False,
show_progress_bar=False
)
query_embedding = query_embedding.reshape(1, -1).astype('float32')
faiss.normalize_L2(query_embedding)
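            # Normalising the query assumes the index stores L2-normalised
            # vectors and uses inner-product (cosine) similarity.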
# Search with reduced k
distances, indices = self.index_faiss.search(query_embedding, k)
results = []
for i in range(min(k, len(indices[0]))):
idx = indices[0][i]
if idx < 0 or idx >= len(self.chunks_df):
continue
score = float(distances[0][i])
chunk_info = self.chunks_df.iloc[idx]
results.append({
'chunk_id': chunk_info['chunk_id'],
'doc_id': chunk_info['doc_id'],
'score': score,
'text': chunk_info['chunk_text'][:500],
'title': chunk_info['original_title'],
'url': chunk_info['original_url']
})
return results
except Exception as e:
print(f"FAISS search error: {e}")
return []
def simple_search(self, query: str, k: int = 3):
"""Ultra-simple search - just use FAISS"""
print(" - Performing simple FAISS-only search...")
return self.search_faiss_fast(query, k=k)
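    # Sketch (not wired into answer_query): one way to combine the BM25 and
    # FAISS retrievers that are both loaded above. The reciprocal-rank fusion
    # with constant 60 is an assumption for illustration, not part of the
    # original pipeline.
    def hybrid_search_sketch(self, query: str, k: int = 3):
        """Merge BM25 and FAISS results by reciprocal-rank fusion (sketch)."""
        bm25_results = self.search_bm25_fast(query, k=k)
        faiss_results = self.search_faiss_fast(query, k=k)
        fused = {}
        for results in (bm25_results, faiss_results):
            for rank, res in enumerate(results):
                entry = fused.setdefault(res['chunk_id'], {'res': res, 'score': 0.0})
                entry['score'] += 1.0 / (60 + rank)
        merged = sorted(fused.values(), key=lambda x: x['score'], reverse=True)
        return [item['res'] for item in merged[:k]]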
def format_rag_prompt_phi3(self, query: str, context_chunks: list):
"""Format prompt for Phi-3"""
# Very short context
context = " ".join([chunk['text'][:200] for chunk in context_chunks[:2]])
# Phi-3 instruct format
prompt = f"""<|system|>
You are a helpful assistant. Answer based only on the context provided. Be very brief.
<|end|>
<|user|>
Context: {context}
Question: {query}
<|end|>
<|assistant|>"""
return prompt
def format_rag_prompt_tinyllama(self, query: str, context_chunks: list):
"""Format prompt for TinyLlama"""
context = " ".join([chunk['text'][:200] for chunk in context_chunks[:2]])
prompt = f"<|system|>\nAnswer briefly based on context.\n</s>\n<|user|>\nContext: {context}\n\nQ: {query}\n</s>\n<|assistant|>\n"
return prompt
def generate_llm_answer_with_timeout(self, query: str, context_chunks: list, timeout_seconds: int = 30):
"""Generate answer with timeout protection"""
if not context_chunks:
return "No relevant context found.", []
# Choose prompt format based on model
if hasattr(self.llm_model, 'model_type') and self.llm_model.model_type == 'phi3':
formatted_prompt = self.format_rag_prompt_phi3(query, context_chunks)
else:
formatted_prompt = self.format_rag_prompt_tinyllama(query, context_chunks)
print(f" - Generating answer (max {timeout_seconds}s)...")
result = {"answer": None, "error": None}
def generate():
try:
answer = self.llm_model(
formatted_prompt,
max_new_tokens=50, # Very short
stop=["<|end|>", "</s>", "\n\n"],
stream=False
)
result["answer"] = answer.strip()
except Exception as e:
result["error"] = str(e)
        # Run generation in a daemon thread; the worker cannot be cancelled, so
        # after a timeout it simply keeps running in the background while we
        # return early to the caller.
        thread = threading.Thread(target=generate, daemon=True)
        thread.start()
        thread.join(timeout=timeout_seconds)
        if thread.is_alive():
            print(" - Generation timed out!")
            return "Generation timed out. The model is too slow for this environment.", context_chunks[:2]
if result["error"]:
print(f" - Generation error: {result['error']}")
return f"Error: {result['error']}", context_chunks[:2]
answer = result["answer"] or "Could not generate answer."
return answer, context_chunks[:2]
def answer_query(self, query: str):
"""Main query answering with aggressive timeouts"""
print(f"Received query: {query}")
total_start = time.time()
try:
# 1. Super fast retrieval
start_time = time.time()
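            # The 5s timeout only bounds how long we wait for a result; leaving
            # the `with` block still waits for the worker to finish, so a very
            # slow search also delays the timeout response.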
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(self.simple_search, query, 3)
try:
retrieved_context = future.result(timeout=5) # 5 second timeout
except TimeoutError:
print(" Search timed out!")
return "Search timed out. Please try a simpler query.", [], ""
print(f" Retrieval completed in {time.time() - start_time:.2f}s")
if not retrieved_context:
return "No relevant documents found.", [], ""
# 2. Generate Answer with timeout
llm_answer, used_chunks = self.generate_llm_answer_with_timeout(
query,
retrieved_context,
timeout_seconds=20
)
# 3. Format sources
sources_text = "\n\n**Sources:**\n"
for chunk in used_chunks:
sources_text += f"- [{chunk['title']}]({chunk['url']})\n"
total_time = time.time() - total_start
print(f"Total processing time: {total_time:.2f}s")
return llm_answer, used_chunks, sources_text
except Exception as e:
print(f"Error in answer_query: {e}")
import traceback
traceback.print_exc()
return f"System error: {str(e)}", [], ""
# For testing without full pipeline
def test_simple_generation():
"""Test if LLM generation works at all"""
try:
from ctransformers import AutoModelForCausalLM
print("Testing simple generation...")
model = AutoModelForCausalLM.from_pretrained(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
model_file="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
model_type="llama",
gpu_layers=0,
threads=2,
context_length=128,
max_new_tokens=20
)
result = model("Hello, how are", max_new_tokens=10, stream=False)
print(f"Test result: {result}")
return True
except Exception as e:
print(f"Test failed: {e}")
return False
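# Minimal usage sketch (assumption): run this module directly to exercise the
# pipeline end-to-end. It assumes chunks_df.parquet, bm25_index.pkl and
# news_chunks.faiss_index sit in the working directory; the sample question is
# illustrative only.
if __name__ == "__main__":
    if test_simple_generation():
        pipeline = RAGPipeline()
        answer, chunks, sources = pipeline.answer_query("What happened in the news today?")
        print(answer)
        print(sources)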