vxa8502 committed on
Commit
5c74b4d
·
1 Parent(s): daa087a

Add percentile comment

Browse files
Files changed (2) hide show
  1. sage/core/chunking.py +3 -0
  2. tests/test_chunking.py +3 -2
sage/core/chunking.py CHANGED
@@ -161,6 +161,9 @@ def find_split_points(
161
  if not similarities:
162
  return []
163
 
 
 
 
164
  threshold = np.percentile(similarities, 100 - threshold_percentile)
165
 
166
  return [i for i, sim in enumerate(similarities) if sim < threshold]
 
161
  if not similarities:
162
  return []
163
 
164
+ # Split where similarity drops into the bottom (100 - threshold_percentile)%.
165
+ # E.g., threshold_percentile=85 splits at values below the 15th percentile
166
+ # (i.e., the lowest 15% of similarities indicate topic shifts).
167
  threshold = np.percentile(similarities, 100 - threshold_percentile)
168
 
169
  return [i for i, sim in enumerate(similarities) if sim < threshold]
tests/test_chunking.py CHANGED
@@ -369,7 +369,7 @@ class TestFindSplitPoints:
369
  splits = find_split_points(sims, threshold_percentile=50)
370
  # With uniform values, the threshold equals the value itself
371
  # so no similarity is strictly below the threshold
372
- assert isinstance(splits, list)
373
 
374
  def test_clear_topic_boundary(self):
375
  # High similarities with one dip
@@ -383,8 +383,9 @@ class TestFindSplitPoints:
383
  assert splits == []
384
 
385
  def test_single_value(self):
 
386
  splits = find_split_points([0.5])
387
- assert isinstance(splits, list)
388
 
389
  def test_returns_sorted_indices(self):
390
  sims = [0.9, 0.1, 0.9, 0.05, 0.9]
 
369
  splits = find_split_points(sims, threshold_percentile=50)
370
  # With uniform values, the threshold equals the value itself
371
  # so no similarity is strictly below the threshold
372
+ assert splits == []
373
 
374
  def test_clear_topic_boundary(self):
375
  # High similarities with one dip
 
383
  assert splits == []
384
 
385
  def test_single_value(self):
386
+ # Single value cannot be below its own percentile threshold
387
  splits = find_split_points([0.5])
388
+ assert splits == []
389
 
390
  def test_returns_sorted_indices(self):
391
  sims = [0.9, 0.1, 0.9, 0.05, 0.9]