sghorbal commited on
Commit
f924769
·
1 Parent(s): 51c1df4

generate dynamically endpoints forwarding to the ML module

Browse files
Files changed (2) hide show
  1. src/api_factory.py +94 -0
  2. src/main.py +23 -43
src/api_factory.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Query
2
+ from typing import Dict, Literal
3
+ import httpx
4
+
5
+ def extract_type(p):
6
+ schema = p.get("schema", {})
7
+ if "$ref" in schema:
8
+ return str # fallback
9
+
10
+ if "enum" in schema:
11
+ return Literal[tuple(schema["enum"])]
12
+
13
+ return openapi_type_to_python(schema.get("type", "string"))
14
+
15
+ def get_param_metadata(p):
16
+ schema = p.get("schema", {})
17
+ return {
18
+ "name": p["name"],
19
+ "type": extract_type(p),
20
+ "required": p.get("required", False),
21
+ "description": p.get("description", ""),
22
+ "default": schema.get("default", ...), # `...` => required if undefined
23
+ "enum": schema.get("enum", None),
24
+ }
25
+
26
+ async def get_remote_params(base_url: str,
27
+ endpoint: str,
28
+ method: str = 'get'):
29
+ async with httpx.AsyncClient() as client:
30
+ response = await client.get(f"{base_url}openapi.json")
31
+ spec = response.json()
32
+
33
+ # Extraction of the parameters from the OpenAPI spec
34
+ path_def = spec["paths"].get(f'/{endpoint}', {})
35
+ method_def = path_def.get(method, {})
36
+ params = method_def.get("parameters", [])
37
+
38
+ return [
39
+ get_param_metadata(p)
40
+ for p in params if p["in"] == "query"
41
+ ]
42
+
43
+ def openapi_type_to_python(t: str):
44
+ return {
45
+ "string": str,
46
+ "integer": int,
47
+ "number": float,
48
+ "boolean": bool,
49
+ "array": list,
50
+ }.get(t, str) # fallback = str
51
+
52
+ def create_forward_endpoint(base_url: str, _endpoint: str, param_defs: Dict):
53
+ # Dynamic endpoint creation
54
+ def endpoint_factory():
55
+ async def endpoint(**kwargs):
56
+ async with httpx.AsyncClient() as client:
57
+ response = await client.get(
58
+ f"{base_url}{_endpoint}",
59
+ params=kwargs
60
+ )
61
+ return response.json()
62
+ return endpoint
63
+
64
+ endpoint = endpoint_factory()
65
+
66
+ # Add the parameters to the endpoint
67
+ endpoint.__annotations__ = {
68
+ p["name"]: openapi_type_to_python(p["type"]) for p in param_defs
69
+ }
70
+
71
+ # Add dependencies to the endpoint
72
+ from inspect import Parameter, Signature
73
+
74
+ parameters = []
75
+ for p in param_defs:
76
+ annotation = p["type"]
77
+ default = p.get("default", ...) if p.get("required", False) else p.get("default", None)
78
+
79
+ query = Query(
80
+ default=default,
81
+ description=p.get("description", "")
82
+ )
83
+
84
+ param = Parameter(
85
+ name=p["name"],
86
+ kind=Parameter.KEYWORD_ONLY,
87
+ default=query,
88
+ annotation=annotation
89
+ )
90
+ parameters.append(param)
91
+
92
+ endpoint.__signature__ = Signature(parameters)
93
+
94
+ return endpoint
src/main.py CHANGED
@@ -39,7 +39,7 @@ from src.repository.common import get_session
39
  from src.service.match import insert_new_match
40
 
41
  from contextlib import asynccontextmanager
42
- import httpx
43
 
44
  load_dotenv()
45
 
@@ -59,9 +59,28 @@ TENNIS_ML_API = os.getenv("TENNIS_ML_API")
59
 
60
  @asynccontextmanager
61
  async def lifespan(app: FastAPI):
62
- # Other server that requests should be forwarded to
63
- async with httpx.AsyncClient(base_url=TENNIS_ML_API) as client:
64
- yield {'client': client}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # ------------------------------------------------------------------------------
67
 
@@ -98,45 +117,6 @@ def redirect_to_docs():
98
  '''
99
  return RedirectResponse(url='/docs')
100
 
101
- @app.get("/run_experiment", tags=["model"], description="Schedule a run of the ML experiment")
102
- async def run_xp(request: Request):
103
- """
104
- Train the model
105
- """
106
- params = dict(request.query_params)
107
-
108
- async with httpx.AsyncClient() as client:
109
- response = await client.get(TENNIS_ML_API + "run_experiment", params=params)
110
- return response.json()
111
-
112
- @app.get("/predict",
113
- tags=["model"],
114
- description="Predict the outcome of a tennis match",)
115
- async def make_prediction(request: Request):
116
- """
117
- Predict the matches
118
- """
119
- params = dict(request.query_params)
120
-
121
- if not TENNIS_ML_API:
122
- return {"message": "TENNIS_ML_API environment variable not set."}
123
-
124
- async with httpx.AsyncClient() as client:
125
- response = await client.get(TENNIS_ML_API + "predict", params=params)
126
- return response.json()
127
-
128
- @app.get("/list_available_models", tags=["model"], description="List the available models")
129
- async def list_available_models():
130
- """
131
- List the available models
132
- """
133
- if not TENNIS_ML_API:
134
- return {"message": "TENNIS_ML_API environment variable not set."}
135
-
136
- async with httpx.AsyncClient() as client:
137
- response = await client.get(TENNIS_ML_API + "list_available_models")
138
- return response.json()
139
-
140
  # List all the tournament names and years
141
  @app.get("/tournament/names", tags=["tournament"], description="List all the tournament names and years", response_model=List[Dict])
142
  async def list_tournament_names(
 
39
  from src.service.match import insert_new_match
40
 
41
  from contextlib import asynccontextmanager
42
+ from src.api_factory import create_forward_endpoint, get_remote_params
43
 
44
  load_dotenv()
45
 
 
59
 
60
  @asynccontextmanager
61
  async def lifespan(app: FastAPI):
62
+ endpoints = [
63
+ "run_experiment",
64
+ "predict",
65
+ "list_available_models"
66
+ ]
67
+
68
+ for endpoint in endpoints:
69
+ param_defs = await get_remote_params(base_url=TENNIS_ML_API,
70
+ endpoint=endpoint,
71
+ method='get')
72
+ forward_endpoint = create_forward_endpoint(base_url=TENNIS_ML_API,
73
+ _endpoint=endpoint,
74
+ param_defs=param_defs)
75
+
76
+ app.add_api_route(
77
+ path=f'/{endpoint}',
78
+ endpoint=forward_endpoint,
79
+ methods=["GET"],
80
+ name=f"Forward to remote {forward_endpoint.__name__}",
81
+ )
82
+
83
+ yield
84
 
85
  # ------------------------------------------------------------------------------
86
 
 
117
  '''
118
  return RedirectResponse(url='/docs')
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  # List all the tournament names and years
121
  @app.get("/tournament/names", tags=["tournament"], description="List all the tournament names and years", response_model=List[Dict])
122
  async def list_tournament_names(