Artem committed on
Commit
5effdd5
·
1 Parent(s): af442f3

added base demo

Browse files
README.md CHANGED
@@ -1,7 +1,8 @@
1
  ---
2
  title: "Test"
3
  sdk: "gradio"
 
4
  python_version: "3.14"
5
- app_file: app.py
6
  pinned: True
7
  ---
 
1
  ---
2
  title: "Test"
3
  sdk: "gradio"
4
+ sdk_version: "6.4.0"
5
  python_version: "3.14"
6
+ app_file: grad_app.py
7
  pinned: True
8
  ---
__pycache__/app.cpython-312.pyc ADDED
Binary file (13.9 kB). View file
 
__pycache__/consts.cpython-312.pyc ADDED
Binary file (166 Bytes). View file
 
__pycache__/globals.cpython-312.pyc ADDED
Binary file (1.86 kB). View file
 
__pycache__/thompson.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
consts.py ADDED
@@ -0,0 +1 @@
 
 
1
+ EPS = 1e-6
globals.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+ import chex
3
+
4
+
5
+ class State(NamedTuple):
6
+ mu: chex.Array
7
+ Sigma: chex.Array
8
+ alpha: chex.Array
9
+ beta: chex.Array
10
+
11
+
12
+ class Char(NamedTuple):
13
+ """
14
+ difficulty, 1-5 * as per in game
15
+ archetype_vec: zoner, grappler, strikethrow, etc...
16
+ execution barrier: harder to quantify, would an ebedding be better?
17
+ footsies/neutral, how brainded is the char. 2mk -> dr = -points. harder buttons, less pokes = more neutral that needs to be played
18
+ tier: float/int, using like a couple of pros' tier lists maybe
19
+
20
+ """
21
+ difficulty: float
22
+ archetype_vec: chex.Array
23
+ execution_level: float
24
+ neutral_required: float
25
+ tier: float
26
+
27
+
28
+ class UserInfo(NamedTuple):
29
+ """
30
+ the one that should be updated over time.
31
+ """
32
+ skill_level: float
33
+ games_played: int
34
+ chars_attempted_mask: chex.Array
35
+ wr: chex.Array
36
+ playtime: chex.Array
37
+ pref_archetype: chex.Array
38
+
grad_app.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from jax import random
5
+ from jax.random import PRNGKey
6
+ import json
7
+ from globals import Char, State, UserInfo
8
+ from thompson import (
9
+ init_thompson,
10
+ recommend_characters,
11
+ update_posterior,
12
+ compute_reward,
13
+ construct_feats,
14
+ )
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM
16
+
17
+
18
+ class LMCharacterKnowledge:
19
+ def __init__(self, model_name: str, game_name: str):
20
+ self.game_name = game_name
21
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ self.model = AutoModelForCausalLM.from_pretrained(model_name)
23
+ self.prompt = [
24
+ {
25
+ "role": "system",
26
+ "content": "You are a knowledgeable bastion of fighting game knowledge. Your goal is to answer questions as best as possible about the game you are asked about.",
27
+ }
28
+ ]
29
+ self.cache = {}
30
+
31
+ def ask_lm(self, prompt, max_tok: int = 4096):
32
+ try:
33
+ messages = self.prompt + [{"role": "user", "content": prompt}]
34
+ inputs = self.tokenizer.apply_chat_template(
35
+ messages,
36
+ add_generation_prompt=True,
37
+ tokenize=True,
38
+ return_dict=True,
39
+ return_tensors="pt",
40
+ )
41
+ outputs = self.model.generate(**inputs, max_new_tokens=512)
42
+ result = self.tokenizer.decode(
43
+ outputs[0][inputs["input_ids"].shape[-1] :], skip_special_tokens=True
44
+ )
45
+ print(result)
46
+ return result
47
+ except Exception as e:
48
+ print(f"Couldn't query{self.model}, error: {e}")
49
+
50
+ def get_roster(self) -> list[str]:
51
+ cache_key = f"roster_{self.game_name}"
52
+ if cache_key in self.cache:
53
+ return self.cache[cache_key]
54
+
55
+ roster_prompt = f"""
56
+ List ALL playable characters in {self.game_name}. Return a structured json array of character names, nothing else at all.
57
+ Example format is : ["Ryu", "Ken", "Chun Li", "Akuma"]
58
+ """
59
+
60
+ response = self.ask_lm(roster_prompt)
61
+
62
+ try:
63
+ start = response.find("[")
64
+ end = response.find("]") + 1
65
+
66
+ if start != -1 and end > start:
67
+ roster = json.loads(response[start:end])
68
+ self.cache[cache_key] = roster
69
+ return roster
70
+ except:
71
+ # TODO: handle errors here way better
72
+ pass
73
+
74
+ return ["Ryu", "Ken", "Luke"]
75
+
76
+ def get_character_data(self, char_name: str) -> dict:
77
+ cache_key = f"char_{self.game_name}_{char_name}"
78
+ if cache_key in self.cache:
79
+ return self.cache[cache_key]
80
+
81
+ char_data_prompt = f"""
82
+ for the character {char_name} in the game {
83
+ self.game_name
84
+ },
85
+ provide some statistics in explicit JSON format:
86
+
87
+ Example format:
88
+ {{
89
+ "difficulty": 0.7,
90
+ "execution_barrier": 0.6,
91
+ "neutral_intensity": 0.5,
92
+ "tier": 0.8,
93
+ "archetypes": {{
94
+ "rushdown": 0.8,
95
+ "zoner": 0.1,
96
+ "grappler": 0.0,
97
+ "all_rounder": 0.1,
98
+ "setplay": 0.0,
99
+ "footsies": 0.0
100
+ }}
101
+ }}
102
+
103
+ Replace ALL values with actual numbers for {char_name}. Return ONLY the JSON object, nothing else.
104
+ """
105
+
106
+ response = self.ask_lm(char_data_prompt, max_tok=300)
107
+ print(f"Raw response for {char_name}: {response}")
108
+
109
+ try:
110
+ start = response.find("{")
111
+ if start == -1:
112
+ raise ValueError("No opening brace found")
113
+
114
+ brace_count = 0
115
+ end = -1
116
+ for i in range(start, len(response)):
117
+ if response[i] == '{':
118
+ brace_count += 1
119
+ elif response[i] == '}':
120
+ brace_count -= 1
121
+ if brace_count == 0:
122
+ end = i + 1
123
+ break
124
+
125
+ if end == -1:
126
+ raise ValueError("No matching closing brace found")
127
+
128
+ json_str = response[start:end]
129
+ print(f"Extracted JSON: {json_str}")
130
+
131
+ data = json.loads(json_str)
132
+
133
+ required_keys = ["difficulty", "execution_barrier", "neutral_intensity", "tier", "archetypes"]
134
+ if not all(key in data for key in required_keys):
135
+ raise ValueError(f"Missing required keys in parsed data")
136
+
137
+ self.cache[cache_key] = data
138
+ return data
139
+
140
+ except Exception as e:
141
+ print(f"Couldn't parse {char_name}'s data: {e}")
142
+ print(f"Response was: {response[:200]}...")
143
+
144
+ return {
145
+ "difficulty": 0.5,
146
+ "execution_barrier": 0.5,
147
+ "neutral_intensity": 0.5,
148
+ "tier": 0.5,
149
+ "archetypes": {
150
+ "rushdown": 0.3,
151
+ "zoner": 0.3,
152
+ "grappler": 0.1,
153
+ "all_rounder": 0.2,
154
+ "setplay": 0.05,
155
+ "footsies": 0.05,
156
+ },
157
+ }
158
+
159
+ def build_roster(self) -> tuple[list[Char], list[str]]:
160
+ roster = self.get_roster()
161
+ chars = []
162
+
163
+ for i, char_name in enumerate(roster):
164
+ data = self.get_character_data(char_name)
165
+ archetype_order = [
166
+ "rushdown",
167
+ "zoner",
168
+ "grappler",
169
+ "all_rounder",
170
+ "setplay",
171
+ "footsies",
172
+ ]
173
+ archetype_vec = jnp.array(
174
+ [data["archetypes"].get(a, 0.0) for a in archetype_order]
175
+ )
176
+
177
+ archetype_vec = archetype_vec / (jnp.sum(archetype_vec) + 1e-8)
178
+
179
+ char = Char(
180
+ difficulty=data["difficulty"],
181
+ archetype_vec=archetype_vec,
182
+ execution_level=data["execution_barrier"],
183
+ neutral_required=data["neutral_intensity"],
184
+ tier=data["tier"],
185
+ )
186
+ chars.append(char)
187
+
188
+ batched_chars = Char(
189
+ difficulty=jnp.array([c.difficulty for c in chars]),
190
+ archetype_vec=jnp.stack([c.archetype_vec for c in chars]),
191
+ execution_level=jnp.array([c.execution_level for c in chars]),
192
+ neutral_required=jnp.array([c.neutral_required for c in chars]),
193
+ tier=jnp.array([c.tier for c in chars]),
194
+ )
195
+
196
+ return batched_chars, roster
197
+
198
+
199
class FGRecommender:
    """Glue between the Gradio UI, the LM roster builder and Thompson sampling.

    Holds all mutable session state: the LM wrapper, the batched Char
    roster, the Thompson posterior (`state`), the UserInfo profile and
    a plain-Python feedback history list.
    """

    def __init__(self):
        self.lm = None
        self.chars = None
        self.roster = None
        self.state = None
        self.user = None
        self.key = PRNGKey(67)
        self.n_archetypes = 6
        self.history = []

    def init_game(self, game_name: str) -> str:
        """Build the roster for `game_name` and reset posterior + profile."""
        if not game_name.strip():
            return "please enter name of game"

        try:
            self.lm = LMCharacterKnowledge(model_name="LiquidAI/LFM2-350M", game_name = game_name)
            self.chars, self.roster = self.lm.build_roster()

            n_chars = len(self.roster)
            # 6 scalar feats + 6 archetype dims + 5 derived feats
            # (must match thompson.construct_feats output length).
            feature_dim = 17

            self.state = init_thompson(n_chars, feature_dim)

            self.user = UserInfo(
                skill_level=0.3,
                games_played=0,
                chars_attempted_mask=jnp.zeros(n_chars),
                wr=jnp.ones(n_chars) * 0.5,  # neutral 50% prior win rate
                playtime=jnp.zeros(n_chars),
                pref_archetype=jnp.zeros(self.n_archetypes),
            )

            return f"loaded {n_chars} from {game_name}"
        except Exception as e:
            return f"Error: {e}"

    def get_recs(self, top_k: int = 5) -> tuple[str, str]:
        """Return (markdown summary, dropdown update) with top_k diverse picks."""
        if self.state is None:
            # BUG FIX: this handler feeds two Gradio outputs; returning a
            # lone string left the dropdown output unfilled and errored.
            return "please init game", gr.Dropdown(choices=[])

        # Split the key so successive calls draw fresh Thompson samples.
        self.key, subkey = random.split(self.key)

        sel, sample_rewards = recommend_characters(
            subkey,
            self.state,
            self.user,
            self.chars,
            len(self.roster),
            top_k=top_k,
            diversity_threshold=0.75,
        )

        recommend_text = "## Recommended Chars: \n\n"
        for i, char_idx in enumerate(sel):
            char_idx = int(char_idx)
            if char_idx < 0:
                # Unfilled slot (not enough diverse candidates).
                continue

            char_name = self.roster[char_idx]
            reward = float(sample_rewards[char_idx])
            tried = bool(self.user.chars_attempted_mask[char_idx] > 0.5)

            status = "NEW" if not tried else "TRIED"

            recommend_text += f"### {i + 1}. {char_name} {status} \n"
            recommend_text += f"expected_reward: {reward: .4f} \n"
            recommend_text += f"difficulty: {self.chars.difficulty[char_idx]:.2f}\n"
            recommend_text += f" Tier: {self.chars.tier[char_idx]:.2f}\n\n"

        char_opts = [self.roster[int(idx)] for idx in sel if idx >= 0]

        return recommend_text, gr.Dropdown(
            choices=char_opts, value=char_opts[0] if char_opts else None
        )

    def record_feedback(
        self, char_name: str, won: bool, rating: float, playtime: float
    ) -> str:
        """Fold one game's result into the posterior and the user profile."""
        if self.state is None or char_name is None:
            return "get recs first"

        try:
            char_idx = self.roster.index(char_name)
        except ValueError:
            return f"char {char_name} not found"

        # Slice the single character out of the batched Char pytree.
        sel_char_obj = jax.tree.map(lambda x: x[char_idx], self.chars)
        feats = construct_feats(self.user, sel_char_obj, char_idx)

        reward = compute_reward(
            won=won, completed=True, rating=rating, playtime_mins=playtime
        )

        # BUG FIX: the posterior was never updated — `feats` and the
        # imported update_posterior were unused, so recommendations
        # could never learn from feedback.
        self.state = update_posterior(self.state, char_idx, feats, reward)

        self.user = self.user._replace(
            games_played=self.user.games_played + 1,
            chars_attempted_mask=self.user.chars_attempted_mask.at[char_idx].set(1),
            wr=self.user.wr.at[char_idx].set(
                0.8 * self.user.wr[char_idx] + 0.2 * float(won)
            ),
            playtime=self.user.playtime.at[char_idx].add(playtime),
        )

        self.history.append(
            {
                "character": char_name,
                "won": won,
                "rating": rating,
                "reward": float(reward),
            }
        )

        return f"recorded {char_name}'s feedback! Reward was {reward:.4f}"

    def get_stats(self) -> str:
        """Render the user's profile and most-played characters as markdown."""
        if self.user is None:
            return "no stats lol. play some games u scrub"

        tried = int(jnp.sum(self.user.chars_attempted_mask))
        total = len(self.roster)
        avg_wr = float(jnp.mean(self.user.wr))

        stats = f"""## Your Stats

- **Games played:** {self.user.games_played}
- **Characters tried:** {tried}/{total}
- **Average win rate:** {avg_wr:.1%}
- **Skill level:** {self.user.skill_level:.2f}
"""
        if tried > 0:
            top_indices = jnp.argsort(-self.user.playtime)[:5]
            stats += "\n###Most Played:\n"
            for idx in top_indices:
                idx = int(idx)
                playtime = float(self.user.playtime[idx])
                if playtime > 0:
                    char_name = self.roster[idx]
                    wr = float(self.user.wr[idx])
                    stats += f"- **{char_name}**: {playtime:.0f}m, {wr:.1%} WR\n"

        return stats
339
+
340
+ #
341
+ app = FGRecommender()
342
+
343
+
344
+ def create_ui():
345
+ with gr.Blocks(
346
+ title="Fighting Game Character Recommender", theme=gr.themes.Soft()
347
+ ) as demo:
348
+ gr.Markdown("# Fighting Game Character Recommender")
349
+
350
+ with gr.Row():
351
+ with gr.Column(scale=1):
352
+ gr.Markdown("### Setup")
353
+ game_input = gr.Textbox(
354
+ label="Game Name",
355
+ placeholder="e.g., Street Fighter 6, Guilty Gear Strive",
356
+ value="Street Fighter 6",
357
+ )
358
+ init_btn = gr.Button("Initialize Game", variant="primary")
359
+ init_output = gr.Markdown()
360
+
361
+ gr.Markdown("### User Profile")
362
+ skill_slider = gr.Slider(0.0, 1.0, value=0.3, label="Skill Level")
363
+
364
+ stats_display = gr.Markdown("No stats yet")
365
+ refresh_stats_btn = gr.Button("Refresh Stats")
366
+
367
+ with gr.Column(scale=2):
368
+ gr.Markdown("### Recommendations")
369
+ top_k_slider = gr.Slider(
370
+ 1, 5, value=3, step=1, label="Number of Recommendations"
371
+ )
372
+ get_rec_btn = gr.Button("Get Recommendations", variant="primary")
373
+ rec_output = gr.Markdown()
374
+
375
+ gr.Markdown("### Record Feedback")
376
+ with gr.Row():
377
+ char_dropdown = gr.Dropdown(label="Character Played", choices=[])
378
+ won_checkbox = gr.Checkbox(label="Won?", value=False)
379
+
380
+ with gr.Row():
381
+ rating_slider = gr.Slider(
382
+ 1, 5, value=3, step=0.5, label="Rating (1-5)"
383
+ )
384
+ playtime_slider = gr.Slider(
385
+ 5, 60, value=20, step=5, label="Playtime (minutes)"
386
+ )
387
+
388
+ submit_btn = gr.Button("Submit Feedback", variant="secondary")
389
+ feedback_output = gr.Markdown()
390
+
391
+ def init_game(game_name):
392
+ result = app.init_game(game_name)
393
+ stats = app.get_stats()
394
+ return result, stats
395
+
396
+ init_btn.click(
397
+ init_game, inputs=[game_input], outputs=[init_output, stats_display]
398
+ )
399
+
400
+ get_rec_btn.click(
401
+ lambda k: app.get_recs(int(k)),
402
+ inputs=[top_k_slider],
403
+ outputs=[rec_output, char_dropdown],
404
+ )
405
+
406
+ submit_btn.click(
407
+ app.record_feedback,
408
+ inputs=[char_dropdown, won_checkbox, rating_slider, playtime_slider],
409
+ outputs=[feedback_output],
410
+ )
411
+
412
+ refresh_stats_btn.click(app.get_stats, outputs=[stats_display])
413
+
414
+ return demo
415
+
416
+ #
417
+ if __name__ == "__main__":
418
+
419
+ demo = create_ui()
420
+ demo.launch()
421
+
main.py DELETED
@@ -1,6 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- from jax
4
- class LinUCB():
5
-
6
-
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -4,4 +4,11 @@ version = "0.1.0"
4
  description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
- dependencies = []
 
 
 
 
 
 
 
 
4
  description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
+ dependencies = [
8
+ "chex>=0.1.91",
9
+ "gradio>=6.4.0",
10
+ "jax>=0.9.0",
11
+ "jaxlib>=0.9.0",
12
+ "torch>=2.10.0",
13
+ "transformers>=4.57.6",
14
+ ]
thompson.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chex
2
+ import jax
3
+ from jax import random, jit, vmap
4
+ import jax.numpy as jnp
5
+ from functools import partial
6
+ from consts import EPS
7
+ from globals import UserInfo, Char, State
8
+
9
+
10
def norm_playtime(arr: chex.Array, cid: int) -> chex.Array:
    """Playtime of character `cid` normalised by the roster-wide maximum."""
    # EPS keeps the division finite when nobody has any playtime yet.
    return arr[cid] / (jnp.max(arr) + EPS)
14
+
15
+
16
+ @jit
17
+ def construct_feats(user: UserInfo, char: Char, char_id: int) -> chex.Array:
18
+ feats = [
19
+ user.skill_level,
20
+ jnp.log1p(user.games_played),
21
+ char.difficulty,
22
+ char.execution_level,
23
+ char.neutral_required,
24
+ char.tier,
25
+ ]
26
+ feats.append(char.archetype_vec)
27
+ skill_match = 1.0 - jnp.abs(user.skill_level - (1.0 - char.difficulty))
28
+
29
+ feats.append(jnp.array([skill_match]))
30
+
31
+ archetype_sim = jnp.dot(user.pref_archetype, char.archetype_vec)
32
+ feats.append(jnp.array([archetype_sim]))
33
+
34
+ tried_before = user.chars_attempted_mask[char_id]
35
+ novelty_bonus = 1.0 - tried_before
36
+ feats.append(jnp.array([novelty_bonus]))
37
+
38
+ past_perf = jnp.where(tried_before > 0.5, user.wr[char_id], 0.5)
39
+
40
+ feats.append(jnp.array([past_perf]))
41
+
42
+ norm = norm_playtime(user.playtime, char_id)
43
+ feats.append(jnp.array([norm]))
44
+
45
+ return jnp.concatenate([jnp.atleast_1d(feat) for feat in feats])
46
+
47
+
48
@partial(jit, static_argnums=(2,))
def build_feats(user: UserInfo, chars: Char, n_chars: int):
    """Feature matrix of shape (n_chars, feature_dim), one row per character."""

    def feats_for(idx: int):
        # Slice character `idx` out of the batched Char pytree.
        single = jax.tree.map(lambda leaf: leaf[idx], chars)
        return construct_feats(user, single, idx)

    ids = jnp.arange(n_chars)
    return vmap(feats_for)(ids)
55
+
56
+
57
@jit
def sample_params(key: chex.PRNGKey, mu: chex.Array, Sigma: chex.Array) -> chex.Array:
    """Draw one weight vector from N(mu, Sigma) with a stabilising jitter."""
    dim = mu.shape[0]
    # EPS on the diagonal keeps the covariance positive definite.
    jittered = Sigma + EPS * jnp.eye(dim)
    return random.multivariate_normal(key, mu, jittered)
63
+
64
+
65
@jit
def compute_expected_rewards(thetas: chex.Array, feats: chex.Array) -> chex.Array:
    """Row-wise linear reward estimate: dot(theta_i, x_i) for every character."""
    return jnp.sum(thetas * feats, axis=-1)
68
+
69
+
70
+ @jit
71
+ def thompson_sample(
72
+ key: chex.PRNGKey, state: State, feats: chex.Array
73
+ ) -> tuple[chex.Array, chex.Array]:
74
+ num_chars = feats.shape[0]
75
+ keys = random.split(key, num_chars)
76
+
77
+ thetas = vmap(sample_params)(keys, state.mu, state.Sigma)
78
+ rewards = compute_expected_rewards(thetas, feats)
79
+ return rewards, thetas
80
+
81
+
82
@jit
def update_posterior(
    state: State,
    char_id: int,
    feats: chex.Array,
    reward: float,
    noise_var: float = 1.0,
    use_adaptive_noise: bool = True,
) -> State:
    """Conjugate Bayesian linear-regression update of one character's posterior.

    Uses the rank-1 (Sherman-Morrison / Kalman-gain) form of the Gaussian
    update. It is algebraically equivalent to the precision-space formula
    Sigma_new^-1 = Sigma^-1 + x x^T / noise_var that was previously
    implemented with two jnp.linalg.inv calls, but needs no matrix
    inversion at all: O(d^2) instead of O(d^3) and numerically stabler
    (the old code's own comment flagged the instability).
    """
    x = feats
    mu_old = state.mu[char_id]
    sigma_old = state.Sigma[char_id]

    # Innovation variance s = noise_var + x^T Sigma x; EPS guards s ~ 0.
    sigma_x = sigma_old @ x
    s = noise_var + x @ sigma_x + EPS

    # Kalman gain, then mean/covariance updates.
    gain = sigma_x / s
    mu_new = mu_old + gain * (reward - x @ mu_old)
    Sigma_new = sigma_old - jnp.outer(gain, sigma_x)

    new_mu = state.mu.at[char_id].set(mu_new)
    new_Sigma = state.Sigma.at[char_id].set(Sigma_new)

    # TODO: figure out whether adaptive noise in gp is needed
    new_beta = None

    if use_adaptive_noise:
        new_beta = state.beta.at[char_id].add(1)
    return State(
        mu=new_mu,
        Sigma=new_Sigma,
        alpha=state.alpha,
        beta=new_beta if new_beta is not None else state.beta,
    )
119
+
120
+
121
@partial(jit, static_argnums=(2, 3))
def select_top_k_diverse(
    scores: chex.Array, archetypes: chex.Array, k: int, diversity_thresh: float
) -> chex.Array:
    """Greedy top-k by score, skipping candidates too similar to earlier picks.

    Returns an int32 array of length k of character indices; slots that
    could not be filled with diverse candidates remain -1.
    """
    n_chars = scores.shape[0]
    sorted_idx = jnp.argsort(-scores)

    def selection_step(carry, rank):
        select, cnt = carry
        cand_idx = sorted_idx[rank]

        # BUG FIX: was `cnt > k`, which let cnt run one past the buffer
        # (the out-of-bounds scatter was only silently dropped by JAX's
        # update semantics, and the final count came out wrong).
        done = cnt >= k

        cand_arch = archetypes[cand_idx]

        def check_item_diversity(sel_idx):
            # Unfilled (-1) slots never veto a candidate.
            is_valid = sel_idx >= 0
            sel_arch = archetypes[sel_idx]
            # Cosine similarity with a small eps to avoid division by zero.
            sim = jnp.dot(cand_arch, sel_arch) / (
                jnp.linalg.norm(cand_arch) * jnp.linalg.norm(sel_arch) + 1e-8
            )
            return jnp.where(is_valid, sim < diversity_thresh, True)

        all_diverse = jnp.all(vmap(check_item_diversity)(select))

        add_op = jnp.logical_and(jnp.logical_not(done), all_diverse)

        new_sel = jnp.where(add_op, select.at[cnt].set(cand_idx), select)
        new_cnt = jnp.where(add_op, cnt + 1, cnt)
        return (new_sel, new_cnt), None

    # The single highest-scoring character is always selected first.
    init = jnp.full(k, -1, dtype=jnp.int32)
    init = init.at[0].set(sorted_idx[0])

    (final_sel, _), _ = jax.lax.scan(
        selection_step, (init, 1), jnp.arange(1, n_chars)
    )
    return final_sel
162
+
163
+
164
@jit
def compute_reward(
    won: bool, completed: bool, rating: float, playtime_mins: float, weights:chex.Array = jnp.array([0.3, 0.15, 0.25, 0.3])
) -> float:
    """Scalar feedback reward in [0, 1]: win + completion + rating + engagement."""
    w_win, w_complete, w_rating, w_engage = weights[0], weights[1], weights[2], weights[3]
    # Binary components: full weight when the flag is set, else zero.
    win_reward = jnp.where(won, w_win, 0.0)
    completion_reward = jnp.where(completed, w_complete, 0.0)
    # Graded components, each squashed into [0, 1] before weighting.
    rating_reward = w_rating * jnp.clip(rating / 5.0, 0.0, 1.0)
    engagement_reward = w_engage * jnp.clip(jnp.log1p(playtime_mins) / 5.0, 0.0, 1.0)
    return win_reward + completion_reward + rating_reward + engagement_reward
174
+
175
+
176
@partial(jit, static_argnums=(4, 5, 6))
def recommend_characters(
    key: chex.PRNGKey,
    state: State,
    user: UserInfo,
    characters: Char,
    n_chars: int,
    top_k: int = 3,
    diversity_threshold: float = 0.75,
) -> tuple[chex.Array, chex.Array]:
    """End-to-end recommendation: features -> Thompson draws -> diverse top-k.

    Returns (selected, sampled_rewards): a length-top_k int array of
    character indices (-1 for unfilled slots) and the per-character
    sampled rewards used for ranking.

    BUG FIX: diversity_threshold is now static (argnum 6). Previously it
    was traced by this outer jit, but select_top_k_diverse declares its
    threshold as a static argument — passing an unhashable tracer there
    fails at trace time.
    """
    features = build_feats(user, characters, n_chars)
    sampled_rewards, sampled_thetas = thompson_sample(key, state, features)

    selected = select_top_k_diverse(
        sampled_rewards, characters.archetype_vec, top_k, diversity_threshold
    )

    return selected, sampled_rewards
194
+
195
+
196
def init_thompson(n_chars: int, feature_dim: int, prior_var: float = 1.0) -> State:
    """Isotropic zero-mean Gaussian prior for every character's weight posterior."""
    # One shared prior covariance, tiled per character.
    prior_cov = prior_var * jnp.eye(feature_dim)
    return State(
        mu=jnp.zeros((n_chars, feature_dim)),
        Sigma=jnp.tile(prior_cov, (n_chars, 1, 1)),
        alpha=jnp.ones(n_chars),
        beta=jnp.ones(n_chars),
    )
203
+
204
+
205
+ @jit
206
+ def batch_update_posterior(
207
+ state: State,
208
+ char_ids: chex.Array,
209
+ features: chex.Array,
210
+ rewards: chex.Array,
211
+ noise_var: float = 1.0,
212
+ ) -> State:
213
+ def single_update(s, data):
214
+ char_id, feat, reward = data
215
+ return update_posterior(s, char_id, feat, reward, noise_var), None
216
+
217
+ final_state, _ = jax.lax.scan(single_update, state, (char_ids, features, rewards))
218
+ return final_state
219
+
220
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff