Updated app with code for deduplication
Browse files
app.py
CHANGED
|
@@ -108,24 +108,31 @@ def perform_deduplication(
|
|
| 108 |
# Convert threshold to float
|
| 109 |
threshold = float(threshold)
|
| 110 |
|
|
|
|
|
|
|
|
|
|
| 111 |
if deduplication_type == "Single dataset":
|
| 112 |
# Load Dataset 1
|
| 113 |
-
|
|
|
|
| 114 |
if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
|
| 115 |
ds = ds_default1
|
| 116 |
else:
|
| 117 |
ds = load_dataset(dataset1_name, split=dataset1_split)
|
| 118 |
|
| 119 |
# Extract texts
|
| 120 |
-
|
|
|
|
| 121 |
texts = [example[dataset1_text_column] for example in ds]
|
| 122 |
|
| 123 |
# Compute embeddings
|
| 124 |
-
|
|
|
|
| 125 |
embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
|
| 126 |
|
| 127 |
# Deduplicate
|
| 128 |
-
|
|
|
|
| 129 |
deduplicated_indices, duplicate_to_original_mapping = deduplicate(
|
| 130 |
embedding_matrix, threshold
|
| 131 |
)
|
|
@@ -154,41 +161,50 @@ def perform_deduplication(
|
|
| 154 |
else:
|
| 155 |
result_text += "No duplicates found."
|
| 156 |
|
| 157 |
-
|
|
|
|
|
|
|
| 158 |
|
| 159 |
elif deduplication_type == "Cross-dataset":
|
| 160 |
# Load Dataset 1
|
| 161 |
-
|
|
|
|
| 162 |
if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
|
| 163 |
ds1 = ds_default1
|
| 164 |
else:
|
| 165 |
ds1 = load_dataset(dataset1_name, split=dataset1_split)
|
| 166 |
|
| 167 |
# Load Dataset 2
|
| 168 |
-
|
|
|
|
| 169 |
if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
|
| 170 |
ds2 = ds_default2
|
| 171 |
else:
|
| 172 |
ds2 = load_dataset(dataset2_name, split=dataset2_split)
|
| 173 |
|
| 174 |
# Extract texts from Dataset 1
|
| 175 |
-
|
|
|
|
| 176 |
texts1 = [example[dataset1_text_column] for example in ds1]
|
| 177 |
|
| 178 |
# Extract texts from Dataset 2
|
| 179 |
-
|
|
|
|
| 180 |
texts2 = [example[dataset2_text_column] for example in ds2]
|
| 181 |
|
| 182 |
# Compute embeddings for Dataset 1
|
| 183 |
-
|
|
|
|
| 184 |
embedding_matrix1 = model.encode(texts1, show_progressbar=True)
|
| 185 |
|
| 186 |
# Compute embeddings for Dataset 2
|
| 187 |
-
|
|
|
|
| 188 |
embedding_matrix2 = model.encode(texts2, show_progressbar=True)
|
| 189 |
|
| 190 |
# Deduplicate across datasets
|
| 191 |
-
|
|
|
|
| 192 |
duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
|
| 193 |
embedding_matrix1, embedding_matrix2, threshold
|
| 194 |
)
|
|
@@ -217,7 +233,9 @@ def perform_deduplication(
|
|
| 217 |
else:
|
| 218 |
result_text += "No duplicates found."
|
| 219 |
|
| 220 |
-
|
|
|
|
|
|
|
| 221 |
|
| 222 |
finally:
|
| 223 |
# Restore original tqdm
|
|
@@ -257,7 +275,8 @@ with gr.Blocks() as demo:
|
|
| 257 |
|
| 258 |
compute_button = gr.Button("Compute")
|
| 259 |
|
| 260 |
-
|
|
|
|
| 261 |
|
| 262 |
# Function to update the visibility of dataset2_inputs
|
| 263 |
def update_visibility(deduplication_type_value):
|
|
@@ -284,9 +303,9 @@ with gr.Blocks() as demo:
|
|
| 284 |
dataset2_text_column,
|
| 285 |
threshold
|
| 286 |
],
|
| 287 |
-
outputs=
|
| 288 |
)
|
| 289 |
-
|
| 290 |
demo.launch()
|
| 291 |
|
| 292 |
|
|
|
|
| 108 |
# Convert threshold to float
|
| 109 |
threshold = float(threshold)
|
| 110 |
|
| 111 |
+
# Initialize status message
|
| 112 |
+
status = ""
|
| 113 |
+
|
| 114 |
if deduplication_type == "Single dataset":
|
| 115 |
# Load Dataset 1
|
| 116 |
+
status = "Loading Dataset 1..."
|
| 117 |
+
yield status, ""
|
| 118 |
if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
|
| 119 |
ds = ds_default1
|
| 120 |
else:
|
| 121 |
ds = load_dataset(dataset1_name, split=dataset1_split)
|
| 122 |
|
| 123 |
# Extract texts
|
| 124 |
+
status = "Extracting texts from Dataset 1..."
|
| 125 |
+
yield status, ""
|
| 126 |
texts = [example[dataset1_text_column] for example in ds]
|
| 127 |
|
| 128 |
# Compute embeddings
|
| 129 |
+
status = "Computing embeddings for Dataset 1..."
|
| 130 |
+
yield status, ""
|
| 131 |
embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
|
| 132 |
|
| 133 |
# Deduplicate
|
| 134 |
+
status = "Deduplicating embeddings..."
|
| 135 |
+
yield status, ""
|
| 136 |
deduplicated_indices, duplicate_to_original_mapping = deduplicate(
|
| 137 |
embedding_matrix, threshold
|
| 138 |
)
|
|
|
|
| 161 |
else:
|
| 162 |
result_text += "No duplicates found."
|
| 163 |
|
| 164 |
+
# Final status
|
| 165 |
+
status = "Deduplication completed."
|
| 166 |
+
yield status, result_text
|
| 167 |
|
| 168 |
elif deduplication_type == "Cross-dataset":
|
| 169 |
# Load Dataset 1
|
| 170 |
+
status = "Loading Dataset 1..."
|
| 171 |
+
yield status, ""
|
| 172 |
if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
|
| 173 |
ds1 = ds_default1
|
| 174 |
else:
|
| 175 |
ds1 = load_dataset(dataset1_name, split=dataset1_split)
|
| 176 |
|
| 177 |
# Load Dataset 2
|
| 178 |
+
status = "Loading Dataset 2..."
|
| 179 |
+
yield status, ""
|
| 180 |
if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
|
| 181 |
ds2 = ds_default2
|
| 182 |
else:
|
| 183 |
ds2 = load_dataset(dataset2_name, split=dataset2_split)
|
| 184 |
|
| 185 |
# Extract texts from Dataset 1
|
| 186 |
+
status = "Extracting texts from Dataset 1..."
|
| 187 |
+
yield status, ""
|
| 188 |
texts1 = [example[dataset1_text_column] for example in ds1]
|
| 189 |
|
| 190 |
# Extract texts from Dataset 2
|
| 191 |
+
status = "Extracting texts from Dataset 2..."
|
| 192 |
+
yield status, ""
|
| 193 |
texts2 = [example[dataset2_text_column] for example in ds2]
|
| 194 |
|
| 195 |
# Compute embeddings for Dataset 1
|
| 196 |
+
status = "Computing embeddings for Dataset 1..."
|
| 197 |
+
yield status, ""
|
| 198 |
embedding_matrix1 = model.encode(texts1, show_progressbar=True)
|
| 199 |
|
| 200 |
# Compute embeddings for Dataset 2
|
| 201 |
+
status = "Computing embeddings for Dataset 2..."
|
| 202 |
+
yield status, ""
|
| 203 |
embedding_matrix2 = model.encode(texts2, show_progressbar=True)
|
| 204 |
|
| 205 |
# Deduplicate across datasets
|
| 206 |
+
status = "Deduplicating embeddings across datasets..."
|
| 207 |
+
yield status, ""
|
| 208 |
duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
|
| 209 |
embedding_matrix1, embedding_matrix2, threshold
|
| 210 |
)
|
|
|
|
| 233 |
else:
|
| 234 |
result_text += "No duplicates found."
|
| 235 |
|
| 236 |
+
# Final status
|
| 237 |
+
status = "Deduplication completed."
|
| 238 |
+
yield status, result_text
|
| 239 |
|
| 240 |
finally:
|
| 241 |
# Restore original tqdm
|
|
|
|
| 275 |
|
| 276 |
compute_button = gr.Button("Compute")
|
| 277 |
|
| 278 |
+
status_output = gr.Markdown()
|
| 279 |
+
result_output = gr.Markdown()
|
| 280 |
|
| 281 |
# Function to update the visibility of dataset2_inputs
|
| 282 |
def update_visibility(deduplication_type_value):
|
|
|
|
| 303 |
dataset2_text_column,
|
| 304 |
threshold
|
| 305 |
],
|
| 306 |
+
outputs=[status_output, result_output]
|
| 307 |
)
|
| 308 |
+
|
| 309 |
demo.launch()
|
| 310 |
|
| 311 |
|