File size: 1,057 Bytes
9e37e43
48505cf
9e37e43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import json
import os

import numpy as np

# Environment variables
# Root directory for model artifacts: taken from the MODEL_PATH environment
# variable when it is set, otherwise '/opt/ml/model'.
# NOTE(review): not referenced by the handler functions in this section.
MODEL_PATH = os.getenv('MODEL_PATH', '/opt/ml/model')

# Load the model
def model_fn(model_dir):
    """Load the model stored as ``xgboost-model`` inside *model_dir*.

    Args:
        model_dir: Directory that contains the ``xgboost-model`` file.

    Returns:
        The model object after ``load_model`` has been called on it.

    FIXME(review): as written this function can never succeed. See the note on
    ``model = model`` below.
    """
    model_file = os.path.join(model_dir, 'xgboost-model')
    # FIXME(review): because ``model`` is assigned in this function, Python treats
    # it as a local name, so the right-hand ``model`` is an unbound local here and
    # this line raises UnboundLocalError on every call. The intent is presumably
    # ``model = xgb.Booster()`` (with ``import xgboost as xgb``) so that
    # ``load_model`` can populate it. Confirm and fix.
    model = model
    model.load_model(model_file)
    return model
# Deserialize the input data
def input_fn(request_body, request_content_type):
    """Deserialize a JSON inference request into a NumPy array.

    The body must be a JSON object of the form ``{"instances": [...]}``; the
    ``instances`` value is converted with ``np.array``.

    Args:
        request_body: Raw request payload (``str``, ``bytes`` or ``bytearray``;
            ``json.loads`` accepts all three).
        request_content_type: The request's content type. Media-type parameters
            (e.g. ``; charset=utf-8``) and letter case are ignored, so
            ``"Application/JSON; charset=utf-8"`` is accepted.

    Returns:
        A NumPy array built from the payload's ``instances`` field.

    Raises:
        ValueError: If the content type is not ``application/json``, or if the
            body is not valid JSON (``json.JSONDecodeError`` subclasses
            ``ValueError``).
        KeyError: If the decoded JSON object has no ``instances`` key.
    """
    # Clients commonly send parameters after the media type
    # ("application/json; charset=utf-8"). Compare only the bare media type,
    # case-insensitively, so such requests are not rejected. ``or ""`` keeps a
    # missing (None) content type on the ValueError path instead of crashing.
    media_type = (request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError("Unsupported content type: {}".format(request_content_type))
    input_data = json.loads(request_body)
    return np.array(input_data['instances'])
# Serialize the output data
def output_fn(prediction, response_content_type):
    """Serialize model predictions into a JSON response body.

    Args:
        prediction: Model output, either a NumPy array or a (nested) list of
            numbers.
        response_content_type: Requested response content type. Media-type
            parameters (e.g. ``; charset=utf-8``) and letter case are ignored.

    Returns:
        A JSON string of the form ``{"predictions": [...]}``.

    Raises:
        ValueError: If the requested content type is not ``application/json``.
    """
    # Compare only the bare media type, case-insensitively. This mirrors how a
    # content type like "application/json; charset=utf-8" should be accepted.
    media_type = (response_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError("Unsupported content type: {}".format(response_content_type))
    # np.asarray returns ndarrays unchanged and also accepts plain lists, where a
    # bare ``prediction.tolist()`` would raise AttributeError. ``.tolist()``
    # converts NumPy scalars to Python numbers so json.dumps can encode them.
    return json.dumps({'predictions': np.asarray(prediction).tolist()})
# Make predictions
def predict_fn(input_data, model):
    """Run the model on already-deserialized input.

    Args:
        input_data: The request data to score. Presumably the NumPy array
            returned by ``input_fn``; verify against the serving framework.
        model: The model object, presumably the one returned by ``model_fn``.

    Returns:
        Whatever ``model.predict`` returns for the wrapped input.

    FIXME(review): see the note on the ``DMatrix`` line below.
    """
    # FIXME(review): ``DMatrix`` is presumably the xgboost module-level class
    # (``xgb.DMatrix``), not an attribute of a loaded model. If ``model`` is an
    # ``xgboost.Booster``, this line raises AttributeError. The likely fix is
    # ``dmatrix = xgb.DMatrix(input_data)`` with ``import xgboost as xgb``.
    # Confirm before changing.
    dmatrix = model.DMatrix(input_data)
    prediction = model.predict(dmatrix)
    return prediction