Spaces:
Running
Running
sghorbal
committed on
Commit
·
7effb2a
0
Parent(s):
initial commit
Browse files- .env.example +16 -0
- .gitattributes +1 -0
- .gitignore +177 -0
- Dockerfile +26 -0
- Jenkinsfile +79 -0
- LICENSE +21 -0
- README.md +44 -0
- api.png +3 -0
- entrypoint.sh +4 -0
- requirements.txt +10 -0
- src/__init__.py +0 -0
- src/enums.py +38 -0
- src/main.py +222 -0
- src/model.py +310 -0
- src/sql.py +73 -0
- tests/__init__.py +0 -0
- tests/conftest.py +39 -0
- tests/test_enums.py +21 -0
- tests/test_model.py +23 -0
.env.example
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PostgreSQL database connection information
|
| 2 |
+
PG_USER=
|
| 3 |
+
PG_PASSWORD=
|
| 4 |
+
PG_HOST=
|
| 5 |
+
PG_PORT=
|
| 6 |
+
PG_DB=
|
| 7 |
+
# One among disable, allow, prefer, require, verify-ca and verify-full
|
| 8 |
+
PG_SSLMODE=
|
| 9 |
+
|
| 10 |
+
# If set, protects the API from unauthorized calls
|
| 11 |
+
FASTAPI_API_KEY=
|
| 12 |
+
|
| 13 |
+
MLFLOW_SERVER_URI=
|
| 14 |
+
|
| 15 |
+
AWS_ACCESS_KEY_ID=
|
| 16 |
+
AWS_SECRET_ACCESS_KEY=
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
api.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/*
|
| 2 |
+
**/*.ipynb
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py,cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
# IPython
|
| 85 |
+
profile_default/
|
| 86 |
+
ipython_config.py
|
| 87 |
+
|
| 88 |
+
# pyenv
|
| 89 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 91 |
+
# .python-version
|
| 92 |
+
|
| 93 |
+
# pipenv
|
| 94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 97 |
+
# install all needed dependencies.
|
| 98 |
+
#Pipfile.lock
|
| 99 |
+
|
| 100 |
+
# UV
|
| 101 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 103 |
+
# commonly ignored for libraries.
|
| 104 |
+
#uv.lock
|
| 105 |
+
|
| 106 |
+
# poetry
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 111 |
+
#poetry.lock
|
| 112 |
+
|
| 113 |
+
# pdm
|
| 114 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 115 |
+
#pdm.lock
|
| 116 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 117 |
+
# in version control.
|
| 118 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 119 |
+
.pdm.toml
|
| 120 |
+
.pdm-python
|
| 121 |
+
.pdm-build/
|
| 122 |
+
|
| 123 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 124 |
+
__pypackages__/
|
| 125 |
+
|
| 126 |
+
# Celery stuff
|
| 127 |
+
celerybeat-schedule
|
| 128 |
+
celerybeat.pid
|
| 129 |
+
|
| 130 |
+
# SageMath parsed files
|
| 131 |
+
*.sage.py
|
| 132 |
+
|
| 133 |
+
# Environments
|
| 134 |
+
.env
|
| 135 |
+
.venv
|
| 136 |
+
env/
|
| 137 |
+
venv/
|
| 138 |
+
ENV/
|
| 139 |
+
env.bak/
|
| 140 |
+
venv.bak/
|
| 141 |
+
|
| 142 |
+
# Spyder project settings
|
| 143 |
+
.spyderproject
|
| 144 |
+
.spyproject
|
| 145 |
+
|
| 146 |
+
# Rope project settings
|
| 147 |
+
.ropeproject
|
| 148 |
+
|
| 149 |
+
# mkdocs documentation
|
| 150 |
+
/site
|
| 151 |
+
|
| 152 |
+
# mypy
|
| 153 |
+
.mypy_cache/
|
| 154 |
+
.dmypy.json
|
| 155 |
+
dmypy.json
|
| 156 |
+
|
| 157 |
+
# Pyre type checker
|
| 158 |
+
.pyre/
|
| 159 |
+
|
| 160 |
+
# pytype static type analyzer
|
| 161 |
+
.pytype/
|
| 162 |
+
|
| 163 |
+
# Cython debug symbols
|
| 164 |
+
cython_debug/
|
| 165 |
+
|
| 166 |
+
# PyCharm
|
| 167 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 168 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 169 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 170 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 171 |
+
#.idea/
|
| 172 |
+
|
| 173 |
+
# PyPI configuration file
|
| 174 |
+
.pypirc
|
| 175 |
+
|
| 176 |
+
.DS_Store
|
| 177 |
+
*.bck
|
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base image bundles uvicorn + gunicorn on Python 3.11.
FROM tiangolo/uvicorn-gunicorn:python3.11

# Requirements and entrypoint go to /tmp: only needed at build / container start.
COPY ./requirements.txt /tmp/requirements.txt
COPY ./entrypoint.sh /tmp/entrypoint.sh
# Application code plus its test suite (tests are executed inside the container by CI).
COPY ./src /app/src
COPY ./tests /app/tests

RUN pip install --no-cache-dir -r /tmp/requirements.txt

WORKDIR /app

# Make 'src' importable as a top-level package (entrypoint runs `uvicorn src.main:app`).
ENV PYTHONPATH=/app

# Port to expose
EXPOSE 7860

# Health Check
# NOTE(review): assumes curl is installed in the base image — confirm; also note
# /check_health always answers 200 (body 0 or 1), so `curl -f` never fails on an
# unhealthy DB — verify this is the intended health semantics.
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD [ "curl", "-f", "http://localhost:7860/check_health" ]

# Create a non-root user 'appuser' and switch to this user
RUN useradd --create-home appuser
USER appuser

# CMD with JSON notation
CMD ["/tmp/entrypoint.sh"]
|
Jenkinsfile
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pipeline {
    agent any

    stages {
        stage('Checkout') {
            steps {
                // Checkout the code from the repository
                git branch: 'master', url: 'https://github.com/slim-git/tennis-api/'
            }
        }

        stage('Build Docker Image') {
            steps {
                script {
                    // Build the Docker image using the Dockerfile
                    sh 'docker build -t tennis_api .'
                }
            }
        }

        stage('Run Tests Inside Docker Container') {
            steps {
                withCredentials([
                    string(credentialsId: 'MLFLOW_SERVER_URI', variable: 'MLFLOW_SERVER_URI'),
                    string(credentialsId: 'AWS_ACCESS_KEY_ID', variable: 'AWS_ACCESS_KEY_ID'),
                    string(credentialsId: 'AWS_SECRET_ACCESS_KEY', variable: 'AWS_SECRET_ACCESS_KEY'),
                    string(credentialsId: 'PG_USER', variable: 'PG_USER'),
                    string(credentialsId: 'PG_PASSWORD', variable: 'PG_PASSWORD'),
                    string(credentialsId: 'PG_HOST', variable: 'PG_HOST'),
                    string(credentialsId: 'PG_PORT', variable: 'PG_PORT'),
                    string(credentialsId: 'PG_DB', variable: 'PG_DB'),
                    string(credentialsId: 'PG_SSLMODE', variable: 'PG_SSLMODE')
                ]) {
                    // FIX: the previous version wrote env.list with Groovy writeFile and a
                    // single-quoted string — Groovy does NOT interpolate ${...} there, so the
                    // container received the literal text "${PG_USER}" etc. instead of the
                    // credential values.
                    //
                    // Instead, forward each credential with `-e NAME` (no value): the docker
                    // CLI reads the value from its own environment, so no secret is ever
                    // written to the workspace.
                    // KEEP SINGLE QUOTES FOR SECURITY PURPOSES (MORE INFO HERE: https://www.jenkins.io/doc/book/pipeline/jenkinsfile/#handling-credentials)
                    sh '''
                        docker run --rm \
                            -e MLFLOW_SERVER_URI \
                            -e AWS_ACCESS_KEY_ID \
                            -e AWS_SECRET_ACCESS_KEY \
                            -e PG_USER \
                            -e PG_PASSWORD \
                            -e PG_HOST \
                            -e PG_PORT \
                            -e PG_DB \
                            -e PG_SSLMODE \
                            tennis_api \
                            bash -c "pytest --maxfail=1 --disable-warnings"
                    '''
                }
            }
        }
    }

    post {
        always {
            // Clean up workspace and remove dangling Docker images
            sh 'docker system prune -f'
        }
        success {
            withCredentials([
                string(credentialsId: 'HF_USERNAME', variable: 'HF_USERNAME'),
                string(credentialsId: 'HF_TOKEN', variable: 'HF_TOKEN')
            ]) {
                echo 'Pipeline completed successfully! Pushing to huggingFace'
                // Single quotes: the shell (not Groovy) expands the credentials,
                // keeping them out of the build log.
                sh 'git push --force https://${HF_USERNAME}:${HF_TOKEN}@huggingface.co/spaces/${HF_USERNAME}/tennis-api master:main'
            }
        }
        failure {
            echo 'Pipeline failed. Check logs for errors.'
        }
    }
}
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 slim-git
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Tennis Api
|
| 3 |
+
emoji: ⚡
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
short_description: API for training and interacting with tennis-insights models
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 12 |
+
|
| 13 |
+
# tennis-api
|
| 14 |
+
|
| 15 |
+
## Docker Install
|
| 16 |
+
|
| 17 |
+
To get the service for training the model and giving predictions up and running locally, simply follow the steps hereafter:
|
| 18 |
+
|
| 19 |
+
### Build the API image
|
| 20 |
+
|
| 21 |
+
From the root of the project:
|
| 22 |
+
```bash
|
| 23 |
+
$> docker build . -t tennis_api:latest -f Dockerfile
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### Run it
|
| 27 |
+
|
| 28 |
+
From the root of the project:
|
| 29 |
+
```bash
|
| 30 |
+
$> docker run --rm -p 7860:7860 --mount type=bind,src=./.env,target=/app/.env tennis_api:latest
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Then go to [http://localhost:7860/](http://localhost:7860/)
|
| 34 |
+
|
| 35 |
+
The API should be accessible:
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
## Resources
|
| 39 |
+
|
| 40 |
+
Website: [http://www.tennis-data.co.uk/alldata.php](http://www.tennis-data.co.uk/alldata.php)
|
| 41 |
+
|
| 42 |
+
## License
|
| 43 |
+
|
| 44 |
+
©2025
|
api.png
ADDED
|
Git LFS Details
|
entrypoint.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Container entrypoint: start the FastAPI app.
set -e

# Run the API.
# FIX: `exec` replaces this shell with uvicorn, so uvicorn becomes PID 1 and
# receives container stop signals (SIGTERM) directly — without it, `docker stop`
# waits for the timeout and SIGKILLs the server.
exec uvicorn src.main:app --host 0.0.0.0 --port 7860
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python-dotenv
|
| 2 |
+
fastapi
|
| 3 |
+
psycopg2-binary
|
| 4 |
+
pandas
|
| 5 |
+
scikit-learn
|
| 6 |
+
openpyxl
|
| 7 |
+
xlrd >= 2.0.1
|
| 8 |
+
mlflow
|
| 9 |
+
boto3
|
| 10 |
+
pytest
|
src/__init__.py
ADDED
|
File without changes
|
src/enums.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import List, Literal
|
| 4 |
+
|
| 5 |
+
class Feature(Enum):
    """Model input features, each tagged as either categorical or numeric.

    The enum value is the feature's column name; `type` tells the pipeline
    which preprocessing branch (one-hot vs. impute+scale) the column takes.
    """

    _name: str
    _type: Literal['category', 'number']

    SERIES = ('Series', 'category')
    SURFACE = ('Surface', 'category')
    COURT = ('Court', 'category')
    ROUND = ('Round', 'category')
    DIFF_RANKING = ('diffRanking', 'number')
    DIFF_POINTS = ('diffPoints', 'number')

    def __new__(cls, label: str, kind: Literal['category', 'number']):
        # Custom __new__ so each member carries both its column name (also
        # used as the enum value) and its preprocessing kind.
        member = object.__new__(cls)
        member._value_ = label
        member._name = label
        member._type = kind
        return member

    @property
    def name(self):
        """The feature's column name (overrides Enum.name with the label)."""
        return self._name

    @property
    def type(self):
        """Either 'category' or 'number'."""
        return self._type

    @classmethod
    def get_features_by_type(cls, type: Literal['category', 'number']) -> List['Feature']:
        """Return every feature whose kind matches *type*."""
        return [member for member in cls if member.type == type]

    @classmethod
    def get_all_features(cls) -> List['Feature']:
        """Return all features in declaration order."""
        return list(cls)
|
src/main.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import joblib
|
| 3 |
+
import logging
|
| 4 |
+
import secrets
|
| 5 |
+
from typing import Literal, Optional, Annotated
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from fastapi import (
|
| 8 |
+
FastAPI,
|
| 9 |
+
Request,
|
| 10 |
+
HTTPException,
|
| 11 |
+
Query,
|
| 12 |
+
Security,
|
| 13 |
+
Depends
|
| 14 |
+
)
|
| 15 |
+
from fastapi.background import BackgroundTasks
|
| 16 |
+
from fastapi.responses import RedirectResponse
|
| 17 |
+
from fastapi.security.api_key import APIKeyHeader
|
| 18 |
+
from pydantic import BaseModel, Field
|
| 19 |
+
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
from mlflow.exceptions import RestException
|
| 22 |
+
|
| 23 |
+
from src.model import (
|
| 24 |
+
run_experiment,
|
| 25 |
+
train_model_from_scratch,
|
| 26 |
+
predict,
|
| 27 |
+
list_registered_models,
|
| 28 |
+
load_model
|
| 29 |
+
)
|
| 30 |
+
from src.sql import (
|
| 31 |
+
_get_connection,
|
| 32 |
+
list_tournaments as _list_tournaments,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# ------------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
load_dotenv()
|
| 38 |
+
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
|
| 39 |
+
safe_clients = ['127.0.0.1']
|
| 40 |
+
|
| 41 |
+
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)
|
| 42 |
+
|
| 43 |
+
async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    '''
    Check if the API key is valid.

    Args:
        request (Request): Incoming request; its client host is checked
            against the safe-client allowlist.
        key (str): The API key taken from the 'Authorization' header.

    Raises:
        HTTPException: 403 when the client is not allowlisted and the key
            does not match FASTAPI_API_KEY.
    '''
    trusted_host = request.client.host in safe_clients
    # compare_digest performs a constant-time comparison (timing-attack safe).
    key_matches = secrets.compare_digest(str(key), str(FASTAPI_API_KEY))
    if not (trusted_host or key_matches):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
|
| 58 |
+
|
| 59 |
+
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
|
| 60 |
+
title="Tennis Insights API")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ------------------------------------------------------------------------------
|
| 64 |
+
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    '''
    Send visitors of the root URL to the interactive Swagger documentation.
    '''
    docs_url = '/docs'
    return RedirectResponse(url=docs_url)
|
| 70 |
+
|
| 71 |
+
@app.get("/train_model", tags=["model"], deprecated=True)
async def train_model(
        background_tasks: BackgroundTasks,
        circuit: Literal["atp", "wta"] = 'atp',
        from_date: str = "2024-01-01",
        to_date: str = "2024-12-31"):
    """
    Schedule a full model (re)training in the background.

    Both date bounds must be ISO 'YYYY-MM-DD'; invalid dates short-circuit
    with an explanatory message instead of scheduling anything.
    """
    # Validate both date bounds up front.
    for bound in (from_date, to_date):
        try:
            datetime.strptime(bound, "%Y-%m-%d")
        except ValueError:
            return {"message": "Invalid date format. Please use the format 'YYYY-MM-DD'"}

    # Training is slow: run it after the response is sent.
    background_tasks.add_task(
        func=train_model_from_scratch,
        circuit=circuit,
        from_date=from_date,
        to_date=to_date)

    return {"message": "Model training in progress"}
|
| 94 |
+
|
| 95 |
+
@app.get("/run_experiment", tags=["model"], description="Schedule a run of the ML experiment")
async def run_xp(
        background_tasks: BackgroundTasks,
        circuit: Literal["atp", "wta"] = 'atp',
        from_date: str = "2024-01-01",
        to_date: str = "2024-12-31"):
    """
    Schedule a background run of the MLflow experiment over the given
    circuit and date range.

    Both date bounds must be ISO 'YYYY-MM-DD'; invalid dates short-circuit
    with an explanatory message instead of scheduling anything.
    """
    # Validate both date bounds up front.
    for bound in (from_date, to_date):
        try:
            datetime.strptime(bound, "%Y-%m-%d")
        except ValueError:
            return {"message": "Invalid date format. Please use the format 'YYYY-MM-DD'"}

    # The experiment is slow: run it after the response is sent.
    background_tasks.add_task(
        func=run_experiment,
        circuit=circuit,
        from_date=from_date,
        to_date=to_date)

    return {"message": "Experiment scheduled"}
|
| 118 |
+
|
| 119 |
+
class ModelInput(BaseModel):
    """Query parameters describing one match for /predict."""
    rank_player_1: int = Field(gt=0, default=1, description="The rank of the 1st player")
    rank_player_2: int = Field(gt=0, default=100, description="The rank of the 2nd player")
    points_player_1: int = Field(gt=0, default=4000, description="The number of points of the 1st player")
    points_player_2: int = Field(gt=0, default=500, description="The number of points of the 2nd player")
    court: Literal['Outdoor', 'Indoor'] = 'Outdoor'
    surface: Literal['Grass', 'Carpet', 'Clay', 'Hard'] = 'Clay'
    # FIX: '3nd Round' was a typo — the round literal is '3rd Round' in the
    # tennis-data source, so the misspelling both rejected the valid value and
    # admitted an impossible one.
    round: Literal['1st Round', '2nd Round', '3rd Round', '4th Round', 'Quarterfinals', 'Semifinals', 'The Final', 'Round Robin'] = '1st Round'
    series: Literal['Grand Slam', 'Masters 1000', 'Masters', 'Masters Cup', 'ATP500', 'ATP250', 'International Gold', 'International'] = 'Grand Slam'
    # Registered MLflow model name/version to load; empty model falls back to
    # the local /data/model.pkl, 'latest' resolves to the newest version.
    model: Optional[str] = 'LogisticRegression'
    version: Optional[str] = 'latest'
|
| 130 |
+
|
| 131 |
+
class ModelOutput(BaseModel):
    """Response schema for /predict: winner flag plus class probabilities."""
    # 1 -> player 1 expected to win, 0 -> player 2 expected to win.
    result: int = Field(description="The prediction result. 1 if player 1 is expected to win, 0 otherwise.", example=1)
    # Two-element list: [P(player 1 loses), P(player 1 wins)].
    prob: list[float] = Field(description="Probability of [defeat, victory] of player 1.", example=[0.15, 0.85])
|
| 134 |
+
|
| 135 |
+
@app.get("/predict",
         tags=["model"],
         description="Predict the outcome of a tennis match",
         response_model=ModelOutput)
async def make_prediction(params: Annotated[ModelInput, Query()]):
    """
    Predict the outcome of a match between two players.

    Loads either the locally trained fallback model (when no model name is
    given) or a registered MLflow model, then delegates to `predict`.

    Raises:
        HTTPException: 404 when the requested MLflow model does not exist.
    """
    if not params.model:
        # No registry model requested: fall back to the locally trained pickle.
        if not os.path.exists("/data/model.pkl"):
            return {"message": "Model not trained. Please train the model first."}

        # Load the model
        pipeline = joblib.load("/data/model.pkl")
    else:
        # Get the model from the MLflow registry
        try:
            pipeline = load_model(params.model, params.version)
        except RestException as e:
            logging.error(e)

            # FIX: the original *returned* an HTTPException constructed with a
            # wrong keyword ('status' instead of 'status_code'), which would
            # crash on construction — and even with the right keyword, a
            # returned (not raised) exception is serialized as a 200 response.
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Model {params.model} not found"
            )

    # Make the prediction
    prediction = predict(
        pipeline=pipeline,
        rank_player_1=params.rank_player_1,
        rank_player_2=params.rank_player_2,
        points_player_1=params.points_player_1,
        points_player_2=params.points_player_2,
        court=params.court,
        surface=params.surface,
        round_stage=params.round,
        series=params.series
    )

    logging.info(prediction)

    return prediction
|
| 179 |
+
|
| 180 |
+
@app.get("/list_available_models", tags=["model"], description="List the available models")
async def list_available_models():
    """
    List the models registered in the MLflow model registry.

    Delegates to src.model.list_registered_models.
    """
    return list_registered_models()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class Tournament(BaseModel):
    """Response schema for /{circuit}/tournaments: one tournament's metadata."""
    name: str = Field(description="The tournament's name.", example='Wimbledon')
    # Tournament tier (e.g. 'Grand Slam', 'ATP250', ...).
    series: Literal['ATP250', 'ATP500', 'Grand Slam', 'Masters 1000', 'Masters', 'Masters Cup', 'International Gold', 'International'] = 'Grand Slam'
    court: Literal['Outdoor', 'Indoor'] = 'Outdoor'
    surface: Literal['Grass', 'Carpet', 'Clay', 'Hard'] = 'Grass'
|
| 193 |
+
|
| 194 |
+
@app.get("/{circuit}/tournaments", tags=["reference"], description="List the tournaments of the circuit", response_model=list[Tournament])
async def list_tournaments(circuit: Literal["atp", "wta"]):
    """
    List the tournaments of the given circuit ('atp' or 'wta').

    Delegates to src.sql.list_tournaments (imported as _list_tournaments),
    which queries the database.
    """
    return _list_tournaments(circuit)
|
| 200 |
+
|
| 201 |
+
@app.get("/check_health", tags=["general"], description="Check the health of the API")
async def check_health():
    """
    Check that the services the API depends on are reachable.

    Currently probes only the PostgreSQL database; returns 0 when the probe
    succeeds and 1 otherwise (the HTTP status is always 200).
    """
    HEALTHY, UNHEALTHY = 0, 1

    # DB probe: any failure to connect or query marks the service unhealthy.
    db_ok = False
    try:
        with _get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                db_ok = True
    except Exception:
        pass

    return HEALTHY if db_ok else UNHEALTHY
|
src/model.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import joblib
|
| 4 |
+
import logging
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from typing import Literal, Any, Tuple, Dict, List
|
| 8 |
+
import mlflow
|
| 9 |
+
from mlflow.models import infer_signature
|
| 10 |
+
from mlflow.tracking import MlflowClient
|
| 11 |
+
from sklearn.model_selection import train_test_split
|
| 12 |
+
from sklearn.impute import SimpleImputer
|
| 13 |
+
from sklearn.preprocessing import OneHotEncoder, StandardScaler
|
| 14 |
+
from sklearn.compose import ColumnTransformer
|
| 15 |
+
from sklearn.linear_model import LogisticRegression
|
| 16 |
+
from sklearn.pipeline import Pipeline
|
| 17 |
+
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
|
| 18 |
+
|
| 19 |
+
from src.sql import load_matches_from_postgres
|
| 20 |
+
from src.enums import Feature
|
| 21 |
+
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
models = {}
|
| 25 |
+
|
| 26 |
+
def create_pairwise_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build a balanced pairwise-comparison dataset from raw match rows.

    Each match yields two records: one with the winner in first position
    (target=1) and a mirrored one with the loser first (target=0), obtained
    by negating the ranking and points differences.
    """
    rows = []
    for _, match in df.iterrows():
        # Winner in position 1, loser in position 2.
        winner_first = {
            Feature.SERIES.name: match['series'],
            Feature.SURFACE.name: match['surface'],
            Feature.COURT.name: match['court'],
            Feature.ROUND.name: match['round'],
            Feature.DIFF_RANKING.name: match['w_rank'] - match['l_rank'],
            Feature.DIFF_POINTS.name: match['w_points'] - match['l_points'],
            'target': 1,
        }

        # Mirrored record: swap the players by negating both differences.
        loser_first = dict(winner_first)
        loser_first[Feature.DIFF_RANKING.name] = -loser_first[Feature.DIFF_RANKING.name]
        loser_first[Feature.DIFF_POINTS.name] = -loser_first[Feature.DIFF_POINTS.name]
        loser_first['target'] = 0

        rows.extend((winner_first, loser_first))

    return pd.DataFrame(rows)
|
| 53 |
+
|
| 54 |
+
def create_pipeline() -> Pipeline:
    """
    Build the training pipeline: mean-imputation plus scaling for numeric
    features, one-hot encoding for categorical features, then a
    logistic-regression classifier.

    Returns:
        Pipeline: An unfitted scikit-learn pipeline.
    """
    # Feature names, grouped by declared type in the Feature enum
    categorical = [f.name for f in Feature.get_features_by_type('category')]
    numerical = [f.name for f in Feature.get_features_by_type('number')]

    # Column-wise preprocessing: numeric branch imputes then scales,
    # categorical branch one-hot encodes (unknown categories ignored).
    preprocessor = ColumnTransformer(
        transformers=[
            (
                'num',
                Pipeline(steps=[
                    ('imputer', SimpleImputer(strategy='mean')),
                    ('scaler', StandardScaler()),
                ]),
                numerical,
            ),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical),
        ]
    )

    # Preprocessing feeds directly into the classifier
    return Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', LogisticRegression(solver='lbfgs', max_iter=1000)),
    ])
|
| 89 |
+
|
| 90 |
+
def train_model_from_scratch(
        circuit: Literal['atp', 'wta'],
        from_date: str,
        to_date: str,
        output_path: str = '/data/model.pkl') -> Pipeline:
    """
    Train a brand-new model on matches loaded from Postgres and persist it.

    Args:
        circuit: Which circuit's table to read ('atp' or 'wta').
        from_date: Inclusive lower bound for match dates (YYYY-MM-DD).
        to_date: Inclusive upper bound for match dates (YYYY-MM-DD).
        output_path: Filesystem path where the fitted pipeline is dumped.

    Returns:
        Pipeline: The fitted pipeline.
    """
    # Fetch the training matches from the circuit-specific table
    matches = load_matches_from_postgres(
        table_name=f"{circuit}_data",
        from_date=from_date,
        to_date=to_date)

    # Build and fit a fresh pipeline on those matches
    fitted = create_and_train_model(matches)

    # Persist the fitted pipeline so it can be reloaded later
    joblib.dump(fitted, output_path)

    return fitted
|
| 111 |
+
|
| 112 |
+
def create_and_train_model(data: pd.DataFrame) -> Pipeline:
    """
    Build a fresh pipeline and fit it on the training split of *data*.

    Args:
        data (pd.DataFrame): Raw match rows.

    Returns:
        Pipeline: The fitted pipeline.
    """
    # Only the training split is used here; the held-out part is discarded.
    X_train, _, y_train, _ = preprocess_data(data)

    return train_model(create_pipeline(), X_train, y_train)
|
| 124 |
+
|
| 125 |
+
def train_model(
        pipeline: Pipeline,
        X_train: pd.DataFrame,
        y_train: pd.DataFrame) -> Pipeline:
    """
    Fit *pipeline* on the given training data.

    Args:
        pipeline: Unfitted scikit-learn pipeline.
        X_train: Training features.
        y_train: Training targets.

    Returns:
        Pipeline: The same pipeline object, now fitted.
    """
    pipeline.fit(X_train, y_train)
    return pipeline
|
| 134 |
+
|
| 135 |
+
def preprocess_data(df: pd.DataFrame,
                    test_size: float = 0.2,
                    random_state: int = None) -> Tuple:
    """
    Turn raw matches into pairwise records and split them into X/y
    train and test sets.

    Args:
        df (pd.DataFrame): Input dataframe of raw matches.
        test_size (float): Fraction of data held out for testing.
        random_state (int): Optional seed for a reproducible split;
            None keeps the original non-deterministic behaviour.

    Returns:
        Tuple: Split data (X_train, X_test, y_train, y_test).
    """
    # Format data for the model (one winner-first and one loser-first
    # record per match)
    df_model = create_pairwise_data(df)

    features = [f.name for f in Feature.get_all_features()]
    X = df_model[features]
    y = df_model['target']

    # Split the data; seeding via random_state makes experiments repeatable
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
|
| 154 |
+
|
| 155 |
+
def evaluate_model(pipeline: Pipeline, X_test: pd.DataFrame, y_test: pd.Series) -> Dict:
    """
    Evaluate a fitted pipeline on a held-out test set.

    Args:
        pipeline: Fitted scikit-learn pipeline.
        X_test: Test features.
        y_test: True test labels.

    Returns:
        Dict: accuracy, ROC-AUC and the confusion matrix.
    """
    predictions = pipeline.predict(X_test)
    # Probability of the positive class (player in first position wins)
    win_probabilities = pipeline.predict_proba(X_test)[:, 1]

    return {
        "accuracy": accuracy_score(y_test, predictions),
        "roc_auc": roc_auc_score(y_test, win_probabilities),
        "confusion_matrix": confusion_matrix(y_test, predictions),
    }
|
| 169 |
+
|
| 170 |
+
def predict(
        pipeline: Pipeline,
        series: str,
        surface: str,
        court: str,
        round_stage: str,
        rank_player_1: int,
        rank_player_2: int,
        points_player_1: int,
        points_player_2: int
) -> Dict[str, Any]:
    """
    Predict the outcome of a single match from player 1's perspective.

    Args:
        pipeline: Fitted scikit-learn pipeline.
        series, surface, court, round_stage: Match context features.
        rank_player_1, rank_player_2: Current rankings of both players.
        points_player_1, points_player_2: Current points of both players.

    Returns:
        Dict: 'result' (1 if player 1 is predicted to win, else 0) and
        'prob' ([lose probability, win probability]).
    """
    # Single-row frame describing the match, with signed differences
    # computed player-1-minus-player-2
    match_frame = pd.DataFrame([{
        Feature.SERIES.name: series,
        Feature.SURFACE.name: surface,
        Feature.COURT.name: court,
        Feature.ROUND.name: round_stage,
        Feature.DIFF_RANKING.name: rank_player_1 - rank_player_2,
        Feature.DIFF_POINTS.name: points_player_1 - points_player_2
    }])

    # Run the pipeline on the new match
    prediction = pipeline.predict(match_frame)[0]
    proba = pipeline.predict_proba(match_frame)[0]

    # Log the outcome for traceability
    logging.info("\n--- 📊 Result ---")
    logging.info(f"🏆 Win probability : {proba[1]:.2f}")
    logging.info(f"❌ Lose probability : {proba[0]:.2f}")
    logging.info(f"🎾 Prediction : {'Victory' if prediction == 1 else 'Loss'}")

    # .item() converts numpy scalars to plain Python types for JSON safety
    return {"result": prediction.item(), "prob": [p.item() for p in proba]}
|
| 205 |
+
|
| 206 |
+
def run_experiment(
    circuit: Literal['atp', 'wta'],
    from_date: str,
    to_date: str,
    artifact_path: str = None,
    registered_model_name: str = 'LogisticRegression',
    experiment_name: str = 'Logistic Tennis Prediction',
):
    """
    Run the entire ML experiment pipeline: load matches from Postgres,
    train the pipeline, evaluate on the held-out split, and log/register
    the fitted model in MLflow.

    Args:
        circuit (Literal['atp', 'wta']): Which circuit's matches to train on.
        from_date (str): Inclusive lower bound for match dates (YYYY-MM-DD).
        to_date (str): Inclusive upper bound for match dates (YYYY-MM-DD).
        artifact_path (str): Path to store the model artifact; defaults
            to '<circuit>_model' when not provided.
        registered_model_name (str): Name to register the model under in MLflow.
        experiment_name (str): Name of the MLflow experiment.
    """
    # Default the artifact path from the circuit name
    if not artifact_path:
        artifact_path = f'{circuit}_model'

    # Point MLflow at the configured tracking server
    mlflow.set_tracking_uri(os.environ["MLFLOW_SERVER_URI"])

    # Start timing the whole experiment
    start_time = time.time()

    # Load matches from Postgres and build the pairwise train/test split
    df = load_matches_from_postgres(
        table_name=f"{circuit}_data",
        from_date=from_date,
        to_date=to_date)
    X_train, X_test, y_train, y_test = preprocess_data(df)

    # Create the (unfitted) pipeline
    pipe = create_pipeline()

    # Select (creating if necessary) the MLflow experiment
    mlflow.set_experiment(experiment_name)

    # Fetch the experiment record to get its id for start_run below
    experiment = mlflow.get_experiment_by_name(experiment_name)

    # Enable sklearn autologging BEFORE the run so params/metrics are captured
    mlflow.sklearn.autolog()

    with mlflow.start_run(experiment_id=experiment.experiment_id):
        # Fit the pipeline on the training split
        train_model(pipe, X_train, y_train)

        # Accuracy on the held-out split
        # predicted_output = pipe.predict(X_test.values)
        accuracy = pipe.score(X_test, y_test)

        # Log results
        logging.info("LogisticRegression model")
        logging.info("Accuracy: {}".format(accuracy))
        # Infer the model signature from test inputs and predictions
        signature = infer_signature(X_test, pipe.predict(X_test))

        # Log the fitted model and register it under registered_model_name
        mlflow.sklearn.log_model(
            sk_model=pipe,
            artifact_path=artifact_path,
            registered_model_name=registered_model_name,
            signature=signature
        )

    # Log total wall-clock time
    logging.info(f"...Training Done! --- Total training time: {time.time() - start_time} seconds")
|
| 273 |
+
|
| 274 |
+
def list_registered_models() -> List[Dict]:
    """
    List the latest versions of every model registered in MLflow.

    Returns:
        List[Dict]: One dict per latest model version, with
        'name', 'run_id' and 'version' keys.
    """
    # Point MLflow at the configured tracking server
    mlflow.set_tracking_uri(os.environ["MLFLOW_SERVER_URI"])

    # Flatten each registered model's latest versions into one list
    return [
        {"name": mv.name, "run_id": mv.run_id, "version": mv.version}
        for registered in mlflow.search_registered_models()
        for mv in registered.latest_versions
    ]
|
| 290 |
+
|
| 291 |
+
def load_model(name: str, version: str = 'latest') -> Pipeline:
    """
    Load a registered model from MLflow, with an in-process cache.

    Args:
        name (str): Registered model name.
        version (str): Specific model version to load, or 'latest'
            (default) for the most recent version.

    Returns:
        Pipeline: The loaded scikit-learn pipeline.
    """
    # Serve from the in-memory cache when possible. NOTE(review): the
    # cache is keyed by name only, so the first version loaded for a
    # name is reused for subsequent calls regardless of `version`.
    if name in models:
        return models[name]

    mlflow.set_tracking_uri(os.environ["MLFLOW_SERVER_URI"])
    client = MlflowClient()

    # Fix: the original ignored `version` and always loaded the latest.
    if version == 'latest':
        source = client.get_registered_model(name).latest_versions[0].source
    else:
        source = client.get_model_version(name, version).source

    # Load the model artifact from its source URI
    pipeline = mlflow.sklearn.load_model(model_uri=source)

    logging.info(f'Model {name} loaded')

    models[name] = pipeline

    return pipeline
|
src/sql.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import psycopg2
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
PG_USER = os.getenv("PG_USER")
|
| 11 |
+
PG_PASSWORD = os.getenv("PG_PASSWORD")
|
| 12 |
+
PG_HOST = os.getenv("PG_HOST")
|
| 13 |
+
PG_PORT = os.getenv("PG_PORT")
|
| 14 |
+
PG_DB = os.getenv("PG_DB")
|
| 15 |
+
PG_SSLMODE = os.getenv("PG_SSLMODE")
|
| 16 |
+
|
| 17 |
+
def _get_connection() -> psycopg2.extensions.connection:
    """
    Open a new connection to the Postgres database using the
    credentials read from the environment at import time.

    Returns:
        psycopg2.extensions.connection: An open database connection;
        the caller is responsible for closing it.
    """
    return psycopg2.connect(
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASSWORD,
        host=PG_HOST,
        port=PG_PORT,
        sslmode=PG_SSLMODE,
    )
|
| 30 |
+
|
| 31 |
+
def load_matches_from_postgres(
        table_name: Literal['atp_data', 'wta_data'],
        from_date: str = None,
        to_date: str = None) -> pd.DataFrame:
    """
    Load match rows from Postgres whose date falls in [from_date, to_date].

    Args:
        table_name: Source table; must be 'atp_data' or 'wta_data'.
        from_date: Inclusive lower bound (YYYY-MM-DD); defaults to 1900-01-01.
        to_date: Inclusive upper bound (YYYY-MM-DD); defaults to today.

    Returns:
        pd.DataFrame: Matching rows, columns named from the cursor description.

    Raises:
        ValueError: If table_name is not one of the allowed tables.
    """
    # Literal[] is not enforced at runtime; validate before interpolating
    # the table name into the SQL text (the dates stay parameterized).
    if table_name not in ('atp_data', 'wta_data'):
        raise ValueError(f"Unexpected table name: {table_name!r}")

    if not to_date:
        to_date = datetime.now().strftime("%Y-%m-%d")

    if not from_date:
        from_date = "1900-01-01"

    query = f"SELECT * FROM {table_name} WHERE date BETWEEN %s AND %s"
    params = [from_date, to_date]  # renamed from `vars`, which shadowed the builtin

    conn = _get_connection()
    try:
        # `with conn` only manages the transaction in psycopg2; the
        # connection itself must be closed explicitly or it leaks.
        with conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                rows = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
    finally:
        conn.close()

    return pd.DataFrame(rows, columns=columns)
|
| 55 |
+
|
| 56 |
+
def list_tournaments(circuit: Literal["atp", "wta"]):
    """
    List the distinct tournaments of a circuit.

    Args:
        circuit: Which circuit's table to query ('atp' or 'wta').

    Returns:
        list[dict]: One dict per tournament with 'name', 'series',
        'court' and 'surface' keys.

    Raises:
        ValueError: If circuit is not 'atp' or 'wta'.
    """
    # Literal[] is not enforced at runtime; validate before interpolating
    # the circuit into the SQL text.
    if circuit not in ("atp", "wta"):
        raise ValueError(f"Unexpected circuit: {circuit!r}")

    query = f"""
    SELECT DISTINCT
        tournament as name,
        series,
        court,
        surface
    FROM {circuit}_data;
    """
    conn = _get_connection()
    try:
        # `with conn` only manages the transaction in psycopg2; close the
        # connection explicitly to avoid leaking it.
        with conn:
            with conn.cursor() as cursor:
                cursor.execute(query)
                tournaments = [
                    {'name': row[0], 'series': row[1], 'court': row[2], 'surface': row[3]}
                    for row in cursor.fetchall()
                ]
    finally:
        conn.close()

    return tournaments
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
@pytest.fixture
def simple_match():
    """A single completed match, shaped like a row of the raw data tables."""
    return pd.DataFrame([{
        'series': 'ATP250',
        'surface': 'Clay',
        'court': 'Indoor',
        'round': 'Round Robin',
        'w_rank': 5,
        'l_rank': 300,
        'w_points': 2000,
        'l_points': 40,
    }])
|
| 16 |
+
|
| 17 |
+
@pytest.fixture
def simple_match_pairwise_data(simple_match: pd.DataFrame):
    """Expected pairwise expansion of the `simple_match` fixture."""
    shared = {'Series': 'ATP250', 'Surface': 'Clay',
              'Court': 'Indoor', 'Round': 'Round Robin'}
    return pd.DataFrame([
        # Winner first: negative rank diff, positive points diff
        {**shared, 'diffRanking': -295, 'diffPoints': 1960, 'target': 1},
        # Players swapped: both differences negated
        {**shared, 'diffRanking': 295, 'diffPoints': -1960, 'target': 0},
    ])
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
def simple_match_empty():
    """An empty frame carrying the pairwise-data columns."""
    column_names = ['Series', 'Surface', 'Court', 'Round',
                    'diffRanking', 'diffPoints', 'target']
    return pd.DataFrame({name: [] for name in column_names})
|
tests/test_enums.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.enums import Feature
|
| 2 |
+
|
| 3 |
+
def test_get_features_by_type():
    """
    Feature.get_features_by_type returns only features of the requested
    type, with the expected counts for each type.
    """
    for feature_type, expected_count in (('category', 4), ('number', 2)):
        features = Feature.get_features_by_type(feature_type)
        assert len(features) == expected_count
        assert all(feature.type == feature_type for feature in features)
|
| 14 |
+
|
| 15 |
+
def test_get_all_features():
    """
    Feature.get_all_features returns every declared feature (6 total).
    """
    assert len(Feature.get_all_features()) == 6
|
| 21 |
+
|
tests/test_model.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from sklearn.pipeline import Pipeline
|
| 3 |
+
|
| 4 |
+
from src.model import create_pairwise_data, create_pipeline
|
| 5 |
+
|
| 6 |
+
def test_create_pairwise_data(simple_match: pd.DataFrame, simple_match_pairwise_data: pd.DataFrame):
    """One match expands into the expected two mirrored pairwise records."""
    expected = simple_match_pairwise_data
    result = create_pairwise_data(simple_match)

    assert set(result.columns) == set(expected.columns), "Columns are different"
    assert expected.equals(result), "Dataframes are different"
|
| 11 |
+
|
| 12 |
+
def test_create_pairwise_data_empty(simple_match_empty: pd.DataFrame):
    """An empty input frame produces an empty pairwise frame."""
    assert create_pairwise_data(simple_match_empty).empty, "Dataframe is not empty"
|
| 16 |
+
|
| 17 |
+
def test_create_pipeline():
    """create_pipeline returns a two-step sklearn Pipeline."""
    pipeline = create_pipeline()

    assert pipeline is not None, "Pipeline is None"
    assert isinstance(pipeline, Pipeline), "Pipeline is not a Pipeline"

    steps = pipeline.named_steps
    assert len(steps) == 2, "Pipeline has wrong number of steps"
    assert 'preprocessor' in steps, "Preprocessor is missing"
    assert 'classifier' in steps, "Classifier is missing"
|