Update rag.py
Browse files
rag.py
CHANGED
|
@@ -96,6 +96,84 @@ class SessionRAG:
|
|
| 96 |
|
| 97 |
return embeddings.cpu().numpy()[0]
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def extract_legal_entities(self, text: str) -> List[Dict[str, Any]]:
|
| 100 |
"""Extract legal entities from text"""
|
| 101 |
entities = []
|
|
|
|
| 96 |
|
| 97 |
return embeddings.cpu().numpy()[0]
|
| 98 |
|
| 99 |
+
def load_existing_session_data(self, chunks_from_db: List[Dict[str, Any]]):
    """Load pre-existing chunks with embeddings from database"""
    print(f"Loading existing session data for {self.session_id}: {len(chunks_from_db)} chunks...")

    # Normalize the raw MongoDB documents into the internal chunk format
    # before anything touches them.
    normalized = self.process_db_chunks(chunks_from_db)
    self.chunks_data = normalized

    # The embeddings are already stored with the chunks, so only the search
    # indices are rebuilt — no re-embedding pass is performed.
    self.rebuild_indices_from_existing_embeddings()

    print(f"Session {self.session_id} loaded with existing embeddings!")
def rebuild_indices_from_existing_embeddings(self):
    """Rebuild search indices using existing embeddings from database"""
    if not self.chunks_data:
        raise ValueError("No chunks data available")

    print(f"Rebuilding indices from existing embeddings...")

    # Gather the stored vectors, failing fast on any chunk without one.
    vectors = []
    for chunk in self.chunks_data:
        vec = chunk.get('embedding')
        if vec is None:
            raise ValueError(f"Missing embedding for chunk {chunk.get('id', 'unknown')}")
        vectors.append(vec)

    # 1. Dense index: stack the stored vectors and feed them to a flat
    #    inner-product FAISS index (float32, as FAISS requires).
    matrix = np.vstack(vectors)
    self.dense_index = faiss.IndexFlatIP(matrix.shape[1])
    self.dense_index.add(matrix.astype('float32'))

    # 2. Sparse lexical index: BM25 over lowercased whitespace tokens.
    corpus_tokens = [c['text'].lower().split() for c in self.chunks_data]
    self.bm25_index = BM25Okapi(corpus_tokens)

    # 3. ColBERT-style token index: token -> set of chunk positions.
    self.token_to_chunks = defaultdict(set)
    for idx, chunk in enumerate(self.chunks_data):
        for token in chunk['text'].lower().split():
            self.token_to_chunks[token].add(idx)

    # 4. Legal concept graph: one node per chunk; an edge between two chunks
    #    is weighted by how many entity surface forms they share.
    self.concept_graph = nx.Graph()
    for idx, chunk in enumerate(self.chunks_data):
        self.concept_graph.add_node(idx, text=chunk['text'][:200], importance=chunk['importance_score'])
        entity_names = {e['text'] for e in chunk['entities']}
        for other_idx in range(idx + 1, len(self.chunks_data)):
            other_names = {e['text'] for e in self.chunks_data[other_idx]['entities']}
            overlap = entity_names & other_names
            if overlap:
                self.concept_graph.add_edge(idx, other_idx, weight=len(overlap))

    print(f"All indices rebuilt from existing embeddings for session {self.session_id}!")
def process_db_chunks(self, chunks_from_db: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert MongoDB chunk format to internal format.

    Args:
        chunks_from_db: Raw chunk documents from the database. Each document
            may use either the DB field names ('chunk_id', 'content') or the
            internal ones ('id', 'text'); missing metadata falls back to
            defaults.

    Returns:
        List of normalized chunk dicts with keys 'id', 'text', 'title',
        'section_type', 'importance_score', 'entities', and 'embedding'
        (a numpy array, or None when the document had no embedding).
    """
    processed_chunks = []
    for chunk in chunks_from_db:
        embedding = chunk.get('embedding')
        # Bug fix: the original guard `if embedding and isinstance(embedding, list)`
        # evaluated the truthiness of `embedding` first, which raises
        # "truth value of an array with more than one element is ambiguous"
        # when the field was already decoded as a numpy array. Check the
        # type only; any list (including an empty one) becomes an ndarray,
        # and existing ndarrays / None pass through unchanged.
        if isinstance(embedding, list):
            embedding = np.array(embedding)

        processed_chunks.append({
            'id': chunk.get('chunk_id', chunk.get('id')),
            'text': chunk.get('content', chunk.get('text', '')),
            'title': chunk.get('title', 'Document'),
            'section_type': chunk.get('section_type', 'general'),
            'importance_score': chunk.get('importance_score', 1.0),
            'entities': chunk.get('entities', []),
            'embedding': embedding,
        })

    return processed_chunks
| 177 |
def extract_legal_entities(self, text: str) -> List[Dict[str, Any]]:
|
| 178 |
"""Extract legal entities from text"""
|
| 179 |
entities = []
|