MusaR committed on
Commit
561d690
·
verified ·
1 Parent(s): 0515601

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +226 -160
pipeline.py CHANGED
@@ -4,6 +4,8 @@ import time
4
  import pickle
5
  from pathlib import Path
6
  import warnings
 
 
7
 
8
  import pandas as pd
9
  import numpy as np
@@ -13,18 +15,18 @@ import faiss
13
  import torch
14
 
15
  from rank_bm25 import BM25Okapi
16
- from sentence_transformers import SentenceTransformer, CrossEncoder
17
  from ctransformers import AutoModelForCausalLM
18
 
19
  # --- Basic Configuration ---
20
  warnings.filterwarnings("ignore")
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
22
  nltk.download('punkt', quiet=True)
23
  RANDOM_SEED = 42
24
  np.random.seed(RANDOM_SEED)
25
  torch.manual_seed(RANDOM_SEED)
26
- if torch.cuda.is_available():
27
- torch.cuda.manual_seed_all(RANDOM_SEED)
28
 
29
  DEVICE = "cpu"
30
 
@@ -34,9 +36,12 @@ class RAGPipeline:
34
  self.bm25 = None
35
  self.index_faiss = None
36
  self.embedding_model = None
37
- self.reranker_model = None
38
  self.llm_model = None
39
- self.llm_tokenizer = None
 
 
 
 
40
  self.load_artifacts()
41
  self.load_models()
42
 
@@ -44,6 +49,14 @@ class RAGPipeline:
44
  print(f"--> Loading artifacts from root directory")
45
  self.chunks_df = pd.read_parquet("chunks_df.parquet")
46
  print(f"Loaded {len(self.chunks_df)} chunks.")
 
 
 
 
 
 
 
 
47
 
48
  with open("bm25_index.pkl", "rb") as f:
49
  self.bm25 = pickle.load(f)
@@ -54,205 +67,258 @@ class RAGPipeline:
54
 
55
  def load_models(self):
56
  print("--> Loading models...")
57
- # Dense Retriever
58
- EMBEDDING_MODEL_NAME = 'multi-qa-MiniLM-L6-cos-v1'
59
  self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
60
- # Optimize for CPU
61
- self.embedding_model.max_seq_length = 256 # Reduce from default 512
62
  print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")
63
 
64
- # Reranker - SKIP for CPU optimization
65
- # We'll use simple score combination instead
66
- self.reranker_model = None
67
- print("Skipping reranker model for CPU optimization.")
68
-
69
- # LLM - Using TinyLlama with optimized settings
70
- LLM_REPO_ID = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
71
- LLM_MODEL_FILE = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
72
 
73
- print(f"Loading LLM: {LLM_REPO_ID}/{LLM_MODEL_FILE}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- try:
76
- self.llm_model = AutoModelForCausalLM.from_pretrained(
77
- LLM_REPO_ID,
78
- model_file=LLM_MODEL_FILE,
79
- model_type="llama",
80
- temperature=0.1,
81
- top_p=0.9,
82
- repetition_penalty=1.05,
83
- context_length=1024, # Reduced from 2048
84
- gpu_layers=0,
85
- threads=4, # Fixed threads for stability
86
- batch_size=8, # Small batch size
87
- stream=False,
88
- local_files_only=False
89
- )
90
- except Exception as e:
91
- print(f"Error loading LLM: {e}")
92
- raise
93
- print("LLM loaded successfully.")
 
 
 
94
 
95
- def search_bm25(self, query: str, k: int = 5):
 
96
  tokenized_query = query.lower().split()
97
- # Get scores more efficiently
98
- scores = self.bm25.get_scores(tokenized_query)
99
- # Use numpy for faster operations
100
- topk_indices = np.argpartition(scores, -k)[-k:]
101
- topk_indices = topk_indices[np.argsort(scores[topk_indices])[::-1]]
102
 
103
- results = []
104
- for idx in topk_indices:
105
- chunk_info = self.chunks_df.iloc[idx]
106
- results.append({
107
- 'chunk_id': chunk_info['chunk_id'],
108
- 'doc_id': chunk_info['doc_id'],
109
- 'score': float(scores[idx]),
110
- 'text': chunk_info['chunk_text'],
111
- 'title': chunk_info['original_title'],
112
- 'url': chunk_info['original_url']
113
- })
114
- return results
115
-
116
- def search_faiss(self, query: str, k: int = 5):
117
- # Encode with reduced batch size
118
- query_embedding = self.embedding_model.encode(
119
- query,
120
- convert_to_tensor=False, # Stay in numpy
121
- show_progress_bar=False
122
- )
123
- query_embedding = query_embedding.reshape(1, -1)
124
- faiss.normalize_L2(query_embedding)
125
 
126
- distances, indices = self.index_faiss.search(query_embedding, k)
 
127
 
128
  results = []
129
- for i in range(len(indices[0])):
130
- idx = indices[0][i]
131
- if idx < 0: # Skip invalid indices
132
- continue
133
- score = float(distances[0][i])
134
  chunk_info = self.chunks_df.iloc[idx]
135
  results.append({
136
  'chunk_id': chunk_info['chunk_id'],
137
  'doc_id': chunk_info['doc_id'],
138
- 'score': score,
139
- 'text': chunk_info['chunk_text'],
140
  'title': chunk_info['original_title'],
141
  'url': chunk_info['original_url']
142
  })
143
  return results
144
 
145
- def hybrid_search_simple(self, query: str, k: int = 5):
146
- """Simplified hybrid search without reranking for CPU efficiency"""
147
- print(" - Performing hybrid search...")
148
- # Reduce k for faster search
149
- bm25_res = self.search_bm25(query, k=k)
150
- faiss_res = self.search_faiss(query, k=k)
151
-
152
- # Simple score combination
153
- combined_scores = {}
154
-
155
- # Normalize BM25 scores
156
- if bm25_res:
157
- max_bm25 = max(r['score'] for r in bm25_res)
158
- for res in bm25_res:
159
- chunk_id = res['chunk_id']
160
- norm_score = res['score'] / max_bm25 if max_bm25 > 0 else 0
161
- combined_scores[chunk_id] = {
162
- 'data': res,
163
- 'score': norm_score * 0.5 # Weight for BM25
164
- }
165
-
166
- # Normalize FAISS scores (cosine similarity already in [0,1])
167
- for res in faiss_res:
168
- chunk_id = res['chunk_id']
169
- if chunk_id in combined_scores:
170
- combined_scores[chunk_id]['score'] += (1 - res['score']) * 0.5
171
- else:
172
- combined_scores[chunk_id] = {
173
- 'data': res,
174
- 'score': (1 - res['score']) * 0.5
175
- }
176
-
177
- # Sort by combined score
178
- sorted_results = sorted(
179
- combined_scores.values(),
180
- key=lambda x: x['score'],
181
- reverse=True
182
- )
183
-
184
- # Return top k results
185
- return [item['data'] for item in sorted_results[:k]]
186
 
187
- def format_rag_prompt(self, query: str, context_chunks: list):
188
- # Limit context to save tokens
189
- context_texts = [chunk['text'][:300] for chunk in context_chunks[:3]] # Max 3 chunks, 300 chars each
190
- context_str = "\n---\n".join(context_texts)
191
-
192
- # Shorter system message
193
- system_message = "Answer based on the context. Be concise."
194
- user_message_content = f"Context:\n{context_str}\n\nQuestion: {query}\n\nAnswer:"
 
195
 
196
- # TinyLlama chat template
197
- prompt = f"<|system|>\n{system_message}</s>\n<|user|>\n{user_message_content}</s>\n<|assistant|>\n"
 
 
 
 
 
 
 
 
198
  return prompt
199
 
200
- def generate_llm_answer(self, query: str, context_chunks: list):
 
 
 
 
 
 
 
201
  if not context_chunks:
202
- return "No relevant context found to answer the question.", []
203
 
204
- formatted_prompt = self.format_rag_prompt(query, context_chunks)
 
 
 
 
205
 
206
- print(" - Generating answer...")
207
- try:
208
- # Generate with timeout protection
209
- answer = self.llm_model(
210
- formatted_prompt,
211
- max_new_tokens=150, # Reduced from 250
212
- stop=["</s>", "\n\n"], # Stop tokens
213
- stream=False
214
- )
215
-
216
- # Clean up the answer
217
- answer = answer.strip()
218
- if not answer:
219
- answer = "I couldn't generate a proper response. Please try again."
220
-
221
- except Exception as e:
222
- print(f"LLM generation error: {e}")
223
- answer = "An error occurred during answer generation."
224
 
225
- return answer, context_chunks[:3] # Return only top 3 chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  def answer_query(self, query: str):
 
228
  print(f"Received query: {query}")
229
 
 
 
230
  try:
231
- # 1. Simplified retrieval (no reranking)
232
  start_time = time.time()
233
- retrieved_context = self.hybrid_search_simple(query, k=5)
 
 
 
 
 
 
 
234
  print(f" Retrieval completed in {time.time() - start_time:.2f}s")
235
 
236
  if not retrieved_context:
237
- return "Could not find any relevant documents to answer your question.", [], "No context found."
238
 
239
- # 2. Generate Answer
240
- start_time = time.time()
241
- llm_answer, used_context_chunks = self.generate_llm_answer(query, retrieved_context)
242
- print(f" Generation completed in {time.time() - start_time:.2f}s")
 
 
243
 
244
  # 3. Format sources
245
  sources_text = "\n\n**Sources:**\n"
246
- seen_urls = set()
247
- for chunk in used_context_chunks:
248
- if chunk['url'] not in seen_urls:
249
- sources_text += f"- [{chunk['title']}]({chunk['url']})\n"
250
- seen_urls.add(chunk['url'])
251
 
252
- return llm_answer, used_context_chunks, sources_text
 
 
 
253
 
254
  except Exception as e:
255
  print(f"Error in answer_query: {e}")
256
  import traceback
257
  traceback.print_exc()
258
- return f"An error occurred: {str(e)}", [], "Error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pickle
5
  from pathlib import Path
6
  import warnings
7
+ import threading
8
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
9
 
10
  import pandas as pd
11
  import numpy as np
 
15
  import torch
16
 
17
  from rank_bm25 import BM25Okapi
18
+ from sentence_transformers import SentenceTransformer
19
  from ctransformers import AutoModelForCausalLM
20
 
21
  # --- Basic Configuration ---
22
  warnings.filterwarnings("ignore")
23
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
24
+ os.environ["OMP_NUM_THREADS"] = "2"
25
+ os.environ["MKL_NUM_THREADS"] = "2"
26
  nltk.download('punkt', quiet=True)
27
  RANDOM_SEED = 42
28
  np.random.seed(RANDOM_SEED)
29
  torch.manual_seed(RANDOM_SEED)
 
 
30
 
31
  DEVICE = "cpu"
32
 
 
36
  self.bm25 = None
37
  self.index_faiss = None
38
  self.embedding_model = None
 
39
  self.llm_model = None
40
+
41
+ # Create a sample of the data for faster search
42
+ self.sample_indices = None
43
+ self.sample_size = 50000 # Work with subset of data
44
+
45
  self.load_artifacts()
46
  self.load_models()
47
 
 
49
  print(f"--> Loading artifacts from root directory")
50
  self.chunks_df = pd.read_parquet("chunks_df.parquet")
51
  print(f"Loaded {len(self.chunks_df)} chunks.")
52
+
53
+ # Create a random sample for faster search
54
+ self.sample_indices = np.random.choice(
55
+ len(self.chunks_df),
56
+ size=min(self.sample_size, len(self.chunks_df)),
57
+ replace=False
58
+ )
59
+ print(f"Created sample of {len(self.sample_indices)} chunks for faster search")
60
 
61
  with open("bm25_index.pkl", "rb") as f:
62
  self.bm25 = pickle.load(f)
 
67
 
68
  def load_models(self):
69
  print("--> Loading models...")
70
+ # Dense Retriever - use smaller model
71
+ EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2' # Faster than multi-qa variant
72
  self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
73
+ self.embedding_model.max_seq_length = 128 # Very short for speed
 
74
  print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")
75
 
76
+ # LLM - Try Phi-3 Mini with different settings
77
+ print("Loading Phi-3 Mini...")
 
 
 
 
 
 
78
 
79
+ # Multiple model options to try
80
+ model_options = [
81
+ # Option 1: Phi-3 Mini 4K
82
+ {
83
+ "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
84
+ "model_file": "Phi-3-mini-4k-instruct-q4.gguf",
85
+ "model_type": "phi3"
86
+ },
87
+ # Option 2: Back to TinyLlama but with different settings
88
+ {
89
+ "repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
90
+ "model_file": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
91
+ "model_type": "llama"
92
+ },
93
+ # Option 3: Even smaller model
94
+ {
95
+ "repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
96
+ "model_file": "tinyllama-1.1b-chat-v1.0.Q2_K.gguf", # Smaller quantization
97
+ "model_type": "llama"
98
+ }
99
+ ]
100
 
101
+ for i, model_config in enumerate(model_options):
102
+ try:
103
+ print(f"Trying model option {i+1}: {model_config['repo_id']}")
104
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
105
+ model_config["repo_id"],
106
+ model_file=model_config["model_file"],
107
+ model_type=model_config["model_type"],
108
+ temperature=0.1,
109
+ max_new_tokens=100,
110
+ context_length=512, # Very short context
111
+ gpu_layers=0,
112
+ threads=2, # Minimal threads
113
+ batch_size=1, # Smallest batch
114
+ stream=False,
115
+ local_files_only=False
116
+ )
117
+ print(f"Successfully loaded model: {model_config['repo_id']}")
118
+ break
119
+ except Exception as e:
120
+ print(f"Failed to load model option {i+1}: {e}")
121
+ if i == len(model_options) - 1:
122
+ raise Exception("Failed to load any LLM model")
123
 
124
+ def search_bm25_fast(self, query: str, k: int = 5):
125
+ """Ultra-fast BM25 search on sample"""
126
  tokenized_query = query.lower().split()
 
 
 
 
 
127
 
128
+ # Only search the sample
129
+ sample_scores = []
130
+ for idx in self.sample_indices[:10000]: # Even smaller subset
131
+ doc_tokens = self.bm25.doc_freqs[idx]
132
+ score = self.bm25._score(tokenized_query, idx)
133
+ sample_scores.append((idx, score))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ # Get top k
136
+ sample_scores.sort(key=lambda x: x[1], reverse=True)
137
 
138
  results = []
139
+ for idx, score in sample_scores[:k]:
 
 
 
 
140
  chunk_info = self.chunks_df.iloc[idx]
141
  results.append({
142
  'chunk_id': chunk_info['chunk_id'],
143
  'doc_id': chunk_info['doc_id'],
144
+ 'score': float(score),
145
+ 'text': chunk_info['chunk_text'][:500], # Truncate text
146
  'title': chunk_info['original_title'],
147
  'url': chunk_info['original_url']
148
  })
149
  return results
150
 
151
+ def search_faiss_fast(self, query: str, k: int = 5):
152
+ """Fast FAISS search with timeout"""
153
+ try:
154
+ # Quick embedding
155
+ query_embedding = self.embedding_model.encode(
156
+ query[:100], # Truncate query if too long
157
+ convert_to_tensor=False,
158
+ show_progress_bar=False
159
+ )
160
+ query_embedding = query_embedding.reshape(1, -1).astype('float32')
161
+ faiss.normalize_L2(query_embedding)
162
+
163
+ # Search with reduced k
164
+ distances, indices = self.index_faiss.search(query_embedding, k)
165
+
166
+ results = []
167
+ for i in range(min(k, len(indices[0]))):
168
+ idx = indices[0][i]
169
+ if idx < 0 or idx >= len(self.chunks_df):
170
+ continue
171
+ score = float(distances[0][i])
172
+ chunk_info = self.chunks_df.iloc[idx]
173
+ results.append({
174
+ 'chunk_id': chunk_info['chunk_id'],
175
+ 'doc_id': chunk_info['doc_id'],
176
+ 'score': score,
177
+ 'text': chunk_info['chunk_text'][:500],
178
+ 'title': chunk_info['original_title'],
179
+ 'url': chunk_info['original_url']
180
+ })
181
+ return results
182
+ except Exception as e:
183
+ print(f"FAISS search error: {e}")
184
+ return []
 
 
 
 
 
 
 
185
 
186
+ def simple_search(self, query: str, k: int = 3):
187
+ """Ultra-simple search - just use FAISS"""
188
+ print(" - Performing simple FAISS-only search...")
189
+ return self.search_faiss_fast(query, k=k)
190
+
191
+ def format_rag_prompt_phi3(self, query: str, context_chunks: list):
192
+ """Format prompt for Phi-3"""
193
+ # Very short context
194
+ context = " ".join([chunk['text'][:200] for chunk in context_chunks[:2]])
195
 
196
+ # Phi-3 instruct format
197
+ prompt = f"""<|system|>
198
+ You are a helpful assistant. Answer based only on the context provided. Be very brief.
199
+ <|end|>
200
+ <|user|>
201
+ Context: {context}
202
+
203
+ Question: {query}
204
+ <|end|>
205
+ <|assistant|>"""
206
  return prompt
207
 
208
+ def format_rag_prompt_tinyllama(self, query: str, context_chunks: list):
209
+ """Format prompt for TinyLlama"""
210
+ context = " ".join([chunk['text'][:200] for chunk in context_chunks[:2]])
211
+ prompt = f"<|system|>\nAnswer briefly based on context.\n</s>\n<|user|>\nContext: {context}\n\nQ: {query}\n</s>\n<|assistant|>\n"
212
+ return prompt
213
+
214
+ def generate_llm_answer_with_timeout(self, query: str, context_chunks: list, timeout_seconds: int = 30):
215
+ """Generate answer with timeout protection"""
216
  if not context_chunks:
217
+ return "No relevant context found.", []
218
 
219
+ # Choose prompt format based on model
220
+ if hasattr(self.llm_model, 'model_type') and self.llm_model.model_type == 'phi3':
221
+ formatted_prompt = self.format_rag_prompt_phi3(query, context_chunks)
222
+ else:
223
+ formatted_prompt = self.format_rag_prompt_tinyllama(query, context_chunks)
224
 
225
+ print(f" - Generating answer (max {timeout_seconds}s)...")
226
+
227
+ result = {"answer": None, "error": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ def generate():
230
+ try:
231
+ answer = self.llm_model(
232
+ formatted_prompt,
233
+ max_new_tokens=50, # Very short
234
+ stop=["<|end|>", "</s>", "\n\n"],
235
+ stream=False
236
+ )
237
+ result["answer"] = answer.strip()
238
+ except Exception as e:
239
+ result["error"] = str(e)
240
+
241
+ # Run generation in thread with timeout
242
+ thread = threading.Thread(target=generate)
243
+ thread.start()
244
+ thread.join(timeout=timeout_seconds)
245
+
246
+ if thread.is_alive():
247
+ print(" - Generation timed out!")
248
+ return "Generation timed out. The model is too slow for this environment.", context_chunks[:2]
249
+
250
+ if result["error"]:
251
+ print(f" - Generation error: {result['error']}")
252
+ return f"Error: {result['error']}", context_chunks[:2]
253
+
254
+ answer = result["answer"] or "Could not generate answer."
255
+ return answer, context_chunks[:2]
256
 
257
  def answer_query(self, query: str):
258
+ """Main query answering with aggressive timeouts"""
259
  print(f"Received query: {query}")
260
 
261
+ total_start = time.time()
262
+
263
  try:
264
+ # 1. Super fast retrieval
265
  start_time = time.time()
266
+ with ThreadPoolExecutor(max_workers=1) as executor:
267
+ future = executor.submit(self.simple_search, query, 3)
268
+ try:
269
+ retrieved_context = future.result(timeout=5) # 5 second timeout
270
+ except TimeoutError:
271
+ print(" Search timed out!")
272
+ return "Search timed out. Please try a simpler query.", [], ""
273
+
274
  print(f" Retrieval completed in {time.time() - start_time:.2f}s")
275
 
276
  if not retrieved_context:
277
+ return "No relevant documents found.", [], ""
278
 
279
+ # 2. Generate Answer with timeout
280
+ llm_answer, used_chunks = self.generate_llm_answer_with_timeout(
281
+ query,
282
+ retrieved_context,
283
+ timeout_seconds=20
284
+ )
285
 
286
  # 3. Format sources
287
  sources_text = "\n\n**Sources:**\n"
288
+ for chunk in used_chunks:
289
+ sources_text += f"- [{chunk['title']}]({chunk['url']})\n"
 
 
 
290
 
291
+ total_time = time.time() - total_start
292
+ print(f"Total processing time: {total_time:.2f}s")
293
+
294
+ return llm_answer, used_chunks, sources_text
295
 
296
  except Exception as e:
297
  print(f"Error in answer_query: {e}")
298
  import traceback
299
  traceback.print_exc()
300
+ return f"System error: {str(e)}", [], ""
301
+
302
# For testing without full pipeline
def test_simple_generation():
    """Test if LLM generation works at all"""
    try:
        from ctransformers import AutoModelForCausalLM
        print("Testing simple generation...")

        # Tiny model + tiny context: the cheapest possible smoke test.
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
            model_file="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
            model_type="llama",
            gpu_layers=0,
            threads=2,
            context_length=128,
            max_new_tokens=20
        )

        result = model("Hello, how are", max_new_tokens=10, stream=False)
        print(f"Test result: {result}")
    except Exception as e:
        print(f"Test failed: {e}")
        return False
    else:
        return True