Upload 4 files
Browse files- app.py +84 -0
- formation_predictor.onnx +3 -0
- label_encoder.pkl +3 -0
- requirements.txt +4 -0
app.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import onnxruntime as ort
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pickle
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
# Load the ONNX model
|
| 8 |
+
onnx_model_path = "formation_predictor.onnx"
|
| 9 |
+
ort_session = ort.InferenceSession(onnx_model_path)
|
| 10 |
+
|
| 11 |
+
# Function to convert input data to one-hot encoding
|
| 12 |
+
def to_one_hot(indices, num_classes):
|
| 13 |
+
indices = np.array(indices, dtype=int)
|
| 14 |
+
return np.eye(num_classes)[indices]
|
| 15 |
+
|
| 16 |
+
# Load the label encoder
|
| 17 |
+
def load_label_encoder():
|
| 18 |
+
with open("label_encoder.pkl", "rb") as f:
|
| 19 |
+
le = pickle.load(f)
|
| 20 |
+
return le
|
| 21 |
+
|
| 22 |
+
le = load_label_encoder()
|
| 23 |
+
num_classes = len(le.classes_)
|
| 24 |
+
|
| 25 |
+
# Function to prepare input data
|
| 26 |
+
def prepare_input(opponent_formation, le, num_classes):
|
| 27 |
+
opponent_formation = opponent_formation.strip().strip("'\"[]") # Ensure no leading/trailing spaces, quotes, or brackets
|
| 28 |
+
opp_idx = le.transform([opponent_formation])[0] if isinstance(opponent_formation, str) else opponent_formation
|
| 29 |
+
opp_one_hot = to_one_hot([opp_idx], num_classes)
|
| 30 |
+
return opp_one_hot
|
| 31 |
+
|
| 32 |
+
# Function to recommend formation using ONNX model
|
| 33 |
+
def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
|
| 34 |
+
opp_one_hot = prepare_input(opponent_formation, le, num_classes)
|
| 35 |
+
|
| 36 |
+
best_formation, best_score = None, -float("inf")
|
| 37 |
+
evaluated_formations = []
|
| 38 |
+
for our_idx in range(num_classes):
|
| 39 |
+
our_one_hot = to_one_hot([our_idx], num_classes)
|
| 40 |
+
input_vector = np.concatenate([opp_one_hot, our_one_hot], axis=1).astype(np.float32)
|
| 41 |
+
|
| 42 |
+
# Run the ONNX model
|
| 43 |
+
ort_inputs = {ort_session.get_inputs()[0].name: input_vector}
|
| 44 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
| 45 |
+
score = ort_outs[0][0, 0]
|
| 46 |
+
|
| 47 |
+
formation = le.inverse_transform([our_idx])[0]
|
| 48 |
+
evaluated_formations.append((formation, score))
|
| 49 |
+
|
| 50 |
+
if score > best_score:
|
| 51 |
+
best_score = score
|
| 52 |
+
best_formation = formation
|
| 53 |
+
|
| 54 |
+
evaluated_formations.sort(key=lambda x: x[1], reverse=True)
|
| 55 |
+
return best_formation, evaluated_formations
|
| 56 |
+
|
| 57 |
+
# Function to handle the recommend button click
|
| 58 |
+
def recommend(opponent_formation):
|
| 59 |
+
opponent_formation = opponent_formation.strip().strip("'\"[]") # Ensure no leading/trailing spaces, quotes, or brackets
|
| 60 |
+
|
| 61 |
+
# Validate the format of the opponent formation
|
| 62 |
+
if not re.match(r'^\d+(-\d+)+$', opponent_formation):
|
| 63 |
+
return f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').", []
|
| 64 |
+
|
| 65 |
+
if opponent_formation not in le.classes_:
|
| 66 |
+
return f"Error: Formation '{opponent_formation}' not recognized.", []
|
| 67 |
+
|
| 68 |
+
best_formation, evaluated_formations = recommend_formation_onnx(opponent_formation, ort_session, le, num_classes)
|
| 69 |
+
return f"Recommended formation: {best_formation}", evaluated_formations
|
| 70 |
+
|
| 71 |
+
# Create the Gradio interface
|
| 72 |
+
iface = gr.Interface(
|
| 73 |
+
fn=recommend,
|
| 74 |
+
inputs=gr.inputs.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')"),
|
| 75 |
+
outputs=[
|
| 76 |
+
gr.outputs.Textbox(label="Recommended Formation"),
|
| 77 |
+
gr.outputs.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations")
|
| 78 |
+
],
|
| 79 |
+
title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
|
| 80 |
+
description="Enter the opponent formation to get the recommended formation and a list of evaluated formations with their scores."
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Launch the Gradio interface
|
| 84 |
+
iface.launch()
|
formation_predictor.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1de5260c6c8457b12da0c2930bc5565b55099bdad0831cab031a690bd8244d94
|
| 3 |
+
size 18124
|
label_encoder.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c26c641ba961031566abddac21718a0bd1e01e1e1e31022a08a81b594c7095fd
|
| 3 |
+
size 402
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
onnxruntime
|
| 3 |
+
numpy
|
| 4 |
+
scikit-learn
|