File size: 4,773 Bytes
45e7dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
online.py -- persistent finetuning of the boss from real player fights.

Flow: the browser logs every boss decision (state, action, HP) during a fight and,
on fight end, POSTs the trajectory + who-died to /learn. We buffer trajectories and,
every MM_UPDATE_EVERY fights, run one REINFORCE step (mm_grad.OnlineLearner) that
nudges the HARD brain toward what actually worked against humans. The adapted
weights feed straight back into the live boss.

ONLY HARD-tier fights are used, so the data stays on-policy (Easy/Normal are
deliberately handicapped checkpoints; learning from them would be off-policy).

Persistence (optional, set as Space secrets):
  HF_TOKEN          - a write token
  MM_DATASET_REPO   - e.g. "your-name/boss-fight-online"
If set, adapted weights are pushed to / pulled from that dataset so learning
survives Space restarts. Without them it still adapts live, just in-memory.

Safety: a frozen copy of the sim-trained weights is kept as an anchor (the learner
pulls gently back toward it), so a weird run can't brick the boss.
"""
from __future__ import annotations

import os
import threading

import numpy as np

from mm_grad import OnlineLearner

# Weight files live next to this module.
HERE = os.path.dirname(os.path.abspath(__file__))
BASE_WEIGHTS = os.path.join(HERE, "mm_weights.npz")        # sim-trained HARD brain
LIVE_WEIGHTS = os.path.join(HERE, "mm_weights_live.npz")   # player-adapted snapshot

# --- configuration (read once, at import time) ---
# Online learning is on unless MM_ONLINE is set to anything other than "1".
ENABLED = os.environ.get("MM_ONLINE", "1") == "1"
# Accepted fights buffered per learner update. int() raises ValueError at import if
# MM_UPDATE_EVERY is not an integer; values <= 1 effectively mean "update after every fight".
UPDATE_EVERY = int(os.environ.get("MM_UPDATE_EVERY", "3"))  # fights per update
# Only trajectories whose "difficulty" equals this tier are used for training.
ADAPT_TIER = "hard"
# Persistence to an HF dataset is active only when BOTH of these are set.
DATASET_REPO = os.environ.get("MM_DATASET_REPO")
HF_TOKEN = os.environ.get("HF_TOKEN")
# File name of the adapted weights inside the dataset repo.
WEIGHTS_IN_REPO = "mm_weights_live.npz"

# --- mutable module state ---
_LOCK = threading.Lock()  # held by record_fight while it buffers fights and runs updates
_LEARNER = None  # lazily built by get_learner()
_BUFFER = []  # accepted trajectories waiting for the next update
_FIGHTS = 0  # accepted fights since process start (in-memory only; resets on restart)


def _pull_from_dataset():
    """Try to download the latest adapted weights from the HF dataset.

    Returns:
        A dict mapping array names to arrays read from WEIGHTS_IN_REPO, or None
        when persistence is not configured (DATASET_REPO / HF_TOKEN unset) or the
        download/load fails for any reason. Failure is deliberately best-effort:
        it is printed and the caller falls back to local/sim weights.
    """
    if not (DATASET_REPO and HF_TOKEN):
        return None
    try:
        from huggingface_hub import hf_hub_download
        p = hf_hub_download(repo_id=DATASET_REPO, filename=WEIGHTS_IN_REPO,
                            repo_type="dataset", token=HF_TOKEN)
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # read every array into memory, then let the context manager close it.
        with np.load(p) as npz:
            return {k: npz[k] for k in npz.files}
    except Exception as e:
        print(f"[online] no persisted weights pulled ({e}); starting from sim weights")
        return None


def _push_to_dataset():
    """Upload the local LIVE_WEIGHTS snapshot to the HF dataset, best-effort.

    Does nothing unless both DATASET_REPO and HF_TOKEN are set. Any failure,
    including huggingface_hub being unavailable, is printed and never raised.
    """
    persistence_configured = DATASET_REPO and HF_TOKEN
    if persistence_configured:
        try:
            from huggingface_hub import upload_file
            upload_file(
                path_or_fileobj=LIVE_WEIGHTS,
                path_in_repo=WEIGHTS_IN_REPO,
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
        except Exception as e:
            print(f"[online] dataset push failed ({e})")


_LEARNER_INIT_LOCK = threading.Lock()  # serializes first-time construction in get_learner()


def _read_npz(path):
    """Load every array of the .npz archive at `path` into a plain dict.

    The archive is used as a context manager so its file handle is closed
    once all arrays have been read into memory.
    """
    with np.load(path) as npz:
        return {k: npz[k] for k in npz.files}


def get_learner():
    """Return the process-wide OnlineLearner, building it on first use.

    Starting weights come from the first source that yields them:
      1. the adapted snapshot in the HF dataset (_pull_from_dataset),
      2. the local LIVE_WEIGHTS snapshot, if it exists and is readable,
      3. the sim-trained BASE_WEIGHTS.
    Whatever the starting point, the learner's anchor is always the pristine
    sim-trained weights, cast to float64.

    First-time construction is guarded by a dedicated lock. _LOCK cannot be
    reused because record_fight already holds it when it calls this. The lock
    ensures concurrent first calls build exactly one learner, and _LEARNER is
    published only after its anchor is in place.
    """
    global _LEARNER
    if _LEARNER is None:
        with _LEARNER_INIT_LOCK:
            if _LEARNER is None:  # re-check: another thread may have built it meanwhile
                sim = _read_npz(BASE_WEIGHTS)
                # anchor is ALWAYS the pristine sim-trained weights; astype() returns
                # new arrays, so the anchor never aliases what the learner is given
                anchor = {k: v.astype(np.float64) for k, v in sim.items()}
                base = _pull_from_dataset()
                if base is None and os.path.exists(LIVE_WEIGHTS):
                    try:
                        base = _read_npz(LIVE_WEIGHTS)
                    except Exception as e:
                        # a torn/corrupt local snapshot must not brick startup
                        print(f"[online] unreadable {LIVE_WEIGHTS} ({e}); starting from sim weights")
                        base = None
                if base is None:
                    base = sim
                learner = OnlineLearner(base)
                learner.anchor = anchor
                _LEARNER = learner
    return _LEARNER


def live_weights():
    """Return the shared HARD-tier weight dict, including any player-driven updates.

    This is the learner's own mapping (`.W`), not a copy.
    """
    learner = get_learner()
    return learner.W


def _save_live():
    """Write the learner's current weights (as float32) to LIVE_WEIGHTS, then push them.

    The snapshot is written to a temporary file next to LIVE_WEIGHTS and moved
    into place with os.replace. An interrupted write therefore never leaves a
    truncated LIVE_WEIGHTS, which get_learner loads at startup. The dataset push
    runs only after the local file is complete. Errors from writing or renaming
    still propagate to the caller, as before.
    """
    weights = {k: v.astype(np.float32) for k, v in get_learner().W.items()}
    tmp_path = LIVE_WEIGHTS + ".tmp"
    try:
        # Passing an open file (not a path) keeps np.savez from appending ".npz".
        with open(tmp_path, "wb") as fh:
            np.savez(fh, **weights)
        os.replace(tmp_path, LIVE_WEIGHTS)
    except Exception:
        # Clean up the partial temp file; the original error is re-raised below.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    _push_to_dataset()


def record_fight(trajectory: dict) -> dict:
    """Buffer one finished fight; run a learner update every UPDATE_EVERY fights.

    Called by /learn. trajectory = {difficulty, steps:[{state,action,bossHP,playerHP}],
    result:{bossDied,playerDied}}.

    Returns a status dict. It is {"enabled": False} when online learning is off,
    and carries a "skipped" reason for non-HARD or too-short fights. Otherwise it
    holds the fight count plus either the learner's update stats (merged in) or
    the buffer size and how many fights remain until the next update. Exceptions
    from the learner update or the save still propagate, but the failing batch
    has already been removed from the buffer, so later fights start fresh.
    """
    global _FIGHTS
    if not ENABLED:
        return {"enabled": False}
    if trajectory.get("difficulty") != ADAPT_TIER:
        return {"enabled": True, "skipped": "only HARD-tier fights train the brain"}
    # `or []` also covers an explicit "steps": null coming from the client
    if len(trajectory.get("steps") or []) < 2:
        return {"enabled": True, "skipped": "too short"}
    with _LOCK:
        _BUFFER.append(trajectory)
        _FIGHTS += 1
        if len(_BUFFER) < UPDATE_EVERY:
            return {"enabled": True, "fights": _FIGHTS, "buffered": len(_BUFFER),
                    "update_in": UPDATE_EVERY - len(_BUFFER)}
        # Detach the batch BEFORE updating. If update() raises on a bad trajectory,
        # the batch is dropped instead of being retried (and failing) on every later fight.
        batch = list(_BUFFER)
        _BUFFER.clear()
        stats = get_learner().update(batch)
        if stats.get("updated"):
            _save_live()
        return {"enabled": True, "fights": _FIGHTS, **stats}


def status() -> dict:
    """Report the online-learning configuration and in-memory counters.

    Keys: whether learning is enabled, fights accepted so far, trajectories
    currently buffered, whether dataset persistence is configured, and the
    tier whose fights are trained on.
    """
    persistent = bool(DATASET_REPO and HF_TOKEN)
    return dict(
        enabled=ENABLED,
        fights_seen=_FIGHTS,
        buffered=len(_BUFFER),
        persistent=persistent,
        adapt_tier=ADAPT_TIER,
    )