File size: 6,565 Bytes
5effdd5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 | import chex
import jax
from jax import random, jit, vmap
import jax.numpy as jnp
from functools import partial
from consts import EPS
from globals import UserInfo, Char, State
def norm_playtime(arr: chex.Array, cid: int) -> chex.Array:
    """Scale the entry of ``arr`` at ``cid`` by the largest entry of ``arr``.

    ``EPS`` is added to the maximum before dividing, so an all-zero ``arr``
    does not produce a division by zero.

    Args:
        arr: Array of playtime values, indexed by ``cid``.
        cid: Index of the entry to normalise.

    Returns:
        ``arr[cid] / (max(arr) + EPS)``.
    """
    peak = EPS + jnp.max(arr)
    return arr[cid] / peak
@jit
def construct_feats(user: UserInfo, char: Char, char_id: int) -> chex.Array:
    """Assemble the flat feature vector for one (user, character) pair.

    The pieces are concatenated in this order:
    ``user.skill_level``, ``log1p(user.games_played)``, ``char.difficulty``,
    ``char.execution_level``, ``char.neutral_required``, ``char.tier``,
    ``char.archetype_vec``, skill match, archetype similarity, novelty bonus,
    past performance, normalised playtime.

    Args:
        user: User record; the fields read here are ``skill_level``,
            ``games_played``, ``pref_archetype``, ``chars_attempted_mask``,
            ``wr`` and ``playtime``.
        char: A single character's record (not the batched table).
        char_id: Index of this character in the per-character user arrays.

    Returns:
        A 1-D array holding all features.
    """
    # NOTE(review): the leading six fields are presumably scalars — confirm
    # against the UserInfo / Char definitions.
    attempted = user.chars_attempted_mask[char_id]

    # 1.0 when the user's skill equals (1 - difficulty); shrinks as they drift apart.
    difficulty_fit = 1.0 - jnp.abs(user.skill_level - (1.0 - char.difficulty))
    archetype_affinity = jnp.dot(user.pref_archetype, char.archetype_vec)
    unseen_bonus = 1.0 - attempted
    # Characters that were never tried fall back to a neutral 0.5 win rate.
    prior_winrate = jnp.where(attempted > 0.5, user.wr[char_id], 0.5)
    playtime_share = norm_playtime(user.playtime, char_id)

    pieces = (
        user.skill_level,
        jnp.log1p(user.games_played),
        char.difficulty,
        char.execution_level,
        char.neutral_required,
        char.tier,
        char.archetype_vec,
        jnp.array([difficulty_fit]),
        jnp.array([archetype_affinity]),
        jnp.array([unseen_bonus]),
        jnp.array([prior_winrate]),
        jnp.array([playtime_share]),
    )
    return jnp.concatenate([jnp.atleast_1d(piece) for piece in pieces])
@partial(jit, static_argnums=(2,))
def build_feats(user: UserInfo, chars: Char, n_chars: int):
    """Build the feature vector of every character for a single user.

    Args:
        user: User record forwarded unchanged to ``construct_feats``.
        chars: Batched character table; every leaf is indexed by character id
            along its first axis.
        n_chars: Number of character ids to cover (``0 .. n_chars - 1``).
            Static under ``jit``.

    Returns:
        The per-character outputs of ``construct_feats``, stacked along a new
        leading axis of length ``n_chars``.
    """

    def feats_for(cid: int):
        # Pick row `cid` out of every leaf to get one character's record.
        one_char = jax.tree.map(lambda leaf: leaf[cid], chars)
        return construct_feats(user, one_char, cid)

    all_ids = jnp.arange(n_chars)
    return vmap(feats_for)(all_ids)
@jit
def sample_params(key: chex.PRNGKey, mu: chex.Array, Sigma: chex.Array) -> chex.Array:
    """Draw one parameter vector from ``N(mu, Sigma + EPS * I)``.

    Args:
        key: PRNG key consumed by the draw.
        mu: Mean vector; its length sets the size of the identity jitter.
        Sigma: Covariance matrix matching ``mu``.

    Returns:
        The sampled parameter vector.
    """
    # A small diagonal jitter is added to the covariance before sampling.
    jitter = EPS * jnp.eye(mu.shape[0])
    return random.multivariate_normal(key, mu, Sigma + jitter)
@jit
def compute_expected_rewards(thetas: chex.Array, feats: chex.Array) -> chex.Array:
    """Return the dot product of each ``thetas`` row with the matching ``feats`` row.

    Args:
        thetas: Sampled parameters, one entry per leading index.
        feats: Feature vectors, aligned with ``thetas`` along the leading axis.

    Returns:
        One score per leading index.
    """

    def row_score(theta, feat):
        return jnp.dot(theta, feat)

    return vmap(row_score)(thetas, feats)
@jit
def thompson_sample(
    key: chex.PRNGKey, state: State, feats: chex.Array
) -> tuple[chex.Array, chex.Array]:
    """Sample one parameter vector per character and score its features.

    Args:
        key: PRNG key; it is split into one sub-key per row of ``feats``.
        state: Posterior state; ``state.mu`` and ``state.Sigma`` are read.
        feats: Per-character feature vectors (leading axis = character).

    Returns:
        ``(rewards, thetas)``: the sampled score of each character and the
        parameter draws that produced them.
    """
    per_char_keys = random.split(key, feats.shape[0])
    draw_all = vmap(sample_params)
    sampled = draw_all(per_char_keys, state.mu, state.Sigma)
    return compute_expected_rewards(sampled, feats), sampled
@partial(jit, static_argnames=("use_adaptive_noise",))
def update_posterior(
    state: State,
    char_id: int,
    feats: chex.Array,
    reward: float,
    noise_var: float = 1.0,
    use_adaptive_noise: bool = True,
) -> State:
    """Fold one (features, reward) observation into a character's posterior.

    This is the Bayesian linear-regression update

        Sigma_new^-1 = Sigma_reg^-1 + x x^T / noise_var
        mu_new       = Sigma_new (Sigma_reg^-1 mu_old + reward * x / noise_var)

    with ``Sigma_reg = Sigma_old + EPS * I``. It is evaluated in its
    Sherman-Morrison (rank-one) form, so no matrix inverse is computed.

    Args:
        state: Current posterior state for all characters.
        char_id: Index of the character whose posterior is updated.
        feats: 1-D feature vector ``x`` of the observation.
        reward: Observed reward.
        noise_var: Observation noise variance (must be positive).
        use_adaptive_noise: When true, ``state.beta[char_id]`` is incremented
            by one. It drives a Python branch, so it is a static ``jit``
            argument and must be a plain bool.

    Returns:
        A new ``State`` with ``mu`` and ``Sigma`` replaced at ``char_id``,
        ``alpha`` carried over, and ``beta`` optionally incremented.
    """
    x = feats
    d = x.shape[0]
    mu_old = state.mu[char_id]
    # Keep the EPS-regularised prior so results match the inverse-based formula.
    prior_cov = state.Sigma[char_id] + EPS * jnp.eye(d)

    # Sherman-Morrison: (S^-1 + x x^T / v)^-1 = S - (S x)(x^T S) / (v + x^T S x).
    cov_x = prior_cov @ x
    x_cov = x @ prior_cov
    denom = noise_var + jnp.dot(x, cov_x)
    Sigma_new = prior_cov - jnp.outer(cov_x, x_cov) / denom
    # Equivalent mean update: move along S x by the scaled prediction residual.
    mu_new = mu_old + cov_x * ((reward - jnp.dot(x, mu_old)) / denom)

    new_mu = state.mu.at[char_id].set(mu_new)
    new_Sigma = state.Sigma.at[char_id].set(Sigma_new)

    # TODO: figure out whether adaptive noise in gp is needed
    if use_adaptive_noise:
        new_beta = state.beta.at[char_id].add(1)
    else:
        new_beta = state.beta

    return State(
        mu=new_mu,
        Sigma=new_Sigma,
        alpha=state.alpha,
        beta=new_beta,
    )
@partial(jit, static_argnums=(2,))
def select_top_k_diverse(
    scores: chex.Array, archetypes: chex.Array, k: int, diversity_thresh: float
) -> chex.Array:
    """Greedily pick up to ``k`` high-scoring, mutually dissimilar items.

    Items are visited in descending score order. The best item is always
    taken; each later item is taken only if its cosine similarity to every
    already-selected item is strictly below ``diversity_thresh``.

    Args:
        scores: 1-D array of item scores.
        archetypes: Per-item archetype vectors, indexed by item id.
        k: Maximum number of items to select (static under ``jit``).
        diversity_thresh: Cosine-similarity ceiling between selected items.

    Returns:
        An ``int32`` array of length ``k`` with the selected item ids in
        selection order; unused slots hold ``-1``.
    """
    n_chars = scores.shape[0]
    # Both values are known at trace time, so a Python branch is fine here.
    if k == 0 or n_chars == 0:
        return jnp.full(k, -1, dtype=jnp.int32)

    sorted_idx = jnp.argsort(-scores)

    def selection_step(carry, rank):
        select, cnt = carry
        cand_idx = sorted_idx[rank]
        # All k slots are used once cnt reaches k; never write past the end.
        done = cnt >= k
        cand_arch = archetypes[cand_idx]

        def check_item_diversity(sel_idx):
            # Empty slots are -1; their lookup result is masked out below.
            is_valid = sel_idx >= 0
            sel_arch = archetypes[sel_idx]
            # Cosine similarity, with a small epsilon against division by zero.
            sim = jnp.dot(cand_arch, sel_arch) / (
                jnp.linalg.norm(cand_arch) * jnp.linalg.norm(sel_arch) + 1e-8
            )
            return jnp.where(is_valid, sim < diversity_thresh, True)

        all_diverse = jnp.all(vmap(check_item_diversity)(select))
        add_op = jnp.logical_and(jnp.logical_not(done), all_diverse)
        new_sel = jnp.where(add_op, select.at[cnt].set(cand_idx), select)
        new_cnt = jnp.where(add_op, cnt + 1, cnt)
        return (new_sel, new_cnt), None

    # Seed the selection with the top-scoring item.
    init = jnp.full(k, -1, dtype=jnp.int32)
    init = init.at[0].set(sorted_idx[0])
    (final_sel, _), _ = jax.lax.scan(
        selection_step, (init, jnp.int32(1)), jnp.arange(1, n_chars)
    )
    return final_sel
@jit
def compute_reward(
    won: bool,
    completed: bool,
    rating: float,
    playtime_mins: float,
    weights: chex.Array = jnp.array([0.3, 0.15, 0.25, 0.3]),
) -> float:
    """Combine session outcomes into a single weighted reward.

    Args:
        won: Whether the session was won; contributes ``weights[0]`` if true.
        completed: Whether the session was completed; contributes
            ``weights[1]`` if true.
        rating: User rating; ``rating / 5`` is clipped to ``[0, 1]`` and
            scaled by ``weights[2]``.
        playtime_mins: Playtime; ``log1p(playtime_mins) / 5`` is clipped to
            ``[0, 1]`` and scaled by ``weights[3]``.
        weights: The four component weights, in the order above.

    Returns:
        The sum of the four weighted components.
    """
    rating_frac = jnp.clip(rating / 5.0, 0.0, 1.0)
    engagement_frac = jnp.clip(jnp.log1p(playtime_mins) / 5.0, 0.0, 1.0)

    outcome_part = jnp.where(won, weights[0], 0.0) + jnp.where(
        completed, weights[1], 0.0
    )
    return outcome_part + weights[2] * rating_frac + weights[3] * engagement_frac
@partial(jit, static_argnums=(4, 5, 6))
def recommend_characters(
    key: chex.PRNGKey,
    state: State,
    user: UserInfo,
    characters: Char,
    n_chars: int,
    top_k: int = 3,
    diversity_threshold: float = 0.75,
) -> tuple[chex.Array, chex.Array]:
    """Recommend a diverse top-k set of characters for one user.

    Builds per-character features, draws Thompson samples of the expected
    reward, and greedily selects high-scoring characters whose archetype
    vectors are sufficiently dissimilar.

    Args:
        key: PRNG key for the Thompson draw.
        state: Current posterior state.
        user: User record used to build features.
        characters: Batched character table; ``archetype_vec`` is also used
            for the diversity check.
        n_chars: Number of characters (static under ``jit``).
        top_k: Maximum number of recommendations (static under ``jit``).
        diversity_threshold: Cosine-similarity ceiling between recommended
            characters. Static under ``jit`` so that an explicitly passed
            value reaches ``select_top_k_diverse`` as a concrete number
            rather than a tracer.

    Returns:
        ``(selected, sampled_rewards)``: the chosen character ids (padded
        with ``-1``) and the sampled reward of every character.
    """
    features = build_feats(user, characters, n_chars)
    # The sampled parameter vectors are not needed here, only the scores.
    sampled_rewards, _ = thompson_sample(key, state, features)
    selected = select_top_k_diverse(
        sampled_rewards, characters.archetype_vec, top_k, diversity_threshold
    )
    return selected, sampled_rewards
def init_thompson(n_chars: int, feature_dim: int, prior_var: float = 1.0) -> State:
    """Create the initial posterior state for ``n_chars`` characters.

    Args:
        n_chars: Number of characters.
        feature_dim: Length of each character's parameter vector.
        prior_var: Variance placed on the diagonal of every prior covariance.

    Returns:
        A ``State`` with zero means of shape ``(n_chars, feature_dim)``,
        ``prior_var * I`` covariances stacked to
        ``(n_chars, feature_dim, feature_dim)``, and ``alpha`` / ``beta``
        vectors of ones with length ``n_chars``.
    """
    zero_means = jnp.zeros((n_chars, feature_dim))
    single_cov = prior_var * jnp.eye(feature_dim)
    # Repeat the same isotropic covariance once per character.
    stacked_cov = jnp.tile(single_cov, (n_chars, 1, 1))
    return State(
        mu=zero_means,
        Sigma=stacked_cov,
        alpha=jnp.ones(n_chars),
        beta=jnp.ones(n_chars),
    )
@jit
def batch_update_posterior(
    state: State,
    char_ids: chex.Array,
    features: chex.Array,
    rewards: chex.Array,
    noise_var: float = 1.0,
) -> State:
    """Apply a sequence of observations to the posterior, one after another.

    Args:
        state: Posterior state before the batch.
        char_ids: Character id of each observation.
        features: Feature vector of each observation, aligned with
            ``char_ids`` along the leading axis.
        rewards: Reward of each observation, aligned the same way.
        noise_var: Observation noise variance passed to every update.

    Returns:
        The state after all observations have been applied in order.
    """

    def apply_observation(current, observation):
        cid, x, r = observation
        updated = update_posterior(current, cid, x, r, noise_var)
        return updated, None

    observations = (char_ids, features, rewards)
    final_state, _ = jax.lax.scan(apply_observation, state, observations)
    return final_state
|