Spaces:
Sleeping
Sleeping
File size: 2,347 Bytes
87acbba 8f19b3a 87acbba f1a26d2 87acbba f1a26d2 87acbba f1a26d2 8f19b3a 87acbba f1a26d2 87acbba f1a26d2 87acbba f1a26d2 87acbba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from typing import Any, Dict

import pandas as pd
from fastai.tabular.all import *

from .mtga import mtga_id_to_card_name, card_name_to_mtga_id


class MTGPickPredictor:
    """MTG draft pick predictor backed by one fastai learner per set.

    Learners are loaded lazily from ``models/{set_name}_draft.pkl`` and cached,
    together with the card names derived from each model's ``pack_card_*``
    continuous feature columns.
    """

    # Prefixes of the model's feature columns: cards in the current pack and
    # cards already in the drafter's pool.
    _PACK_PREFIX = "pack_card_"
    _POOL_PREFIX = "pool_"

    def __init__(self):
        self._models: Dict[str, Any] = {}  # set name -> loaded learner
        self._card_names: Dict[str, list] = {}  # set name -> card names in pack_card_ column order

    def predict(self, set_name: str, input_data: dict, k: int = 3) -> list:
        """Return the top-``k`` suggested picks for one draft state.

        Args:
            set_name: Set identifier used to locate the model file.
            input_data: Mapping with a ``'pack'`` list of MTGA card ids and an
                optional ``'pool'`` list of MTGA card ids (missing/None = empty).
            k: Maximum number of suggestions; clamped to the number of
                probabilities the model returns. Defaults to 3.

        Returns:
            A list of ``(mtga_id, card_name, prob)`` tuples in the order given
            by ``topk`` (highest probability first). ``prob`` is taken directly
            from the ``topk`` values without conversion.

        Side effects:
            Prints a summary with one ``"<card name>: <NN>%"`` line per pick.
        """
        self._load_model(set_name)  # no-op when already cached
        df = self._json_to_df(set_name, input_data)
        _, _, pred_probs = self._models[set_name].predict(df.iloc[0])
        # topk raises if k exceeds the number of classes, so clamp it.
        k = min(k, pred_probs.shape[-1])
        topk_values, topk_indices = pred_probs.topk(k)

        # NOTE(review): this assumes the model's output class order matches the
        # order of the pack_card_ feature columns. fastai orders predicted
        # probabilities by the target vocab (dls.vocab) -- confirm they agree.
        card_names = self._card_names[set_name]
        lines = []
        output = []
        for prob, idx in zip(topk_values, topk_indices):
            card_name = card_names[idx]
            lines.append(f"{card_name}: {prob*100:.0f}%\n")
            output.append((card_name_to_mtga_id(card_name), card_name, prob))
        print("".join(lines))
        return output

    def _load_model(self, set_name: str) -> None:
        """Lazily load and cache the learner and card names for ``set_name``."""
        if set_name in self._models:
            return
        # NOTE(review): load_learner unpickles this file. If set_name can come
        # from untrusted input, validate it (e.g. reject path separators) first.
        model_path = f"models/{set_name}_draft.pkl"  # relative to the working directory
        model = load_learner(model_path)
        card_names = self._pack_card_names(model.dls.train_ds.cont_names)
        # Cache both only after everything succeeded, so the two dicts stay in sync.
        self._models[set_name] = model
        self._card_names[set_name] = card_names

    @classmethod
    def _pack_card_names(cls, columns) -> list:
        """Return the card names of the ``pack_card_*`` entries in ``columns``, in order."""
        prefix = cls._PACK_PREFIX
        # Slice off only the leading prefix (str.replace would strip every occurrence).
        return [col[len(prefix):] for col in columns if col.startswith(prefix)]

    def _json_to_df(self, set_name: str, json_data: dict) -> pd.DataFrame:
        """Build the one-row float feature frame the set's model expects.

        Every continuous column of the model starts at 0. Each card id in
        ``json_data['pack']`` increments its ``pack_card_<name>`` column. Each
        card id in ``json_data['pool']`` (optional; missing or None means
        empty) increments its ``pool_<name>`` column.

        Raises:
            KeyError: if ``'pack'`` is missing, or a card id maps to a column
                the model does not have.
        """
        self._load_model(set_name)
        # TODO: shortcut for testing -- the columns are taken from the model's
        # training columns rather than built from the set's card list.
        counts = dict.fromkeys(self._models[set_name].dls.train_ds.cont_names, 0)
        self._add_cards(counts, self._PACK_PREFIX, json_data['pack'])
        self._add_cards(counts, self._POOL_PREFIX, json_data.get('pool') or [])
        return pd.DataFrame([counts]).astype(float)

    @staticmethod
    def _add_cards(counts: dict, prefix: str, card_ids) -> None:
        """Increment ``counts[prefix + card name]`` once per MTGA id in ``card_ids``."""
        for card_id in card_ids:
            col_name = f"{prefix}{mtga_id_to_card_name(card_id)}"
            if col_name not in counts:
                raise KeyError(
                    f"card id {card_id!r} maps to column {col_name!r}, "
                    "which the model does not have"
                )
            counts[col_name] += 1
|