kisejin committed on
Commit
0e76626
·
verified ·
1 Parent(s): e4ff17f

Update BERTopic/bertopic/_bertopic.py

Browse files
Files changed (1) hide show
  1. BERTopic/bertopic/_bertopic.py +11 -1
BERTopic/bertopic/_bertopic.py CHANGED
@@ -535,7 +535,10 @@ class BERTopic:
535
  logger.info("Clustering - Approximating new points with `hdbscan_model`")
536
  if is_supported_hdbscan(self.hdbscan_model):
537
  predictions, probabilities = hdbscan_delegator(self.hdbscan_model, "approximate_predict", umap_embeddings)
538
-
 
 
 
539
  # Calculate probabilities
540
  if self.calculate_probabilities:
541
  logger.info("Probabilities - Start calculation of probabilities with HDBSCAN")
@@ -548,9 +551,16 @@ class BERTopic:
548
 
549
  # Map probabilities and predictions
550
  probabilities = self._map_probabilities(probabilities, original_topics=True)
 
551
  predictions = self._map_predictions(predictions)
 
 
552
  return predictions, probabilities
553
 
 
 
 
 
554
  def partial_fit(self,
555
  documents: List[str],
556
  embeddings: np.ndarray = None,
 
535
  logger.info("Clustering - Approximating new points with `hdbscan_model`")
536
  if is_supported_hdbscan(self.hdbscan_model):
537
  predictions, probabilities = hdbscan_delegator(self.hdbscan_model, "approximate_predict", umap_embeddings)
538
+
539
+ # Compute the probabilities of all topics for each sentence
540
+ self.probabilities_transform = hdbscan_delegator(self.hdbscan_model, "membership_vector", umap_embeddings)
541
+
542
  # Calculate probabilities
543
  if self.calculate_probabilities:
544
  logger.info("Probabilities - Start calculation of probabilities with HDBSCAN")
 
551
 
552
  # Map probabilities and predictions
553
  probabilities = self._map_probabilities(probabilities, original_topics=True)
554
+ self.probabilities_transform = self._map_probabilities(self.probabilities_transform, original_topics=True)
555
  predictions = self._map_predictions(predictions)
556
+
557
+ self.predictions_transform = predictions
558
  return predictions, probabilities
559
 
560
+ def get_result_transform(self):
561
+ return self.predictions_transform, self.probabilities_transform
562
+
563
+
564
  def partial_fit(self,
565
  documents: List[str],
566
  embeddings: np.ndarray = None,