Guess_the_Live / evaluate.py
trioskosmos's picture
Upload 6 files
6d30012 verified
import json
import torch
import random
import numpy as np
from model import LoveLiveTransformer
from game import LoveLiveGame
def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def _restrict_to_candidates(probs, possible_indices):
    """Zero every live outside *possible_indices* and renormalise *probs*.

    If the model gives zero mass to all remaining candidates, fall back to a
    uniform distribution over them so that a prediction can still be made.
    """
    mask = torch.zeros_like(probs)
    mask[possible_indices] = 1.0
    masked = probs * mask
    if masked.sum() == 0:
        masked[possible_indices] = 1.0
    # Small epsilon guards the division; the sum is > 0 after the fallback.
    return masked / (masked.sum() + 1e-9)


def evaluate(max_turns=20, confidence_threshold=0.7, seed=None,
             mappings_path='mappings.json',
             model_path='transformer_model.pth'):
    """Simulate one game in which the trained transformer hunts a hidden live.

    Loads the index mappings and trained weights, then starts a
    ``LoveLiveGame``. For each turn, up to ``max_turns``:

    * From the second turn on, the model scores every live from the
      song/artist/feedback history. Scores are restricted to the lives the
      game still considers possible. If the top candidate's probability
      exceeds ``confidence_threshold``, that live is guessed outright.
    * A song is guessed. The first turn picks at random; later turns take the
      game engine's top move from ``get_best_moves()``. The feedback is used
      to prune the candidate lives.

    Args:
        max_turns: Maximum number of turns before giving up.
        confidence_threshold: Masked probability the top live must exceed
            before a live guess is attempted.
        seed: If not None, seeds Python's ``random``, NumPy and torch so the
            run is reproducible.
        mappings_path: JSON file containing ``song_to_idx``,
            ``artist_to_idx`` and ``live_to_idx``.
        model_path: Saved ``state_dict`` for ``LoveLiveTransformer``.

    Returns:
        True if the target live was guessed correctly, otherwise False.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    print("Loading resources...")
    # JSON text is UTF-8 by specification. Be explicit so a non-UTF-8 locale
    # codec (e.g. on Windows) is never used to decode it.
    with open(mappings_path, 'r', encoding='utf-8') as f:
        mappings = json.load(f)
    # NOTE(review): JSON object keys are always strings, so these lookups
    # assume the game engine's song/artist/live ids are strings too — confirm.
    song_to_idx = mappings['song_to_idx']
    artist_to_idx = mappings['artist_to_idx']
    live_to_idx = mappings['live_to_idx']
    idx_to_live = {v: k for k, v in live_to_idx.items()}

    game = LoveLiveGame()

    # Model parameters must match train.py. Song/artist/feedback inputs are
    # shifted by +1 below, so their vocabularies get one extra slot
    # (presumably index 0 is padding in train.py).
    num_songs = len(song_to_idx) + 1
    num_artists = len(artist_to_idx) + 1
    num_feedback = 4
    num_lives = len(live_to_idx)

    device = _select_device()
    model = LoveLiveTransformer(num_songs, num_artists, num_feedback, num_lives).to(device)
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Start a simulation
    target_id = game.start_game()
    print(f"Target Live: {game.lives[target_id]['name']}")

    songs_seq = []
    artists_seq = []
    feedbacks_seq = []
    guessed_lives = set()
    solved = False
    all_song_ids = list(game.songs)

    for turn in range(max_turns):
        if not songs_seq:
            # The model was trained on sequences of length >= 1, so the
            # opening guess cannot come from it; pick randomly for diversity.
            guess_song_id = random.choice(all_song_ids)
            artist_candidates = game.songs[guess_song_id]['artist_ids']
            guess_artist_id = random.choice(artist_candidates) if artist_candidates else random.choice(list(game.artists.keys()))
            print(f"Turn {turn+1}: First guess random -> {game.songs[guess_song_id]['name']}")
        else:
            # Model input is (seq_len, batch=1); indices shifted by +1 to
            # match the vocab sizes above.
            s_in = torch.tensor([x + 1 for x in songs_seq], device=device).unsqueeze(1)
            a_in = torch.tensor([x + 1 for x in artists_seq], device=device).unsqueeze(1)
            f_in = torch.tensor([x + 1 for x in feedbacks_seq], device=device).unsqueeze(1)
            with torch.no_grad():
                logits = model(s_in, a_in, f_in)
            probs = torch.softmax(logits, dim=1).squeeze(0)  # (num_lives,)

            # Diagnostic: report when the unconstrained model prefers a live
            # that the hard constraints have already ruled out.
            raw_top_live_id = idx_to_live[torch.argmax(probs).item()]
            if raw_top_live_id not in game.possible_live_ids:
                print(f" [Model Warning] Model wanted to pick {game.lives[raw_top_live_id]['name']} but it is pruned.")

            # Hard constraints: keep only lives the game still allows.
            # NOTE(review): raises KeyError if the game knows a live that is
            # absent from mappings.json.
            possible_indices = [live_to_idx[lid] for lid in game.possible_live_ids]
            if not possible_indices:
                print("Error: No possible lives remaining according to hard constraints!")
                break
            probs = _restrict_to_candidates(probs, possible_indices)

            top_idx = torch.argmax(probs).item()
            top_live_id = idx_to_live[top_idx]
            top_prob = probs[top_idx].item()
            print(f"Turn {turn+1}: Top Prediction: {game.lives[top_live_id]['name']} ({top_prob:.4f}) [Candidates: {len(possible_indices)}]")

            if top_prob > confidence_threshold and top_live_id not in guessed_lives:
                print(">> Guessing LIVE!")
                if game.guess_live(top_live_id):
                    print("CORRECT! Solved.")
                    solved = True
                    break
                print("WRONG Live guess. Continuing...")
                guessed_lives.add(top_live_id)
                if top_live_id in game.possible_live_ids:
                    game.possible_live_ids.remove(top_live_id)

            # Song choice uses the game engine's ranking, which treats the
            # remaining candidates uniformly. This is more robust than
            # weighting by a possibly overconfident model.
            best_moves = game.get_best_moves(top_k=1)
            if best_moves:
                guess_song_id = best_moves[0][0]
                print(f"Guessing Song: {game.songs[guess_song_id]['name']} (Score: {best_moves[0][1]:.4f})")
            else:
                guess_song_id = random.choice(all_song_ids)
                print(f"Guessing Song: {game.songs[guess_song_id]['name']} (Random Fallback)")
            a_ids = game.songs[guess_song_id]['artist_ids']
            guess_artist_id = a_ids[0] if a_ids else next(iter(game.artists))

        # Execute the song guess, prune, and extend the model's history.
        feedback = game.guess_song(guess_song_id, guess_artist_id)
        print(f"Feedback: {feedback}")
        game.prune_candidates(guess_song_id, guess_artist_id, feedback)
        songs_seq.append(song_to_idx[guess_song_id])
        artists_seq.append(artist_to_idx[guess_artist_id])
        feedbacks_seq.append(feedback)

    if not solved:
        print("Failed to solve in max turns.")
    return solved
if __name__ == "__main__":
    # Script entry point: play a single simulated game with default settings.
    evaluate()