Spaces:
Running
Running
| import os | |
| import logging | |
| import secrets | |
| from typing import Generator, Optional, Annotated, List, Dict, AsyncGenerator | |
| from datetime import datetime | |
| from contextlib import asynccontextmanager | |
| from fastapi import ( | |
| FastAPI, | |
| Request, | |
| HTTPException, | |
| Query, | |
| Security, | |
| Depends | |
| ) | |
| from fastapi.responses import RedirectResponse, JSONResponse | |
| from fastapi.security.api_key import APIKeyHeader | |
| from starlette.status import ( | |
| HTTP_200_OK, | |
| HTTP_403_FORBIDDEN, | |
| HTTP_404_NOT_FOUND, | |
| HTTP_422_UNPROCESSABLE_ENTITY, | |
| HTTP_503_SERVICE_UNAVAILABLE) | |
| from sqlalchemy import text | |
| from sqlalchemy.orm import Session | |
| from pydantic import BaseModel, Field | |
| from dotenv import load_dotenv | |
| from sqlalchemy.exc import IntegrityError | |
| from src.api_factory import create_forward_endpoint, get_remote_params | |
| from src.entity.match import ( | |
| Match, | |
| RawMatch, | |
| MatchApiBase, | |
| MatchApiDetail | |
| ) | |
| from src.entity.player import ( | |
| Player, | |
| PlayerApiDetail, | |
| ) | |
| from src.jobs.match import schedule_matches_ingestion | |
| from src.jobs.views import schedule_refresh | |
| from src.repository.common import get_session | |
| from src.service.match import ( | |
| insert_new_match, | |
| insert_batch_matches, | |
| ) | |
# Populate the process environment (python-dotenv) before any of the
# os.getenv() lookups further down in this module run.
load_dotenv()
# ------------------------------------------------------------------------------
# Root logging configuration: INFO and above, emitted through a StreamHandler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def provide_connection() -> Generator[Session, None, None]:
    """
    Yield a database session obtained from ``get_session()``.

    Written as a generator so it can be plugged into ``Depends(...)``: the
    ``get_session()`` context manager is only exited once the consumer has
    finished with the yielded session.
    """
    with get_session() as session:
        yield session
| # ------------------------------------------------------------------------------ | |
# Make sure all the necessary jobs are scheduled. This runs at import time and
# is skipped entirely when REDIS_URL is not set.
if os.environ.get("REDIS_URL"):
    schedule_refresh()
    schedule_matches_ingestion(year=None)
| # ------------------------------------------------------------------------------ | |
# Base URL of the remote ML API; forwarding routes are only created when set.
TENNIS_ML_API = os.getenv("TENNIS_ML_API")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Application lifespan hook: register forwarding routes to the remote ML API.

    Before the application starts serving, each known remote endpoint is
    looked up with ``get_remote_params``; for every endpoint reported as
    found, a GET route with the same path is added to ``app`` that forwards
    to the remote API. Endpoints that are not found are logged and skipped.
    Nothing is done after the ``yield`` (no shutdown work).

    Args:
        app (FastAPI): The application the forwarding routes are added to

    Yields:
        None: Control is handed back to the framework once routes are set up
    """
    # The decorator turns this async generator into the async context manager
    # that the ``lifespan=`` argument of FastAPI expects.
    if not TENNIS_ML_API:
        yield
        return

    # Create a forward endpoint for each endpoint in the remote API
    endpoints = (
        "run_experiment",
        "predict",
        "list_available_models",
    )
    for endpoint in endpoints:
        endpoint_def = await get_remote_params(
            base_url=TENNIS_ML_API,
            endpoint=endpoint,
            method='get',
        )
        if not endpoint_def["found"]:
            logger.warning(f"Endpoint {endpoint} not found in the remote API")
            continue

        # Create a forward endpoint for the remote API
        forward_endpoint = create_forward_endpoint(
            base_url=TENNIS_ML_API,
            _endpoint=endpoint,
            param_defs=endpoint_def["params"],
        )
        # Reuse the remote endpoint's own metadata for the local OpenAPI entry
        general = endpoint_def["general"]
        app.add_api_route(
            path=f'/{endpoint}',
            endpoint=forward_endpoint,
            methods=["GET"],
            name=f"Forward to remote {forward_endpoint.__name__}",
            tags=general["tags"],
            description=general["description"],
            summary=general["summary"],
        )

    yield
| # ------------------------------------------------------------------------------ | |
# Expected API key; authentication is only wired into the app when this is set.
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
# Client hosts that are let through without presenting an API key.
# NOTE(review): behind a local reverse proxy every request may appear to come
# from 127.0.0.1 - confirm this allow-list matches the deployment.
safe_clients = ['127.0.0.1']
# auto_error=False: a missing header reaches validate_api_key as None instead
# of being rejected by the security scheme itself.
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)


async def validate_api_key(request: Request, key: Optional[str] = Security(api_key_header)) -> None:
    '''
    Check if the API key is valid

    Requests whose client host is listed in ``safe_clients`` are accepted
    without a key. Every other request must carry, in the ``Authorization``
    header, a value equal to ``FASTAPI_API_KEY``.

    Args:
        request (Request): The incoming request, used for the client host
        key (str): The API key to check (None when the header is absent)

    Raises:
        HTTPException: If the API key is invalid
    '''
    # request.client can be None (no peer information available); such a
    # request is never treated as a safe client.
    client = request.client
    if client is not None and client.host in safe_clients:
        return None

    # Fail closed: a missing header or a missing configured key never
    # authenticates (previously str(None) == str(None) compared equal).
    # Comparing bytes keeps compare_digest from raising TypeError on
    # non-ASCII header values.
    authorized = (
        bool(key)
        and bool(FASTAPI_API_KEY)
        and secrets.compare_digest(key.encode("utf-8"), FASTAPI_API_KEY.encode("utf-8"))
    )
    if not authorized:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
# Application object. The API-key check is attached as a global dependency only
# when FASTAPI_API_KEY is set, and the lifespan hook only when TENNIS_ML_API is.
app = FastAPI(
    title="Tennis Insights API",
    lifespan=lifespan if TENNIS_ML_API else None,
    dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
)
| # ------------------------------------------------------------------------------ | |
def redirect_to_docs():
    '''
    Redirect the caller to the API documentation page.
    '''
    docs_url = '/docs'
    return RedirectResponse(url=docs_url)
# List all the tournament names and years
async def list_tournament_names(
    session: Session = Depends(provide_connection)
):
    """
    List every tournament name with its first and last year of occurrence
    """
    query = text("SELECT t.name, t.first_year, t.last_year FROM data.tournaments_list_m_view AS t")
    rows = session.execute(query).all()
    # Each row is (name, first_year, last_year), in the order of the SELECT list
    return [
        {"name": name, "first_year": first_year, "last_year": last_year}
        for name, first_year, last_year in rows
    ]
async def list_courts(
    session: Session = Depends(provide_connection)
):
    """
    List the names of all the courts
    """
    rows = session.execute(text("SELECT name FROM data.ref_court_m_view")).all()
    # Single-column rows: unwrap each one into its bare name
    return [name for (name,) in rows]
async def list_surfaces(
    session: Session = Depends(provide_connection)
):
    """
    List the names of all the surfaces
    """
    query = text("SELECT name FROM data.ref_surface_m_view")
    names = []
    # Only one column is selected, so the name is the first item of each row
    for row in session.execute(query).all():
        names.append(row[0])
    return names
async def list_series(
    session: Session = Depends(provide_connection)
):
    """
    List the names of all the series
    """
    result = session.execute(text("SELECT name FROM data.ref_series_m_view"))
    # Flatten the single-column rows into a plain list of names
    return [record[0] for record in result.all()]
class ListPlayersInput(BaseModel):
    # Input model consumed by list_players as its query parameters.
    # `ids` is declared without a default value.
    ids: List[int] = Field(description="List of player IDs")
# Get a list of players
async def list_players(
    params: Annotated[ListPlayersInput, Query()],
    session: Session = Depends(provide_connection),
):
    """
    Get several players from the database, keyed by player ID
    """
    ids = sorted(params.ids)
    found = session.query(Player).filter(Player.id.in_(ids)).all()
    # 404 only when none of the requested players exist
    if not found:
        raise HTTPException(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"Players {ids} not found"
        )

    by_id = {}
    for player in found:
        by_id[player.id] = player
    # Every requested ID appears in the result; IDs without a player map to None
    return {player_id: by_id.get(player_id) for player_id in ids}
async def get_player(
    player_id: int,
    session: Session = Depends(provide_connection)
):
    """
    Get a single player from the database by ID
    """
    lookup = session.query(Player).filter(Player.id == player_id)
    player = lookup.first()
    if player:
        return player
    # No row with that ID
    raise HTTPException(
        status_code=HTTP_404_NOT_FOUND,
        detail=f"Player {player_id} not found"
    )
# Get all the matches from a tournament
async def search_tournament_matches(
    name: str,
    year: int,
    session: Session = Depends(provide_connection)
):
    """
    Get all the matches from a tournament for a given year, most recent first
    """
    try:
        start_date = datetime(year, 1, 1)
        # Upper bound is the last instant of Dec 31 rather than its midnight,
        # so rows dated Dec 31 with a time-of-day component are not dropped.
        end_date = datetime(year, 12, 31, 23, 59, 59, 999999)
    except (ValueError, OverflowError) as e:
        # Years outside what datetime supports are a client error, not a 500
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid year: {year}"
        ) from e

    matches = session.query(Match).filter(
        Match.tournament_name == name,
        Match.date.between(start_date, end_date)
    ).all()
    return sorted(matches, key=lambda x: x.date, reverse=True)
# Get a match
async def get_match(
    match_id: int,
    session: Session = Depends(provide_connection)
):
    """
    Get a single match from the database by ID
    """
    lookup = session.query(Match).filter(Match.id == match_id)
    match = lookup.first()
    if match:
        return match
    # No row with that ID
    raise HTTPException(
        status_code=HTTP_404_NOT_FOUND,
        detail=f"Match {match_id} not found"
    )
async def insert_match(
    raw_match: RawMatch,
    session: Session = Depends(provide_connection),
    on_conflict_do_nothing: bool = False,
):
    """
    Insert a match into the database
    """
    try:
        match = insert_new_match(
            db=session,
            # Only the fields present in the payload are forwarded
            raw_match=raw_match.model_dump(exclude_unset=True),
            on_conflict_do_nothing=on_conflict_do_nothing,
        )
    except IntegrityError as e:
        logger.error(f"Error inserting match: {e}")
        # Chain the database error so the original cause stays in the traceback
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Entity already exists in the database"
        ) from e

    output = {
        "status": "ok",
        # insert_new_match may hand back a falsy value; report a null ID then
        "match_id": match.id if match else None,
    }
    return JSONResponse(content=output, status_code=HTTP_200_OK)
async def ingest_matches(year: Optional[int] = None):
    """
    Schedule the ingestion of matches from tennis-data.co.uk for a given year
    """
    scheduled = schedule_matches_ingestion(year=year)
    # Report the scheduled job's identifier and status back to the caller
    job_id = scheduled.get_id()
    job_status = scheduled.get_status()
    return {"job_id": job_id, "job_status": job_status}
async def insert_batch_match(
    raw_matches: list[RawMatch],
    on_conflict_do_nothing: bool = False,
    session: Session = Depends(provide_connection)
):
    """
    Insert a batch of matches into the database
    """
    result = insert_batch_matches(db=session,
                                  raw_matches=raw_matches,
                                  on_conflict_do_nothing=on_conflict_do_nothing)
    matches = result['matches']
    nb_errors = result['nb_errors']
    logger.info(f"Number of matches inserted: {len(matches)}")

    if nb_errors > 0:
        logger.warning(f"Number of errors: {nb_errors}")
        # The body must agree with the 422 status code: this used to say
        # "ok" while signalling a failure.
        return JSONResponse(
            content={"status": "error", "message": f"{len(matches)} matches inserted, {nb_errors} errors"},
            status_code=HTTP_422_UNPROCESSABLE_ENTITY
        )

    output = {
        "status": "ok",
        "match_ids": [match.id for match in matches],
    }
    return JSONResponse(content=output, status_code=HTTP_200_OK)
| # ------------------------------------------------------------------------------ | |
def _unhealthy_response(detail: str) -> JSONResponse:
    # Build the 503 payload shared by every failed health probe.
    return JSONResponse(content={"status": "unhealthy", "detail": detail},
                        status_code=HTTP_503_SERVICE_UNAVAILABLE)


async def check_health(session: Session = Depends(provide_connection)):
    """
    Check all the services in the infrastructure are working
    """
    # Check if the database is alive
    try:
        session.execute(text("SELECT 1"))
    except Exception as e:
        # Broad on purpose: any failure of the probe means "unhealthy"
        logger.error(f"DB check failed: {e}")
        return _unhealthy_response("Database not reachable")

    # Check if the scraper endpoint is reachable (only when it is configured)
    if FLARESOLVERR_API := os.getenv("FLARESOLVERR_API"):
        import requests

        # Normalise the separator so the URL is valid whether or not the
        # configured base URL ends with a slash.
        health_url = FLARESOLVERR_API.rstrip("/") + "/health"
        try:
            # Ping the scraper endpoint
            response = requests.get(health_url, timeout=5)
        except requests.RequestException as e:
            logger.error(f"Scraper check failed: {e}")
            return _unhealthy_response("Flaresolverr not reachable")

        if response.status_code != HTTP_200_OK:
            logger.error(f"Scraper check failed: {response.status_code}")
            return _unhealthy_response("Flaresolverr not reachable")

    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)