kn29 commited on
Commit
fef0353
·
verified ·
1 Parent(s): 4790f1a

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +78 -0
rag.py CHANGED
@@ -96,6 +96,84 @@ class SessionRAG:
96
 
97
  return embeddings.cpu().numpy()[0]
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def extract_legal_entities(self, text: str) -> List[Dict[str, Any]]:
100
  """Extract legal entities from text"""
101
  entities = []
 
96
 
97
  return embeddings.cpu().numpy()[0]
98
 
99
+ def load_existing_session_data(self, chunks_from_db: List[Dict[str, Any]]):
100
+ """Load pre-existing chunks with embeddings from database"""
101
+ print(f"Loading existing session data for {self.session_id}: {len(chunks_from_db)} chunks...")
102
+
103
+ # Process chunks from MongoDB format
104
+ self.chunks_data = self.process_db_chunks(chunks_from_db)
105
+
106
+ # Rebuild indices from existing embeddings (don't recreate embeddings)
107
+ self.rebuild_indices_from_existing_embeddings()
108
+
109
+ print(f"Session {self.session_id} loaded with existing embeddings!")
110
+
111
+ def rebuild_indices_from_existing_embeddings(self):
112
+ """Rebuild search indices using existing embeddings from database"""
113
+ if not self.chunks_data:
114
+ raise ValueError("No chunks data available")
115
+
116
+ print(f"Rebuilding indices from existing embeddings...")
117
+
118
+ # Extract existing embeddings
119
+ embeddings = []
120
+ for chunk in self.chunks_data:
121
+ if 'embedding' in chunk and chunk['embedding'] is not None:
122
+ embeddings.append(chunk['embedding'])
123
+ else:
124
+ raise ValueError(f"Missing embedding for chunk {chunk.get('id', 'unknown')}")
125
+
126
+ # Build FAISS index from existing embeddings
127
+ embeddings_matrix = np.vstack(embeddings)
128
+ self.dense_index = faiss.IndexFlatIP(embeddings_matrix.shape[1])
129
+ self.dense_index.add(embeddings_matrix.astype('float32'))
130
+
131
+ # Build other indices
132
+ tokenized_corpus = [chunk['text'].lower().split() for chunk in self.chunks_data]
133
+ self.bm25_index = BM25Okapi(tokenized_corpus)
134
+
135
+ # 3. ColBERT-style token index
136
+ self.token_to_chunks = defaultdict(set)
137
+ for i, chunk in enumerate(self.chunks_data):
138
+ tokens = chunk['text'].lower().split()
139
+ for token in tokens:
140
+ self.token_to_chunks[token].add(i)
141
+
142
+ # 4. Legal concept graph
143
+ self.concept_graph = nx.Graph()
144
+ for i, chunk in enumerate(self.chunks_data):
145
+ self.concept_graph.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
146
+
147
+ for j, other_chunk in enumerate(self.chunks_data[i+1:], i+1):
148
+ shared_entities = set(e['text'] for e in chunk['entities']) & \
149
+ set(e['text'] for e in other_chunk['entities'])
150
+ if shared_entities:
151
+ self.concept_graph.add_edge(i, j, weight=len(shared_entities))
152
+
153
+ print(f"All indices rebuilt from existing embeddings for session {self.session_id}!")
154
+
155
+ def process_db_chunks(self, chunks_from_db: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
156
+ """Convert MongoDB chunk format to internal format"""
157
+ processed_chunks = []
158
+ for chunk in chunks_from_db:
159
+ # Convert embedding from list to numpy array if needed
160
+ embedding = chunk.get('embedding')
161
+ if embedding and isinstance(embedding, list):
162
+ embedding = np.array(embedding)
163
+
164
+ processed_chunk = {
165
+ 'id': chunk.get('chunk_id', chunk.get('id')),
166
+ 'text': chunk.get('content', chunk.get('text', '')),
167
+ 'title': chunk.get('title', 'Document'),
168
+ 'section_type': chunk.get('section_type', 'general'),
169
+ 'importance_score': chunk.get('importance_score', 1.0),
170
+ 'entities': chunk.get('entities', []),
171
+ 'embedding': embedding
172
+ }
173
+ processed_chunks.append(processed_chunk)
174
+
175
+ return processed_chunks
176
+
177
  def extract_legal_entities(self, text: str) -> List[Dict[str, Any]]:
178
  """Extract legal entities from text"""
179
  entities = []