Spaces:
Sleeping
Sleeping
File size: 3,909 Bytes
f1a26d2 87acbba f1a26d2 87acbba f1a26d2 87acbba f1a26d2 8f19b3a 7848c86 8f19b3a f1a26d2 87acbba f1a26d2 096e0d2 f1a26d2 8f19b3a f1a26d2 8f19b3a f1a26d2 bc19533 87acbba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gradio as gr
import json
from src.predictor import MTGPickPredictor
from src.utils import top_k_accuracy
from src.mtga import card_name_to_image_url
from src.mtga import mtga_id_to_image_url
# TODO: find a cleaner way to define the helper function required for the model
def top_3_accuracy(o, t):
    """Call ``top_k_accuracy(o, t, k=3)`` and return its result.

    Written as a ``def`` instead of a lambda bound to a name so that the
    function reports itself as ``top_3_accuracy`` (not ``<lambda>``) in
    tracebacks and reprs.

    NOTE(review): per the TODO above, the model needs this name to exist at
    module level -- presumably it is looked up by name when the saved model
    is loaded; confirm before moving or renaming it.
    """
    return top_k_accuracy(o, t, k=3)
def main():
    """Build the Gradio demo for the MTG draft pick predictor and launch it.

    Creates a single ``MTGPickPredictor``, defines the UI callbacks
    (prediction, example refresh on set change, input preview), lays out the
    interface inside ``gr.Blocks`` and finally calls ``demo.launch()``.
    """
    # Example draft states keyed by set code. "pack" and "pool" hold the card
    # ids that ``preview_cards`` passes to ``mtga_id_to_image_url``.
    examples = {
        "DSK": {
            "pack": [92163, 92375, 92091, 92173, 92257, 92112, 92212],
            "pool": [92218, 92127, 92148, 92152, 92242, 92113, 92140],
        },
        "FDN": {
            "pack": [
                93741, 93974, 93966, 93870, 93781, 93796, 93875,
                93744, 93938, 93733, 93801, 93732, 93943, 93832,
            ],
            "pool": [],
        },
    }
    # The textbox is pre-filled with this set's example, so the dropdown must
    # start on the same set; otherwise "Predict" would send set_code=None.
    default_set = "FDN"

    predictor = MTGPickPredictor()

    def parse_state(raw_text):
        """Parse the Pack/Pool textbox content into a dict.

        Single quotes are replaced by double quotes before parsing so that a
        pasted Python-style dict is accepted. This is a naive substitution:
        it will corrupt string values that contain an apostrophe.

        Raises:
            ValueError: if the text is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass) or if
                the parsed value is not a JSON object.
        """
        state = json.loads((raw_text or "").replace("'", '"'))
        if not isinstance(state, dict):
            raise ValueError("expected a JSON object with 'pack' and 'pool' keys")
        return state

    def to_gallery(card_ids):
        """Turn card ids into ``(image_url, caption)`` pairs for a gallery."""
        # ``or []`` tolerates an explicit JSON null for the key.
        return [(mtga_id_to_image_url(card), str(card)) for card in card_ids or []]

    def predict_fn(set_code, state_text):
        """Run the predictor and format its output for the two result widgets.

        Args:
            set_code: value selected in the set dropdown.
            state_text: raw text of the Pack/Pool textbox.

        Returns:
            A pair ``(gallery_output, prob_output)``: a list of
            ``(image_url, "name (NN%)")`` tuples for the predictions gallery
            and a list of ``(card_id, probability)`` tuples for the JSON view.

        Raises:
            gr.Error: if the textbox content cannot be parsed.
        """
        try:
            state = parse_state(state_text)
        except ValueError as err:
            # Show a readable message in the UI instead of a bare traceback.
            raise gr.Error(f"Could not parse Pack/Pool input: {err}") from err

        # Each prediction is unpacked as (card id, card name, probability).
        predictions = predictor.predict(set_code, state)

        # For gallery
        gallery_output = [
            (card_name_to_image_url(card_name), f"{card_name} ({probability*100:.0f}%)")
            for _, card_name, probability in predictions
        ]
        # For probabilities list
        prob_output = [
            (mtga_card_id, float(probability))
            for mtga_card_id, _, probability in predictions
        ]
        return gallery_output, prob_output

    def update_example(set_choice):
        """Return the example state for ``set_choice`` as indented JSON text."""
        return json.dumps(examples.get(set_choice, {}), indent=2)

    def preview_cards(set_choice, state_text):
        """Build the pack and pool preview galleries from the textbox content.

        ``set_choice`` is accepted because the click handler passes the
        dropdown value, but it is not used. On unparsable input the problem is
        reported and two empty galleries are returned (best effort, no crash).
        """
        try:
            state = parse_state(state_text)
        except ValueError:
            print("Error parsing JSON")
            gr.Warning("Error parsing JSON")
            return [], []
        return to_gallery(state.get("pack", [])), to_gallery(state.get("pool", []))

    # NOTE(review): the Row/Blocks nesting below is reconstructed from the
    # order of the calls -- confirm it matches the intended layout.
    with gr.Blocks() as demo:
        gr.Markdown("# MTG Draft Pick Predictor")
        gr.Markdown("🔮 Predict the next best card for your draft deck.")

        with gr.Row():
            set_dropdown = gr.Dropdown(
                choices=[("(FDN) Foundations", "FDN"), ("(DSK) Duskmourn", "DSK")],
                value=default_set,
                label="Set",
            )
            text_input = gr.Text(
                value=json.dumps(examples.get(default_set, {}), indent=2),  # Default example
                label="Pack/Pool:",
            )

        with gr.Row():
            pack_gallery = gr.Gallery(
                label="Pack",
                columns=3,
                object_fit="scale-down",
            )
            pool_gallery = gr.Gallery(
                label="Pool",
                columns=3,
                object_fit="scale-down",
            )

        gr.Button("Preview inputs").click(
            fn=preview_cards,
            inputs=[set_dropdown, text_input],
            outputs=[pack_gallery, pool_gallery],
        )

        with gr.Row():
            predictions_gallery = gr.Gallery(
                label="Predictions",
                columns=3,
                object_fit="scale-down",
                allow_preview=False,
            )
            predictions_probs = gr.JSON(label="Card IDs and Probabilities")

        # Update text input when dropdown changes
        set_dropdown.change(
            fn=update_example,
            inputs=[set_dropdown],
            outputs=[text_input],
        )

        # Predict button
        gr.Button("Predict").click(
            fn=predict_fn,
            inputs=[set_dropdown, text_input],
            outputs=[predictions_gallery, predictions_probs],
        )

    demo.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()