File size: 3,909 Bytes
f1a26d2
 
 
87acbba
 
f1a26d2
 
87acbba
 
 
 
 
f1a26d2
 
 
 
 
 
 
 
 
 
 
87acbba
f1a26d2
 
 
8f19b3a
 
 
 
 
7848c86
8f19b3a
 
 
 
 
 
 
 
 
 
f1a26d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87acbba
f1a26d2
 
096e0d2
f1a26d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f19b3a
 
 
 
 
 
 
 
 
 
 
f1a26d2
 
 
 
 
 
 
 
 
 
 
8f19b3a
f1a26d2
 
bc19533
87acbba
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import json

from src.predictor import MTGPickPredictor
from src.utils import top_k_accuracy
from src.mtga import card_name_to_image_url
from src.mtga import mtga_id_to_image_url

# TODO: find a cleaner way to define helper function required for the model
def top_3_accuracy(o, t):
    """Return ``top_k_accuracy(o, t, k=3)``.

    Defined with ``def`` rather than a lambda so the callable carries its
    real name in tracebacks and can be pickled by reference.

    NOTE(review): presumably the saved model resolves this name from the
    main module when it is loaded (see the TODO above) -- confirm before
    moving or renaming it.
    """
    return top_k_accuracy(o, t, k=3)

def main():
    """Build the Gradio Blocks UI for the draft pick predictor and launch it.

    The UI offers a set selector, a JSON text box holding the current
    ``pack`` and ``pool`` card ids, a preview of both card lists, and a
    prediction view backed by ``MTGPickPredictor``.
    """
    # Example draft states shown in the text box, keyed by set code.
    examples = {
        "DSK": {
            "pack": [92163, 92375, 92091, 92173, 92257, 92112, 92212],
            "pool": [92218, 92127, 92148, 92152, 92242, 92113, 92140],
        },
        "FDN": {
            "pack": [
                93741, 93974, 93966, 93870, 93781, 93796, 93875,
                93744, 93938, 93733, 93801, 93732, 93943, 93832,
            ],
            "pool": [],
        },
    }
    # Set selected on page load; must match the example pre-filled below.
    default_set = "FDN"

    predictor = MTGPickPredictor()

    def parse_state(raw_text):
        """Parse the pack/pool text box into a dict.

        Single quotes are converted to double quotes first so that a pasted
        Python-style dict literal is accepted as JSON.

        Raises:
            ValueError: if the text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass) or is not a JSON object.
        """
        state = json.loads((raw_text or "").replace("'", "\""))
        if not isinstance(state, dict):
            raise ValueError("expected a JSON object with 'pack' and 'pool' keys")
        return state

    def cards_to_gallery(card_ids):
        """Turn card ids into ``(image_url, caption)`` pairs for a gallery."""
        return [(mtga_id_to_image_url(card), f"{card}") for card in card_ids]

    def predict_fn(set_code, raw_state):
        """Run the predictor and format its output for the UI.

        Returns a ``(gallery_items, probabilities)`` pair. Invalid input is
        reported to the user through ``gr.Error`` instead of an unhandled
        traceback.
        """
        if not set_code:
            raise gr.Error("Please select a set before predicting.")
        try:
            state = parse_state(raw_state)
        except ValueError as err:
            raise gr.Error(f"Could not parse Pack/Pool input: {err}") from err

        # Each prediction is unpacked as (mtga_card_id, card_name, probability).
        predictions = predictor.predict(set_code, state)

        # For gallery
        gallery_output = [
            (card_name_to_image_url(card_name), f"{card_name} ({probability*100:.0f}%)")
            for _, card_name, probability in predictions
        ]

        # For probabilities list
        prob_output = [
            (mtga_card_id, float(probability))
            for mtga_card_id, _, probability in predictions
        ]

        return gallery_output, prob_output

    def update_example(set_choice):
        """Return the example state for ``set_choice`` as pretty-printed JSON."""
        return json.dumps(examples.get(set_choice, {}), indent=2)

    def preview_cards(set_choice, text_input):
        """Return gallery items for the pack and pool in the text box.

        ``set_choice`` is unused but kept because the click handler passes
        the dropdown value. Unparseable input yields two empty galleries.
        """
        try:
            input_json = parse_state(text_input)
        except ValueError:
            print("Error parsing JSON")
            return [], []

        pack_cards = cards_to_gallery(input_json.get('pack', []))
        pool_cards = cards_to_gallery(input_json.get('pool', []))
        return pack_cards, pool_cards

    with gr.Blocks() as demo:
        gr.Markdown("# MTG Draft Pick Predictor")
        gr.Markdown("🔮 Predict the next best card for your draft deck.")

        with gr.Row():
            set_dropdown = gr.Dropdown(
                choices=[('(FDN) Foundations', 'FDN'), ('(DSK) Duskmourn', 'DSK')],
                # Keep the selected set in sync with the pre-filled example so
                # "Predict" works on a fresh page without touching the dropdown.
                value=default_set,
                label='Set'
            )
            text_input = gr.Text(
                value=json.dumps(examples.get(default_set, {}), indent=2),  # Default example
                label="Pack/Pool:"
            )

        with gr.Row():
            pack_gallery = gr.Gallery(
                label="Pack",
                columns=3,
                object_fit='scale-down'
            )
            pool_gallery = gr.Gallery(
                label="Pool",
                columns=3,
                object_fit='scale-down'
            )

        gr.Button("Preview inputs").click(
            fn=preview_cards,
            inputs=[set_dropdown, text_input],
            outputs=[pack_gallery, pool_gallery]
        )

        with gr.Row():
            predictions_gallery = gr.Gallery(
                label="Predictions",
                columns=3,
                object_fit='scale-down',
                allow_preview=False
            )

            predictions_probs = gr.JSON(label="Card IDs and Probabilities")

        # Update text input when dropdown changes
        set_dropdown.change(
            fn=update_example,
            inputs=[set_dropdown],
            outputs=[text_input]
        )

        # Predict button
        gr.Button("Predict").click(
            fn=predict_fn,
            inputs=[set_dropdown, text_input],
            outputs=[predictions_gallery, predictions_probs]
        )

    demo.launch()

# Build and launch the demo only when this file is executed as a script,
# so importing the module does not start the Gradio server.
if __name__ == "__main__":
    main()