File size: 4,179 Bytes
46bfd91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Simulate N users going through the study and verify all 50 items get covered.

Usage:
    cd /dfs/scratch1/echoi1/prolific_preferences
    HF_TOKEN=hf_... python scripts/test_coverage.py
"""
import sys
import uuid
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.config import load_config
from src.data import (
    ensure_datasets,
    assign_items,
    release_reservation,
    record_completion,
    _load_pool,
    _pool_path,
    _data_dir,
)


def simulate_user(cfg: dict, complete: bool = True) -> dict:
    """Push one synthetic participant through item assignment.

    A fresh UUID4 string is used as the participant id and ``assign_items``
    is called for it. When *complete* is true, the participant's reservation
    is released and the completion is recorded. Otherwise neither call is
    made, which mimics a participant who abandons the study.

    Returns a dict with keys:
        ``user_id``   - the generated participant id,
        ``items``     - list of ``(id, category)`` tuples, where ``id`` is the
                        item's ``pair_id`` falling back to ``item_id``,
        ``raw_items`` - the list returned by ``assign_items``,
        ``completed`` - the *complete* flag.
    """
    participant = str(uuid.uuid4())
    assigned = assign_items(cfg, participant)

    if complete:
        # Release the reservation first, then record the completion.
        release_reservation(participant, cfg)
        record_completion(participant, assigned, cfg)

    summary = []
    for entry in assigned:
        ident = entry.get("pair_id") or entry.get("item_id", "")
        summary.append((ident, entry.get("category", "")))

    return {
        "user_id": participant,
        "items": summary,
        "raw_items": assigned,
        "completed": complete,
    }


def clear_local_state(cfg: dict):
    """Remove locally persisted assignment state so a simulation starts clean.

    Every entry directly inside ``_data_dir(cfg)`` that matches one of the
    glob patterns below is unlinked. The patterns cover reservation files,
    completion caches, local completions, and the variant/alternation
    counters.
    """
    stale_patterns = (
        "reservations*",
        "completion_cache*",
        "local_completions*",
        "variant_counter*",
        "alternation_counter*",
    )
    root = _data_dir(cfg)
    for glob_pattern in stale_patterns:
        for stale_file in root.glob(glob_pattern):
            stale_file.unlink()


def analyse_coverage(results: list, cfg: dict) -> bool:
    """Print a per-category coverage report for the simulated users.

    Only results whose ``completed`` flag is true count toward coverage.
    Items handed to users who abandoned are ignored. For each category in
    ``cfg["categories"]``:
      - the item pool is loaded;
      - each pool id (``pair_id``, falling back to ``item_id``) is counted
        once per completed assignment;
      - the category passes when every pool item was covered at least once.
    Over-coverage is reported but does not fail a category.

    Args:
        results: Records as built by ``simulate_user``. Each is a dict with a
            ``completed`` bool and ``items``, a list of ``(id, category)``
            tuples.
        cfg: Study config. ``cfg["categories"]`` is a list of dicts with a
            ``"name"`` key.

    Returns:
        True when every category passed, False otherwise.
    """
    cats = [c["name"] for c in cfg["categories"]]
    all_passed = True

    print()
    print("=" * 60)
    print("COVERAGE ANALYSIS")
    print("=" * 60)

    for cat in cats:
        pool = _load_pool(str(_pool_path(cat, cfg)))
        pool_ids = [p.get("pair_id") or p.get("item_id", "") for p in pool]
        covered = {pid: 0 for pid in pool_ids}

        for result in results:
            if not result["completed"]:
                continue  # abandoned users contribute no coverage
            for item_id, item_cat in result["items"]:
                # Ids not present in this category's pool are ignored.
                if item_cat == cat and item_id in covered:
                    covered[item_id] += 1

        covered_once = sum(1 for c in covered.values() if c >= 1)
        never_covered = [pid[:8] for pid, c in covered.items() if c == 0]
        over_covered = [pid[:8] for pid, c in covered.items() if c > 1]

        print(f"\nCategory: {cat}")
        print(f"  Pool size:     {len(pool)}")
        # Duplicate ids collapse into a single key of `covered`. This also
        # happens when several items lack both pair_id and item_id, since
        # they all map to "". In that case covered_once can never reach
        # len(pool), so surface the cause instead of an unexplained FAIL.
        n_duplicates = len(pool) - len(covered)
        if n_duplicates:
            print(f"  WARNING: {n_duplicates} duplicate/missing item id(s) in pool; "
                  "this category cannot reach full coverage")
        print(f"  Covered >= 1x: {covered_once} / {len(pool)}")
        print(f"  Never covered: {len(never_covered)} {never_covered[:5]}")
        print(f"  Over-covered:  {len(over_covered)} {over_covered[:5]}")

        if covered_once == len(pool):
            print(f"  ✅ PASS — all {len(pool)} items covered")
        else:
            print(f"  ❌ FAIL — {len(pool) - covered_once} items not covered")
            all_passed = False

    print()
    print("=" * 60)
    print("OVERALL:", "✅ PASS" if all_passed else "❌ FAIL")
    print("=" * 60)
    return all_passed


def run_simulation(label: str, n_users: int, dropout_indices: list = None) -> bool:
    """Reset local state, simulate *n_users* participants and report coverage.

    Args:
        label: Heading printed before the run.
        n_users: Number of simulated participants.
        dropout_indices: Zero-based user indices that abandon the study.
            Those users are simulated with ``complete=False``, so their
            reservation is not released and no completion is recorded.
            ``None`` means nobody drops out.

    Returns:
        The value of ``analyse_coverage``: True when every category was
        fully covered.
    """
    dropout_indices = dropout_indices or []
    cfg = load_config()
    ensure_datasets(cfg)
    clear_local_state(cfg)  # each scenario starts from a clean slate

    print(f"\n── {label} ──")
    print(f"[TEST] {n_users} users, dropouts at: {dropout_indices}")

    results = []
    for i in range(n_users):
        complete = i not in dropout_indices
        result = simulate_user(cfg, complete=complete)
        results.append(result)
        status = "✅ completed" if complete else "❌ abandoned"
        # Print the first 8 chars of each assigned item id for a compact trace.
        print(f"  User {i+1:2d} ({status}): "
              f"indices = {[r[0][:8] for r in result['items']]}")

    return analyse_coverage(results, cfg)


if __name__ == "__main__":
    # run_simulation clears local state before each scenario and returns
    # True on PASS. Collect every outcome (so all scenarios always run), then
    # exit non-zero if any failed, so the script can gate CI or shell scripts.
    outcomes = []

    # Test 1: perfect run — all 10 users complete. The intended outcome is
    # that all 50 items are covered exactly once. analyse_coverage only
    # requires >= 1 and reports over-coverage without failing.
    outcomes.append(run_simulation("Test 1: Perfect run", n_users=10))

    # Test 2: 2 dropouts — users at indices 3 and 7 abandon, leaving their
    # items reserved but uncovered. The assignment sort_key prefers
    # uncovered+reserved items over covered+unreserved ones, so the two extra
    # users (11-12) should pick up the abandoned items instead of re-covering
    # already-covered ones. Assuming 5 items per user handed out in pool
    # order, those would be items 15-19 and 35-39; confirm against
    # assign_items.
    outcomes.append(
        run_simulation("Test 2: 2 dropouts, 12 users", n_users=12, dropout_indices=[7, 3])
    )

    # Test 3: first user drops out — 11 users needed to cover all 50
    outcomes.append(run_simulation("Test 3: First user drops out", n_users=11, dropout_indices=[0]))

    sys.exit(0 if all(outcomes) else 1)