Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| from src.predictor import MTGPickPredictor | |
| from src.utils import top_k_accuracy | |
| from src.mtga import card_name_to_image_url | |
| from src.mtga import mtga_id_to_image_url | |
# TODO: find a cleaner way to define helper functions required by the model.
# NOTE(review): presumably the model loaded by MTGPickPredictor refers to this
# metric by its module-level name -- keep the name and signature unchanged.
def top_3_accuracy(o, t):
    """Return ``top_k_accuracy`` of outputs *o* against targets *t* with ``k=3``."""
    return top_k_accuracy(o, t, k=3)
| def main(): | |
| examples = { | |
| "DSK": { | |
| "pack": [92163, 92375, 92091, 92173, 92257, 92112, 92212], | |
| "pool": [92218, 92127, 92148, 92152, 92242, 92113, 92140] | |
| }, | |
| "FDN": { | |
| 'pack': [93741, 93974, 93966, 93870, 93781, 93796, 93875, 93744, 93938, 93733, 93801, 93732, 93943, 93832], | |
| 'pool': [] | |
| } | |
| } | |
| predictor = MTGPickPredictor() | |
| def predict_fn(set_code, input): | |
| state = json.loads(input.replace("'", "\"")) | |
| predictions = predictor.predict(set_code, state) | |
| # For gallery | |
| gallery_output = [ | |
| (card_name_to_image_url(card_name), f"{card_name} ({probability*100:.0f}%)") | |
| for _, card_name, probability in predictions | |
| ] | |
| # For probabilities list | |
| prob_output = [ | |
| (mtga_card_id, float(probability)) | |
| for mtga_card_id, _, probability in predictions | |
| ] | |
| return gallery_output, prob_output | |
| def update_example(set_choice): | |
| return json.dumps(examples.get(set_choice, {}), indent=2) | |
| def preview_cards(set_choice, text_input): | |
| try: | |
| input_json = json.loads(text_input.replace("'", "\"")) | |
| pack_cards = [ | |
| (mtga_id_to_image_url(card), f"{card}") | |
| for card in input_json.get('pack', []) | |
| ] | |
| pool_cards = [ | |
| (mtga_id_to_image_url(card), f"{card}") | |
| for card in input_json.get('pool', []) | |
| ] | |
| return pack_cards, pool_cards | |
| except json.JSONDecodeError: | |
| print("Error parsing JSON") | |
| return [], [] | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# MTG Draft Pick Predictor") | |
| gr.Markdown("🔮 Predict the next best card for your draft deck.") | |
| with gr.Row(): | |
| set_dropdown = gr.Dropdown( | |
| choices=[('(FDN) Foundations','FDN'), ('(DSK) Duskmourn', 'DSK')], | |
| label='Set' | |
| ) | |
| text_input = gr.Text( | |
| value=json.dumps(examples.get('FDN', {}), indent=2), # Default example | |
| label="Pack/Pool:" | |
| ) | |
| with gr.Row(): | |
| pack_gallery = gr.Gallery( | |
| label="Pack", | |
| columns=3, | |
| object_fit='scale-down' | |
| ) | |
| pool_gallery = gr.Gallery( | |
| label="Pool", | |
| columns=3, | |
| object_fit='scale-down' | |
| ) | |
| button = gr.Button("Preview inputs").click( | |
| fn=preview_cards, | |
| inputs=[set_dropdown, text_input], | |
| outputs=[pack_gallery, pool_gallery] | |
| ) | |
| with gr.Row(): | |
| preidctions_gallery = gr.Gallery( | |
| label="Predictions", | |
| #height=300, | |
| columns=3, | |
| object_fit='scale-down', | |
| allow_preview=False | |
| ) | |
| predictions_probs = gr.JSON(label="Card IDs and Probabilities") | |
| # Update text input when dropdown changes | |
| set_dropdown.change( | |
| fn=update_example, | |
| inputs=[set_dropdown], | |
| outputs=[text_input] | |
| ) | |
| # Predict button | |
| gr.Button("Predict").click( | |
| fn=predict_fn, | |
| inputs=[set_dropdown, text_input], | |
| outputs=[preidctions_gallery, predictions_probs] | |
| ) | |
| demo.launch() | |
if __name__ == "__main__":
    # Build and serve the UI only when executed as a script, never on import.
    main()