tennis-api / src /main.py
SlimG's picture
improve resilience to non-existing remote endpoints
9316f10
import os
import logging
import secrets
from typing import Generator, Optional, Annotated, List, Dict, AsyncGenerator
from datetime import datetime
from contextlib import asynccontextmanager
from fastapi import (
FastAPI,
Request,
HTTPException,
Query,
Security,
Depends
)
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.security.api_key import APIKeyHeader
from starlette.status import (
HTTP_200_OK,
HTTP_403_FORBIDDEN,
HTTP_404_NOT_FOUND,
HTTP_422_UNPROCESSABLE_ENTITY,
HTTP_503_SERVICE_UNAVAILABLE)
from sqlalchemy import text
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from sqlalchemy.exc import IntegrityError
from src.api_factory import create_forward_endpoint, get_remote_params
from src.entity.match import (
Match,
RawMatch,
MatchApiBase,
MatchApiDetail
)
from src.entity.player import (
Player,
PlayerApiDetail,
)
from src.jobs.match import schedule_matches_ingestion
from src.jobs.views import schedule_refresh
from src.repository.common import get_session
from src.service.match import (
insert_new_match,
insert_batch_matches,
)
# Load variables from a .env file first: the os.getenv() calls further down
# this module are evaluated at import time.
load_dotenv()
# ------------------------------------------------------------------------------
# Root logger: INFO and above, emitted through a single stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def provide_connection() -> Generator[Session, None, None]:
    """
    Dependency that hands a database session to an endpoint.

    The session is obtained from ``get_session()`` used as a context manager
    and is yielded while that context is still open; the context is exited
    once the generator is resumed or closed.

    Yields:
        Session: the object produced by ``get_session()``
    """
    with get_session() as session:
        yield session
# ------------------------------------------------------------------------------
# Ensure all the necessary jobs are scheduled.
# This happens at import time and is gated on the REDIS_URL environment
# variable being set to a non-empty value.
# NOTE(review): the indentation of this file was lost in extraction, so it is
# not visible here whether only schedule_refresh() or both calls belong to the
# `if` body - confirm against the original file before editing.
if os.getenv("REDIS_URL"):
schedule_refresh()
schedule_matches_ingestion(year=None)
# ------------------------------------------------------------------------------
# Base URL of the remote ML API; forwarding is disabled when it is unset/empty
TENNIS_ML_API = os.getenv("TENNIS_ML_API")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Application lifespan hook that registers forwarding routes to the remote API.

    When ``TENNIS_ML_API`` is set, the definition of each listed remote endpoint
    is fetched with ``get_remote_params`` and, if the remote reports it as
    found, a GET route with the same name is added to ``app``. Endpoints that
    are not found are logged and skipped. Nothing is done after the ``yield``.

    Args:
        app (FastAPI): The application the routes are added to
    """
    if TENNIS_ML_API:
        # Remote endpoints that get a local forwarding route
        remote_endpoints = (
            "run_experiment",
            "predict",
            "list_available_models",
        )

        for endpoint in remote_endpoints:
            definition = await get_remote_params(
                base_url=TENNIS_ML_API,
                endpoint=endpoint,
                method='get',
            )

            if definition["found"]:
                # Build the forwarding callable and expose it under the same path
                forward_endpoint = create_forward_endpoint(
                    base_url=TENNIS_ML_API,
                    _endpoint=endpoint,
                    param_defs=definition["params"],
                )
                app.add_api_route(
                    path=f'/{endpoint}',
                    endpoint=forward_endpoint,
                    methods=["GET"],
                    name=f"Forward to remote {forward_endpoint.__name__}",
                    tags=definition["general"]["tags"],
                    description=definition["general"]["description"],
                    summary=definition["general"]["summary"],
                )
            else:
                logger.warning(f"Endpoint {endpoint} not found in the remote API")

    yield
# ------------------------------------------------------------------------------
# Expected API key; the check below is only wired into the app when this is set
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
# Client hosts that are allowed through without presenting a key
safe_clients = ['127.0.0.1']
# The key is read from the Authorization header; auto_error=False leaves the
# rejection to validate_api_key instead of the security scheme itself
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)


async def validate_api_key(request: Request, key: str = Security(api_key_header)) -> None:
    """
    Check if the API key is valid

    Requests whose client host is listed in ``safe_clients`` are accepted
    without a key. Every other request must carry a key equal to
    ``FASTAPI_API_KEY``; a missing key, or a missing configured key, is
    always rejected.

    Args:
        request (Request): The incoming request (only its client address is read)
        key (str): The API key to check (``None`` when the header is absent)

    Raises:
        HTTPException: If the API key is invalid
    """
    # request.client can be None (no client address available); treat that as
    # an untrusted caller instead of failing with an AttributeError.
    client_host = request.client.host if request.client else None
    if client_host in safe_clients:
        return None

    # Compare as bytes: compare_digest() raises TypeError for non-ASCII str,
    # which would surface as a 500 instead of a 403. Using "" for missing
    # values (rather than str(None)) prevents the text "None" from matching.
    provided = (key or "").encode("utf-8")
    expected = (FASTAPI_API_KEY or "").encode("utf-8")

    if not expected or not secrets.compare_digest(provided, expected):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )

    return None
# Application instance.
# - The API-key dependency is attached only when FASTAPI_API_KEY is set.
# - The lifespan hook is attached only when TENNIS_ML_API is set.
app = FastAPI(
    title="Tennis Insights API",
    lifespan=lifespan if TENNIS_ML_API else None,
    dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
)
# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """
    Send requests for the root path to the API documentation at /docs.
    """
    docs_url = '/docs'
    return RedirectResponse(url=docs_url)
# List all the tournament names and years
@app.get("/tournament/names", tags=["tournament"], description="List all the tournament names and years", response_model=List[Dict])
async def list_tournament_names(
    session: Session = Depends(provide_connection)
):
    """
    Return every tournament with the first and last year it occurred.

    Rows are read from data.tournaments_list_m_view and returned as a list of
    dicts with the keys "name", "first_year" and "last_year".
    """
    query = text("SELECT t.name, t.first_year, t.last_year FROM data.tournaments_list_m_view AS t")
    rows = session.execute(query).all()

    # Each row carries exactly the three selected columns, in SELECT order
    return [
        {
            "name": name,
            "first_year": first_year,
            "last_year": last_year,
        }
        for name, first_year, last_year in rows
    ]
@app.get("/references/courts", tags=["reference"], description="List all the courts")
async def list_courts(
    session: Session = Depends(provide_connection)
):
    """
    Return the court names stored in data.ref_court_m_view as a flat list.
    """
    rows = session.execute(text("SELECT name FROM data.ref_court_m_view")).all()

    # Single-column rows: unpack the only value of each one
    return [name for (name,) in rows]
@app.get("/references/surfaces", tags=["reference"], description="List all the surfaces")
async def list_surfaces(
    session: Session = Depends(provide_connection)
):
    """
    Return the surface names stored in data.ref_surface_m_view as a flat list.
    """
    rows = session.execute(text("SELECT name FROM data.ref_surface_m_view")).all()

    # Single-column rows: unpack the only value of each one
    return [name for (name,) in rows]
@app.get("/references/series", tags=["reference"], description="List all the series")
async def list_series(
    session: Session = Depends(provide_connection)
):
    """
    Return the series names stored in data.ref_series_m_view as a flat list.
    """
    rows = session.execute(text("SELECT name FROM data.ref_series_m_view")).all()

    # Single-column rows: unpack the only value of each one
    return [name for (name,) in rows]
# Query-parameter model for GET /players.
class ListPlayersInput(BaseModel):
    # IDs of the players to look up
    ids: List[int] = Field(description="List of player IDs")
# Get a list of players
@app.get("/players", tags=["player"], description="Get a list of players from the database", response_model=Dict[str|int, Optional[PlayerApiDetail]])
async def list_players(
    params: Annotated[ListPlayersInput, Query()],
    session: Session = Depends(provide_connection),
):
    """
    Look up several players at once.

    Returns a mapping from each requested ID (in ascending order) to the
    matching player, or to None for an ID that has no player.

    Raises:
        HTTPException: 404 when none of the requested IDs matches a player
    """
    ids = sorted(params.ids)

    found = session.query(Player).filter(Player.id.in_(ids)).all()
    if not found:
        raise HTTPException(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"Players {ids} not found"
        )

    # Index the hits by ID so that every requested ID can be answered
    players_by_id = {}
    for player in found:
        players_by_id[player.id] = player

    return {player_id: players_by_id.get(player_id) for player_id in ids}
@app.get("/player/{player_id}", tags=["player"], description="Get a player from the database", response_model=PlayerApiDetail)
async def get_player(
    player_id: int,
    session: Session = Depends(provide_connection)
):
    """
    Fetch a single player by ID.

    Raises:
        HTTPException: 404 when no player has this ID
    """
    lookup = session.query(Player).filter(Player.id == player_id)
    player = lookup.first()

    if player:
        return player

    raise HTTPException(
        status_code=HTTP_404_NOT_FOUND,
        detail=f"Player {player_id} not found"
    )
# Get all the matches from a tournament
@app.get("/tournament/matches", tags=["tournament"], description="Get all the matches from a tournament", response_model=List[MatchApiBase])
async def search_tournament_matches(
    name: str,
    year: int,
    session: Session = Depends(provide_connection)
):
    """
    Get all the matches from a tournament

    Args:
        name (str): Tournament name, compared for equality with Match.tournament_name
        year (int): Calendar year the matches must fall in

    Returns:
        The matching matches, most recent first

    Raises:
        HTTPException: 422 when ``year`` cannot be represented as a date
    """
    try:
        start_date = datetime(year, 1, 1)
        # Use the last instant of Dec 31 so that day is fully included should
        # Match.date carry a time component.
        end_date = datetime(year, 12, 31, 23, 59, 59, 999999)
    except (ValueError, OverflowError) as exc:
        # datetime() rejects years outside its supported range; report that as
        # a client error instead of letting it become a 500.
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid year: {year}"
        ) from exc

    matches = session.query(Match).filter(
        Match.tournament_name == name,
        Match.date.between(start_date, end_date)
    ).all()

    return sorted(matches, key=lambda x: x.date, reverse=True)
# Get a match
@app.get("/match/{match_id}", tags=["match"], description="Get a match from the database", response_model=MatchApiDetail)
async def get_match(
    match_id: int,
    session: Session = Depends(provide_connection)
):
    """
    Fetch a single match by ID.

    Raises:
        HTTPException: 404 when no match has this ID
    """
    lookup = session.query(Match).filter(Match.id == match_id)
    match = lookup.first()

    if match:
        return match

    raise HTTPException(
        status_code=HTTP_404_NOT_FOUND,
        detail=f"Match {match_id} not found"
    )
@app.post("/match/insert", tags=["match"], description="Insert a match into the database")
async def insert_match(
    raw_match: RawMatch,
    session: Session = Depends(provide_connection),
    on_conflict_do_nothing: bool = False,
):
    """
    Store one match.

    Only the fields explicitly set on ``raw_match`` are passed on to
    ``insert_new_match``. The response carries the ID of the stored match, or
    None when ``insert_new_match`` returned nothing.

    Raises:
        HTTPException: 422 when the insert fails with an IntegrityError
    """
    payload = raw_match.model_dump(exclude_unset=True)

    try:
        match = insert_new_match(
            db=session,
            raw_match=payload,
            on_conflict_do_nothing=on_conflict_do_nothing,
        )
    except IntegrityError as e:
        logger.error(f"Error inserting match: {e}")
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Entity already exists in the database"
        )

    match_id = match.id if match else None
    return JSONResponse(
        content={"status": "ok", "match_id": match_id},
        status_code=HTTP_200_OK,
    )
@app.post("/match/ingest_year", tags=["match"], description="Ingest matches from tennis-data.co.uk for a given year")
async def ingest_matches(year: Optional[int] = None):
    """
    Schedule an ingestion job for the given year and report the job's ID and status.

    ``year`` is passed through to ``schedule_matches_ingestion`` unchanged,
    including the default of None.
    """
    job = schedule_matches_ingestion(year=year)

    job_id = job.get_id()
    job_status = job.get_status()
    return {"job_id": job_id, "job_status": job_status}
@app.post("/batch/match/insert", tags=["match"], description="Insert a batch of matches into the database")
async def insert_batch_match(
    raw_matches: list[RawMatch],
    on_conflict_do_nothing: bool = False,
    session: Session = Depends(provide_connection)
):
    """
    Store several matches in one call.

    The outcome of ``insert_batch_matches`` decides the response: when it
    reports at least one error the reply is a 422 with a summary message,
    otherwise a 200 listing the IDs of the stored matches.
    """
    result = insert_batch_matches(
        db=session,
        raw_matches=raw_matches,
        on_conflict_do_nothing=on_conflict_do_nothing,
    )
    matches, nb_errors = result['matches'], result['nb_errors']

    logger.info(f"Number of matches inserted: {len(matches)}")

    if nb_errors > 0:
        logger.warning(f"Number of errors: {nb_errors}")
        summary = f"{len(matches)} matches inserted, {nb_errors} errors"
        return JSONResponse(
            content={"status": "ok", "message": summary},
            status_code=HTTP_422_UNPROCESSABLE_ENTITY
        )

    # No error reported: answer with the IDs of the stored matches
    match_ids = [match.id for match in matches]
    return JSONResponse(
        content={"status": "ok", "match_ids": match_ids},
        status_code=HTTP_200_OK,
    )
# ------------------------------------------------------------------------------
def _unhealthy_response(detail: str) -> JSONResponse:
    """Build the 503 response returned when one of the health checks fails."""
    return JSONResponse(content={"status": "unhealthy", "detail": detail},
                        status_code=HTTP_503_SERVICE_UNAVAILABLE)


@app.get("/check_health", tags=["general"], description="Check the health of the API")
async def check_health(session: Session = Depends(provide_connection)):
    """
    Check all the services in the infrastructure are working

    Two checks run in order: a trivial query against the database, then - only
    when the FLARESOLVERR_API environment variable is set - a GET on that
    service's ``health`` path.

    Returns:
        JSONResponse: 200 with status "healthy" when every check passes,
        otherwise 503 with status "unhealthy" and a detail message
    """
    # Check if the database is alive
    try:
        session.execute(text("SELECT 1"))
    except Exception as e:
        logger.error(f"DB check failed: {e}")
        return _unhealthy_response("Database not reachable")

    # Check if the scraper endpoint is reachable
    if flaresolverr_api := os.getenv("FLARESOLVERR_API"):
        import requests

        # Join with exactly one slash so the URL is valid whether or not the
        # configured base URL ends with "/".
        health_url = flaresolverr_api.rstrip("/") + "/health"

        try:
            # Ping the scraper endpoint (blocking call, bounded by the timeout)
            response = requests.get(health_url, timeout=5)
        except requests.RequestException as e:
            logger.error(f"Scraper check failed: {e}")
            return _unhealthy_response("Flaresolverr not reachable")

        if response.status_code != HTTP_200_OK:
            logger.error(f"Scraper check failed: {response.status_code}")
            return _unhealthy_response("Flaresolverr not reachable")

    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)