Updates
Browse files
app.py
CHANGED
|
@@ -20,6 +20,35 @@ default_threshold = 0.9
|
|
| 20 |
ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
|
| 21 |
ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def batch_iterable(iterable, batch_size):
|
| 24 |
"""Helper function to create batches from an iterable."""
|
| 25 |
for i in range(0, len(iterable), batch_size):
|
|
@@ -114,15 +143,18 @@ def perform_deduplication(
|
|
| 114 |
yield status, ""
|
| 115 |
texts = [example[dataset1_text_column] for example in ds]
|
| 116 |
|
|
|
|
|
|
|
| 117 |
# Compute embeddings
|
| 118 |
status = "Computing embeddings for Dataset 1..."
|
| 119 |
yield status, ""
|
| 120 |
-
embedding_matrix =
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
# Deduplicate
|
| 128 |
status = "Deduplicating embeddings..."
|
|
|
|
| 20 |
ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
|
| 21 |
ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
|
| 22 |
|
| 23 |
+
from tqdm import tqdm as original_tqdm
|
| 24 |
+
# Patch tqdm to use Gradio's progress bar
|
| 25 |
+
def patch_tqdm_for_gradio(progress):
|
| 26 |
+
class GradioTqdm(original_tqdm):
|
| 27 |
+
def __init__(self, *args, **kwargs):
|
| 28 |
+
super().__init__(*args, **kwargs)
|
| 29 |
+
self.progress = progress
|
| 30 |
+
self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
|
| 31 |
+
|
| 32 |
+
def update(self, n=1):
|
| 33 |
+
super().update(n)
|
| 34 |
+
self.progress(self.n / self.total_batches)
|
| 35 |
+
|
| 36 |
+
return GradioTqdm
|
| 37 |
+
# Function to patch the original encode function with our Gradio tqdm
|
| 38 |
+
def original_encode_with_tqdm(original_encode_func, patched_tqdm):
|
| 39 |
+
def new_encode(*args, **kwargs):
|
| 40 |
+
# Replace tqdm with our patched version
|
| 41 |
+
original_tqdm_backup = original_tqdm
|
| 42 |
+
try:
|
| 43 |
+
# Patch the `tqdm` within encode
|
| 44 |
+
globals()['tqdm'] = patched_tqdm
|
| 45 |
+
return original_encode_func(*args, **kwargs)
|
| 46 |
+
finally:
|
| 47 |
+
# Restore original tqdm after calling encode
|
| 48 |
+
globals()['tqdm'] = original_tqdm_backup
|
| 49 |
+
|
| 50 |
+
return new_encode
|
| 51 |
+
|
| 52 |
def batch_iterable(iterable, batch_size):
|
| 53 |
"""Helper function to create batches from an iterable."""
|
| 54 |
for i in range(0, len(iterable), batch_size):
|
|
|
|
| 143 |
yield status, ""
|
| 144 |
texts = [example[dataset1_text_column] for example in ds]
|
| 145 |
|
| 146 |
+
patched_tqdm = patch_tqdm_for_gradio(progress)
|
| 147 |
+
model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
|
| 148 |
# Compute embeddings
|
| 149 |
status = "Computing embeddings for Dataset 1..."
|
| 150 |
yield status, ""
|
| 151 |
+
embedding_matrix = model.encode(texts, show_progressbar=True)
|
| 152 |
+
# embedding_matrix = compute_embeddings(
|
| 153 |
+
# texts,
|
| 154 |
+
# batch_size=64,
|
| 155 |
+
# progress=progress,
|
| 156 |
+
# desc="Computing embeddings for Dataset 1",
|
| 157 |
+
# )
|
| 158 |
|
| 159 |
# Deduplicate
|
| 160 |
status = "Deduplicating embeddings..."
|