Spaces:
Build error
Build error
Timo commited on
Commit ·
e6f2697
1
Parent(s): cc0b700
OK
Browse files- src/__pycache__/draft_model.cpython-39.pyc +0 -0
- src/__pycache__/helpers.cpython-39.pyc +0 -0
- src/draft_model.py +38 -13
- src/helper_files/supported_sets.txt +2 -0
- src/helpers.py +42 -1
- src/streamlit_app.py +29 -29
src/__pycache__/draft_model.cpython-39.pyc
DELETED
|
Binary file (2.29 kB)
|
|
|
src/__pycache__/helpers.cpython-39.pyc
DELETED
|
Binary file (7.94 kB)
|
|
|
src/draft_model.py
CHANGED
|
@@ -2,6 +2,7 @@ from pathlib import Path
|
|
| 2 |
import json
|
| 3 |
import torch
|
| 4 |
from typing import List, Dict
|
|
|
|
| 5 |
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
|
|
@@ -23,12 +24,12 @@ class DraftModel:
|
|
| 23 |
def __init__(self):
|
| 24 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
|
| 26 |
-
cfg_path = hf_hub_download(
|
| 27 |
-
repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
|
| 28 |
-
)
|
| 29 |
weight_path = hf_hub_download(
|
| 30 |
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
| 31 |
)
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
with open(cfg_path, "r") as f:
|
| 34 |
cfg = json.load(f)
|
|
@@ -37,7 +38,7 @@ class DraftModel:
|
|
| 37 |
|
| 38 |
|
| 39 |
self.net = MLP_CrossAttention(**cfg).to(self.device)
|
| 40 |
-
self.net.load_state_dict(weight_path, map_location=self.device)
|
| 41 |
self.net.eval()
|
| 42 |
|
| 43 |
# ---- embeddings – one-time load ------------------------------------
|
|
@@ -47,19 +48,43 @@ class DraftModel:
|
|
| 47 |
)
|
| 48 |
self.emb_size = next(iter(self.embed_dict.values())).shape[0]
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# --------------------------------------------------------------------- #
|
| 51 |
# Public API expected by streamlit_app.py #
|
| 52 |
# --------------------------------------------------------------------- #
|
| 53 |
@torch.no_grad()
|
| 54 |
-
def predict(self, pack: List[Dict],
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
|
|
|
|
| 2 |
import json
|
| 3 |
import torch
|
| 4 |
from typing import List, Dict
|
| 5 |
+
from collections import defaultdict
|
| 6 |
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
|
|
|
|
| 24 |
def __init__(self):
|
| 25 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 26 |
|
|
|
|
|
|
|
|
|
|
| 27 |
weight_path = hf_hub_download(
|
| 28 |
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
| 29 |
)
|
| 30 |
+
cfg_path = hf_hub_download(
|
| 31 |
+
repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
|
| 32 |
+
)
|
| 33 |
|
| 34 |
with open(cfg_path, "r") as f:
|
| 35 |
cfg = json.load(f)
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
self.net = MLP_CrossAttention(**cfg).to(self.device)
|
| 41 |
+
self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
|
| 42 |
self.net.eval()
|
| 43 |
|
| 44 |
# ---- embeddings – one-time load ------------------------------------
|
|
|
|
| 48 |
)
|
| 49 |
self.emb_size = next(iter(self.embed_dict.values())).shape[0]
|
| 50 |
|
| 51 |
+
raw_card_file = json.load(open(hf_hub_download(
|
| 52 |
+
repo_id=DATA_REPO, filename=CARD_FILE, repo_type="dataset"
|
| 53 |
+
)))
|
| 54 |
+
self.cards = defaultdict(dict)
|
| 55 |
+
for card in raw_card_file:
|
| 56 |
+
self.cards[card["set"]][card["name"]] = card
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _embed(self, name): # helper
|
| 60 |
+
return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]
|
| 61 |
+
|
| 62 |
# --------------------------------------------------------------------- #
|
| 63 |
# Public API expected by streamlit_app.py #
|
| 64 |
# --------------------------------------------------------------------- #
|
| 65 |
@torch.no_grad()
|
| 66 |
+
def predict(self, pack: List[Dict], deck: List[Dict]) -> Dict:
|
| 67 |
+
|
| 68 |
+
card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
|
| 69 |
+
if deck is None:
|
| 70 |
+
deck_t = torch.zeros((1, 45, self.emb_size), device=self.device)
|
| 71 |
+
else:
|
| 72 |
+
deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)
|
| 73 |
+
|
| 74 |
+
vals = self.net(deck = deck_t, cards = card_t)
|
| 75 |
+
#scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
|
| 76 |
+
scores = vals.squeeze(0).cpu().numpy()
|
| 77 |
+
print(scores)
|
| 78 |
+
return {
|
| 79 |
+
"pick": pack[scores.argmax()],
|
| 80 |
+
"scores": scores.tolist(),
|
| 81 |
+
}
|
| 82 |
+
@torch.no_grad()
|
| 83 |
+
def get_p1p1(self, set_code:str):
|
| 84 |
+
keys = list(self.cards[set_code].keys())
|
| 85 |
+
cards = torch.stack([self._embed(c) for c in keys]).unsqueeze(0).to(self.device)
|
| 86 |
|
| 87 |
+
vals = self.predict(pack=keys, deck=None)["scores"]
|
| 88 |
+
return keys, vals
|
| 89 |
|
| 90 |
|
src/helper_files/supported_sets.txt
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
DFT
|
| 2 |
PIO
|
| 3 |
FDN
|
|
|
|
| 1 |
+
EOE
|
| 2 |
+
FIN
|
| 3 |
DFT
|
| 4 |
PIO
|
| 5 |
FDN
|
src/helpers.py
CHANGED
|
@@ -280,4 +280,45 @@ def get_card_embeddings(card_names, embedding_dict, embedding_size = 1330):
|
|
| 280 |
else:
|
| 281 |
embedding, got_new = get_embedding_of_card(card, embedding_dict)
|
| 282 |
embeddings.append(embedding)
|
| 283 |
-
return embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
else:
|
| 281 |
embedding, got_new = get_embedding_of_card(card, embedding_dict)
|
| 282 |
embeddings.append(embedding)
|
| 283 |
+
return embeddings
|
| 284 |
+
|
| 285 |
+
def check_for_basics(card_name, embedding_dict):
|
| 286 |
+
ints = ['1','2','3','4','5']
|
| 287 |
+
basics = ['Mountain','Forest','Swamp','Island','Plains']
|
| 288 |
+
for b in basics:
|
| 289 |
+
if b in card_name:
|
| 290 |
+
for i in ints:
|
| 291 |
+
if card_name == f'{b}_{i}':
|
| 292 |
+
return b
|
| 293 |
+
return card_name
|
| 294 |
+
|
| 295 |
+
def get_embedding_of_card(card_name, embedding_dict):
|
| 296 |
+
try:
|
| 297 |
+
card_name = check_for_basics(card_name, embedding_dict)
|
| 298 |
+
card_name = card_name.replace('_', ' ')
|
| 299 |
+
card_name = card_name.replace("Sol'kanar","Sol'Kanar")
|
| 300 |
+
if card_name not in embedding_dict and card_name.split(' // ')[0] not in embedding_dict and card_name.replace('A-','') not in embedding_dict:
|
| 301 |
+
# print(f'Requesting new embedding for {card_name}')
|
| 302 |
+
# attributes, text = get_card_representation(card_name = card_name)
|
| 303 |
+
# text_embedding = embedd_text([text]).squeeze()
|
| 304 |
+
# return torch.Tensor(np.concatenate((attributes, text_embedding), axis = 0)), True
|
| 305 |
+
raise Exception(f'Could not find {card_name}')
|
| 306 |
+
else:
|
| 307 |
+
try:
|
| 308 |
+
return torch.Tensor(embedding_dict[card_name]), False
|
| 309 |
+
except:
|
| 310 |
+
try:
|
| 311 |
+
return torch.Tensor(embedding_dict[card_name.split(' // ')[0]]), False
|
| 312 |
+
except:
|
| 313 |
+
try:
|
| 314 |
+
return torch.Tensor(embedding_dict[card_name.replace('_',' ')]), False
|
| 315 |
+
except:
|
| 316 |
+
try:
|
| 317 |
+
return torch.Tensor(embedding_dict[card_name.replace('A-','')]), False
|
| 318 |
+
except:
|
| 319 |
+
print(f'Could not find {card_name}')
|
| 320 |
+
raise Exception
|
| 321 |
+
except Exception as e:
|
| 322 |
+
print(f'Could not find {card_name}')
|
| 323 |
+
print(e)
|
| 324 |
+
raise e
|
src/streamlit_app.py
CHANGED
|
@@ -32,13 +32,14 @@ import random
|
|
| 32 |
from pathlib import Path
|
| 33 |
|
| 34 |
from typing import Dict, List
|
|
|
|
| 35 |
|
| 36 |
import requests
|
| 37 |
import streamlit as st
|
| 38 |
|
| 39 |
from draft_model import DraftModel
|
| 40 |
|
| 41 |
-
SUPPORTED_SETS_PATH = Path("supported_sets.txt")
|
| 42 |
|
| 43 |
|
| 44 |
@st.cache_data(show_spinner="Reading supported sets …")
|
|
@@ -54,6 +55,14 @@ def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
|
|
| 54 |
def load_model():
|
| 55 |
return DraftModel()
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
@st.cache_data(show_spinner="Calculating card rankings …")
|
| 59 |
def rank_cards(set_code: str) -> List[Dict]:
|
|
@@ -82,11 +91,6 @@ def rank_cards(set_code: str) -> List[Dict]:
|
|
| 82 |
|
| 83 |
model = load_model()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
# -----------------------------------------------------------------------------
|
| 87 |
-
# 2. Draft‑logic helpers (stubs)
|
| 88 |
-
# -----------------------------------------------------------------------------
|
| 89 |
-
|
| 90 |
def suggest_pick(pack: List[Dict], picks: List[Dict]) -> Dict:
|
| 91 |
if model is None:
|
| 92 |
return random.choice(pack)
|
|
@@ -105,18 +109,6 @@ def fetch_card_image(card_name: str) -> str:
|
|
| 105 |
return data["card_faces"][0]["image_uris"]["normal"]
|
| 106 |
|
| 107 |
|
| 108 |
-
def generate_booster(set_code: str) -> List[Dict]:
|
| 109 |
-
url = f"https://api.scryfall.com/cards/search?q=set%3A{set_code}+is%3Abooster+unique%3Aprints"
|
| 110 |
-
cards: List[Dict] = []
|
| 111 |
-
while url:
|
| 112 |
-
resp = requests.get(url)
|
| 113 |
-
resp.raise_for_status()
|
| 114 |
-
payload = resp.json()
|
| 115 |
-
cards += payload["data"]
|
| 116 |
-
url = payload.get("next_page") if payload.get("has_more") else None
|
| 117 |
-
return random.sample(cards, 15)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
# -----------------------------------------------------------------------------
|
| 121 |
# 3. Streamlit UI
|
| 122 |
# -----------------------------------------------------------------------------
|
|
@@ -135,13 +127,19 @@ with st.sidebar:
|
|
| 135 |
# Hide control in an expander (collapsed by default)
|
| 136 |
with st.expander("Set selection", expanded=False):
|
| 137 |
if supported_sets:
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
else:
|
| 140 |
st.warning(
|
| 141 |
-
"*supported_sets.txt* not found or empty. Using free
|
| 142 |
icon="⚠️",
|
| 143 |
)
|
| 144 |
-
set_code = st.text_input("Set code", value="
|
| 145 |
|
| 146 |
if st.button("Start new draft", type="primary"):
|
| 147 |
st.session_state["pack"] = generate_booster(set_code)
|
|
@@ -154,7 +152,7 @@ st.session_state.setdefault("picks", [])
|
|
| 154 |
|
| 155 |
# -------- Main content organised in tabs ------------------------------------
|
| 156 |
|
| 157 |
-
tabs = st.tabs(["Draft", "
|
| 158 |
|
| 159 |
# --- Tab 1: Draft ------------------------------------------------------------
|
| 160 |
|
|
@@ -189,11 +187,13 @@ with tabs[1]:
|
|
| 189 |
st.header("Card rankings for set " + set_code)
|
| 190 |
|
| 191 |
if set_code:
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
st.
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
| 198 |
else:
|
| 199 |
-
st.info("Select a set in the sidebar to view
|
|
|
|
| 32 |
from pathlib import Path
|
| 33 |
|
| 34 |
from typing import Dict, List
|
| 35 |
+
import pandas as pd
|
| 36 |
|
| 37 |
import requests
|
| 38 |
import streamlit as st
|
| 39 |
|
| 40 |
from draft_model import DraftModel
|
| 41 |
|
| 42 |
+
SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
|
| 43 |
|
| 44 |
|
| 45 |
@st.cache_data(show_spinner="Reading supported sets …")
|
|
|
|
| 55 |
def load_model():
|
| 56 |
return DraftModel()
|
| 57 |
|
| 58 |
+
@st.cache_data(show_spinner="Calculating P1P1 ...")
|
| 59 |
+
def p1p1_ranking(set_code: str):
|
| 60 |
+
names, scores = model.get_p1p1(set_code)
|
| 61 |
+
df = pd.DataFrame({"card": names, "p1p1_score": scores})
|
| 62 |
+
df = df.sort_values("p1p1_score", ascending=False, ignore_index=True)
|
| 63 |
+
df.index += 1 # 1-based ranks look nicer
|
| 64 |
+
return df
|
| 65 |
+
|
| 66 |
|
| 67 |
@st.cache_data(show_spinner="Calculating card rankings …")
|
| 68 |
def rank_cards(set_code: str) -> List[Dict]:
|
|
|
|
| 91 |
|
| 92 |
model = load_model()
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def suggest_pick(pack: List[Dict], picks: List[Dict]) -> Dict:
|
| 95 |
if model is None:
|
| 96 |
return random.choice(pack)
|
|
|
|
| 109 |
return data["card_faces"][0]["image_uris"]["normal"]
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
# -----------------------------------------------------------------------------
|
| 113 |
# 3. Streamlit UI
|
| 114 |
# -----------------------------------------------------------------------------
|
|
|
|
| 127 |
# Hide control in an expander (collapsed by default)
|
| 128 |
with st.expander("Set selection", expanded=False):
|
| 129 |
if supported_sets:
|
| 130 |
+
# UPDATED: dropdown instead of radio
|
| 131 |
+
set_code = st.selectbox(
|
| 132 |
+
"Choose a set",
|
| 133 |
+
supported_sets,
|
| 134 |
+
index=0,
|
| 135 |
+
key="set_code",
|
| 136 |
+
)
|
| 137 |
else:
|
| 138 |
st.warning(
|
| 139 |
+
"*supported_sets.txt* not found or empty. Using free-text input instead.",
|
| 140 |
icon="⚠️",
|
| 141 |
)
|
| 142 |
+
set_code = st.text_input("Set code", value="EOE", key="set_code")
|
| 143 |
|
| 144 |
if st.button("Start new draft", type="primary"):
|
| 145 |
st.session_state["pack"] = generate_booster(set_code)
|
|
|
|
| 152 |
|
| 153 |
# -------- Main content organised in tabs ------------------------------------
|
| 154 |
|
| 155 |
+
tabs = st.tabs(["Draft", "P1P1 Rankings"])
|
| 156 |
|
| 157 |
# --- Tab 1: Draft ------------------------------------------------------------
|
| 158 |
|
|
|
|
| 187 |
st.header("Card rankings for set " + set_code)
|
| 188 |
|
| 189 |
if set_code:
|
| 190 |
+
try:
|
| 191 |
+
df = p1p1_ranking(set_code.lower()) # cached; auto-updates on dropdown change
|
| 192 |
+
if not df.empty:
|
| 193 |
+
st.dataframe(df, use_container_width=True)
|
| 194 |
+
else:
|
| 195 |
+
st.info("No P1P1 results returned for this set.")
|
| 196 |
+
except Exception as e:
|
| 197 |
+
st.error(f"Could not calculate P1P1: {e}")
|
| 198 |
else:
|
| 199 |
+
st.info("Select a set in the sidebar to view P1P1.")
|