#!/usr/bin/env python3
"""
Evaluate KT performance by context length (history size).

Analyzes how KT accuracy changes as student history grows from 50 to 400 interactions.
Plots all models in a single figure for comparison.

Usage:
    python evaluate_kt_by_context.py [--results-dir DIR] [--output-dir DIR] [--no-plots]
"""

import argparse
import json
import math
import os
from glob import glob
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score

# Set publication-quality font sizes
plt.rcParams.update({
    'font.size': 14,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# Model name mapping for display
MODEL_NAMES = {
    'gptoss120b': 'GPT-OSS-120B',
    'llama33_70b_instruct': 'Llama-3.3-70B-Instruct',
    'qwen3next80binstruct': 'Qwen3-Next-80B-Instruct',
    'qwen3next80bthinking': 'Qwen3-Next-80B-Thinking',
}

# Colors for each model
MODEL_COLORS = {
    'gptoss120b': '#1f77b4',  # blue
    'llama33_70b_instruct': '#ff7f0e',  # orange
    'qwen3next80binstruct': '#2ca02c',  # green
    'qwen3next80bthinking': '#d62728',  # red
}


def normalize_mcq_answer(answer_str: str) -> str:
    """Normalize MCQ answer format for consistent comparison."""
    parts = [p.strip().upper() for p in answer_str.split(',')]
    parts = [p for p in parts if p]
    if parts and all(len(p) == 1 and p.isalpha() for p in parts):
        return ', '.join(sorted(set(parts)))
    return answer_str


def numerical_match(answer1: str, answer2: str, atol: float = 0.01, rtol: float = 0.01) -> bool:
    """Check if two answers are numerically close within tolerance."""
    try:
        a = float(answer1.strip())
        b = float(answer2.strip())
        return math.isclose(a, b, abs_tol=atol, rel_tol=rtol)
    except (ValueError, AttributeError):
        return False


def answers_match(pred, actual):
    """Check if predicted answer matches actual answer."""
    if pred is None or actual is None:
        return False

    pred_str = str(pred).strip()
    actual_str = str(actual).strip()

    if pred_str == actual_str:
        return True

    pred_normalized = normalize_mcq_answer(pred_str)
    actual_normalized = normalize_mcq_answer(actual_str)
    if pred_normalized == actual_normalized:
        return True

    return numerical_match(pred_str, actual_str)


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate KT by context length")
    parser.add_argument(
        "--results-dir", "-r",
        type=str,
        default="inference_data_kt_results",
        help="Directory containing JSONL results files"
    )
    parser.add_argument(
        "--output-dir", "-o",
        type=str,
        default="dataset_analysis/plots",
        help="Directory to save output plots"
    )
    parser.add_argument(
        "--no-plots",
        action="store_true",
        help="Skip generating plots"
    )
    return parser.parse_args()


def extract_model_name(filename):
    """Extract model identifier from filename."""
    basename = os.path.basename(filename)
    # Pattern: modelname_n500_bin10_hist50.jsonl
    for model_key in MODEL_NAMES.keys():
        if basename.startswith(model_key):
            return model_key
    return basename.replace('.jsonl', '')


def load_results(jsonl_path):
    """Load results from JSONL file."""
    results = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            if line.strip():
                results.append(json.loads(line))
    return results


def compute_metrics_by_bin(results):
    """Compute metrics grouped by history_size."""
    bins = defaultdict(list)

    for r in results:
        history_size = r.get('history_size', 50)
        bins[history_size].append(r)

    metrics = {}
    for history_size in sorted(bins.keys()):
        bin_results = bins[history_size]
        n = len(bin_results)

        # FKT: Collect valid predictions for AUC-ROC
        y_true = []
        y_pred = []
        for r in bin_results:
            actual = r.get('actual_score')
            pred = r.get('predicted_question_level')
            if actual is not None and pred is not None:
                y_true.append(int(actual))
                y_pred.append(int(pred))

        # Compute AUC-ROC (requires both classes present)
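        # Note: predicted_question_level is used directly as the score; if it is a hard
        # 0/1 label, this AUC reduces to balanced accuracy rather than a threshold-free
        # ranking metric.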
        fkt_auc = None
        if len(set(y_true)) == 2:
            try:
                fkt_auc = roc_auc_score(y_true, y_pred)
            except ValueError:
                pass

        # FKT accuracy (for reference)
        fkt_correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
        fkt_acc = fkt_correct / len(y_true) if y_true else 0.0

        # Cognitive accuracy (answer match)
        cognitive_correct = sum(
            1 for r in bin_results
            if answers_match(r.get('predicted_student_answer'), r.get('actual_answer'))
        )

        metrics[history_size] = {
            'n': n,
            'fkt_auc': fkt_auc,
            'fkt_acc': fkt_acc,
            'fkt_valid': len(y_true),
            'cognitive_acc': cognitive_correct / n if n > 0 else 0.0,
        }

    return metrics


def print_table(all_metrics):
    """Print metrics table to console."""
    # Get all history sizes across all models
    all_history_sizes = sorted(set(
        hs for model_metrics in all_metrics.values()
        for hs in model_metrics.keys()
    ))

    # Header
    print("\n" + "=" * 100)
    print("KT Performance by Context Length (History Size)")
    print("=" * 100)

    # Print FKT AUC-ROC table
    print("\nFKT AUC-ROC (Question-Level):")
    print("-" * 80)
    header = f"{'History':>8}"
    for model_key in all_metrics.keys():
        header += f"  {MODEL_NAMES.get(model_key, model_key)[:20]:>20}"
    print(header)
    print("-" * 80)

    for hs in all_history_sizes:
        row = f"{hs:>8}"
        for model_key in all_metrics.keys():
            if hs in all_metrics[model_key]:
                auc = all_metrics[model_key][hs]['fkt_auc']
                if auc is not None:
                    row += f"  {auc:>20.3f}"
                else:
                    row += f"  {'N/A':>20}"
            else:
                row += f"  {'N/A':>20}"
        print(row)

    # Print Cognitive accuracy table
    print("\nCognitive Accuracy (Answer Prediction):")
    print("-" * 80)
    print(header)
    print("-" * 80)

    for hs in all_history_sizes:
        row = f"{hs:>8}"
        for model_key in all_metrics.keys():
            if hs in all_metrics[model_key]:
                acc = all_metrics[model_key][hs]['cognitive_acc']
                row += f"  {acc:>20.3f}"
            else:
                row += f"  {'N/A':>20}"
        print(row)


def plot_results(all_metrics, output_dir):
    """Generate plot with all models."""
    os.makedirs(output_dir, exist_ok=True)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: FKT AUC-ROC
    for model_key, metrics in all_metrics.items():
        history_sizes = sorted(metrics.keys())
        # Filter out None values
        valid_hs = [hs for hs in history_sizes if metrics[hs]['fkt_auc'] is not None]
        fkt_aucs = [metrics[hs]['fkt_auc'] for hs in valid_hs]

        if valid_hs:
            axes[0].plot(
                valid_hs, fkt_aucs,
                marker='o', markersize=4,
                color=MODEL_COLORS.get(model_key, 'gray'),
                label=MODEL_NAMES.get(model_key, model_key),
                linewidth=2
            )

    axes[0].set_xlabel('History Size (# prior interactions)')
    axes[0].set_ylabel('AUC-ROC')
    axes[0].set_title('FKT AUC-ROC vs Context Length')
    # Draw the chance-level reference line before building the legend so its label appears
    axes[0].axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random')
    axes[0].legend(loc='best')
    axes[0].grid(True, alpha=0.3)
    axes[0].set_xlim(40, 410)

    # Plot 2: Cognitive Accuracy
    for model_key, metrics in all_metrics.items():
        history_sizes = sorted(metrics.keys())
        cognitive_accs = [metrics[hs]['cognitive_acc'] for hs in history_sizes]

        axes[1].plot(
            history_sizes, cognitive_accs,
            marker='o', markersize=4,
            color=MODEL_COLORS.get(model_key, 'gray'),
            label=MODEL_NAMES.get(model_key, model_key),
            linewidth=2
        )

    axes[1].set_xlabel('History Size (# prior interactions)')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Cognitive Modeling Accuracy vs Context Length')
    axes[1].legend(loc='best')
    axes[1].grid(True, alpha=0.3)
    axes[1].set_xlim(40, 410)

    plt.tight_layout()

    plot_path = os.path.join(output_dir, 'kt_context_scaling.png')
    plt.savefig(plot_path, dpi=150)
    plt.close()

    print(f"\nSaved: {plot_path}")


def main():
    args = parse_args()

    # Find all JSONL files
    jsonl_files = glob(os.path.join(args.results_dir, '*.jsonl'))

    if not jsonl_files:
        print(f"No JSONL files found in {args.results_dir}")
        return

    print(f"Found {len(jsonl_files)} result files:")
    for f in jsonl_files:
        print(f"  - {os.path.basename(f)}")

    # Load and analyze each model
    all_metrics = {}

    for jsonl_path in sorted(jsonl_files):
        model_key = extract_model_name(jsonl_path)
        print(f"\nProcessing {MODEL_NAMES.get(model_key, model_key)}...")

        results = load_results(jsonl_path)
        print(f"  Loaded {len(results):,} predictions")

        metrics = compute_metrics_by_bin(results)
        all_metrics[model_key] = metrics

        # Print quick summary
        history_sizes = sorted(metrics.keys())
        valid_aucs = [metrics[hs]['fkt_auc'] for hs in history_sizes if metrics[hs]['fkt_auc'] is not None]
        avg_auc = np.mean(valid_aucs) if valid_aucs else 0.0
        avg_cognitive = np.mean([metrics[hs]['cognitive_acc'] for hs in history_sizes])
        print(f"  Avg FKT AUC-ROC: {avg_auc:.3f}")
        print(f"  Avg Cognitive accuracy: {avg_cognitive:.3f}")

    # Print detailed table
    print_table(all_metrics)

    # Generate plot
    if not args.no_plots:
        plot_results(all_metrics, args.output_dir)

    print("\nDone!")


if __name__ == "__main__":
    main()