Spaces:
Runtime error
Runtime error
Commit
·
18d08ed
1
Parent(s):
833d467
update
Browse files
app.py
CHANGED
|
@@ -126,19 +126,10 @@ def load_session_data(username):
|
|
| 126 |
|
| 127 |
def load_samples(methods):
|
| 128 |
logger.info(f"Loading samples for methods: {methods}")
|
| 129 |
-
samples =
|
| 130 |
categories = ["TP", "TN", "FP", "FN"]
|
| 131 |
|
| 132 |
-
method_dirs = []
|
| 133 |
-
for method in methods:
|
| 134 |
-
if method == 'No-XAI':
|
| 135 |
-
method_dirs.append('NO_XAI')
|
| 136 |
-
elif method == 'Dater':
|
| 137 |
-
method_dirs.append('DATER')
|
| 138 |
-
elif method == 'Chain-of-Table':
|
| 139 |
-
method_dirs.append('COT')
|
| 140 |
-
elif method == 'Plan-of-SQLs':
|
| 141 |
-
method_dirs.append('POS')
|
| 142 |
|
| 143 |
for category in categories:
|
| 144 |
dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
|
|
@@ -150,18 +141,29 @@ def load_samples(methods):
|
|
| 150 |
matching_files = files_a & files_b
|
| 151 |
|
| 152 |
for file in matching_files:
|
| 153 |
-
samples.
|
| 154 |
-
'category': category,
|
| 155 |
-
'file': file
|
| 156 |
-
})
|
| 157 |
|
| 158 |
-
|
|
|
|
| 159 |
|
|
|
|
|
|
|
| 160 |
|
| 161 |
def select_balanced_samples(samples):
|
| 162 |
try:
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
return selected_samples
|
| 166 |
except Exception as e:
|
| 167 |
logger.exception("Error selecting balanced samples")
|
|
@@ -285,7 +287,6 @@ def get_method_dir(method):
|
|
| 285 |
elif method == 'Plan-of-SQLs':
|
| 286 |
return 'POS'
|
| 287 |
|
| 288 |
-
|
| 289 |
def get_visualization_dir(method):
|
| 290 |
if method == "No-XAI":
|
| 291 |
return 'htmls_NO_XAI'
|
|
|
|
| 126 |
|
| 127 |
def load_samples(methods):
|
| 128 |
logger.info(f"Loading samples for methods: {methods}")
|
| 129 |
+
samples = set() # Use a set to avoid duplicates
|
| 130 |
categories = ["TP", "TN", "FP", "FN"]
|
| 131 |
|
| 132 |
+
method_dirs = [get_method_dir(method) for method in methods]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
for category in categories:
|
| 135 |
dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
|
|
|
|
| 141 |
matching_files = files_a & files_b
|
| 142 |
|
| 143 |
for file in matching_files:
|
| 144 |
+
samples.add((category, file))
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Convert set of tuples back to list of dictionaries
|
| 147 |
+
samples = [{'category': category, 'file': file} for category, file in samples]
|
| 148 |
|
| 149 |
+
logger.info(f"Loaded {len(samples)} unique samples across all categories")
|
| 150 |
+
return samples
|
| 151 |
|
| 152 |
def select_balanced_samples(samples):
|
| 153 |
try:
|
| 154 |
+
# Ensure we have at least 10 unique samples
|
| 155 |
+
unique_samples = list({(s['category'], s['file']) for s in samples})
|
| 156 |
+
|
| 157 |
+
if len(unique_samples) < 10:
|
| 158 |
+
logger.warning(f"Not enough unique samples. Only {len(unique_samples)} available.")
|
| 159 |
+
selected_samples = unique_samples
|
| 160 |
+
else:
|
| 161 |
+
selected_samples = random.sample(unique_samples, 10)
|
| 162 |
+
|
| 163 |
+
# Convert back to dictionary format
|
| 164 |
+
selected_samples = [{'category': category, 'file': file} for category, file in selected_samples]
|
| 165 |
+
|
| 166 |
+
logger.info(f"Selected {len(selected_samples)} unique samples")
|
| 167 |
return selected_samples
|
| 168 |
except Exception as e:
|
| 169 |
logger.exception("Error selecting balanced samples")
|
|
|
|
| 287 |
elif method == 'Plan-of-SQLs':
|
| 288 |
return 'POS'
|
| 289 |
|
|
|
|
| 290 |
def get_visualization_dir(method):
|
| 291 |
if method == "No-XAI":
|
| 292 |
return 'htmls_NO_XAI'
|