Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| from fastai.tabular.all import * | |
| from .mtga import mtga_id_to_card_name, card_name_to_mtga_id | |
class MTGPickPredictor:
    """Suggest draft picks for a Magic: The Gathering set with a fastai tabular model.

    One model per set is loaded lazily from ``models/<set_name>_draft.pkl`` and
    cached on the instance, together with the card names derived from the
    model's ``pack_card_*`` input columns.
    """

    # Prefixes of the model's continuous input columns: one count column per
    # card for the current pack and one per card for the drafted pool.
    _PACK_PREFIX = "pack_card_"
    _POOL_PREFIX = "pool_"

    def __init__(self):
        # set name -> loaded fastai Learner
        self._models: Dict[str, object] = {}
        # set name -> card names, in the order of the model's pack_card_* columns
        self._card_names: Dict[str, list] = {}

    def predict(self, set_name, input_data, top_k: int = 3):
        """Return the ``top_k`` most likely picks for one pack/pool state.

        Args:
            set_name: Set code; selects ``models/<set_name>_draft.pkl``.
            input_data: Mapping with ``'pack'`` and ``'pool'`` lists of MTGA card ids.
            top_k: Number of suggestions (default 3, the previously hard-coded
                value). Clamped to the number of model outputs.

        Returns:
            A list of ``(mtga_id, card_name, probability)`` tuples, most likely
            first. ``probability`` is the element yielded by iterating the
            model's top-k values, unchanged (presumably a 0-d torch tensor).

        Side effect: prints one ``"<card>: <pct>%"`` line per suggestion.
        """
        if set_name not in self._models:
            self._load_model(set_name)
        df = self._json_to_df(set_name, input_data)
        _, _, pred_probs = self._models[set_name].predict(df.iloc[0])
        # topk() raises when k exceeds the number of elements, so clamp it.
        k = min(top_k, len(pred_probs))
        topk_values, topk_indices = pred_probs.topk(k)
        # NOTE(review): assumes the model's output classes are ordered exactly
        # like its pack_card_* input columns; verify against model.dls.vocab.
        card_names = self._card_names[set_name]
        lines = []
        output = []
        for prob, idx in zip(topk_values, topk_indices):
            card_name = card_names[idx]
            lines.append(f"{card_name}: {prob*100:.0f}%\n")
            output.append((card_name_to_mtga_id(card_name), card_name, prob))
        print("".join(lines))
        return output

    def _load_model(self, set_name: str):
        """Lazily load and cache the model and its card names for ``set_name``.

        Does nothing if the set is already cached.
        """
        if set_name in self._models:
            return
        # SECURITY NOTE(review): load_learner unpickles the file, which can run
        # arbitrary code, and set_name is interpolated into the path unchecked.
        # If set_name can come from an untrusted client, validate it (e.g.
        # against a list of known set codes) before it reaches this point.
        model_path = f"models/{set_name}_draft.pkl"
        model = load_learner(model_path)
        prefix = self._PACK_PREFIX
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence inside the card name.
        card_names = [
            col[len(prefix):]
            for col in model.dls.train_ds.cont_names
            if col.startswith(prefix)
        ]
        # Cache both together, so a failure above never leaves a model cached
        # without its card names.
        self._models[set_name] = model
        self._card_names[set_name] = card_names

    def _json_to_df(self, set_name, json_data):
        """Encode a pack/pool state as a one-row float DataFrame for the model.

        The columns are exactly the model's training continuous columns, in
        training order, so the row lines up with what the model expects. Each
        ``pack_card_<name>`` / ``pool_<name>`` cell counts how many ids in
        ``json_data['pack']`` / ``json_data['pool']`` map to that card name.

        Raises:
            KeyError: if ``json_data`` lacks ``'pack'``/``'pool'``, or a card id
                maps to a name the model has no input column for.
        """
        all_cols = {col: 0 for col in self._models[set_name].dls.train_ds.cont_names}
        sources = (
            (self._PACK_PREFIX, json_data['pack']),
            (self._POOL_PREFIX, json_data['pool']),
        )
        for prefix, card_ids in sources:
            for card_id in card_ids:
                col_name = f"{prefix}{mtga_id_to_card_name(card_id)}"
                if col_name not in all_cols:
                    # Same exception type as before, but with a useful message.
                    raise KeyError(
                        f"{col_name!r} (card id {card_id!r}) is not an input "
                        f"column of the {set_name!r} model"
                    )
                all_cols[col_name] += 1
        return pd.DataFrame([all_cols]).astype(float)