test / thompson.py
Artem
added base demo
5effdd5
import chex
import jax
from jax import random, jit, vmap
import jax.numpy as jnp
from functools import partial
from consts import EPS
from globals import UserInfo, Char, State
def norm_playtime(arr: chex.Array, cid: int) -> chex.Array:
max_playtime = jnp.max(arr) + EPS
norm = arr[cid] / max_playtime
return norm
@jit
def construct_feats(user: UserInfo, char: Char, char_id: int) -> chex.Array:
feats = [
user.skill_level,
jnp.log1p(user.games_played),
char.difficulty,
char.execution_level,
char.neutral_required,
char.tier,
]
feats.append(char.archetype_vec)
skill_match = 1.0 - jnp.abs(user.skill_level - (1.0 - char.difficulty))
feats.append(jnp.array([skill_match]))
archetype_sim = jnp.dot(user.pref_archetype, char.archetype_vec)
feats.append(jnp.array([archetype_sim]))
tried_before = user.chars_attempted_mask[char_id]
novelty_bonus = 1.0 - tried_before
feats.append(jnp.array([novelty_bonus]))
past_perf = jnp.where(tried_before > 0.5, user.wr[char_id], 0.5)
feats.append(jnp.array([past_perf]))
norm = norm_playtime(user.playtime, char_id)
feats.append(jnp.array([norm]))
return jnp.concatenate([jnp.atleast_1d(feat) for feat in feats])
@partial(jit, static_argnums=(2,))
def build_feats(user: UserInfo, chars: Char, n_chars: int):
def build_single(cid: int):
char = jax.tree.map(lambda x: x[cid], chars)
return construct_feats(user, char, cid)
return vmap(build_single)(jnp.arange(n_chars))
@jit
def sample_params(key: chex.PRNGKey, mu: chex.Array, Sigma: chex.Array) -> chex.Array:
d = mu.shape[0]
Lambda = Sigma + EPS * jnp.eye(d)
theta = random.multivariate_normal(key, mu, Lambda)
return theta
@jit
def compute_expected_rewards(thetas: chex.Array, feats: chex.Array) -> chex.Array:
return vmap(jnp.dot)(thetas, feats)
@jit
def thompson_sample(
key: chex.PRNGKey, state: State, feats: chex.Array
) -> tuple[chex.Array, chex.Array]:
num_chars = feats.shape[0]
keys = random.split(key, num_chars)
thetas = vmap(sample_params)(keys, state.mu, state.Sigma)
rewards = compute_expected_rewards(thetas, feats)
return rewards, thetas
@jit
def update_posterior(
state: State,
char_id: int,
feats: chex.Array,
reward: float,
noise_var: float = 1.0,
use_adaptive_noise: bool = True,
) -> State:
x = feats
d = x.shape[0]
mu_old = state.mu[char_id]
sigma_old = state.Sigma[char_id]
# might be numerically unstable, not sure... for noninvertivle matrices should check this later when not lazy
# ugly and hacky but idk how to approx this outside of inv, solve and do op, then inv to undo
Sigma_old_inv = jnp.linalg.inv(sigma_old + EPS * jnp.eye(d))
Sigma_new_inv = Sigma_old_inv + (1.0 / noise_var) * jnp.outer(x, x)
Sigma_new = jnp.linalg.inv(Sigma_new_inv)
mu_new = Sigma_new @ (Sigma_old_inv @ mu_old + (reward / noise_var) * x)
new_mu = state.mu.at[char_id].set(mu_new)
new_Sigma = state.Sigma.at[char_id].set(Sigma_new)
# TODO: figure out whether adaptive noise in gp is needed
new_beta = None
if use_adaptive_noise:
new_beta = state.beta.at[char_id].add(1)
return State(
mu=new_mu,
Sigma=new_Sigma,
alpha=state.alpha,
beta=new_beta if new_beta is not None else state.beta,
)
@partial(jit, static_argnums=(2, 3))
def select_top_k_diverse(
scores: chex.Array, archetypes: chex.Array, k: int, diversity_thresh: float
) -> chex.Array:
n_chars = scores.shape[0]
sorted_idx = jnp.argsort(-scores)
def selection_step(carry, cand_idx):
select, cnt = carry
cand_idx = sorted_idx[cand_idx]
done = cnt > k
cand_arch = archetypes[cand_idx]
def check_item_diversity(sel_idx):
# may need a max bound here
is_valid = sel_idx >= 0
sel_arch = archetypes[sel_idx]
# cos_sim w little eps to avoid div 0
sim = jnp.dot(cand_arch, sel_arch) / (
jnp.linalg.norm(cand_arch) * jnp.linalg.norm(sel_arch) + 1e-8
)
return jnp.where(is_valid, sim < diversity_thresh, True)
all_diverse = jnp.all(vmap(check_item_diversity)(select))
add_op = jnp.logical_and(jnp.logical_not(done), all_diverse)
new_sel = jnp.where(add_op, select.at[cnt].set(cand_idx), select)
new_cnt = jnp.where(add_op, cnt + 1, cnt)
return (new_sel, new_cnt), None
init = jnp.full(k, -1, dtype=jnp.int32)
init = init.at[0].set(sorted_idx[0])
(final_sel, null), null = jax.lax.scan(
selection_step, (init, 1), jnp.arange(1, n_chars)
)
return final_sel
@jit
def compute_reward(
won: bool, completed: bool, rating: float, playtime_mins: float, weights:chex.Array = jnp.array([0.3, 0.15, 0.25, 0.3])
) -> float:
win_reward = jnp.where(won, weights[0], 0.0)
completion_reward = jnp.where(completed, weights[1], 0.0)
rating_reward = weights[2] * jnp.clip(rating / 5.0, 0.0, 1.0)
engagement_reward = weights[3] * jnp.clip(jnp.log1p(playtime_mins) / 5.0, 0.0, 1.0)
return win_reward + completion_reward + rating_reward + engagement_reward
@partial(jit, static_argnums=(4, 5))
def recommend_characters(
key: chex.PRNGKey,
state: State,
user: UserInfo,
characters: Char,
n_chars: int,
top_k: int = 3,
diversity_threshold: float = 0.75,
) -> tuple[chex.Array, chex.Array]:
features = build_feats(user, characters, n_chars)
sampled_rewards, sampled_thetas = thompson_sample(key, state, features)
selected = select_top_k_diverse(
sampled_rewards, characters.archetype_vec, top_k, diversity_threshold
)
return selected, sampled_rewards
def init_thompson(n_chars: int, feature_dim: int, prior_var: float = 1.0) -> State:
return State(
mu=jnp.zeros((n_chars, feature_dim)),
Sigma=jnp.tile(prior_var * jnp.eye(feature_dim), (n_chars, 1, 1)),
alpha=jnp.ones(n_chars),
beta=jnp.ones(n_chars),
)
@jit
def batch_update_posterior(
state: State,
char_ids: chex.Array,
features: chex.Array,
rewards: chex.Array,
noise_var: float = 1.0,
) -> State:
def single_update(s, data):
char_id, feat, reward = data
return update_posterior(s, char_id, feat, reward, noise_var), None
final_state, _ = jax.lax.scan(single_update, state, (char_ids, features, rewards))
return final_state