File size: 4,247 Bytes
31fc7e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import json
from collections import Counter
import matplotlib.pyplot as plt
from pathlib import Path
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.utils.utils import get_latest_run_dir

def _load_misclassifications(misclassifications_dir):
    """Group misclassification records from every ``epoch_*.json`` file by video path.

    The same file set is used both for collecting records and for counting
    epochs, so per-video counts can never be measured against a smaller
    denominator than the data they came from.

    Args:
        misclassifications_dir: Directory holding one JSON file per epoch.

    Returns:
        A tuple ``(records_by_video, total_epochs)``. ``records_by_video`` maps
        each ``video_path`` to the list of records that mention it.
        ``total_epochs`` is the number of epoch files read.

    Raises:
        FileNotFoundError: If ``misclassifications_dir`` does not exist.
    """
    # Sort for a deterministic read order: os.listdir order is arbitrary, and
    # insertion order later decides how ties are ranked in the report.
    epoch_files = sorted(
        name for name in os.listdir(misclassifications_dir)
        if name.startswith('epoch_') and name.endswith('.json')
    )
    records_by_video = {}
    for name in epoch_files:
        path = os.path.join(misclassifications_dir, name)
        with open(path, 'r', encoding='utf-8') as f:
            for item in json.load(f):
                records_by_video.setdefault(item['video_path'], []).append(item)
    return records_by_video, len(epoch_files)


def _build_report(all_misclassifications, misclassification_counts, sorted_videos,
                  total_epochs, top_n):
    """Render the plain-text misclassification report and return it as a string."""
    lines = [
        "Misclassification Analysis Report",
        "=================================",
        "",
        f"Top {top_n} Most Misclassified Videos:",
    ]
    for video, percentage in sorted_videos[:top_n]:
        records = all_misclassifications[video]
        predicted_labels = Counter(m['predicted_label'] for m in records)
        lines.append(
            f"{Path(video).name}: Misclassified in {percentage:.2f}% of epochs "
            f"({misclassification_counts[video]} out of {total_epochs})"
        )
        # The true label is taken from the first record only.
        lines.append(f"  True Label: {records[0]['true_label']}")
        lines.append(f"  Predicted Labels: {dict(predicted_labels)}")
        lines.append("")

    total_videos = len(misclassification_counts)
    # Guard: a run with no misclassified videos would otherwise divide by zero.
    average = (sum(p for _, p in sorted_videos) / total_videos) if total_videos else 0.0
    lines += [
        "Overall Statistics:",
        f"Total misclassified videos: {total_videos}",
        f"Total misclassifications: {sum(misclassification_counts.values())}",
        f"Average misclassification percentage per video: {average:.2f}%",
        f"Total epochs: {total_epochs}",
    ]
    return "\n".join(lines) + "\n"


def _plot_distribution(sorted_videos, total_epochs, output_path):
    """Save a bar chart of per-video misclassification percentages, highest first."""
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.bar(range(len(sorted_videos)), [percentage for _, percentage in sorted_videos])
        plt.title(f'Videos Ranked by Misclassification Percentage (Total Epochs: {total_epochs})')
        plt.xlabel('Video Rank')
        plt.ylabel('Misclassification Percentage')
        plt.ylim(0, 100)  # Percentages are bounded by 100% of epochs.
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the whole process.
        plt.close(fig)


def analyze_misclassifications(run_dir=None, top_n=20):
    """Summarize per-video misclassifications recorded across training epochs.

    Reads every ``epoch_*.json`` file in ``<run_dir>/misclassifications``.
    Each file is expected to contain a JSON list of records carrying
    ``video_path``, ``true_label`` and ``predicted_label`` keys (the keys read
    here). The function writes two outputs into ``run_dir``:

    - ``misclassification_report.txt``: a text report.
    - ``misclassification_distribution.png``: a bar chart.

    Args:
        run_dir: Run directory to analyze. Defaults to the directory returned by
            ``get_latest_run_dir()``.
        top_n: Number of most-misclassified videos listed in the report.

    Raises:
        FileNotFoundError: If ``<run_dir>/misclassifications`` does not exist.
    """
    if run_dir is None:
        run_dir = get_latest_run_dir()

    misclassifications_dir = os.path.join(run_dir, 'misclassifications')
    all_misclassifications, total_epochs = _load_misclassifications(misclassifications_dir)

    # Number of records per video. This equals the number of epochs the video
    # was misclassified in, assuming one record per video per epoch file.
    misclassification_counts = {video: len(records)
                                for video, records in all_misclassifications.items()}

    # total_epochs > 0 whenever this dict is non-empty, since records only
    # come from epoch files.
    misclassification_percentages = {video: (count / total_epochs) * 100
                                     for video, count in misclassification_counts.items()}
    sorted_videos = sorted(misclassification_percentages.items(),
                           key=lambda x: x[1], reverse=True)

    report = _build_report(all_misclassifications, misclassification_counts,
                           sorted_videos, total_epochs, top_n)
    report_path = os.path.join(run_dir, 'misclassification_report.txt')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report)

    plot_path = os.path.join(run_dir, 'misclassification_distribution.png')
    _plot_distribution(sorted_videos, total_epochs, plot_path)

    print(f"Analysis complete. Report saved to {report_path}")
    print(f"Visualization saved to {plot_path}")

if __name__ == "__main__":
    # CLI entry point: accepts at most one optional positional argument, the run
    # directory. `sys` is already imported at module level, so it is not
    # re-imported here.
    if len(sys.argv) > 2:
        # Usage errors go to stderr so they don't mix with normal output.
        print("Usage: python analyze_misclassifications.py [path_to_run_directory]",
              file=sys.stderr)
        sys.exit(1)

    # None makes analyze_misclassifications fall back to its default run directory.
    run_dir = sys.argv[1] if len(sys.argv) == 2 else None
    analyze_misclassifications(run_dir)