# test/grad_app.py
# Artem — added base demo (5effdd5)
import gradio as gr
import jax
import jax.numpy as jnp
from jax import random
from jax.random import PRNGKey
import json
from globals import Char, State, UserInfo
from thompson import (
init_thompson,
recommend_characters,
update_posterior,
compute_reward,
construct_feats,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
class LMCharacterKnowledge:
    """Ask a local causal LM about a fighting game's roster and character stats.

    Successfully parsed answers are memoised in ``self.cache`` (keyed by game
    and character name), so repeated lookups skip generation. When an answer
    cannot be parsed, a hard-coded fallback is returned and nothing is cached,
    so a later call will query the model again.
    """

    # Archetype keys, in the order used for each character's archetype vector.
    ARCHETYPES = (
        "rushdown",
        "zoner",
        "grappler",
        "all_rounder",
        "setplay",
        "footsies",
    )
    # Scalar stats every character record must contain (besides "archetypes").
    SCALAR_STATS = ("difficulty", "execution_barrier", "neutral_intensity", "tier")
    # Roster returned when the model's answer cannot be parsed.
    FALLBACK_ROSTER = ("Ryu", "Ken", "Luke")
    # Example answer embedded in the stats prompt to show the expected JSON shape.
    _EXAMPLE_STATS = {
        "difficulty": 0.7,
        "execution_barrier": 0.6,
        "neutral_intensity": 0.5,
        "tier": 0.8,
        "archetypes": {
            "rushdown": 0.8,
            "zoner": 0.1,
            "grappler": 0.0,
            "all_rounder": 0.1,
            "setplay": 0.0,
            "footsies": 0.0,
        },
    }

    def __init__(self, model_name: str, game_name: str):
        """Load the tokenizer and model weights for ``model_name``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            game_name: game whose roster/stats are queried; used in prompts
                and in cache keys.
        """
        self.model_name = model_name
        self.game_name = game_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.prompt = [
            {
                "role": "system",
                "content": "You are a knowledgeable bastion of fighting game knowledge. Your goal is to answer questions as best as possible about the game you are asked about.",
            }
        ]
        self.cache = {}

    def ask_lm(self, prompt: str, max_tok: int = 4096) -> str:
        """Send ``prompt`` as a user turn after the system prompt; return the reply.

        Args:
            prompt: the user message.
            max_tok: maximum number of new tokens to generate.

        Returns:
            The decoded generated text, or ``""`` if tokenisation/generation
            fails (the error is printed). Returning a string keeps callers'
            ``str`` methods safe on the failure path.
        """
        try:
            messages = self.prompt + [{"role": "user", "content": prompt}]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            outputs = self.model.generate(**inputs, max_new_tokens=max_tok)
            # Drop the echoed prompt tokens; decode only what was generated.
            result = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[-1] :], skip_special_tokens=True
            )
        except Exception as e:  # inference boundary: report and degrade gracefully
            print(f"Couldn't query {self.model_name}, error: {e}")
            return ""
        print(result)
        return result

    @staticmethod
    def _extract_json(text, opener, accept):
        """Return the first JSON value in ``text`` that starts at ``opener`` and passes ``accept``.

        Every occurrence of ``opener`` (``"["`` or ``"{"``) is tried in turn with
        ``json.JSONDecoder.raw_decode``, which stops at the end of the value, so
        prose before or after the JSON is ignored. Returns ``None`` when no
        candidate decodes to a value for which ``accept(value)`` is true.
        """
        if not text:
            return None
        decoder = json.JSONDecoder()
        pos = text.find(opener)
        while pos != -1:
            try:
                value, _ = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                pass
            else:
                if accept(value):
                    return value
            pos = text.find(opener, pos + 1)
        return None

    @classmethod
    def _normalize_stats(cls, raw: dict) -> dict:
        """Coerce a parsed stats object to plain floats.

        Unknown archetype keys are dropped and missing ones default to 0.0.

        Raises:
            ValueError / TypeError: if a required key is missing or a value
                cannot be converted with ``float``.
        """
        missing = [k for k in (*cls.SCALAR_STATS, "archetypes") if k not in raw]
        if missing:
            raise ValueError(f"Missing required keys in parsed data: {missing}")
        archetypes = raw["archetypes"]
        if not isinstance(archetypes, dict):
            raise TypeError("'archetypes' must be a JSON object")
        data = {key: float(raw[key]) for key in cls.SCALAR_STATS}
        data["archetypes"] = {a: float(archetypes.get(a, 0.0)) for a in cls.ARCHETYPES}
        return data

    @staticmethod
    def _default_stats() -> dict:
        """Return a fresh copy of the neutral stats used when parsing fails."""
        return {
            "difficulty": 0.5,
            "execution_barrier": 0.5,
            "neutral_intensity": 0.5,
            "tier": 0.5,
            "archetypes": {
                "rushdown": 0.3,
                "zoner": 0.3,
                "grappler": 0.1,
                "all_rounder": 0.2,
                "setplay": 0.05,
                "footsies": 0.05,
            },
        }

    def get_roster(self) -> list[str]:
        """Return the playable character names for ``self.game_name``.

        Names are stripped, non-string/empty entries dropped, and duplicates
        removed (keeping first occurrence), because callers look characters up
        by name. Falls back to ``FALLBACK_ROSTER`` (uncached) if no non-empty
        JSON array of names can be parsed from the model's answer.
        """
        cache_key = f"roster_{self.game_name}"
        if cache_key in self.cache:
            return self.cache[cache_key]
        roster_prompt = (
            f"List ALL playable characters in {self.game_name}. "
            "Return a structured json array of character names, nothing else at all.\n"
            'Example format is : ["Ryu", "Ken", "Chun Li", "Akuma"]\n'
        )
        response = self.ask_lm(roster_prompt, max_tok=512)
        parsed = self._extract_json(response, "[", lambda v: isinstance(v, list))
        if parsed is not None:
            roster = list(
                dict.fromkeys(
                    name.strip()
                    for name in parsed
                    if isinstance(name, str) and name.strip()
                )
            )
            if roster:
                self.cache[cache_key] = roster
                return roster
        print(f"Couldn't parse roster for {self.game_name}; using fallback roster")
        return list(self.FALLBACK_ROSTER)

    def get_character_data(self, char_name: str) -> dict:
        """Return stats for ``char_name`` as a dict of floats.

        Keys are ``SCALAR_STATS`` plus ``"archetypes"`` (a dict over
        ``ARCHETYPES``). If the model's answer has no JSON object with the
        required keys, or its values are not numeric, the error is printed and
        ``_default_stats()`` is returned (uncached).
        """
        cache_key = f"char_{self.game_name}_{char_name}"
        if cache_key in self.cache:
            return self.cache[cache_key]
        char_data_prompt = (
            f"For the character {char_name} in the game {self.game_name}, "
            "provide some statistics in explicit JSON format:\n"
            "Example format:\n"
            f"{json.dumps(self._EXAMPLE_STATS, indent=4)}\n"
            f"Replace ALL values with actual numbers for {char_name}. "
            "Return ONLY the JSON object, nothing else.\n"
        )
        response = self.ask_lm(char_data_prompt, max_tok=300)
        print(f"Raw response for {char_name}: {response}")
        required = (*self.SCALAR_STATS, "archetypes")
        try:
            parsed = self._extract_json(
                response,
                "{",
                lambda v: isinstance(v, dict) and all(k in v for k in required),
            )
            if parsed is None:
                raise ValueError("No JSON object with the required keys found")
            data = self._normalize_stats(parsed)
        except (ValueError, TypeError, OverflowError) as e:
            print(f"Couldn't parse {char_name}'s data: {e}")
            print(f"Response was: {response[:200]}...")
            return self._default_stats()
        self.cache[cache_key] = data
        return data

    def build_roster(self) -> "tuple[Char, list[str]]":
        """Fetch the roster and stats and pack them into one batched ``Char``.

        Returns:
            ``(chars, roster)``: ``chars`` is a single ``Char`` whose fields are
            arrays with one entry per roster name (``archetype_vec`` has one row
            per character, normalised to sum to ~1); ``roster`` lists the names
            in the same order.

        Raises:
            ValueError: if the roster is empty.
        """
        roster = self.get_roster()
        if not roster:
            raise ValueError(f"Empty roster for {self.game_name}")
        stats = [self.get_character_data(name) for name in roster]
        archetypes = jnp.array(
            [[s["archetypes"].get(a, 0.0) for a in self.ARCHETYPES] for s in stats]
        )
        # Row-normalise; the epsilon avoids division by zero for all-zero rows.
        archetypes = archetypes / (jnp.sum(archetypes, axis=1, keepdims=True) + 1e-8)
        batched_chars = Char(
            difficulty=jnp.array([s["difficulty"] for s in stats]),
            archetype_vec=archetypes,
            execution_level=jnp.array([s["execution_barrier"] for s in stats]),
            neutral_required=jnp.array([s["neutral_intensity"] for s in stats]),
            tier=jnp.array([s["tier"] for s in stats]),
        )
        return batched_chars, roster
class FGRecommender:
    """App state for the demo: LM-derived roster plus a Thompson-sampling recommender.

    Call ``init_game`` first. Until a game is loaded, ``get_recs``,
    ``record_feedback`` and ``get_stats`` return a hint message instead.
    """

    # Hugging Face model used to derive the roster and character stats.
    MODEL_NAME = "LiquidAI/LFM2-350M"
    # Feature dimension handed to init_thompson.
    # NOTE(review): presumably must match the length of construct_feats' output — confirm in thompson.py.
    FEATURE_DIM = 17

    def __init__(self):
        self.lm = None
        self.chars = None
        self.roster = None
        self.state = None
        self.user = None
        self.key = PRNGKey(67)
        self.n_archetypes = 6
        self.history = []

    def init_game(self, game_name: str) -> str:
        """Load ``game_name``'s roster and reset the bandit state and user profile.

        The language model is loaded once and reused for later games. All
        attributes are updated together only after everything succeeds, so a
        failure leaves the previously loaded game intact.

        Returns:
            A status message (an ``"Error: ..."`` string if loading fails).
        """
        if not game_name or not game_name.strip():
            return "please enter name of game"
        game_name = game_name.strip()
        try:
            if self.lm is None:
                lm = LMCharacterKnowledge(model_name=self.MODEL_NAME, game_name=game_name)
            else:
                # Reuse the already-loaded weights; cache keys include the game name.
                lm = self.lm
                lm.game_name = game_name
            chars, roster = lm.build_roster()
            n_chars = len(roster)
            state = init_thompson(n_chars, self.FEATURE_DIM)
            user = UserInfo(
                skill_level=0.3,
                games_played=0,
                chars_attempted_mask=jnp.zeros(n_chars),
                wr=jnp.ones(n_chars) * 0.5,
                playtime=jnp.zeros(n_chars),
                pref_archetype=jnp.zeros(self.n_archetypes),
            )
        except Exception as e:  # UI boundary: surface the error to the user
            return f"Error: {e}"
        self.lm, self.chars, self.roster = lm, chars, roster
        self.state, self.user = state, user
        self.history = []  # feedback from a previous game refers to other characters
        return f"loaded {n_chars} characters from {game_name}"

    def get_recs(self, top_k: int = 5) -> "tuple[str, gr.Dropdown | None]":
        """Sample up to ``top_k`` recommendations.

        Returns:
            ``(markdown, dropdown)``: the recommendations rendered as Markdown,
            and a ``gr.Dropdown`` listing them. Before ``init_game`` this is
            ``("please init game", None)``, because the UI always expects two
            outputs.
        """
        if self.state is None:
            return "please init game", None
        self.key, subkey = random.split(self.key)
        sel, sample_rewards = recommend_characters(
            subkey,
            self.state,
            self.user,
            self.chars,
            len(self.roster),
            top_k=top_k,
            diversity_threshold=0.75,
        )
        # Negative indices are skipped (presumably unfilled slots from recommend_characters).
        picks = [int(idx) for idx in sel if int(idx) >= 0]
        lines = ["## Recommended Chars: ", ""]
        for rank, char_idx in enumerate(picks, start=1):
            tried = bool(self.user.chars_attempted_mask[char_idx] > 0.5)
            status = "TRIED" if tried else "NEW"
            lines += [
                f"### {rank}. {self.roster[char_idx]} {status}",
                f"- expected_reward: {float(sample_rewards[char_idx]):.4f}",
                f"- difficulty: {float(self.chars.difficulty[char_idx]):.2f}",
                f"- Tier: {float(self.chars.tier[char_idx]):.2f}",
                "",
            ]
        char_opts = [self.roster[i] for i in picks]
        dropdown = gr.Dropdown(
            choices=char_opts, value=char_opts[0] if char_opts else None
        )
        return "\n".join(lines) + "\n", dropdown

    def record_feedback(
        self, char_name: str, won: bool, rating: float, playtime: float
    ) -> str:
        """Record one session with ``char_name`` and update the user profile.

        Args:
            char_name: a name from the current roster.
            won: whether the session was won.
            rating: user rating (the UI slider spans 1-5).
            playtime: minutes played.

        Returns:
            A status message including the computed reward.
        """
        if self.state is None or not char_name:
            return "get recs first"
        try:
            char_idx = self.roster.index(char_name)
        except ValueError:
            return f"char {char_name} not found"
        # Features are taken from the profile *before* this session is recorded.
        sel_char_obj = jax.tree.map(lambda x: x[char_idx], self.chars)
        feats = construct_feats(self.user, sel_char_obj, char_idx)
        reward = compute_reward(
            won=won, completed=True, rating=rating, playtime_mins=playtime
        )
        # NOTE(review): `feats` is computed but the posterior is never updated —
        # `update_posterior` is imported yet unused, so recommendations never
        # learn from feedback. Wire it in (e.g. `self.state = update_posterior(...)`)
        # once its signature in thompson.py is confirmed.
        user = self.user
        self.user = user._replace(
            games_played=user.games_played + 1,
            chars_attempted_mask=user.chars_attempted_mask.at[char_idx].set(1),
            # Exponential moving average of the win rate (weight 0.2 on the new result).
            wr=user.wr.at[char_idx].set(0.8 * user.wr[char_idx] + 0.2 * float(won)),
            playtime=user.playtime.at[char_idx].add(playtime),
        )
        self.history.append(
            {
                "character": char_name,
                "won": won,
                "rating": rating,
                "reward": float(reward),
            }
        )
        return f"recorded {char_name}'s feedback! Reward was {reward:.4f}"

    def get_stats(self) -> str:
        """Render the user profile and the (up to 5) most-played characters as Markdown."""
        if self.user is None:
            return "no stats lol. play some games u scrub"
        user = self.user
        tried_mask = user.chars_attempted_mask > 0.5
        tried = int(jnp.sum(tried_mask))
        # Average only over characters actually played; untried entries still
        # hold their 0.5 prior and would pull every average toward 50%.
        if tried > 0:
            avg_wr = f"{float(jnp.sum(jnp.where(tried_mask, user.wr, 0.0))) / tried:.1%}"
        else:
            avg_wr = "n/a"
        stats = (
            "## Your Stats\n"
            f"- **Games played:** {user.games_played}\n"
            f"- **Characters tried:** {tried}/{len(self.roster)}\n"
            f"- **Average win rate:** {avg_wr}\n"
            f"- **Skill level:** {user.skill_level:.2f}\n"
        )
        if tried > 0:
            stats += "\n### Most Played:\n"
            for idx in jnp.argsort(-user.playtime)[:5]:
                idx = int(idx)
                minutes = float(user.playtime[idx])
                if minutes > 0:
                    wr = float(user.wr[idx])
                    stats += f"- **{self.roster[idx]}**: {minutes:.0f}m, {wr:.1%} WR\n"
        return stats
# Single recommender instance shared by every event handler wired up in create_ui().
# NOTE(review): being module-level, this state is presumably shared by all
# browser sessions of the served app — consider gr.State if per-user state matters.
app = FGRecommender()
def create_ui():
    """Build the Gradio Blocks UI, wired to the module-level ``app`` recommender.

    Returns:
        The ``gr.Blocks`` demo; call ``.launch()`` on it to serve the UI.
    """
    with gr.Blocks(
        title="Fighting Game Character Recommender", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("# Fighting Game Character Recommender")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Setup")
                game_input = gr.Textbox(
                    label="Game Name",
                    placeholder="e.g., Street Fighter 6, Guilty Gear Strive",
                    value="Street Fighter 6",
                )
                init_btn = gr.Button("Initialize Game", variant="primary")
                init_output = gr.Markdown()
                gr.Markdown("### User Profile")
                skill_slider = gr.Slider(0.0, 1.0, value=0.3, label="Skill Level")
                stats_display = gr.Markdown("No stats yet")
                refresh_stats_btn = gr.Button("Refresh Stats")
            with gr.Column(scale=2):
                gr.Markdown("### Recommendations")
                top_k_slider = gr.Slider(
                    1, 5, value=3, step=1, label="Number of Recommendations"
                )
                get_rec_btn = gr.Button("Get Recommendations", variant="primary")
                rec_output = gr.Markdown()
                gr.Markdown("### Record Feedback")
                with gr.Row():
                    char_dropdown = gr.Dropdown(label="Character Played", choices=[])
                    won_checkbox = gr.Checkbox(label="Won?", value=False)
                with gr.Row():
                    rating_slider = gr.Slider(
                        1, 5, value=3, step=0.5, label="Rating (1-5)"
                    )
                    playtime_slider = gr.Slider(
                        5, 60, value=20, step=5, label="Playtime (minutes)"
                    )
                submit_btn = gr.Button("Submit Feedback", variant="secondary")
                feedback_output = gr.Markdown()

        def apply_skill(skill):
            """Copy the slider's skill level into the current user profile, if one exists."""
            if app.user is not None:
                app.user = app.user._replace(skill_level=float(skill))

        def init_game(game_name, skill):
            result = app.init_game(game_name)
            # app.init_game builds a fresh profile with a fixed default skill,
            # so re-apply whatever the slider currently shows.
            apply_skill(skill)
            return result, app.get_stats()

        def update_skill(skill):
            apply_skill(skill)
            return app.get_stats()

        init_btn.click(
            init_game,
            inputs=[game_input, skill_slider],
            outputs=[init_output, stats_display],
        )
        skill_slider.change(
            update_skill, inputs=[skill_slider], outputs=[stats_display]
        )
        get_rec_btn.click(
            lambda k: app.get_recs(int(k)),
            inputs=[top_k_slider],
            outputs=[rec_output, char_dropdown],
        )
        # Refresh the stats panel once the feedback has been recorded.
        submit_btn.click(
            app.record_feedback,
            inputs=[char_dropdown, won_checkbox, rating_slider, playtime_slider],
            outputs=[feedback_output],
        ).then(app.get_stats, outputs=[stats_display])
        refresh_stats_btn.click(app.get_stats, outputs=[stats_display])
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    create_ui().launch()