kn29 committed on
Commit
4c6aa01
·
verified ·
1 Parent(s): df9660d

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +529 -508
rag.py CHANGED
@@ -14,21 +14,13 @@ from collections import defaultdict
14
  import spacy
15
  from rank_bm25 import BM25Okapi
16
 
17
- # Global variables for models
18
- MODEL = None
19
- TOKENIZER = None
20
- GROQ_CLIENT = None
21
- NLP_MODEL = None
22
- DEVICE = None
23
 
24
- # Global indices
25
- DENSE_INDEX = None
26
- BM25_INDEX = None
27
- CONCEPT_GRAPH = None
28
- TOKEN_TO_CHUNKS = None
29
- CHUNKS_DATA = []
30
-
31
- # Legal knowledge base
32
  LEGAL_CONCEPTS = {
33
  'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
34
  'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
@@ -46,8 +38,8 @@ QUERY_PATTERNS = {
46
  }
47
 
48
  def initialize_models(model_id: str, groq_api_key: str = None):
49
- """Initialize all models and components"""
50
- global MODEL, TOKENIZER, GROQ_CLIENT, NLP_MODEL, DEVICE
51
 
52
  try:
53
  nltk.download('punkt', quiet=True)
@@ -55,539 +47,568 @@ def initialize_models(model_id: str, groq_api_key: str = None):
55
  except:
56
  pass
57
 
58
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59
- print(f"Using device: {DEVICE}")
60
 
61
  print(f"Loading model: {model_id}")
62
- TOKENIZER = AutoTokenizer.from_pretrained(model_id)
63
- MODEL = AutoModel.from_pretrained(model_id).to(DEVICE)
64
- MODEL.eval()
65
-
66
- if groq_api_key:
67
- GROQ_CLIENT = Groq(api_key=groq_api_key)
68
 
69
  try:
70
- NLP_MODEL = spacy.load("en_core_web_sm")
71
  except:
72
  print("SpaCy model not found, using basic NER")
73
- NLP_MODEL = None
74
 
75
- def create_embedding(text: str) -> np.ndarray:
76
- """Create dense embedding for text"""
77
- inputs = TOKENIZER(text, padding=True, truncation=True,
78
- max_length=512, return_tensors='pt').to(DEVICE)
79
 
80
- with torch.no_grad():
81
- outputs = MODEL(**inputs)
82
- attention_mask = inputs['attention_mask']
83
- token_embeddings = outputs.last_hidden_state
84
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
85
- embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
86
 
87
- # Normalize embeddings
88
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
 
 
 
 
89
 
90
- return embeddings.cpu().numpy()[0]
91
-
92
- def extract_legal_entities(text: str) -> List[Dict[str, Any]]:
93
- """Extract legal entities from text"""
94
- entities = []
95
-
96
- if NLP_MODEL:
97
- doc = NLP_MODEL(text[:5000]) # Limit for performance
98
- for ent in doc.ents:
99
- if ent.label_ in ['PERSON', 'ORG', 'LAW', 'GPE']:
100
- entities.append({
101
- 'text': ent.text,
102
- 'type': ent.label_,
103
- 'importance': 1.0
104
- })
105
-
106
- # Legal citations
107
- citation_pattern = r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b'
108
- for match in re.finditer(citation_pattern, text):
109
- entities.append({
110
- 'text': match.group(),
111
- 'type': 'case_citation',
112
- 'importance': 2.0
113
- })
114
-
115
- # Statute references
116
- statute_pattern = r'§\s*\d+[\.\d]*|\bSection\s+\d+'
117
- for match in re.finditer(statute_pattern, text):
118
- entities.append({
119
- 'text': match.group(),
120
- 'type': 'statute',
121
- 'importance': 1.5
122
- })
123
-
124
- return entities
125
 
126
- def analyze_query(query: str) -> Dict[str, Any]:
127
- """Analyze query to understand intent"""
128
- query_lower = query.lower()
129
-
130
- # Classify query type
131
- query_type = 'general'
132
- for qtype, patterns in QUERY_PATTERNS.items():
133
- if any(pattern in query_lower for pattern in patterns):
134
- query_type = qtype
135
- break
136
-
137
- # Extract entities
138
- entities = extract_legal_entities(query)
139
-
140
- # Extract key concepts
141
- key_concepts = []
142
- for concept_category, concepts in LEGAL_CONCEPTS.items():
143
- for concept in concepts:
144
- if concept in query_lower:
145
- key_concepts.append(concept)
146
-
147
- # Generate expanded queries
148
- expanded_queries = [query]
149
-
150
- # Concept expansion
151
- if key_concepts:
152
- expanded_queries.append(f"{query} {' '.join(key_concepts[:3])}")
153
-
154
- # Type-based expansion
155
- if query_type == 'precedent':
156
- expanded_queries.append(f"legal precedent case law {query}")
157
- elif query_type == 'statute_interpretation':
158
- expanded_queries.append(f"statutory interpretation meaning {query}")
159
-
160
- # HyDE - Hypothetical document generation
161
- if GROQ_CLIENT:
162
- hyde_doc = generate_hypothetical_document(query)
163
- if hyde_doc:
164
- expanded_queries.append(hyde_doc)
165
-
166
- return {
167
- 'original_query': query,
168
- 'query_type': query_type,
169
- 'entities': entities,
170
- 'key_concepts': key_concepts,
171
- 'expanded_queries': expanded_queries[:4] # Limit to 4 queries
172
- }
173
 
174
- def generate_hypothetical_document(query: str) -> Optional[str]:
175
- """Generate hypothetical answer document (HyDE technique)"""
176
- if not GROQ_CLIENT:
177
- return None
178
-
179
- try:
180
- prompt = f"""Generate a brief hypothetical legal document excerpt that would answer this question: {query}
181
-
182
- Write it as if it's from an actual legal case or statute. Be specific and use legal language.
183
- Keep it under 100 words."""
184
-
185
- response = GROQ_CLIENT.chat.completions.create(
186
- messages=[
187
- {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
188
- {"role": "user", "content": prompt}
189
- ],
190
- model="llama-3.1-8b-instant",
191
- temperature=0.3,
192
- max_tokens=150
193
- )
194
-
195
- return response.choices[0].message.content
196
- except:
197
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- def chunk_text_hierarchical(text: str, title: str = "") -> List[Dict[str, Any]]:
200
- """Create hierarchical chunks with legal structure awareness"""
201
- chunks = []
202
-
203
- # Clean text
204
- text = re.sub(r'\s+', ' ', text)
205
-
206
- # Identify legal sections
207
- section_patterns = [
208
- (r'(?i)\bFACTS?\b[:\s]', 'facts'),
209
- (r'(?i)\bHOLDING\b[:\s]', 'holding'),
210
- (r'(?i)\bREASONING\b[:\s]', 'reasoning'),
211
- (r'(?i)\bDISSENT\b[:\s]', 'dissent'),
212
- (r'(?i)\bCONCLUSION\b[:\s]', 'conclusion')
213
- ]
214
-
215
- sections = []
216
- for pattern, section_type in section_patterns:
217
- matches = list(re.finditer(pattern, text))
218
- for match in matches:
219
- sections.append((match.start(), section_type))
220
-
221
- sections.sort(key=lambda x: x[0])
222
-
223
- # Split into sentences
224
- import nltk
225
- try:
226
- sentences = nltk.sent_tokenize(text)
227
- except:
228
- sentences = text.split('. ')
229
-
230
- # Create chunks
231
- current_section = 'introduction'
232
- section_sentences = []
233
- chunk_size = 500 # words
234
-
235
- for sent in sentences:
236
- # Check section type
237
- sent_pos = text.find(sent)
238
- for pos, stype in sections:
239
- if sent_pos >= pos:
240
- current_section = stype
241
-
242
- section_sentences.append(sent)
243
-
244
- # Create chunk when we have enough content
245
- chunk_text = ' '.join(section_sentences)
246
- if len(chunk_text.split()) >= chunk_size or len(section_sentences) >= 10:
247
- chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
248
 
249
- # Calculate importance
250
- importance = 1.0
251
- section_weights = {
252
- 'holding': 2.0, 'conclusion': 1.8, 'reasoning': 1.5,
253
- 'facts': 1.2, 'dissent': 0.8
254
- }
255
- importance *= section_weights.get(current_section, 1.0)
256
 
257
- # Entity importance
258
- entities = extract_legal_entities(chunk_text)
259
- if entities:
260
- entity_score = sum(e['importance'] for e in entities) / len(entities)
261
- importance *= (1 + entity_score * 0.5)
 
 
 
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  chunks.append({
264
  'id': chunk_id,
265
  'text': chunk_text,
266
  'title': title,
267
  'section_type': current_section,
268
- 'importance_score': importance,
269
- 'entities': entities,
270
- 'embedding': None # Will be filled during indexing
271
  })
272
-
273
- section_sentences = []
274
-
275
- # Add remaining sentences
276
- if section_sentences:
277
- chunk_text = ' '.join(section_sentences)
278
- chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
279
- chunks.append({
280
- 'id': chunk_id,
281
- 'text': chunk_text,
282
- 'title': title,
283
- 'section_type': current_section,
284
- 'importance_score': 1.0,
285
- 'entities': extract_legal_entities(chunk_text),
286
- 'embedding': None
287
- })
288
-
289
- return chunks
290
 
291
- def build_all_indices(chunks: List[Dict[str, Any]]):
292
- """Build all retrieval indices"""
293
- global DENSE_INDEX, BM25_INDEX, CONCEPT_GRAPH, TOKEN_TO_CHUNKS, CHUNKS_DATA
294
-
295
- CHUNKS_DATA = chunks
296
- print(f"Building indices for {len(chunks)} chunks...")
297
-
298
- # 1. Dense embeddings + FAISS index
299
- print("Building FAISS index...")
300
- embeddings = []
301
- for chunk in tqdm(chunks, desc="Creating embeddings"):
302
- embedding = create_embedding(chunk['text'])
303
- chunk['embedding'] = embedding
304
- embeddings.append(embedding)
305
-
306
- embeddings_matrix = np.vstack(embeddings)
307
- DENSE_INDEX = faiss.IndexFlatIP(embeddings_matrix.shape[1]) # Inner product for normalized vectors
308
- DENSE_INDEX.add(embeddings_matrix.astype('float32'))
309
-
310
- # 2. BM25 index for sparse retrieval
311
- print("Building BM25 index...")
312
- tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
313
- BM25_INDEX = BM25Okapi(tokenized_corpus)
314
-
315
- # 3. ColBERT-style token index
316
- print("Building ColBERT token index...")
317
- TOKEN_TO_CHUNKS = defaultdict(set)
318
- for i, chunk in enumerate(chunks):
319
- # Simple tokenization for token-level matching
320
- tokens = chunk['text'].lower().split()
321
- for token in tokens:
322
- TOKEN_TO_CHUNKS[token].add(i)
323
-
324
- # 4. Legal concept graph
325
- print("Building legal concept graph...")
326
- CONCEPT_GRAPH = nx.Graph()
327
-
328
- for i, chunk in enumerate(chunks):
329
- CONCEPT_GRAPH.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
330
-
331
- # Add edges between chunks with shared entities
332
- for j, other_chunk in enumerate(chunks[i+1:], i+1):
333
- shared_entities = set(e['text'] for e in chunk['entities']) & \
334
- set(e['text'] for e in other_chunk['entities'])
335
- if shared_entities:
336
- CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))
337
-
338
- print("All indices built successfully!")
339
 
340
def multi_stage_retrieval(query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
    """Perform multi-stage retrieval combining all techniques.

    Runs four retrieval stages (dense/FAISS, sparse/BM25, entity match,
    concept-graph expansion), then fuses the per-stage scores with a fixed
    weighting and section/importance boosts.

    Args:
        query_analysis: Output of analyze_query (expanded queries, entities,
            query type).
        top_k: Number of (chunk, score) pairs to return.

    Returns:
        Top-k (chunk_dict, fused_score) pairs, best first.
    """
    # Candidates keyed by chunk id; each entry accumulates per-method scores.
    candidates = {}

    print("Performing multi-stage retrieval...")

    # Stage 1: Dense retrieval with expanded queries
    print("Stage 1: Dense retrieval...")
    for query in query_analysis['expanded_queries'][:3]:
        query_emb = create_embedding(query)
        # Over-fetch (2x) so later fusion has room to re-rank.
        scores, indices = DENSE_INDEX.search(
            query_emb.reshape(1, -1).astype('float32'),
            top_k * 2
        )

        for idx, score in zip(indices[0], scores[0]):
            if idx < len(CHUNKS_DATA):
                chunk_id = CHUNKS_DATA[idx]['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {'chunk': CHUNKS_DATA[idx], 'scores': {}}
                candidates[chunk_id]['scores']['dense'] = float(score)

    # Stage 2: Sparse retrieval (BM25)
    print("Stage 2: Sparse retrieval...")
    query_tokens = query_analysis['original_query'].lower().split()
    bm25_scores = BM25_INDEX.get_scores(query_tokens)
    top_bm25_indices = np.argsort(bm25_scores)[-top_k*2:][::-1]

    for idx in top_bm25_indices:
        if idx < len(CHUNKS_DATA):
            chunk_id = CHUNKS_DATA[idx]['id']
            if chunk_id not in candidates:
                candidates[chunk_id] = {'chunk': CHUNKS_DATA[idx], 'scores': {}}
            candidates[chunk_id]['scores']['bm25'] = float(bm25_scores[idx])

    # Stage 3: Entity-based retrieval
    print("Stage 3: Entity-based retrieval...")
    # Linear scan over all chunks per query entity; entity scores accumulate.
    for entity in query_analysis['entities']:
        for chunk in CHUNKS_DATA:
            chunk_entity_texts = [e['text'].lower() for e in chunk['entities']]
            if entity['text'].lower() in chunk_entity_texts:
                chunk_id = chunk['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
                candidates[chunk_id]['scores']['entity'] = \
                    candidates[chunk_id]['scores'].get('entity', 0) + entity['importance']

    # Stage 4: Graph-based retrieval
    print("Stage 4: Graph-based retrieval...")
    if candidates and CONCEPT_GRAPH:
        # Expand from the first 5 candidates to their graph neighbors.
        seed_chunks = []
        for chunk_id, data in list(candidates.items())[:5]:
            for i, chunk in enumerate(CHUNKS_DATA):
                if chunk['id'] == chunk_id:
                    seed_chunks.append(i)
                    break

        for seed_idx in seed_chunks:
            if seed_idx in CONCEPT_GRAPH:
                neighbors = list(CONCEPT_GRAPH.neighbors(seed_idx))[:3]
                for neighbor_idx in neighbors:
                    if neighbor_idx < len(CHUNKS_DATA):
                        chunk = CHUNKS_DATA[neighbor_idx]
                        chunk_id = chunk['id']
                        if chunk_id not in candidates:
                            candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
                        # Flat bonus for graph proximity (no per-edge weighting).
                        candidates[chunk_id]['scores']['graph'] = 0.5

    # Combine scores
    print("Combining scores...")
    weights = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}
    final_scores = []

    for chunk_id, data in candidates.items():
        chunk = data['chunk']
        scores = data['scores']

        final_score = 0
        for method, weight in weights.items():
            if method in scores:
                # Normalize scores into [0, 1] before weighting.
                if method == 'dense':
                    normalized = (scores[method] + 1) / 2  # [-1, 1] to [0, 1]
                elif method == 'bm25':
                    # Heuristic cap — assumes BM25 scores rarely exceed 10.
                    normalized = min(scores[method] / 10, 1)
                elif method == 'entity':
                    normalized = min(scores[method] / 3, 1)
                else:
                    normalized = scores[method]

                final_score += weight * normalized

        # Boost by importance and section relevance
        final_score *= chunk['importance_score']

        if query_analysis['query_type'] == 'precedent' and chunk['section_type'] == 'holding':
            final_score *= 1.5
        elif query_analysis['query_type'] == 'factual' and chunk['section_type'] == 'facts':
            final_score *= 1.5

        final_scores.append((chunk, final_score))

    # Sort and return top-k
    final_scores.sort(key=lambda x: x[1], reverse=True)
    return final_scores[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
def generate_answer_with_reasoning(query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
    """Generate an IRAC-style legal answer grounded in the retrieved chunks.

    Args:
        query: The user's legal question.
        retrieved_chunks: (chunk_dict, relevance_score) pairs from retrieval.

    Returns:
        On success: dict with 'answer', 'confidence' (0-100) and 'sources'.
        On failure: dict with an 'error' key (client missing, no chunks, or
        API error).
    """
    if not GROQ_CLIENT:
        return {'error': 'Groq client not initialized'}

    # Guard: the confidence computation below divides by
    # min(3, len(retrieved_chunks)), which is 0 for an empty list.
    if not retrieved_chunks:
        return {'error': 'No retrieved documents to answer from', 'sources': []}

    # Prepare context
    context_parts = []
    for i, (chunk, score) in enumerate(retrieved_chunks, 1):
        entities = ', '.join([e['text'] for e in chunk['entities'][:3]])
        context_parts.append(f"""
Document {i} [{chunk['title']}] - Relevance: {score:.2f}
Section: {chunk['section_type']}
Key Entities: {entities}
Content: {chunk['text'][:800]}
""")

    context = "\n---\n".join(context_parts)

    system_prompt = """You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:
1. ISSUE: Identify the legal issue(s)
2. RULE: State the applicable legal rules/precedents
3. APPLICATION: Apply the rules to the facts
4. CONCLUSION: Provide a clear conclusion

CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.
If information is not in the excerpts, state "This information is not provided in the available documents."
"""

    user_prompt = f"""Query: {query}

Retrieved Legal Documents:
{context}

Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."""

    try:
        # Low temperature: we want grounded, reproducible legal analysis.
        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.1,
            max_tokens=1000
        )

        answer = response.choices[0].message.content

        # Confidence: mean retrieval score of the top-3 chunks, scaled to [0, 100].
        avg_score = sum(score for _, score in retrieved_chunks[:3]) / min(3, len(retrieved_chunks))
        confidence = min(avg_score * 100, 100)

        return {
            'answer': answer,
            'confidence': confidence,
            'sources': [
                {
                    'chunk_id': chunk['id'],
                    'title': chunk['title'],
                    'section': chunk['section_type'],
                    'relevance_score': float(score),
                    'excerpt': chunk['text'][:200] + '...',
                    'entities': [e['text'] for e in chunk['entities'][:5]]
                }
                for chunk, score in retrieved_chunks
            ]
        }

    except Exception as e:
        # Best-effort fallback: surface the error plus the raw top sources.
        return {
            'error': f'Error generating answer: {str(e)}',
            'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]]
        }
519
 
520
- # Main functions for external use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
    """Chunk every input document and build the retrieval indices.

    Each document dict must carry a 'text' key; 'title' is optional and
    defaults to 'Document'. Returns a small status dict.
    """
    collected: List[Dict[str, Any]] = []
    for entry in documents:
        title = entry.get('title', 'Document')
        collected.extend(chunk_text_hierarchical(entry['text'], title))

    build_all_indices(collected)

    summary = f'Processed {len(documents)} documents into {len(collected)} chunks'
    return {'success': True, 'chunk_count': len(collected), 'message': summary}
536
 
537
def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
    """End-to-end query pipeline: analyze, retrieve, then generate an answer.

    Returns an error dict when nothing is indexed or nothing is retrieved;
    otherwise the generation result augmented with the query analysis.
    """
    if not CHUNKS_DATA:
        return {'error': 'No documents indexed. Call process_documents first.'}

    analysis = analyze_query(query)
    hits = multi_stage_retrieval(analysis, top_k)

    if not hits:
        return {
            'error': 'No relevant documents found',
            'query_analysis': analysis
        }

    answer = generate_answer_with_reasoning(query, hits)
    answer['query_analysis'] = analysis
    return answer
559
 
560
def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Compatibility search: return trimmed chunk dicts with their scores."""
    if not CHUNKS_DATA:
        return []

    analysis = analyze_query(query)
    hits = multi_stage_retrieval(analysis, top_k)

    # Expose only the fields legacy callers rely on.
    return [
        {
            'chunk': {'id': c['id'], 'text': c['text'], 'title': c['title']},
            'score': s,
        }
        for c, s in hits
    ]
580
 
581
def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
    """Compatibility wrapper producing a plain-text answer from pre-retrieved chunks.

    Returns a fallback message when no context is supplied, the error text
    when generation fails, or the generated answer otherwise.
    """
    if not context_chunks:
        return "No relevant information found."

    # Re-shape the incoming dicts into the (chunk, score) pairs expected downstream.
    pairs = []
    for entry in context_chunks:
        pairs.append((entry['chunk'], entry['score']))

    outcome = generate_answer_with_reasoning(query, pairs)
    if 'error' in outcome:
        return outcome['error']
    return outcome.get('answer', 'Unable to generate answer.')
 
14
  import spacy
15
  from rank_bm25 import BM25Okapi
16
 
17
+ # Global model instances (shared across sessions)
18
+ _SHARED_MODEL = None
19
+ _SHARED_TOKENIZER = None
20
+ _SHARED_NLP_MODEL = None
21
+ _DEVICE = None
 
22
 
23
+ # Legal knowledge base (shared constants)
 
 
 
 
 
 
 
24
  LEGAL_CONCEPTS = {
25
  'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
26
  'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
 
38
  }
39
 
40
  def initialize_models(model_id: str, groq_api_key: str = None):
41
+ """Initialize shared models (call once at startup)"""
42
+ global _SHARED_MODEL, _SHARED_TOKENIZER, _SHARED_NLP_MODEL, _DEVICE
43
 
44
  try:
45
  nltk.download('punkt', quiet=True)
 
47
  except:
48
  pass
49
 
50
+ _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ print(f"Using device: {_DEVICE}")
52
 
53
  print(f"Loading model: {model_id}")
54
+ _SHARED_TOKENIZER = AutoTokenizer.from_pretrained(model_id)
55
+ _SHARED_MODEL = AutoModel.from_pretrained(model_id).to(_DEVICE)
56
+ _SHARED_MODEL.eval()
 
 
 
57
 
58
  try:
59
+ _SHARED_NLP_MODEL = spacy.load("en_core_web_sm")
60
  except:
61
  print("SpaCy model not found, using basic NER")
62
+ _SHARED_NLP_MODEL = None
63
 
64
+ class SessionRAG:
65
+ """Session-specific RAG instance"""
 
 
66
 
67
+ def __init__(self, session_id: str, groq_api_key: str = None):
68
+ self.session_id = session_id
69
+ self.groq_client = Groq(api_key=groq_api_key) if groq_api_key else None
 
 
 
70
 
71
+ # Session-specific indices and data
72
+ self.dense_index = None
73
+ self.bm25_index = None
74
+ self.concept_graph = None
75
+ self.token_to_chunks = None
76
+ self.chunks_data = []
77
 
78
+ # Verify shared models are initialized
79
+ if _SHARED_MODEL is None or _SHARED_TOKENIZER is None:
80
+ raise ValueError("Models not initialized. Call initialize_models() first.")
81
+
82
+ def create_embedding(self, text: str) -> np.ndarray:
83
+ """Create dense embedding for text"""
84
+ inputs = _SHARED_TOKENIZER(text, padding=True, truncation=True,
85
+ max_length=512, return_tensors='pt').to(_DEVICE)
86
+
87
+ with torch.no_grad():
88
+ outputs = _SHARED_MODEL(**inputs)
89
+ attention_mask = inputs['attention_mask']
90
+ token_embeddings = outputs.last_hidden_state
91
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
92
+ embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
93
+
94
+ # Normalize embeddings
95
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
96
+
97
+ return embeddings.cpu().numpy()[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ def extract_legal_entities(self, text: str) -> List[Dict[str, Any]]:
100
+ """Extract legal entities from text"""
101
+ entities = []
102
+
103
+ if _SHARED_NLP_MODEL:
104
+ doc = _SHARED_NLP_MODEL(text[:5000]) # Limit for performance
105
+ for ent in doc.ents:
106
+ if ent.label_ in ['PERSON', 'ORG', 'LAW', 'GPE']:
107
+ entities.append({
108
+ 'text': ent.text,
109
+ 'type': ent.label_,
110
+ 'importance': 1.0
111
+ })
112
+
113
+ # Legal citations
114
+ citation_pattern = r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b'
115
+ for match in re.finditer(citation_pattern, text):
116
+ entities.append({
117
+ 'text': match.group(),
118
+ 'type': 'case_citation',
119
+ 'importance': 2.0
120
+ })
121
+
122
+ # Statute references
123
+ statute_pattern = r'§\s*\d+[\.\d]*|\bSection\s+\d+'
124
+ for match in re.finditer(statute_pattern, text):
125
+ entities.append({
126
+ 'text': match.group(),
127
+ 'type': 'statute',
128
+ 'importance': 1.5
129
+ })
130
+
131
+ return entities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ def analyze_query(self, query: str) -> Dict[str, Any]:
134
+ """Analyze query to understand intent"""
135
+ query_lower = query.lower()
136
+
137
+ # Classify query type
138
+ query_type = 'general'
139
+ for qtype, patterns in QUERY_PATTERNS.items():
140
+ if any(pattern in query_lower for pattern in patterns):
141
+ query_type = qtype
142
+ break
143
+
144
+ # Extract entities
145
+ entities = self.extract_legal_entities(query)
146
+
147
+ # Extract key concepts
148
+ key_concepts = []
149
+ for concept_category, concepts in LEGAL_CONCEPTS.items():
150
+ for concept in concepts:
151
+ if concept in query_lower:
152
+ key_concepts.append(concept)
153
+
154
+ # Generate expanded queries
155
+ expanded_queries = [query]
156
+
157
+ # Concept expansion
158
+ if key_concepts:
159
+ expanded_queries.append(f"{query} {' '.join(key_concepts[:3])}")
160
+
161
+ # Type-based expansion
162
+ if query_type == 'precedent':
163
+ expanded_queries.append(f"legal precedent case law {query}")
164
+ elif query_type == 'statute_interpretation':
165
+ expanded_queries.append(f"statutory interpretation meaning {query}")
166
+
167
+ # HyDE - Hypothetical document generation
168
+ if self.groq_client:
169
+ hyde_doc = self.generate_hypothetical_document(query)
170
+ if hyde_doc:
171
+ expanded_queries.append(hyde_doc)
172
+
173
+ return {
174
+ 'original_query': query,
175
+ 'query_type': query_type,
176
+ 'entities': entities,
177
+ 'key_concepts': key_concepts,
178
+ 'expanded_queries': expanded_queries[:4] # Limit to 4 queries
179
+ }
180
 
181
+ def generate_hypothetical_document(self, query: str) -> Optional[str]:
182
+ """Generate hypothetical answer document (HyDE technique)"""
183
+ if not self.groq_client:
184
+ return None
185
+
186
+ try:
187
+ prompt = f"""Generate a brief hypothetical legal document excerpt that would answer this question: {query}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ Write it as if it's from an actual legal case or statute. Be specific and use legal language.
190
+ Keep it under 100 words."""
 
 
 
 
 
191
 
192
+ response = self.groq_client.chat.completions.create(
193
+ messages=[
194
+ {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
195
+ {"role": "user", "content": prompt}
196
+ ],
197
+ model="llama-3.1-8b-instant",
198
+ temperature=0.3,
199
+ max_tokens=150
200
+ )
201
 
202
+ return response.choices[0].message.content
203
+ except:
204
+ return None
205
+
206
+ def chunk_text_hierarchical(self, text: str, title: str = "") -> List[Dict[str, Any]]:
207
+ """Create hierarchical chunks with legal structure awareness"""
208
+ chunks = []
209
+
210
+ # Clean text
211
+ text = re.sub(r'\s+', ' ', text)
212
+
213
+ # Identify legal sections
214
+ section_patterns = [
215
+ (r'(?i)\bFACTS?\b[:\s]', 'facts'),
216
+ (r'(?i)\bHOLDING\b[:\s]', 'holding'),
217
+ (r'(?i)\bREASONING\b[:\s]', 'reasoning'),
218
+ (r'(?i)\bDISSENT\b[:\s]', 'dissent'),
219
+ (r'(?i)\bCONCLUSION\b[:\s]', 'conclusion')
220
+ ]
221
+
222
+ sections = []
223
+ for pattern, section_type in section_patterns:
224
+ matches = list(re.finditer(pattern, text))
225
+ for match in matches:
226
+ sections.append((match.start(), section_type))
227
+
228
+ sections.sort(key=lambda x: x[0])
229
+
230
+ # Split into sentences
231
+ import nltk
232
+ try:
233
+ sentences = nltk.sent_tokenize(text)
234
+ except:
235
+ sentences = text.split('. ')
236
+
237
+ # Create chunks
238
+ current_section = 'introduction'
239
+ section_sentences = []
240
+ chunk_size = 500 # words
241
+
242
+ for sent in sentences:
243
+ # Check section type
244
+ sent_pos = text.find(sent)
245
+ for pos, stype in sections:
246
+ if sent_pos >= pos:
247
+ current_section = stype
248
+
249
+ section_sentences.append(sent)
250
+
251
+ # Create chunk when we have enough content
252
+ chunk_text = ' '.join(section_sentences)
253
+ if len(chunk_text.split()) >= chunk_size or len(section_sentences) >= 10:
254
+ chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
255
+
256
+ # Calculate importance
257
+ importance = 1.0
258
+ section_weights = {
259
+ 'holding': 2.0, 'conclusion': 1.8, 'reasoning': 1.5,
260
+ 'facts': 1.2, 'dissent': 0.8
261
+ }
262
+ importance *= section_weights.get(current_section, 1.0)
263
+
264
+ # Entity importance
265
+ entities = self.extract_legal_entities(chunk_text)
266
+ if entities:
267
+ entity_score = sum(e['importance'] for e in entities) / len(entities)
268
+ importance *= (1 + entity_score * 0.5)
269
+
270
+ chunks.append({
271
+ 'id': chunk_id,
272
+ 'text': chunk_text,
273
+ 'title': title,
274
+ 'section_type': current_section,
275
+ 'importance_score': importance,
276
+ 'entities': entities,
277
+ 'embedding': None # Will be filled during indexing
278
+ })
279
+
280
+ section_sentences = []
281
+
282
+ # Add remaining sentences
283
+ if section_sentences:
284
+ chunk_text = ' '.join(section_sentences)
285
+ chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
286
  chunks.append({
287
  'id': chunk_id,
288
  'text': chunk_text,
289
  'title': title,
290
  'section_type': current_section,
291
+ 'importance_score': 1.0,
292
+ 'entities': self.extract_legal_entities(chunk_text),
293
+ 'embedding': None
294
  })
295
+
296
+ return chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
+ def build_all_indices(self, chunks: List[Dict[str, Any]]):
299
+ """Build all retrieval indices for this session"""
300
+ self.chunks_data = chunks
301
+ print(f"Building indices for session {self.session_id}: {len(chunks)} chunks...")
302
+
303
+ # 1. Dense embeddings + FAISS index
304
+ print("Building FAISS index...")
305
+ embeddings = []
306
+ for chunk in tqdm(chunks, desc="Creating embeddings"):
307
+ embedding = self.create_embedding(chunk['text'])
308
+ chunk['embedding'] = embedding
309
+ embeddings.append(embedding)
310
+
311
+ embeddings_matrix = np.vstack(embeddings)
312
+ self.dense_index = faiss.IndexFlatIP(embeddings_matrix.shape[1]) # Inner product for normalized vectors
313
+ self.dense_index.add(embeddings_matrix.astype('float32'))
314
+
315
+ # 2. BM25 index for sparse retrieval
316
+ print("Building BM25 index...")
317
+ tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
318
+ self.bm25_index = BM25Okapi(tokenized_corpus)
319
+
320
+ # 3. ColBERT-style token index
321
+ print("Building ColBERT token index...")
322
+ self.token_to_chunks = defaultdict(set)
323
+ for i, chunk in enumerate(chunks):
324
+ # Simple tokenization for token-level matching
325
+ tokens = chunk['text'].lower().split()
326
+ for token in tokens:
327
+ self.token_to_chunks[token].add(i)
328
+
329
+ # 4. Legal concept graph
330
+ print("Building legal concept graph...")
331
+ self.concept_graph = nx.Graph()
332
+
333
+ for i, chunk in enumerate(chunks):
334
+ self.concept_graph.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
335
+
336
+ # Add edges between chunks with shared entities
337
+ for j, other_chunk in enumerate(chunks[i+1:], i+1):
338
+ shared_entities = set(e['text'] for e in chunk['entities']) & \
339
+ set(e['text'] for e in other_chunk['entities'])
340
+ if shared_entities:
341
+ self.concept_graph.add_edge(i, j, weight=len(shared_entities))
342
+
343
+ print(f"All indices built successfully for session {self.session_id}!")
 
 
344
 
345
def multi_stage_retrieval(self, query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
    """Perform multi-stage retrieval combining all techniques.

    Fuses four signals — dense (FAISS), sparse (BM25), entity overlap, and
    concept-graph neighborhood — into a single weighted score per chunk.

    Args:
        query_analysis: Result of ``analyze_query``; must contain
            'expanded_queries', 'original_query', 'entities' and 'query_type'.
        top_k: Number of (chunk, score) pairs to return.

    Returns:
        List of (chunk_dict, fused_score) tuples, sorted by descending score.
    """
    candidates = {}
    # Map chunk id -> position once; avoids a linear scan per candidate later.
    id_to_index = {chunk['id']: i for i, chunk in enumerate(self.chunks_data)}

    print(f"Performing multi-stage retrieval for session {self.session_id}...")

    # Stage 1: Dense retrieval with expanded queries
    print("Stage 1: Dense retrieval...")
    for query in query_analysis['expanded_queries'][:3]:
        query_emb = self.create_embedding(query)
        scores, indices = self.dense_index.search(
            query_emb.reshape(1, -1).astype('float32'),
            top_k * 2
        )

        for idx, score in zip(indices[0], scores[0]):
            # FAISS pads missing results with -1; a bare `idx < len(...)` test
            # would let -1 through and silently pick the *last* chunk.
            if 0 <= idx < len(self.chunks_data):
                chunk_id = self.chunks_data[idx]['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {'chunk': self.chunks_data[idx], 'scores': {}}
                candidates[chunk_id]['scores']['dense'] = float(score)

    # Stage 2: Sparse retrieval (BM25)
    print("Stage 2: Sparse retrieval...")
    query_tokens = query_analysis['original_query'].lower().split()
    bm25_scores = self.bm25_index.get_scores(query_tokens)
    top_bm25_indices = np.argsort(bm25_scores)[-top_k * 2:][::-1]

    for idx in top_bm25_indices:
        if 0 <= idx < len(self.chunks_data):
            chunk_id = self.chunks_data[idx]['id']
            if chunk_id not in candidates:
                candidates[chunk_id] = {'chunk': self.chunks_data[idx], 'scores': {}}
            candidates[chunk_id]['scores']['bm25'] = float(bm25_scores[idx])

    # Stage 3: Entity-based retrieval
    print("Stage 3: Entity-based retrieval...")
    # Precompute lowercase entity sets once instead of re-lowering every
    # chunk's entities for each query entity.
    chunk_entity_sets = [
        {e['text'].lower() for e in chunk['entities']} for chunk in self.chunks_data
    ]
    for entity in query_analysis['entities']:
        entity_text = entity['text'].lower()
        for chunk, entity_set in zip(self.chunks_data, chunk_entity_sets):
            if entity_text in entity_set:
                chunk_id = chunk['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
                candidates[chunk_id]['scores']['entity'] = \
                    candidates[chunk_id]['scores'].get('entity', 0) + entity['importance']

    # Stage 4: Graph-based retrieval
    print("Stage 4: Graph-based retrieval...")
    if candidates and self.concept_graph:
        # Take the first few candidates (insertion order) as graph seeds.
        seed_chunks = [
            id_to_index[chunk_id]
            for chunk_id in list(candidates)[:5]
            if chunk_id in id_to_index
        ]

        for seed_idx in seed_chunks:
            if seed_idx in self.concept_graph:
                neighbors = list(self.concept_graph.neighbors(seed_idx))[:3]
                for neighbor_idx in neighbors:
                    if neighbor_idx < len(self.chunks_data):
                        chunk = self.chunks_data[neighbor_idx]
                        chunk_id = chunk['id']
                        if chunk_id not in candidates:
                            candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
                        candidates[chunk_id]['scores']['graph'] = 0.5

    # Combine scores
    print("Combining scores...")
    weights = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}
    final_scores = []

    for chunk_id, data in candidates.items():
        chunk = data['chunk']
        scores = data['scores']

        final_score = 0
        for method, weight in weights.items():
            if method in scores:
                # Normalize each signal into [0, 1] before weighting.
                if method == 'dense':
                    normalized = (scores[method] + 1) / 2  # [-1, 1] to [0, 1]
                elif method == 'bm25':
                    normalized = min(scores[method] / 10, 1)
                elif method == 'entity':
                    normalized = min(scores[method] / 3, 1)
                else:
                    normalized = scores[method]

                final_score += weight * normalized

        # Boost by chunk importance and section relevance to the query type.
        final_score *= chunk['importance_score']

        if query_analysis['query_type'] == 'precedent' and chunk['section_type'] == 'holding':
            final_score *= 1.5
        elif query_analysis['query_type'] == 'factual' and chunk['section_type'] == 'facts':
            final_score *= 1.5

        final_scores.append((chunk, final_score))

    # Sort and return top-k
    final_scores.sort(key=lambda x: x[1], reverse=True)
    return final_scores[:top_k]
450
 
451
def generate_answer_with_reasoning(self, query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
    """Generate an answer with legal reasoning (IRAC) via the Groq LLM.

    Args:
        query: The user's legal question.
        retrieved_chunks: (chunk_dict, relevance_score) pairs from retrieval.

    Returns:
        Dict with 'answer', 'confidence' (0-100) and 'sources' on success,
        or a dict containing an 'error' key on failure.
    """
    if not self.groq_client:
        return {'error': 'Groq client not initialized'}

    # Guard: an empty chunk list would otherwise divide by zero in the
    # confidence computation (after a wasted API call).
    if not retrieved_chunks:
        return {'error': 'No relevant documents found', 'sources': []}

    # Prepare context
    context_parts = []
    for i, (chunk, score) in enumerate(retrieved_chunks, 1):
        entities = ', '.join([e['text'] for e in chunk['entities'][:3]])
        context_parts.append(f"""
Document {i} [{chunk['title']}] - Relevance: {score:.2f}
Section: {chunk['section_type']}
Key Entities: {entities}
Content: {chunk['text'][:800]}
""")

    context = "\n---\n".join(context_parts)

    system_prompt = """You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:
1. ISSUE: Identify the legal issue(s)
2. RULE: State the applicable legal rules/precedents
3. APPLICATION: Apply the rules to the facts
4. CONCLUSION: Provide a clear conclusion

CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.
If information is not in the excerpts, state "This information is not provided in the available documents."
"""

    user_prompt = f"""Query: {query}

Retrieved Legal Documents:
{context}

Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."""

    try:
        response = self.groq_client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.1,   # low temperature for factual, reproducible analysis
            max_tokens=1000
        )

        answer = response.choices[0].message.content

        # Confidence: mean relevance of the top 3 chunks, capped at 100.
        avg_score = sum(score for _, score in retrieved_chunks[:3]) / min(3, len(retrieved_chunks))
        confidence = min(avg_score * 100, 100)

        return {
            'answer': answer,
            'confidence': confidence,
            'sources': [
                {
                    'chunk_id': chunk['id'],
                    'title': chunk['title'],
                    'section': chunk['section_type'],
                    'relevance_score': float(score),
                    'excerpt': chunk['text'][:200] + '...',
                    'entities': [e['text'] for e in chunk['entities'][:5]]
                }
                for chunk, score in retrieved_chunks
            ]
        }

    except Exception as e:
        return {
            'error': f'Error generating answer: {str(e)}',
            'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]]
        }
524
+
525
def process_documents(self, documents: List[Dict[str, str]]) -> Dict[str, Any]:
    """Chunk every document and build this session's retrieval indices.

    Args:
        documents: Dicts with a 'text' key and an optional 'title'.

    Returns:
        Summary dict with 'success', 'chunk_count' and a human-readable 'message'.
    """
    # Flatten all per-document chunk lists into one corpus.
    all_chunks = [
        chunk
        for doc in documents
        for chunk in self.chunk_text_hierarchical(doc['text'], doc.get('title', 'Document'))
    ]

    self.build_all_indices(all_chunks)

    n_chunks = len(all_chunks)
    return {
        'success': True,
        'chunk_count': n_chunks,
        'message': f'Processed {len(documents)} documents into {n_chunks} chunks for session {self.session_id}'
    }
540
 
541
def query_documents(self, query: str, top_k: int = 5) -> Dict[str, Any]:
    """Answer a query against this session's indexed documents.

    Runs query analysis, multi-stage retrieval, and LLM answer generation.

    Args:
        query: The user's question.
        top_k: Number of chunks to retrieve for context.

    Returns:
        Answer dict (including 'query_analysis'), or a dict with 'error'.
    """
    # Guard: nothing to search until process_documents has been called.
    if not self.chunks_data:
        return {'error': f'No documents indexed for session {self.session_id}. Call process_documents first.'}

    analysis = self.analyze_query(query)
    hits = self.multi_stage_retrieval(analysis, top_k)

    if not hits:
        return {
            'error': 'No relevant documents found',
            'query_analysis': analysis
        }

    result = self.generate_answer_with_reasoning(query, hits)
    result['query_analysis'] = analysis
    return result
563
+
564
def search_chunks_simple(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Lightweight search wrapper kept for backward compatibility.

    Args:
        query: The search query.
        top_k: Maximum number of results.

    Returns:
        List of {'chunk': {...}, 'score': float} dicts; empty list when no
        documents have been indexed yet.
    """
    if not self.chunks_data:
        return []

    hits = self.multi_stage_retrieval(self.analyze_query(query), top_k)

    # Project each hit down to the minimal legacy result shape.
    return [
        {
            'chunk': {
                'id': hit['id'],
                'text': hit['text'],
                'title': hit['title']
            },
            'score': relevance
        }
        for hit, relevance in hits
    ]
584
+
585
def generate_conservative_answer(self, query: str, context_chunks: List[Dict[str, Any]]) -> str:
    """Compatibility wrapper returning a plain-text answer.

    Args:
        query: The user's question.
        context_chunks: Legacy-format results ({'chunk': ..., 'score': ...}).

    Returns:
        The generated answer text, or an error/fallback message.
    """
    if not context_chunks:
        return "No relevant information found."

    # Adapt legacy dicts to the (chunk, score) tuples the generator expects.
    pairs = [(item['chunk'], item['score']) for item in context_chunks]
    outcome = self.generate_answer_with_reasoning(query, pairs)

    if 'error' in outcome:
        return outcome['error']
    return outcome.get('answer', 'Unable to generate answer.')
598
+
599
+ # Backward compatibility functions (deprecated - use SessionRAG instead)
600
def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
    """Deprecated: Use SessionRAG.process_documents() instead"""
    message = "Global functions are deprecated. Use SessionRAG class instead."
    raise NotImplementedError(message)
 
 
 
 
 
 
 
 
 
 
 
 
603
 
604
def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
    """Deprecated: Use SessionRAG.query_documents() instead"""
    message = "Global functions are deprecated. Use SessionRAG class instead."
    raise NotImplementedError(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
 
608
def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Deprecated: Use SessionRAG.search_chunks_simple() instead"""
    message = "Global functions are deprecated. Use SessionRAG class instead."
    raise NotImplementedError(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
 
612
def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
    """Deprecated: Use SessionRAG.generate_conservative_answer() instead"""
    message = "Global functions are deprecated. Use SessionRAG class instead."
    raise NotImplementedError(message)