Timo committed on
Commit
e6f2697
·
1 Parent(s): cc0b700
src/__pycache__/draft_model.cpython-39.pyc DELETED
Binary file (2.29 kB)
 
src/__pycache__/helpers.cpython-39.pyc DELETED
Binary file (7.94 kB)
 
src/draft_model.py CHANGED
@@ -2,6 +2,7 @@ from pathlib import Path
2
  import json
3
  import torch
4
  from typing import List, Dict
 
5
 
6
  from huggingface_hub import hf_hub_download
7
  from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
@@ -23,12 +24,12 @@ class DraftModel:
23
  def __init__(self):
24
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
- cfg_path = hf_hub_download(
27
- repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
28
- )
29
  weight_path = hf_hub_download(
30
  repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
31
  )
 
 
 
32
 
33
  with open(cfg_path, "r") as f:
34
  cfg = json.load(f)
@@ -37,7 +38,7 @@ class DraftModel:
37
 
38
 
39
  self.net = MLP_CrossAttention(**cfg).to(self.device)
40
- self.net.load_state_dict(weight_path, map_location=self.device)
41
  self.net.eval()
42
 
43
  # ---- embeddings – one-time load ------------------------------------
@@ -47,19 +48,43 @@ class DraftModel:
47
  )
48
  self.emb_size = next(iter(self.embed_dict.values())).shape[0]
49
 
 
 
 
 
 
 
 
 
 
 
 
50
  # --------------------------------------------------------------------- #
51
  # Public API expected by streamlit_app.py #
52
  # --------------------------------------------------------------------- #
53
  @torch.no_grad()
54
- def predict(self, pack: List[Dict], picks: List[Dict], deck: List[Dict]) -> Dict:
55
- names = [c["name"] for c in pack]
56
-
57
- def embed(name): # helper
58
- return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]
59
-
60
- card_t = torch.stack([embed(n) for n in names]).unsqueeze(0).to(self.device)
61
- deck_t = torch.zeros((1, 45, self.emb_size), device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- return torch.softmax(self.net(card_t, deck_t), dim=1).squeeze(0).cpu().numpy().tolist()
 
64
 
65
 
 
2
  import json
3
  import torch
4
  from typing import List, Dict
5
+ from collections import defaultdict
6
 
7
  from huggingface_hub import hf_hub_download
8
  from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
 
24
  def __init__(self):
25
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
 
 
 
27
  weight_path = hf_hub_download(
28
  repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
29
  )
30
+ cfg_path = hf_hub_download(
31
+ repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
32
+ )
33
 
34
  with open(cfg_path, "r") as f:
35
  cfg = json.load(f)
 
38
 
39
 
40
  self.net = MLP_CrossAttention(**cfg).to(self.device)
41
+ self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
42
  self.net.eval()
43
 
44
  # ---- embeddings – one-time load ------------------------------------
 
48
  )
49
  self.emb_size = next(iter(self.embed_dict.values())).shape[0]
50
 
51
+ raw_card_file = json.load(open(hf_hub_download(
52
+ repo_id=DATA_REPO, filename=CARD_FILE, repo_type="dataset"
53
+ )))
54
+ self.cards = defaultdict(dict)
55
+ for card in raw_card_file:
56
+ self.cards[card["set"]][card["name"]] = card
57
+
58
+
59
def _embed(self, name):
    """Return the embedding vector for a single card name.

    Thin convenience wrapper: delegates to ``get_card_embeddings`` with a
    one-element tuple and unwraps the single result.
    """
    vectors = get_card_embeddings((name,), embedding_dict=self.embed_dict)
    return vectors[0]
61
+
62
  # --------------------------------------------------------------------- #
63
  # Public API expected by streamlit_app.py #
64
  # --------------------------------------------------------------------- #
65
@torch.no_grad()
def predict(self, pack: List[Dict], deck: List[Dict]) -> Dict:
    """Score every card in *pack* against the current *deck* and suggest a pick.

    Args:
        pack: Cards available to pick. Items are passed straight to
            ``self._embed`` — in practice either card dicts or plain card
            names (``get_p1p1`` passes names); TODO confirm the intended type.
        deck: Cards already picked, or ``None`` for the first pick, in which
            case an all-zero deck context is used (45 slots — presumably the
            maximum deck size; verify against the model config).

    Returns:
        Dict with:
            ``"pick"``   — the element of *pack* with the highest score.
            ``"scores"`` — raw (un-normalized) network scores, one per card,
            in pack order.
    """
    card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
    if deck is None:
        # First pick: no deck context yet, feed a fixed-size zero tensor.
        deck_t = torch.zeros((1, 45, self.emb_size), device=self.device)
    else:
        deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)

    # Raw scores; softmax was deliberately disabled upstream, so keep raw
    # values (removed the leftover commented-out softmax and debug print).
    scores = self.net(deck=deck_t, cards=card_t).squeeze(0).cpu().numpy()
    return {
        "pick": pack[scores.argmax()],
        "scores": scores.tolist(),
    }
82
@torch.no_grad()
def get_p1p1(self, set_code: str):
    """Rank every card of *set_code* as a pick-1-pack-1 (empty deck) choice.

    Args:
        set_code: Set identifier used as a key into ``self.cards``.

    Returns:
        Tuple ``(names, scores)`` — the card names of the set and their
        corresponding raw model scores, in the same order.
    """
    keys = list(self.cards[set_code].keys())
    # predict() embeds the pack itself; the original pre-built an embedding
    # tensor here that was never used (dead work) — removed.
    vals = self.predict(pack=keys, deck=None)["scores"]
    return keys, vals
89
 
90
 
src/helper_files/supported_sets.txt CHANGED
@@ -1,3 +1,5 @@
 
 
1
  DFT
2
  PIO
3
  FDN
 
1
+ EOE
2
+ FIN
3
  DFT
4
  PIO
5
  FDN
src/helpers.py CHANGED
@@ -280,4 +280,45 @@ def get_card_embeddings(card_names, embedding_dict, embedding_size = 1330):
280
  else:
281
  embedding, got_new = get_embedding_of_card(card, embedding_dict)
282
  embeddings.append(embedding)
283
- return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  else:
281
  embedding, got_new = get_embedding_of_card(card, embedding_dict)
282
  embeddings.append(embedding)
283
+ return embeddings
284
+
285
def check_for_basics(card_name, embedding_dict):
    """Collapse a numbered basic-land name to its plain land name.

    Names of the exact form ``'<Land>_<n>'`` with ``n`` in 1..5 (presumably
    art variants of the five basic lands — confirm against the data export)
    are mapped to the plain land name, e.g. ``'Island_3'`` -> ``'Island'``.
    Every other name is returned unchanged.

    Args:
        card_name: Raw card name.
        embedding_dict: Unused; kept for signature compatibility with callers.

    Returns:
        The plain basic-land name, or *card_name* unchanged.
    """
    basics = ('Mountain', 'Forest', 'Swamp', 'Island', 'Plains')
    for land in basics:
        # Only the exact forms '<Land>_1' .. '<Land>_5' are collapsed.
        if any(card_name == f'{land}_{n}' for n in range(1, 6)):
            return land
    return card_name
294
+
295
def get_embedding_of_card(card_name, embedding_dict):
    """Look up the embedding tensor for *card_name* in *embedding_dict*.

    The name is first normalized (numbered basics collapsed, underscores
    replaced by spaces, one known spelling fix), then a small set of fallback
    spellings is tried in order: the name itself, its front face (text before
    ``' // '``), and the name with the ``'A-'`` (Alchemy) prefix stripped.

    Args:
        card_name: Raw card name, possibly with underscores / face separators.
        embedding_dict: Mapping from card name to embedding array.

    Returns:
        Tuple ``(torch.Tensor(embedding), False)`` — the second element kept
        for compatibility with callers (the original could also signal a
        freshly requested embedding with ``True``; that path is disabled).

    Raises:
        Exception: If no candidate spelling is present in *embedding_dict*.
    """
    try:
        card_name = check_for_basics(card_name, embedding_dict)
        card_name = card_name.replace('_', ' ')
        # Known data inconsistency in the embedding export.
        card_name = card_name.replace("Sol'kanar", "Sol'Kanar")

        # Fallback spellings, tried in order.  Replaces the original's four
        # levels of nested bare `except:` blocks with explicit membership
        # tests (same lookup order; the `replace('_',' ')` fallback was a
        # no-op because underscores were already replaced above).
        candidates = (
            card_name,
            card_name.split(' // ')[0],
            card_name.replace('A-', ''),
        )
        for candidate in candidates:
            if candidate in embedding_dict:
                return torch.Tensor(embedding_dict[candidate]), False
        raise Exception(f'Could not find {card_name}')
    except Exception as e:
        print(f'Could not find {card_name}')
        print(e)
        raise
src/streamlit_app.py CHANGED
@@ -32,13 +32,14 @@ import random
32
  from pathlib import Path
33
 
34
  from typing import Dict, List
 
35
 
36
  import requests
37
  import streamlit as st
38
 
39
  from draft_model import DraftModel
40
 
41
- SUPPORTED_SETS_PATH = Path("supported_sets.txt")
42
 
43
 
44
  @st.cache_data(show_spinner="Reading supported sets …")
@@ -54,6 +55,14 @@ def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
54
  def load_model():
55
  return DraftModel()
56
 
 
 
 
 
 
 
 
 
57
 
58
  @st.cache_data(show_spinner="Calculating card rankings …")
59
  def rank_cards(set_code: str) -> List[Dict]:
@@ -82,11 +91,6 @@ def rank_cards(set_code: str) -> List[Dict]:
82
 
83
  model = load_model()
84
 
85
-
86
- # -----------------------------------------------------------------------------
87
- # 2. Draft‑logic helpers (stubs)
88
- # -----------------------------------------------------------------------------
89
-
90
  def suggest_pick(pack: List[Dict], picks: List[Dict]) -> Dict:
91
  if model is None:
92
  return random.choice(pack)
@@ -105,18 +109,6 @@ def fetch_card_image(card_name: str) -> str:
105
  return data["card_faces"][0]["image_uris"]["normal"]
106
 
107
 
108
- def generate_booster(set_code: str) -> List[Dict]:
109
- url = f"https://api.scryfall.com/cards/search?q=set%3A{set_code}+is%3Abooster+unique%3Aprints"
110
- cards: List[Dict] = []
111
- while url:
112
- resp = requests.get(url)
113
- resp.raise_for_status()
114
- payload = resp.json()
115
- cards += payload["data"]
116
- url = payload.get("next_page") if payload.get("has_more") else None
117
- return random.sample(cards, 15)
118
-
119
-
120
  # -----------------------------------------------------------------------------
121
  # 3. Streamlit UI
122
  # -----------------------------------------------------------------------------
@@ -135,13 +127,19 @@ with st.sidebar:
135
  # Hide control in an expander (collapsed by default)
136
  with st.expander("Set selection", expanded=False):
137
  if supported_sets:
138
- set_code = st.radio("Choose a set to draft", supported_sets, index=0)
 
 
 
 
 
 
139
  else:
140
  st.warning(
141
- "*supported_sets.txt* not found or empty. Using freetext input instead.",
142
  icon="⚠️",
143
  )
144
- set_code = st.text_input("Set code", value="WOE")
145
 
146
  if st.button("Start new draft", type="primary"):
147
  st.session_state["pack"] = generate_booster(set_code)
@@ -154,7 +152,7 @@ st.session_state.setdefault("picks", [])
154
 
155
  # -------- Main content organised in tabs ------------------------------------
156
 
157
- tabs = st.tabs(["Draft", "Card rankings"])
158
 
159
  # --- Tab 1: Draft ------------------------------------------------------------
160
 
@@ -189,11 +187,13 @@ with tabs[1]:
189
  st.header("Card rankings for set " + set_code)
190
 
191
  if set_code:
192
- ranking = rank_cards(set_code)
193
- if ranking:
194
- for idx, card in enumerate(ranking, start=1):
195
- st.write(f"{idx}. {card['name']} — **{card['score']:.2f}**")
196
- else:
197
- st.info("No cards found for this set (or Scryfall unavailable).")
 
 
198
  else:
199
- st.info("Select a set in the sidebar to view rankings.")
 
32
  from pathlib import Path
33
 
34
  from typing import Dict, List
35
+ import pandas as pd
36
 
37
  import requests
38
  import streamlit as st
39
 
40
  from draft_model import DraftModel
41
 
42
+ SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
43
 
44
 
45
  @st.cache_data(show_spinner="Reading supported sets …")
 
55
  def load_model():
56
  return DraftModel()
57
 
58
@st.cache_data(show_spinner="Calculating P1P1 ...")
def p1p1_ranking(set_code: str):
    """Return a DataFrame of all cards in *set_code* ranked by P1P1 score.

    Columns: ``card`` and ``p1p1_score``, sorted best-first; the index is
    shifted to start at 1 so it doubles as the rank. Results are cached per
    set code by streamlit.
    """
    card_names, card_scores = model.get_p1p1(set_code)
    ranking = pd.DataFrame({"card": card_names, "p1p1_score": card_scores})
    ranking = ranking.sort_values("p1p1_score", ascending=False, ignore_index=True)
    ranking.index = ranking.index + 1  # 1-based ranks look nicer
    return ranking
65
+
66
 
67
  @st.cache_data(show_spinner="Calculating card rankings …")
68
  def rank_cards(set_code: str) -> List[Dict]:
 
91
 
92
  model = load_model()
93
 
 
 
 
 
 
94
  def suggest_pick(pack: List[Dict], picks: List[Dict]) -> Dict:
95
  if model is None:
96
  return random.choice(pack)
 
109
  return data["card_faces"][0]["image_uris"]["normal"]
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # -----------------------------------------------------------------------------
113
  # 3. Streamlit UI
114
  # -----------------------------------------------------------------------------
 
127
  # Hide control in an expander (collapsed by default)
128
  with st.expander("Set selection", expanded=False):
129
  if supported_sets:
130
+ # UPDATED: dropdown instead of radio
131
+ set_code = st.selectbox(
132
+ "Choose a set",
133
+ supported_sets,
134
+ index=0,
135
+ key="set_code",
136
+ )
137
  else:
138
  st.warning(
139
+ "*supported_sets.txt* not found or empty. Using free-text input instead.",
140
  icon="⚠️",
141
  )
142
+ set_code = st.text_input("Set code", value="EOE", key="set_code")
143
 
144
  if st.button("Start new draft", type="primary"):
145
  st.session_state["pack"] = generate_booster(set_code)
 
152
 
153
  # -------- Main content organised in tabs ------------------------------------
154
 
155
+ tabs = st.tabs(["Draft", "P1P1 Rankings"])
156
 
157
  # --- Tab 1: Draft ------------------------------------------------------------
158
 
 
187
  st.header("Card rankings for set " + set_code)
188
 
189
  if set_code:
190
+ try:
191
+ df = p1p1_ranking(set_code.lower()) # cached; auto-updates on dropdown change
192
+ if not df.empty:
193
+ st.dataframe(df, use_container_width=True)
194
+ else:
195
+ st.info("No P1P1 results returned for this set.")
196
+ except Exception as e:
197
+ st.error(f"Could not calculate P1P1: {e}")
198
  else:
199
+ st.info("Select a set in the sidebar to view P1P1.")