more cross comparing
Browse files
script/hyperparameter_tuning.py
CHANGED
|
@@ -84,8 +84,9 @@ def objective(trial, hyperparam_run_dir, data_path):
|
|
| 84 |
'dataset_label': dataset_label,
|
| 85 |
'trial_number': trial.number,
|
| 86 |
'parameters': trial.params,
|
| 87 |
-
'
|
| 88 |
-
'visualization_dir': vis_dir
|
|
|
|
| 89 |
}
|
| 90 |
|
| 91 |
with open(os.path.join(trial_dir, 'trial_info.json'), 'w') as f:
|
|
|
|
| 84 |
'dataset_label': dataset_label,
|
| 85 |
'trial_number': trial.number,
|
| 86 |
'parameters': trial.params,
|
| 87 |
+
'accuracy': val_accuracy,
|
| 88 |
+
'visualization_dir': vis_dir,
|
| 89 |
+
'trial_dir': trial_dir
|
| 90 |
}
|
| 91 |
|
| 92 |
with open(os.path.join(trial_dir, 'trial_info.json'), 'w') as f:
|
script/visualization/analyze_trials.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
def parse_error_analysis(vis_dir):
    """Parse <vis_dir>/error_analysis.txt into a metrics dict.

    The file is expected to contain, in order:
      * an "Overall Accuracy: NN.NN%" line,
      * per-class lines of the form "name: NN.NN% (K samples)",
      * a "Misclassified Videos:" table whose rows are
        "<filename (may contain spaces)> <true> <predicted> <NN.NN%>".

    Returns:
        dict with 'overall_accuracy' (fraction in [0, 1]),
        'class_accuracies' ({name: {'accuracy', 'samples'}}) and
        'misclassified_files' (list of {'filename', 'true_class',
        'predicted_class', 'confidence'}).

    Raises:
        FileNotFoundError: if error_analysis.txt does not exist in vis_dir.
    """
    metrics = {}
    class_accuracies = {}
    misclassified_files = []

    with open(os.path.join(vis_dir, 'error_analysis.txt'), 'r') as f:
        lines = f.readlines()

    in_misclassified_section = False
    header_found = False

    for line in lines:
        # Overall accuracy, stored as a fraction.
        if line.startswith("Overall Accuracy:"):
            metrics['overall_accuracy'] = float(line.split(":")[1].strip().rstrip('%')) / 100

        # Per-class accuracy rows: "name: NN.NN% (K samples)".
        if "samples)" in line and ":" in line:
            class_name = line.split(":")[0].strip()
            accuracy = float(line.split(":")[1].split("%")[0].strip()) / 100
            samples = int(line.split("(")[1].split(" ")[0])
            class_accuracies[class_name] = {
                'accuracy': accuracy,
                'samples': samples
            }

        # Misclassification table starts after the section title and its
        # column header; separator lines begin with "-".
        if "Misclassified Videos:" in line:
            in_misclassified_section = True
            continue
        if "Filename" in line and "True Class" in line:
            header_found = True
            continue
        if in_misclassified_section and header_found and line.strip() and not line.startswith("-"):
            try:
                # Columns are space-separated, but the filename itself may
                # contain spaces; anchor on the trailing "NN.NN%" token.
                parts = line.strip().split()
                confidence_idx = next(i for i, part in enumerate(parts) if part.endswith('%'))
                # Guard: with fewer than two columns before the confidence,
                # parts[confidence_idx - 2] would wrap around to the end.
                if confidence_idx < 2:
                    raise ValueError("row has too few columns")
                filename = ' '.join(parts[:confidence_idx - 2])
                true_class = parts[confidence_idx - 2]
                pred_class = parts[confidence_idx - 1]
                confidence = float(parts[confidence_idx].rstrip('%')) / 100

                misclassified_files.append({
                    'filename': filename,
                    'true_class': true_class,
                    'predicted_class': pred_class,
                    'confidence': confidence
                })
            # Narrowed from bare `except Exception`: only the failures a
            # malformed row can actually produce, with the reason reported.
            except (StopIteration, ValueError, IndexError) as e:
                print(f"Warning: Could not parse line: {line.strip()} ({e})")
                continue

    metrics['class_accuracies'] = class_accuracies
    metrics['misclassified_files'] = misclassified_files
    return metrics
|
| 65 |
+
|
| 66 |
+
def analyze_trial(trial_dir):
    """Aggregate error-analysis metrics across a trial's visualization dirs.

    Args:
        trial_dir: ``pathlib.Path`` of a single trial directory containing
            one or more ``visualization_*`` subdirectories.

    Returns:
        dict with sample-weighted ``overall_accuracy``, ``total_samples``,
        per-class ``{'correct', 'total', 'accuracy'}`` counts, and the
        combined ``misclassified_files`` list (each entry tagged with the
        ``vis_dir`` it came from) — or ``None`` when the trial has no
        visualization output yet.
    """
    trial_metrics = {
        'overall_accuracy': 0,
        'total_samples': 0,
        'class_accuracies': defaultdict(lambda: {'correct': 0, 'total': 0}),
        'misclassified_files': []
    }

    # Find all visualization directories
    vis_dirs = [d for d in trial_dir.iterdir() if d.is_dir() and d.name.startswith('visualization_')]
    if not vis_dirs:
        return None

    for vis_dir in vis_dirs:
        try:
            metrics = parse_error_analysis(vis_dir)

            # Weight the overall accuracy by this dataset's sample count so
            # larger evaluation sets dominate the average.
            samples = sum(m['samples'] for m in metrics['class_accuracies'].values())
            trial_metrics['total_samples'] += samples
            trial_metrics['overall_accuracy'] += metrics['overall_accuracy'] * samples

            # Aggregate per-class counts (accuracy * samples ~= correct count).
            for class_name, class_metrics in metrics['class_accuracies'].items():
                trial_metrics['class_accuracies'][class_name]['correct'] += (
                    class_metrics['accuracy'] * class_metrics['samples']
                )
                trial_metrics['class_accuracies'][class_name]['total'] += class_metrics['samples']

            # Tag each misclassification with the dataset it came from.
            for error in metrics['misclassified_files']:
                error['vis_dir'] = vis_dir.name
                trial_metrics['misclassified_files'].append(error)

        except Exception as e:
            # Best-effort aggregation: a corrupt or missing error_analysis.txt
            # in one dataset should not invalidate the rest of the trial.
            print(f"Error processing visualization directory {vis_dir}: {e}")

    # Convert the weighted sums into final averages.
    if trial_metrics['total_samples'] > 0:
        trial_metrics['overall_accuracy'] /= trial_metrics['total_samples']

    for class_metrics in trial_metrics['class_accuracies'].values():
        # BUG FIX: always set 'accuracy' — downstream code (analyze_trials)
        # reads this key unconditionally, and a class with zero samples
        # previously had no 'accuracy' key at all, causing a KeyError.
        class_metrics['accuracy'] = (
            class_metrics['correct'] / class_metrics['total']
            if class_metrics['total'] > 0 else 0.0
        )

    return trial_metrics
|
| 113 |
+
|
| 114 |
+
def analyze_trials(hyperparam_dir):
    """Scan every search_*/trial_* directory under hyperparam_dir.

    Returns a dict keyed by search-directory name, each entry holding the
    best overall trial, the best trial per class, and every misclassified
    file collected across that search's trials.
    """
    results = {
        'search_dirs': defaultdict(lambda: {
            'best_overall': {'accuracy': 0, 'trial': None},
            'best_per_class': defaultdict(lambda: {'accuracy': 0, 'trial': None}),
            'misclassified_files': []
        })
    }

    search_paths = (
        p for p in Path(hyperparam_dir).iterdir()
        if p.is_dir() and p.name.startswith('search_')
    )
    for search_path in search_paths:
        trial_paths = (
            p for p in search_path.iterdir()
            if p.is_dir() and p.name.startswith('trial_')
        )
        for trial_path in trial_paths:
            trial_metrics = analyze_trial(trial_path)
            if trial_metrics is None:
                continue

            summary = results['search_dirs'][search_path.name]

            # Keep the single best trial by sample-weighted accuracy.
            best = summary['best_overall']
            if trial_metrics['overall_accuracy'] > best['accuracy']:
                best['accuracy'] = trial_metrics['overall_accuracy']
                best['trial'] = trial_path.name

            # Track the best trial independently for every class.
            for class_name, class_metrics in trial_metrics['class_accuracies'].items():
                class_best = summary['best_per_class'][class_name]
                if class_metrics['accuracy'] > class_best['accuracy']:
                    class_best['accuracy'] = class_metrics['accuracy']
                    class_best['trial'] = trial_path.name

            # Pool misclassifications across all trials of this search.
            summary['misclassified_files'].extend(trial_metrics['misclassified_files'])

    return results
|
| 154 |
+
|
| 155 |
+
def save_analysis_report(results, hyperparam_dir):
    """Write a human-readable per-search report to trial_analysis_report.txt."""
    output_file = os.path.join(hyperparam_dir, 'trial_analysis_report.txt')

    chunks = []
    for search_dir, search_results in results['search_dirs'].items():
        chunks.append(f"\n=== Results for {search_dir} ===\n")
        chunks.append("-" * 80 + "\n")

        # Single best trial for this search directory.
        chunks.append("\nBest Overall Model:\n")
        chunks.append(f"Trial: {search_results['best_overall']['trial']}\n")
        chunks.append(f"Accuracy: {search_results['best_overall']['accuracy']:.2%}\n")

        # Per-class winners.
        chunks.append("\nBest Model Per Class:\n")
        chunks.append(f"{'Class':<20} {'Accuracy':<10} {'Trial'}\n")
        chunks.append("-" * 60 + "\n")
        for class_name, data in search_results['best_per_class'].items():
            chunks.append(f"{class_name:<20} {data['accuracy']:.2%} {data['trial']}\n")

        # Lowest-confidence misclassifications first: those are the videos
        # the models struggle with most.
        chunks.append("\nMost Frequently Misclassified Files:\n")
        chunks.append(f"{'Filename':<40} {'True Class':<15} {'Predicted':<15} {'Confidence':<10} {'Dataset'}\n")
        chunks.append("-" * 100 + "\n")

        worst_first = sorted(search_results['misclassified_files'],
                             key=lambda err: err['confidence'])
        for err in worst_first[:10]:
            chunks.append(f"{err['filename']:<40} {err['true_class']:<15} "
                          f"{err['predicted_class']:<15} {err['confidence']:<10.2%} {err['vis_dir']}\n")

        chunks.append("\n" + "=" * 80 + "\n")

    # Single buffered write instead of many small f.write calls.
    with open(output_file, 'w') as f:
        f.write("".join(chunks))
|
| 188 |
+
|
| 189 |
+
def print_results(results):
    """Print a console summary of the analysis results."""
    for search_dir, search_results in results['search_dirs'].items():
        out = []
        out.append(f"\n=== Results for {search_dir} ===")
        out.append("-" * 80)

        # Single best trial for this search directory.
        out.append("\nBest Overall Model:")
        out.append(f"Trial: {search_results['best_overall']['trial']}")
        out.append(f"Accuracy: {search_results['best_overall']['accuracy']:.2%}")

        # Per-class winners.
        out.append("\nBest Model Per Class:")
        out.append(f"{'Class':<20} {'Accuracy':<10} {'Trial'}")
        out.append("-" * 60)
        for class_name, data in search_results['best_per_class'].items():
            out.append(f"{class_name:<20} {data['accuracy']:.2%} {data['trial']}")

        # Five lowest-confidence misclassifications — the hardest cases.
        out.append("\nTop 5 Most Problematic Files:")
        out.append(f"{'Filename':<40} {'True Class':<15} {'Predicted':<15} {'Confidence'}")
        out.append("-" * 80)
        hardest = sorted(search_results['misclassified_files'],
                         key=lambda err: err['confidence'])[:5]
        for err in hardest:
            out.append(f"{err['filename']:<40} {err['true_class']:<15} "
                       f"{err['predicted_class']:<15} {err['confidence']:.2%}")

        print("\n".join(out))
|
| 216 |
+
|
| 217 |
+
if __name__ == "__main__":
    # Root of the hyperparameter sweep to analyze.
    hyperparam_dir = "runs_hyperparam/hyperparam_20241106_124214"

    analysis = analyze_trials(hyperparam_dir)

    # Console summary first, then the full report on disk.
    print_results(analysis)
    save_analysis_report(analysis, hyperparam_dir)
|
script/visualization/visualize.py
CHANGED
|
@@ -45,9 +45,10 @@ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_la
|
|
| 45 |
all_preds = []
|
| 46 |
all_labels = []
|
| 47 |
all_probs = []
|
|
|
|
| 48 |
|
| 49 |
with torch.no_grad():
|
| 50 |
-
for frames, labels,
|
| 51 |
frames = frames.to(device)
|
| 52 |
labels = labels.to(device)
|
| 53 |
|
|
@@ -58,11 +59,44 @@ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_la
|
|
| 58 |
all_preds.extend(predicted.cpu().numpy())
|
| 59 |
all_labels.extend(labels.cpu().numpy())
|
| 60 |
all_probs.extend(probs.cpu().numpy())
|
|
|
|
| 61 |
|
| 62 |
all_labels = np.array(all_labels)
|
| 63 |
all_preds = np.array(all_preds)
|
| 64 |
all_probs = np.array(all_probs)
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
# Compute and plot confusion matrix
|
| 67 |
cm = confusion_matrix(all_labels, all_preds)
|
| 68 |
plt.figure(figsize=(10, 8))
|
|
@@ -124,7 +158,11 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
|
|
| 124 |
|
| 125 |
class_labels = config['class_labels']
|
| 126 |
num_classes = config['num_classes']
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
# Paths
|
| 130 |
log_file = os.path.join(run_dir, 'training_log.csv')
|
|
@@ -136,6 +174,8 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
|
|
| 136 |
# Get the last directory of data_path and the file name
|
| 137 |
last_dir = os.path.basename(os.path.normpath(data_path))
|
| 138 |
file_name = os.path.basename(test_csv)
|
|
|
|
|
|
|
| 139 |
|
| 140 |
# Create a directory for visualization outputs
|
| 141 |
vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_name.split(".")[0]}')
|
|
@@ -164,8 +204,9 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
|
|
| 164 |
|
| 165 |
if __name__ == "__main__":
|
| 166 |
# Find the most recent run directory
|
| 167 |
-
run_dir = get_latest_run_dir()
|
|
|
|
| 168 |
# run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241024-150232_otherpeopleval_large_model"
|
| 169 |
# run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
|
| 170 |
-
|
| 171 |
-
run_visualization(run_dir)
|
|
|
|
| 45 |
all_preds = []
|
| 46 |
all_labels = []
|
| 47 |
all_probs = []
|
| 48 |
+
all_files = []
|
| 49 |
|
| 50 |
with torch.no_grad():
|
| 51 |
+
for frames, labels, filenames in data_loader:
|
| 52 |
frames = frames.to(device)
|
| 53 |
labels = labels.to(device)
|
| 54 |
|
|
|
|
| 59 |
all_preds.extend(predicted.cpu().numpy())
|
| 60 |
all_labels.extend(labels.cpu().numpy())
|
| 61 |
all_probs.extend(probs.cpu().numpy())
|
| 62 |
+
all_files.extend(filenames)
|
| 63 |
|
| 64 |
all_labels = np.array(all_labels)
|
| 65 |
all_preds = np.array(all_preds)
|
| 66 |
all_probs = np.array(all_probs)
|
| 67 |
|
| 68 |
+
# Generate error analysis file
|
| 69 |
+
error_file = os.path.join(output_dir, 'error_analysis.txt')
|
| 70 |
+
with open(error_file, 'w') as f:
|
| 71 |
+
f.write(f"Error Analysis for {data_info}\n")
|
| 72 |
+
f.write("=" * 80 + "\n\n")
|
| 73 |
+
|
| 74 |
+
# Overall accuracy
|
| 75 |
+
accuracy = (all_labels == all_preds).mean()
|
| 76 |
+
f.write(f"Overall Accuracy: {accuracy:.2%}\n\n")
|
| 77 |
+
|
| 78 |
+
# Per-class accuracy
|
| 79 |
+
f.write("Per-Class Accuracy:\n")
|
| 80 |
+
for i, class_name in enumerate(class_labels):
|
| 81 |
+
class_mask = all_labels == i
|
| 82 |
+
if class_mask.sum() > 0:
|
| 83 |
+
class_acc = (all_preds[class_mask] == i).mean()
|
| 84 |
+
f.write(f"{class_name}: {class_acc:.2%} ({(class_mask).sum()} samples)\n")
|
| 85 |
+
f.write("\n")
|
| 86 |
+
|
| 87 |
+
# Detailed error analysis
|
| 88 |
+
f.write("Misclassified Videos:\n")
|
| 89 |
+
f.write("-" * 80 + "\n")
|
| 90 |
+
f.write(f"{'Filename':<40} {'True Class':<20} {'Predicted Class':<20} Confidence\n")
|
| 91 |
+
f.write("-" * 80 + "\n")
|
| 92 |
+
|
| 93 |
+
for i, (true_label, pred_label, probs, filename) in enumerate(zip(all_labels, all_preds, all_probs, all_files)):
|
| 94 |
+
if true_label != pred_label:
|
| 95 |
+
true_class = class_labels[true_label]
|
| 96 |
+
pred_class = class_labels[pred_label]
|
| 97 |
+
confidence = probs[pred_label]
|
| 98 |
+
f.write(f"{filename:<40} {true_class:<20} {pred_class:<20} {confidence:.2%}\n")
|
| 99 |
+
|
| 100 |
# Compute and plot confusion matrix
|
| 101 |
cm = confusion_matrix(all_labels, all_preds)
|
| 102 |
plt.figure(figsize=(10, 8))
|
|
|
|
| 158 |
|
| 159 |
class_labels = config['class_labels']
|
| 160 |
num_classes = config['num_classes']
|
| 161 |
+
|
| 162 |
+
# Update the config's data_path if provided
|
| 163 |
+
if data_path:
|
| 164 |
+
config['data_path'] = data_path
|
| 165 |
+
data_path = config['data_path']
|
| 166 |
|
| 167 |
# Paths
|
| 168 |
log_file = os.path.join(run_dir, 'training_log.csv')
|
|
|
|
| 174 |
# Get the last directory of data_path and the file name
|
| 175 |
last_dir = os.path.basename(os.path.normpath(data_path))
|
| 176 |
file_name = os.path.basename(test_csv)
|
| 177 |
+
|
| 178 |
+
print(f"Running visualization for {data_path} with {test_csv} from CWD {os.getcwd()}")
|
| 179 |
|
| 180 |
# Create a directory for visualization outputs
|
| 181 |
vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_name.split(".")[0]}')
|
|
|
|
| 204 |
|
| 205 |
if __name__ == "__main__":
    # Pick the run to visualize. Normally the most recent run would be used:
    # run_dir = get_latest_run_dir()
    # ...but here a specific hyperparameter-search trial is pinned instead.
    run_dir = "/home/bawolf/workspace/break/clip/runs_hyperparam/hyperparam_20241106_124214/search_combined_adjusted/trial_combined_adjusted_20241106-195023/"
    # Other runs of interest, kept for quick switching:
    # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241024-150232_otherpeopleval_large_model"
    # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"

    # Evaluate against this dataset rather than the one stored in the
    # run's config (run_visualization overrides config['data_path']).
    data_path = "/home/bawolf/workspace/break/finetune/blog/combined/all"
    run_visualization(run_dir, data_path=data_path)
|
script/visualization/viz_cross_compare.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from visualize import run_visualization
|
| 4 |
+
|
| 5 |
+
def get_opposite_dataset_path(run_folder):
    """Return the dataset path a search folder was NOT trained on.

    Matches the folder name against known search-folder prefixes; returns
    the opposite dataset's training-data path, or None when no mapping fits.
    """
    # Each search-folder prefix maps to the *other* source's dataset.
    dataset_mapping = {
        'search_bryant_adjusted': '../finetune/blog/youtube/adjusted',
        'search_bryant_random': '../finetune/blog/youtube/random',
        'search_youtube_adjusted': '../finetune/blog/bryant/adjusted',
        'search_youtube_random': '../finetune/blog/bryant/random',
    }

    return next(
        (path for prefix, path in dataset_mapping.items()
         if run_folder.startswith(prefix)),
        None,
    )
|
| 18 |
+
|
| 19 |
+
def process_runs(base_dir):
    """Cross-evaluate every trial under base_dir against the opposite dataset.

    For each ``search_*`` directory, look up the dataset its models were NOT
    trained on and run the visualization/evaluation for every ``trial_*``
    directory inside it. Failures in individual trials are reported and
    skipped so one bad trial does not abort the whole sweep.

    Args:
        base_dir: path to a hyperparameter-run directory containing
            search directories.
    """
    runs_dir = Path(base_dir)

    for search_dir in runs_dir.iterdir():
        # Only directories hold trials; 'visualization' is an output folder.
        if not search_dir.is_dir() or search_dir.name == 'visualization':
            continue

        # Get the opposite dataset path for this search directory
        opposite_dataset = get_opposite_dataset_path(search_dir.name)

        # BUG FIX: the original tested `is not None` here, which skipped
        # every *mapped* search directory (while printing "no matching
        # dataset mapping") and processed unmapped ones with a None dataset
        # path. Skip only when no mapping exists.
        if opposite_dataset is None:
            print(f"Skipping {search_dir.name} - no matching dataset mapping")
            continue

        # Process each trial directory within the search directory
        for trial_dir in search_dir.iterdir():
            if not trial_dir.is_dir() or not trial_dir.name.startswith('trial_'):
                continue

            print(f"Processing {trial_dir} with {opposite_dataset}")
            try:
                vis_dir, cm = run_visualization(
                    run_dir=str(trial_dir),
                    data_path=opposite_dataset,
                    test_csv=os.path.join(opposite_dataset, "train.csv")
                )
                print(f"Visualization complete: {vis_dir}")
            except Exception as e:
                # Keep going: report the failure and move to the next trial.
                print(f"Error processing {trial_dir}: {e}")
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
    # Root of the hyperparameter sweep whose trials should be cross-compared.
    runs_path = "runs_hyperparam/hyperparam_20241106_124214"
    process_runs(runs_path)
|
src/dataset/dataset.py
CHANGED
|
@@ -48,6 +48,11 @@ class VideoDataset(Dataset):
|
|
| 48 |
def __getitem__(self, idx):
|
| 49 |
video_path, label = self.data[idx]
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
frames, success = extract_frames(video_path,
|
| 52 |
{"max_frames": self.max_frames, "sigma": self.sigma},
|
| 53 |
self.transform)
|
|
|
|
| 48 |
def __getitem__(self, idx):
|
| 49 |
video_path, label = self.data[idx]
|
| 50 |
|
| 51 |
+
if not os.path.exists(video_path):
|
| 52 |
+
print(f"File not found: {video_path}")
|
| 53 |
+
print(f"Absolute path attempt: {os.path.abspath(video_path)}")
|
| 54 |
+
raise FileNotFoundError(f"File not found: {video_path}")
|
| 55 |
+
|
| 56 |
frames, success = extract_frames(video_path,
|
| 57 |
{"max_frames": self.max_frames, "sigma": self.sigma},
|
| 58 |
self.transform)
|