File size: 5,420 Bytes
0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from collections import defaultdict

from sotopia.database.logs import EpisodeLog


def analyze_episodes_with_positions(tag):
    """Aggregate and print reward statistics for all episodes with *tag*.

    For every episode whose log carries the given tag, accumulate each
    reward key per model, both overall and split by the position the model
    played (agent1 vs agent2), then print averages and model pairings.

    Args:
        tag: The ``EpisodeLog.tag`` value used to select episodes.

    Returns:
        dict with three entries:
        - 'model_rewards':   {model: {reward_key: average}}
        - 'position_counts': {model: {'agent1': n, 'agent2': n}}
        - 'position_rewards': {model: {'agent1'|'agent2': {key: average}}}
        Note: the sums are converted to averages in place before returning,
        and the nested values are still defaultdicts (only the outer layer
        is converted via ``dict(...)``).
    """
    # Find episodes with the specified tag
    episodes = EpisodeLog.find(EpisodeLog.tag == tag).all()
    print(f"Total episodes found: {len(episodes)}")

    # Cumulative reward sums per model (averaged in place later)
    model_rewards = defaultdict(lambda: defaultdict(float))
    model_counts = defaultdict(int)

    # How many times each model appears as agent1 vs agent2
    position_counts = defaultdict(lambda: {'agent1': 0, 'agent2': 0})

    # Reward sums split by the position the model played
    position_rewards = defaultdict(lambda: {
        'agent1': defaultdict(float),
        'agent2': defaultdict(float)
    })

    episode_message_length_sum = 0

    # Process each episode
    for episode in episodes:
        try:
            # Skip malformed logs: need at least 3 model entries
            # (index 0 is presumably the env/judge model — TODO confirm)
            # and reward lists for both agents.
            if not hasattr(episode, 'models') or len(episode.models) < 3:
                continue
            if not hasattr(episode, 'rewards') or len(episode.rewards) < 2:
                continue

            print(episode.models)
            print('Episode message length:', len(episode.messages))
            episode_message_length_sum += len(episode.messages)

            # Get model names
            model1_name = episode.models[1]  # agent1's model
            model2_name = episode.models[2]  # agent2's model

            # The final element of each reward list is assumed to be the
            # per-key reward dict; skip episodes with other formats.
            try:
                reward1 = episode.rewards[0][-1]
                reward2 = episode.rewards[1][-1]
            except (IndexError, TypeError):
                continue
            if not isinstance(reward1, dict) or not isinstance(reward2, dict):
                continue

            # Accumulate rewards per model and per position
            for key, value in reward1.items():
                model_rewards[model1_name][key] += value
                position_rewards[model1_name]['agent1'][key] += value
                if key == 'goal':
                    print(f"Model1 Goal: {value}")

            for key, value in reward2.items():
                model_rewards[model2_name][key] += value
                position_rewards[model2_name]['agent2'][key] += value
                if key == 'goal':
                    print(f"Model2 Goal: {value}")

            # Count appearances (overall and per position)
            model_counts[model1_name] += 1
            model_counts[model2_name] += 1
            position_counts[model1_name]['agent1'] += 1
            position_counts[model2_name]['agent2'] += 1

        except Exception as e:
            # Best-effort: log and keep going on unexpected log shapes
            print(f"Error: {e}")

    # Calculate overall averages
    print("\n===== OVERALL MODEL PERFORMANCE =====")
    print(f"Total episodes: {len(episodes)}")
    # Guard against ZeroDivisionError when no episode matched the tag.
    # NOTE: the denominator counts ALL found episodes, including ones
    # skipped above, so this is a lower bound on the true mean length.
    if episodes:
        print(f"Mean message length: {episode_message_length_sum / len(episodes):.2f}")
    else:
        print("Mean message length: n/a (no episodes)")

    for model, rewards in model_rewards.items():
        print(f"\nModel: {model} (appeared in {model_counts[model]} episodes)")
        print(f"  As agent1: {position_counts[model]['agent1']} times")
        print(f"  As agent2: {position_counts[model]['agent2']} times")

        for key, value in rewards.items():
            avg = value / model_counts[model]
            print(f"  {key}: {avg:.4f}")
            # Convert the accumulated sum to an average in place
            model_rewards[model][key] = avg

    # Calculate position-specific averages
    print("\n===== PERFORMANCE BY POSITION =====")
    for model in position_rewards:
        print(f"\nModel: {model}")

        # Agent1 position
        if position_counts[model]['agent1'] > 0:
            print(f"  As agent1 ({position_counts[model]['agent1']} episodes):")
            for key, value in position_rewards[model]['agent1'].items():
                avg = value / position_counts[model]['agent1']
                print(f"    {key}: {avg:.4f}")
                position_rewards[model]['agent1'][key] = avg
        else:
            print("  Never appeared as agent1")

        # Agent2 position
        if position_counts[model]['agent2'] > 0:
            print(f"  As agent2 ({position_counts[model]['agent2']} episodes):")
            for key, value in position_rewards[model]['agent2'].items():
                avg = value / position_counts[model]['agent2']
                print(f"    {key}: {avg:.4f}")
                position_rewards[model]['agent2'][key] = avg
        else:
            print("  Never appeared as agent2")

    # Count model pairs (ordered: "A vs B" is distinct from "B vs A")
    print("\n===== MODEL PAIRINGS =====")
    model_pairs = defaultdict(int)
    for episode in episodes:
        try:
            if not hasattr(episode, 'models') or len(episode.models) < 3:
                continue
            pair_key = f"{episode.models[1]} vs {episode.models[2]}"
            model_pairs[pair_key] += 1
        except Exception:
            continue

    for pair, count in model_pairs.items():
        print(f"{pair}: {count} episodes")

    return {
        'model_rewards': dict(model_rewards),
        'position_counts': dict(position_counts),
        'position_rewards': dict(position_rewards)
    }

# Run the analysis only when executed as a script, so importing this
# module for its function does not trigger a database query.
if __name__ == "__main__":
    results = analyze_episodes_with_positions(
        "rm_goal_direct_0507_rej_sampling_num10_vs_sft_0510_epoch500_0511"
    )