Denis Lebedev
Test emoji
096e0d2 unverified
import gradio as gr
import json
from src.predictor import MTGPickPredictor
from src.utils import top_k_accuracy
from src.mtga import card_name_to_image_url
from src.mtga import mtga_id_to_image_url
# TODO: find a cleaner way to define the helper function required by the model.
# NOTE(review): the model presumably references a metric named `top_3_accuracy`
# that must be resolvable at module level when it is loaded — confirm against
# the training/export code before renaming or moving this.
def top_3_accuracy(o, t):
    """Return ``top_k_accuracy(o, t, k=3)``.

    Written as a ``def`` rather than a lambda bound to a name (PEP 8, E731) so
    the function carries a real ``__name__`` ("top_3_accuracy") for tracebacks
    and for any code that reports metrics by function name.
    """
    return top_k_accuracy(o, t, k=3)
def parse_draft_state(text):
    """Decode the contents of the Pack/Pool textbox.

    Every single quote is replaced with a double quote before decoding, so
    Python-literal style input such as ``{'pack': [1], 'pool': []}`` is
    accepted as well as strict JSON. Note that this substitution also rewrites
    any apostrophes inside string values.

    Args:
        text: Raw textbox contents.

    Returns:
        The decoded JSON value (a dict for well-formed Pack/Pool input).

    Raises:
        ValueError: If ``text`` is not valid JSON after the quote substitution.
    """
    try:
        return json.loads(text.replace("'", '"'))
    except json.JSONDecodeError as err:
        raise ValueError(f"Could not parse Pack/Pool input as JSON: {err}") from err


def main():
    """Build and launch the MTG draft pick predictor Gradio app.

    The UI lets the user pick a set, edit a Pack/Pool JSON description,
    preview the referenced cards, and request pick predictions from
    ``MTGPickPredictor``.
    """
    # Example draft states per set code. Entries are the card ids that
    # preview_cards passes to mtga_id_to_image_url.
    examples = {
        "DSK": {
            "pack": [92163, 92375, 92091, 92173, 92257, 92112, 92212],
            "pool": [92218, 92127, 92148, 92152, 92242, 92113, 92140],
        },
        "FDN": {
            "pack": [
                93741, 93974, 93966, 93870, 93781, 93796, 93875,
                93744, 93938, 93733, 93801, 93732, 93943, 93832,
            ],
            "pool": [],
        },
    }
    # The dropdown and the pre-filled textbox must start on the same set.
    default_set = "FDN"
    predictor = MTGPickPredictor()

    def ids_to_gallery(card_ids):
        """Return (image_url, caption) pairs for a list of MTGA card ids."""
        return [(mtga_id_to_image_url(card_id), str(card_id)) for card_id in card_ids]

    def predict_fn(set_code, state_text):
        """Run the predictor and format its output for the two result widgets.

        Returns:
            A gallery list of (image_url, "name (NN%)") pairs and a list of
            (mtga_card_id, probability) pairs.

        Raises:
            gr.Error: If no set is selected or the input is not valid JSON.
        """
        if not set_code:
            raise gr.Error("Please select a set before predicting.")
        try:
            state = parse_draft_state(state_text)
        except ValueError as err:
            # gr.Error is displayed to the user instead of a generic failure.
            raise gr.Error(str(err)) from err
        # Materialize: the result is iterated twice below, which would silently
        # yield an empty second list if predict() returned a one-shot iterator.
        predictions = list(predictor.predict(set_code, state))
        # Each prediction is unpacked as (mtga_card_id, card_name, probability).
        gallery_output = [
            (
                card_name_to_image_url(card_name),
                f"{card_name} ({probability * 100:.0f}%)",
            )
            for _, card_name, probability in predictions
        ]
        prob_output = [
            (mtga_card_id, float(probability))
            for mtga_card_id, _, probability in predictions
        ]
        return gallery_output, prob_output

    def update_example(set_choice):
        """Return the pretty-printed example for ``set_choice`` ("{}" if unknown)."""
        return json.dumps(examples.get(set_choice, {}), indent=2)

    def preview_cards(set_choice, state_text):
        """Render the pack and pool card ids from the textbox as image galleries.

        ``set_choice`` is unused. It is accepted because the click handler
        passes the dropdown value as the first input. Invalid input is reported
        on stdout and yields two empty galleries.
        """
        try:
            state = parse_draft_state(state_text)
        except ValueError as err:
            print(f"Error parsing JSON: {err}")
            return [], []
        if not isinstance(state, dict):
            print("Error parsing JSON: expected an object with 'pack' and 'pool' keys")
            return [], []
        # `or []` also covers explicit nulls such as {"pool": null}.
        return (
            ids_to_gallery(state.get("pack") or []),
            ids_to_gallery(state.get("pool") or []),
        )

    with gr.Blocks() as demo:
        gr.Markdown("# MTG Draft Pick Predictor")
        gr.Markdown("🔮 Predict the next best card for your draft deck.")
        with gr.Row():
            set_dropdown = gr.Dropdown(
                choices=[('(FDN) Foundations', 'FDN'), ('(DSK) Duskmourn', 'DSK')],
                value=default_set,
                label='Set',
            )
            text_input = gr.Text(
                value=json.dumps(examples.get(default_set, {}), indent=2),
                label="Pack/Pool:",
            )
        with gr.Row():
            pack_gallery = gr.Gallery(
                label="Pack",
                columns=3,
                object_fit='scale-down',
            )
            pool_gallery = gr.Gallery(
                label="Pool",
                columns=3,
                object_fit='scale-down',
            )
        gr.Button("Preview inputs").click(
            fn=preview_cards,
            inputs=[set_dropdown, text_input],
            outputs=[pack_gallery, pool_gallery],
        )
        with gr.Row():
            predictions_gallery = gr.Gallery(
                label="Predictions",
                columns=3,
                object_fit='scale-down',
                allow_preview=False,
            )
            predictions_probs = gr.JSON(label="Card IDs and Probabilities")
        # Replace the textbox with the chosen set's example whenever the set changes.
        set_dropdown.change(
            fn=update_example,
            inputs=[set_dropdown],
            outputs=[text_input],
        )
        gr.Button("Predict").click(
            fn=predict_fn,
            inputs=[set_dropdown, text_input],
            outputs=[predictions_gallery, predictions_probs],
        )
    demo.launch()
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not on import.
    main()