"""Gradio demo for the MTG draft pick predictor.

Lets a user pick a set, edit the current pack/pool as MTGA card ids,
preview those cards as images, and ask ``MTGPickPredictor`` for pick
predictions.
"""

import ast
import json
import logging

import gradio as gr

from src.predictor import MTGPickPredictor
from src.utils import top_k_accuracy
from src.mtga import card_name_to_image_url
from src.mtga import mtga_id_to_image_url

logger = logging.getLogger(__name__)

# TODO: find a cleaner way to define helper function required for the model
# NOTE(review): left as a module-level name on purpose; the model presumably
# looks it up by this exact name at load time, so do not rename or move it.
top_3_accuracy = lambda o, t: top_k_accuracy(o, t, k=3)


def _parse_state(text):
    """Parse the Pack/Pool textbox contents into a dict.

    Strict JSON is tried first. If that fails, the text is read as a Python
    literal, so single-quoted input such as ``{'pack': [1, 2], 'pool': []}``
    keeps working. Unlike a blanket ``'`` -> ``"`` replacement, this never
    rewrites apostrophes inside string values.

    Args:
        text: Raw textbox contents; ``None`` is treated as an empty string.

    Returns:
        The parsed dict.

    Raises:
        ValueError: If the text is neither a JSON object nor a Python dict
            literal.
    """
    text = text or ""
    try:
        state = json.loads(text)
    except json.JSONDecodeError:
        try:
            # literal_eval only accepts literals, so it is safe on user input.
            state = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise ValueError(
                'Pack/Pool input must be a JSON object, e.g. {"pack": [...], "pool": [...]}'
            ) from err
    if not isinstance(state, dict):
        raise ValueError(
            'Pack/Pool input must be a JSON object, e.g. {"pack": [...], "pool": [...]}'
        )
    return state


def main():
    """Build the Gradio Blocks UI for the pick predictor and launch it."""
    # Sample pack/pool states (MTGA card ids) per set code, used to pre-fill
    # the textbox and to refill it when the set dropdown changes.
    examples = {
        "DSK": {
            "pack": [92163, 92375, 92091, 92173, 92257, 92112, 92212],
            "pool": [92218, 92127, 92148, 92152, 92242, 92113, 92140],
        },
        "FDN": {
            'pack': [93741, 93974, 93966, 93870, 93781, 93796, 93875,
                     93744, 93938, 93733, 93801, 93732, 93943, 93832],
            'pool': [],
        },
    }
    # Set selected on load; the textbox starts with this set's example so the
    # two inputs agree.
    default_set = "FDN"

    predictor = MTGPickPredictor()

    def predict_fn(set_code, state_text):
        """Run the predictor on the textbox state.

        Returns a pair: gallery items ``(image_url, "name (NN%)")`` and a list
        of ``(mtga_card_id, probability)`` for the JSON panel. Input problems
        are raised as ``gr.Error`` so the user sees the message in the UI.
        """
        if not set_code:
            raise gr.Error("Please select a set.")
        try:
            state = _parse_state(state_text)
        except ValueError as err:
            raise gr.Error(str(err)) from err

        gallery_output = []
        prob_output = []
        # Single pass over the predictions, so this also works if predict()
        # returns a one-shot iterator rather than a list.
        for mtga_card_id, card_name, probability in predictor.predict(set_code, state):
            gallery_output.append(
                (card_name_to_image_url(card_name), f"{card_name} ({probability*100:.0f}%)")
            )
            prob_output.append((mtga_card_id, float(probability)))
        return gallery_output, prob_output

    def update_example(set_choice):
        """Return the example state for ``set_choice`` as pretty JSON ("{}" if unknown)."""
        return json.dumps(examples.get(set_choice, {}), indent=2)

    def preview_cards(set_choice, state_text):
        """Return ``(pack_items, pool_items)`` gallery lists for the textbox state.

        ``set_choice`` is accepted because the click handler passes the
        dropdown value, but images are resolved from MTGA ids alone. On
        unparseable input both galleries receive empty lists.
        """
        try:
            state = _parse_state(state_text)
        except ValueError as err:
            logger.warning("Error parsing Pack/Pool input: %s", err)
            return [], []

        def to_gallery(card_ids):
            # Each gallery item is (image_url, caption); the caption is the id.
            return [(mtga_id_to_image_url(card), str(card)) for card in card_ids]

        # `or []` also covers explicit nulls such as {"pool": null}.
        return to_gallery(state.get('pack') or []), to_gallery(state.get('pool') or [])

    with gr.Blocks() as demo:
        gr.Markdown("# MTG Draft Pick Predictor")
        gr.Markdown("🔮 Predict the next best card for your draft deck.")

        with gr.Row():
            set_dropdown = gr.Dropdown(
                choices=[('(FDN) Foundations', 'FDN'), ('(DSK) Duskmourn', 'DSK')],
                value=default_set,
                label='Set',
            )
            text_input = gr.Text(
                value=json.dumps(examples.get(default_set, {}), indent=2),
                label="Pack/Pool:",
            )

        with gr.Row():
            pack_gallery = gr.Gallery(label="Pack", columns=3, object_fit='scale-down')
            pool_gallery = gr.Gallery(label="Pool", columns=3, object_fit='scale-down')

        gr.Button("Preview inputs").click(
            fn=preview_cards,
            inputs=[set_dropdown, text_input],
            outputs=[pack_gallery, pool_gallery],
        )

        with gr.Row():
            predictions_gallery = gr.Gallery(
                label="Predictions",
                columns=3,
                object_fit='scale-down',
                allow_preview=False,
            )
            predictions_probs = gr.JSON(label="Card IDs and Probabilities")

        # Update text input when dropdown changes
        set_dropdown.change(
            fn=update_example,
            inputs=[set_dropdown],
            outputs=[text_input],
        )

        # Predict button
        gr.Button("Predict").click(
            fn=predict_fn,
            inputs=[set_dropdown, text_input],
            outputs=[predictions_gallery, predictions_probs],
        )

    demo.launch()


if __name__ == "__main__":
    main()