SlimG commited on
Commit
8a395d5
·
1 Parent(s): d5e00d5

move model classes to dedicated file

Browse files
Files changed (3) hide show
  1. src/entity/__init__.py +0 -0
  2. src/entity/model.py +28 -0
  3. src/main.py +3 -27
src/entity/__init__.py ADDED
File without changes
src/entity/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, List
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
+ class ModelInput(BaseModel):
6
+ p1_rank: int = Field(gt=0, default=1, description="The rank of the 1st player")
7
+ p2_rank: int = Field(gt=0, default=100, description="The rank of the 2nd player")
8
+ court: Literal['Outdoor', 'Indoor'] = Field(default='Outdoor', description="The type of court")
9
+ surface: Literal['Grass', 'Carpet', 'Clay', 'Hard'] = Field(default='Clay', description="The type of surface")
10
+ series: Literal['Grand Slam', 'Masters 1000', 'Masters', 'Masters Cup', 'ATP500', 'ATP250', 'International Gold', 'International'] = Field(default='Grand Slam', description="The series of the tournament")
11
+ p1_height: Optional[int] = Field(gt=0, default=180, description="The height of the 1st player in centimeters")
12
+ p2_height: Optional[int] = Field(gt=0, default=180, description="The height of the 2nd player in centimeters")
13
+ p1_weight: Optional[int] = Field(gt=0, default=80, description="The weight of the 1st player in kilograms")
14
+ p2_weight: Optional[int] = Field(gt=0, default=80, description="The weight of the 2nd player in kilograms")
15
+ p1_year_of_birth: Optional[int] = Field(gt=1950, default=1980, description="The year of birth of the 1st player")
16
+ p2_year_of_birth: Optional[int] = Field(gt=1950, default=1980, description="The year of birth of the 2nd player")
17
+ p1_play_hand: Literal['Right', 'Left'] = Field(default='Right', description="The play hand of the 1st player")
18
+ p2_play_hand: Literal['Right', 'Left'] = Field(default='Right', description="The play hand of the 2nd player")
19
+ p1_back_hand: int = Field(default=1, ge=1, le=2, description="The back hand of the 1st player. 1 for one-handed, 2 for two-handed")
20
+ p2_back_hand: int = Field(default=1, ge=1, le=2, description="The back hand of the 2nd player. 1 for one-handed, 2 for two-handed")
21
+ p1_pro_year: Optional[int] = Field(gt=1970, default=2000, description="The year the 1st player turned pro")
22
+ p2_pro_year: Optional[int] = Field(gt=1970, default=2000, description="The year the 2nd player turned pro")
23
+ model: Optional[str] = Field(default='LogisticRegression', description="The name of the model to use for prediction")
24
+ version: Optional[str] = Field(default='latest', description="The version of the model to use for prediction")
25
+
26
+ class ModelOutput(BaseModel):
27
+ result: int = Field(description="The prediction result. 1 if player 1 is expected to win, 0 otherwise.", json_schema_extra={"example": "1"})
28
+ prob: List[float] = Field(description="Probability of [defeat, victory] of player 1.", json_schema_extra={"example": "[0.15, 0.85]"})
src/main.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import joblib
3
  import logging
4
  import secrets
5
- from typing import Generator, Literal, Optional, Annotated
6
  from fastapi import (
7
  FastAPI,
8
  Request,
@@ -19,10 +19,11 @@ from starlette.status import (
19
  HTTP_403_FORBIDDEN,
20
  HTTP_404_NOT_FOUND,
21
  HTTP_503_SERVICE_UNAVAILABLE)
22
- from pydantic import BaseModel, Field
23
  from dotenv import load_dotenv
24
  from mlflow.exceptions import RestException
25
 
 
26
  from src.service.model import (
27
  run_experiment,
28
  predict,
@@ -93,31 +94,6 @@ async def run_xp(background_tasks: BackgroundTasks,
93
 
94
  return {"message": "Experiment scheduled"}
95
 
96
- class ModelInput(BaseModel):
97
- p1_rank: int = Field(gt=0, default=1, description="The rank of the 1st player")
98
- p2_rank: int = Field(gt=0, default=100, description="The rank of the 2nd player")
99
- court: Literal['Outdoor', 'Indoor'] = Field(default='Outdoor', description="The type of court")
100
- surface: Literal['Grass', 'Carpet', 'Clay', 'Hard'] = Field(default='Clay', description="The type of surface")
101
- series: Literal['Grand Slam', 'Masters 1000', 'Masters', 'Masters Cup', 'ATP500', 'ATP250', 'International Gold', 'International'] = Field(default='Grand Slam', description="The series of the tournament")
102
- p1_height: Optional[int] = Field(gt=0, default=180, description="The height of the 1st player in centimeters")
103
- p2_height: Optional[int] = Field(gt=0, default=180, description="The height of the 2nd player in centimeters")
104
- p1_weight: Optional[int] = Field(gt=0, default=80, description="The weight of the 1st player in kilograms")
105
- p2_weight: Optional[int] = Field(gt=0, default=80, description="The weight of the 2nd player in kilograms")
106
- p1_year_of_birth: Optional[int] = Field(gt=1950, default=1980, description="The year of birth of the 1st player")
107
- p2_year_of_birth: Optional[int] = Field(gt=1950, default=1980, description="The year of birth of the 2nd player")
108
- p1_play_hand: Literal['Right', 'Left'] = Field(default='Right', description="The play hand of the 1st player")
109
- p2_play_hand: Literal['Right', 'Left'] = Field(default='Right', description="The play hand of the 2nd player")
110
- p1_back_hand: int = Field(default=1, ge=1, le=2, description="The back hand of the 1st player. 1 for one-handed, 2 for two-handed")
111
- p2_back_hand: int = Field(default=1, ge=1, le=2, description="The back hand of the 2nd player. 1 for one-handed, 2 for two-handed")
112
- p1_pro_year: Optional[int] = Field(gt=1970, default=2000, description="The year the 1st player turned pro")
113
- p2_pro_year: Optional[int] = Field(gt=1970, default=2000, description="The year the 2nd player turned pro")
114
- model: Optional[str] = Field(default='LogisticRegression', description="The name of the model to use for prediction")
115
- version: Optional[str] = Field(default='latest', description="The version of the model to use for prediction")
116
-
117
- class ModelOutput(BaseModel):
118
- result: int = Field(description="The prediction result. 1 if player 1 is expected to win, 0 otherwise.", json_schema_extra={"example": "1"})
119
- prob: list[float] = Field(description="Probability of [defeat, victory] of player 1.", json_schema_extra={"example": "[0.15, 0.85]"})
120
-
121
  @app.get("/predict",
122
  tags=["model"],
123
  description="Predict the outcome of a tennis match",
 
2
  import joblib
3
  import logging
4
  import secrets
5
+ from typing import Generator, Optional, Annotated
6
  from fastapi import (
7
  FastAPI,
8
  Request,
 
19
  HTTP_403_FORBIDDEN,
20
  HTTP_404_NOT_FOUND,
21
  HTTP_503_SERVICE_UNAVAILABLE)
22
+ from pydantic import Field
23
  from dotenv import load_dotenv
24
  from mlflow.exceptions import RestException
25
 
26
+ from src.entity.model import ModelInput, ModelOutput
27
  from src.service.model import (
28
  run_experiment,
29
  predict,
 
94
 
95
  return {"message": "Experiment scheduled"}
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  @app.get("/predict",
98
  tags=["model"],
99
  description="Predict the outcome of a tennis match",