MTG_Drafting_AI / src /streamlit_app.py
Timo
Works now
8ef1a8f
raw
history blame
11 kB
"""
---------------------------------
Booster‑draft helper for Magic: The Gathering, built with Streamlit and ready
for Hugging Face Spaces deployment.
🆕 UI tweaks in this revision
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* **Set selection is now hidden** in an *expander* inside the sidebar—keeps the
layout clean.
* Added a **second tab – “Card rankings”**. When you pick a set, the tab shows a
(stub) ranked list of cards from that set. Replace the `rank_cards()` stub
with real logic later.
"""
from __future__ import annotations
import os, pathlib
# Writable locations --------------------------------------------------
# Point home/cache/config locations at /tmp unless the environment already
# defines them (presumably because the deployment container's default home
# is not writable — confirm for the target platform).
_WRITABLE_ENV_DEFAULTS = (
    ("HOME", "/tmp"),  # fixes expanduser("~")
    ("STREAMLIT_HOME", "/tmp/.st_home"),  # Streamlit telemetry
    ("HF_HOME", "/tmp/hf_cache"),  # Hub model cache
)
for _env_name, _env_default in _WRITABLE_ENV_DEFAULTS:
    os.environ.setdefault(_env_name, _env_default)
# --------------------------------------------------------------------
# Create the Streamlit and Hugging Face directories up front (HOME is not
# created here).
for _env_name in ("STREAMLIT_HOME", "HF_HOME"):
    pathlib.Path(os.environ[_env_name]).mkdir(parents=True, exist_ok=True)
# disable usage-stats banner *before* importing streamlit
os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
import random
from pathlib import Path
from typing import Dict, List
import pandas as pd
import copy
import numpy as np
import requests
import streamlit as st
from draft_model import DraftModel
SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
@st.cache_resource(show_spinner="Loading draft model …")
def load_model():
    """Construct the ``DraftModel`` and cache it with ``st.cache_resource``.

    The cached object is reused on later calls instead of being rebuilt.
    """
    instance = DraftModel()
    return instance
# --- callbacks ---
def _add_selected_to_deck():
    """on_change callback of the deck search box: add the chosen card to the deck."""
    chosen = st.session_state.get("deck_selectbox")
    if not chosen:
        return
    add_card("deck", chosen)
    # Empty the search box again for the next card.
    st.session_state["deck_selectbox"] = None
    st.toast(f"Added to deck: {chosen}")
def _add_selected_to_pack():
    """on_change callback of the pack search box: add the chosen card to the pack.

    The confirmation toast is shown only when ``add_card`` accepted the card
    (it rejects cards already in the pack).
    """
    chosen = st.session_state.get("pack_selectbox")
    if not chosen:
        return
    accepted = add_card("pack", chosen)
    st.session_state["pack_selectbox"] = None  # clear selection
    if accepted:
        st.toast(f"Added to pack: {chosen}")
def _reset_draft_state():
    """Empty the pack, deck and undo history and clear both search boxes."""
    state = st.session_state
    for list_key in ("pack", "deck", "undo_stack"):
        state[list_key] = []
    for widget_key in ("deck_selectbox", "pack_selectbox"):
        state[widget_key] = None
def _on_set_changed():
    """on_change callback of the set selector.

    If the selected set differs from the one recorded in ``prev_set_code``,
    wipe the draft state, record the new set and notify the user.
    """
    new_set = st.session_state.get("set_code")
    if st.session_state.get("prev_set_code") == new_set:
        return
    _reset_draft_state()
    st.session_state["prev_set_code"] = new_set
    st.toast(f"Switched to set {new_set}. Cleared current pack & deck.")
if "model" not in st.session_state:
st.session_state.model = load_model() # your class
if "deck" not in st.session_state:
st.session_state.deck: List[str] = []
if "pack" not in st.session_state:
st.session_state.pack: List[str] = []
if "undo_stack" not in st.session_state:
st.session_state.undo_stack: List[str] = []
if "set_code" not in st.session_state:
# choose a default set code that exists in model.cards, e.g., "eoe"
st.session_state.set_code = "EOE"
model = st.session_state.model
@st.cache_data(show_spinner="Reading supported sets …")
def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
    """Return the legal set codes listed in *supported_sets.txt*, one per line.

    Surrounding whitespace and blank lines are ignored. A missing file yields
    an empty list rather than an error.
    """
    if not path.is_file():
        return []
    # Decode explicitly instead of relying on the platform default encoding;
    # "utf-8-sig" also drops a leading BOM that would otherwise stick to the
    # first set code (str.strip() does not remove U+FEFF).
    text = path.read_text(encoding="utf-8-sig")
    return [line.strip() for line in text.splitlines() if line.strip()]
@st.cache_data(show_spinner="Calculating P1P1 ...")
def p1p1_ranking(set_code: str):
    """Return a DataFrame ranking the cards of *set_code* by ``model.get_p1p1``.

    Columns are ``card`` and ``p1p1_score``. Rows are sorted from highest to
    lowest score, and the index is the 1-based rank.
    """
    card_names, p1p1_scores = model.get_p1p1(set_code)
    ranking = pd.DataFrame({"card": card_names, "p1p1_score": p1p1_scores}).sort_values(
        "p1p1_score", ascending=False, ignore_index=True
    )
    ranking.index = ranking.index + 1  # 1-based ranks look nicer
    return ranking
@st.cache_data(show_spinner="Calculating card rankings …")
def rank_cards(deck: List[str], pack: List[str]) -> tuple:
    """Ask the draft model to score every card in *pack* given *deck*.

    :param deck: card names already drafted; an empty deck is passed to the
        model as ``None``.
    :param pack: card names in the current pack.
    :returns: a 3-tuple ``(pick, logits, scores)``. ``pick`` is the model's
        ``out["pick"]``. ``logits`` and ``scores`` map each pack card name to
        the entry at the same position in ``out["logits"]`` / ``out["scores"]``.
        NOTE(review): this assumes the model output is positionally aligned
        with *pack* — confirm against ``DraftModel.predict``.

    Results are memoised by ``st.cache_data`` on the (deck, pack) contents.
    """
    out = model.predict(pack, deck=deck if deck else None)
    pick = out["pick"]
    # Entry i of each model output array belongs to pack[i].
    logits = {pack[i]: value for i, value in enumerate(out["logits"])}
    scores = {pack[i]: value for i, value in enumerate(out["scores"])}
    return pick, logits, scores
def fetch_card_image(card_name: str, timeout: float = 10.0) -> str:
    """Return the URL of Scryfall's "normal"-size image for *card_name*.

    The card is looked up by exact name via Scryfall's ``/cards/named``
    endpoint. If the response has no top-level ``image_uris`` (multi-faced
    cards), the first face's image is returned.

    :param card_name: exact card name.
    :param timeout: seconds to wait for Scryfall; without a timeout
        ``requests`` may block indefinitely on an unresponsive server.
    :raises requests.HTTPError: if Scryfall answers with an error status.
    :raises requests.Timeout: if Scryfall does not respond within *timeout*.
    """
    response = requests.get(
        "https://api.scryfall.com/cards/named",
        params={"exact": card_name, "format": "json"},
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    if "image_uris" in data:
        return data["image_uris"]["normal"]
    return data["card_faces"][0]["image_uris"]["normal"]
# -----------------------------------------------------------------------------
# 3. Streamlit UI
# -----------------------------------------------------------------------------
st.set_page_config(page_title="MTG Draft Assistant", page_icon="🃏")
st.title("🃏 MTG Draft Assistant")
# -------- Sidebar ------------------------------------------------------------
with st.sidebar:
st.header("Draft setup")
supported_sets = get_supported_sets()
set_code = st.selectbox(
"Choose a set",
supported_sets,
index=0,
key="set_code",
on_change=_on_set_changed
)
# -------- Main content organised in tabs ------------------------------------
tabs = st.tabs(["Draft", "P1P1 Rankings"])
def add_card(target: str, card: str) -> bool:
    """Append *card* to ``st.session_state[target]``.

    :param target: ``"pack"`` or ``"deck"``.
    :param card: card name.
    :returns: True if the card was added. The pack holds each card at most
        once, so a duplicate shows a warning and returns False. The deck
        accepts duplicates.
    :raises ValueError: for any other *target* (previously this silently
        returned None).
    """
    if target == "pack":
        if card in st.session_state["pack"]:
            st.warning(f"{card} is already in the pack.", icon="⚠️")
            return False
    elif target != "deck":
        raise ValueError(f"Unknown target {target!r}; expected 'pack' or 'deck'.")
    st.session_state[target].append(card)
    return True
def remove_card(target: str, key: str):
    """Remove the first entry equal to *key* from ``st.session_state[target]``.

    Does nothing when *key* is not present.
    """
    cards = st.session_state[target]
    if key in cards:
        cards.remove(key)
def push_undo():
    """Push a deep copy of the current pack and deck onto the undo stack.

    At most 20 snapshots are kept. When a push exceeds that, the oldest
    snapshot is dropped.
    """
    history = st.session_state["undo_stack"]
    snapshot = {name: copy.deepcopy(st.session_state[name]) for name in ("pack", "deck")}
    history.append(snapshot)
    if len(history) > 20:
        del history[0]
def undo_last():
    """Restore pack and deck from the most recent undo snapshot, if there is one."""
    stack = st.session_state.get("undo_stack")
    if not stack:
        return
    previous = stack.pop()
    st.session_state["pack"] = previous["pack"]
    st.session_state["deck"] = previous["deck"]
# --- Tab 1: Draft -------------------------------------------------------
with tabs[0]:
    if st.session_state["undo_stack"]:
        st.button("↩️ Undo last action", on_click=undo_last)

    # Model output for the current pack. All three names are initialised up
    # front so the pack table below still renders (without scores) when the
    # pack is empty or the model call fails. Previously ``logits`` was left
    # unbound in that case and the table raised NameError.
    pick = None
    logits = {}
    scores = {}
    if st.session_state["pack"]:
        pack = st.session_state["pack"]
        deck = st.session_state["deck"]
        try:
            pick, logits, scores = rank_cards(deck, pack)
            if pick:
                st.success(f"💡 Suggested pick: **{pick}**", icon="✨")
        except Exception as e:
            # Keep the UI usable: report the failure and show no scores.
            st.error(f"Error calculating card rankings: {e}")

    options = list(model.cards[set_code.lower()].keys())
    c1, c2 = st.columns(2)

    with c1:
        st.subheader("Add to Deck")
        st.selectbox(
            "Search card (deck)",
            options,
            index=None,
            placeholder="Type to search…",
            key="deck_selectbox",
            on_change=_add_selected_to_deck,  # <- auto-add
        )
        if st.session_state["deck"]:
            st.button("🗑️ Clear deck", on_click=lambda: st.session_state.update(deck=[]), use_container_width=True)
            # header row
            h1, h2 = st.columns([6, 3])
            h1.markdown("**Card**")
            h2.markdown("**Remove?**")
            # The deck may contain duplicates, so buttons are keyed by position.
            for i, card in enumerate(st.session_state["deck"]):
                name_col, rm_col = st.columns([6, 3], gap="small")
                name_col.write(card)
                with rm_col:
                    if st.button("Remove", key=f"rm-deck-{i}", use_container_width=True):
                        remove_card("deck", card)
                        st.rerun()
        else:
            st.caption("Deck is empty.")

    with c2:
        st.subheader("Add to pack")
        st.selectbox(
            "Search card (pack)",
            options,
            index=None,
            placeholder="Type to search…",
            key="pack_selectbox",
            on_change=_add_selected_to_pack,  # <- auto-add
        )
        if st.session_state["pack"]:
            st.button("🗑️ Clear pack", on_click=lambda: st.session_state.update(pack=[]), use_container_width=True)
            # header row
            h1, h2, h3 = st.columns([6, 2, 3])
            h1.markdown("**Card**")
            h2.markdown("**Score**")
            h3.markdown("**Pick?**")
            pack_list = st.session_state["pack"]
            # Cards without a model result get NaN and are sorted last.
            vals = [scores.get(c, np.nan) for c in pack_list]
            logit_vals = [logits.get(c, np.nan) for c in pack_list]
            df_scores = pd.DataFrame({"Card": pack_list, "Score": vals, "Logits": logit_vals})
            df_scores = df_scores.sort_values("Score", ascending=False, na_position="last").reset_index(drop=True)
            # rows
            for _, row in df_scores.iterrows():
                card = row["Card"]
                score = row["Score"]
                logit = row["Logits"]
                # Distinct names so the outer c1/c2 columns are not shadowed.
                name_col, score_col, pick_col = st.columns([6, 2, 3], gap="small")
                name_col.write(card)
                # Progress bar of the score (max="1"); hovering shows the raw
                # logit. Built without leading whitespace so markdown cannot
                # mistake indented HTML for a code block.
                tooltip_html = (
                    f'<div title="{logit:.4f}">'
                    f'<progress Value="{score}" max="1" style="width: 100%; height: 20px;"></progress>'
                    "</div>"
                )
                score_col.markdown(tooltip_html, unsafe_allow_html=True)
                with pick_col:
                    # Pack cards are unique (add_card rejects duplicates), so
                    # the card name gives a key that is independent of sort order.
                    if st.button("Pick", key=f"pick_btn_{card}", use_container_width=True, help="Add to deck & clear pack"):
                        push_undo()
                        st.session_state["deck"].append(card)
                        st.session_state["pack"] = []
                        st.rerun()
        else:
            st.caption("Pack is empty.")
# --- Tab 2: Card rankings ----------------------------------------------------
with tabs[1]:
    if set_code:
        # The header is built inside the guard: concatenating a None set code
        # raised TypeError before the "select a set" hint could be shown.
        st.header(f"Card rankings for set {set_code}")
        try:
            df = p1p1_ranking(set_code.lower())  # cached; auto-updates on dropdown change
        except Exception as e:
            st.error(f"Could not calculate P1P1: {e}")
        else:
            if df.empty:
                st.info("No P1P1 results returned for this set.")
            else:
                st.dataframe(df, use_container_width=True)
    else:
        st.header("Card rankings")
        st.info("Select a set in the sidebar to view P1P1.")