# mtg-draft-predictor / src/predictor.py
# Denis Lebedev — commit 8f19b3a (unverified): "Return mtga_card_id in predict output"
import re

import pandas as pd
from fastai.tabular.all import *

from .mtga import mtga_id_to_card_name, card_name_to_mtga_id
class MTGPickPredictor:
    """Suggest MTG Arena draft picks with per-set fastai tabular models.

    Models are loaded lazily from ``models/<set_name>_draft.pkl`` on first use
    and cached on the instance. The card names derived from the model's
    ``pack_card_*`` feature columns are cached alongside each model.
    """

    # Allowed characters for a set name. set_name is interpolated into a file
    # path handed to load_learner, so path separators and dots are rejected to
    # keep lookups inside models/ (e.g. no "../elsewhere").
    _SET_NAME_PATTERN = re.compile(r"[A-Za-z0-9_-]+")

    def __init__(self):
        self._models: dict[str, Learner] = {}  # set name -> loaded learner
        # set name -> card names, in pack_card_* feature-column order
        self._card_names: dict[str, list[str]] = {}

    def predict(self, set_name, input_data, *, top_k: int = 3):
        """Return the ``top_k`` most likely picks for the given draft state.

        Args:
            set_name: Set identifier used to locate the model file
                (letters, digits, ``_`` and ``-`` only).
            input_data: Mapping with a ``"pack"`` list of MTGA card ids and an
                optional ``"pool"`` list of MTGA card ids already picked.
            top_k: Number of suggestions to return. It is capped at the number
                of classes the model outputs.

        Returns:
            List of ``(mtga_id, card_name, probability)`` tuples, most likely
            first. ``probability`` is the element taken from the model's
            probability vector (presumably a 0-d torch tensor).

        Also prints one ``"<card>: <pct>%"`` line per suggestion to stdout.
        """
        self._load_model(set_name)  # no-op when the set is already cached
        learner = self._models[set_name]
        card_names = self._card_names[set_name]

        df = self._json_to_df(set_name, input_data)
        _, _, pred_probs = learner.predict(df.iloc[0])
        topk_values, topk_indices = pred_probs.topk(min(top_k, len(pred_probs)))

        lines = []
        output = []
        for prob, idx in zip(topk_values, topk_indices):
            card_name = card_names[idx]
            lines.append(f"{card_name}: {prob*100:.0f}%\n")
            output.append((card_name_to_mtga_id(card_name), card_name, prob))
        print("".join(lines))
        return output

    def _load_model(self, set_name: str):
        """Load and cache the learner and card names for ``set_name``.

        Does nothing if the set is already cached.

        Raises:
            ValueError: if ``set_name`` is not a non-empty string of letters,
                digits, ``_`` or ``-``.
        """
        if set_name in self._models:
            return
        if not isinstance(set_name, str) or not self._SET_NAME_PATTERN.fullmatch(set_name):
            # NOTE(review): the model file is a .pkl. If load_learner unpickles
            # it, loading a file from an attacker-chosen path could run code,
            # so only plain set codes are accepted.
            raise ValueError(f"Invalid set name: {set_name!r}")

        model = load_learner(f"models/{set_name}_draft.pkl")

        # NOTE(review): predict() indexes these names with the model's output
        # class indices. That assumes the classes are ordered like the
        # pack_card_* feature columns. fastai classifiers order probabilities
        # by dls.vocab, so confirm the two orders agree for every trained set.
        prefix = "pack_card_"
        card_names = [
            col[len(prefix):]
            for col in model.dls.train_ds.cont_names
            if col.startswith(prefix)
        ]
        # Store both only after both are computed, so a failure above never
        # leaves a model cached without its card names.
        self._card_names[set_name] = card_names
        self._models[set_name] = model

    def _json_to_df(self, set_name, json_data):
        """Encode a draft state as a one-row float DataFrame of model features.

        Every continuous feature column of the cached model starts at 0.
        Each MTGA card id in ``json_data["pack"]`` increments its
        ``pack_card_<name>`` column. Each id in ``json_data["pool"]``
        increments its ``pool_<name>`` column; a missing or null pool is
        treated as empty.

        Raises:
            KeyError: if ``set_name`` has no cached model, if ``"pack"`` is
                missing, or if a card maps to a column the model was not
                trained with.
        """
        # TODO: the column set is taken directly from the trained model's
        # continuous features (originally marked as a testing shortcut).
        counts = dict.fromkeys(self._models[set_name].dls.train_ds.cont_names, 0)
        pack_ids = json_data["pack"]
        pool_ids = json_data.get("pool") or ()

        for prefix, card_ids in (("pack_card_", pack_ids), ("pool_", pool_ids)):
            for card_id in card_ids:
                col_name = f"{prefix}{mtga_id_to_card_name(card_id)}"
                if col_name not in counts:
                    raise KeyError(
                        f"card id {card_id!r} maps to unknown feature column "
                        f"{col_name!r} for set {set_name!r}"
                    )
                counts[col_name] += 1
        return pd.DataFrame([counts]).astype(float)