# Author: SoumyaRanjan
# Commit 48505cf (verified): "Update code/inference.py"
import json
import os

import numpy as np

# Environment variables
# Root directory for model artifacts; overridable through the MODEL_PATH
# environment variable, otherwise '/opt/ml/model'.
MODEL_PATH = os.getenv('MODEL_PATH', '/opt/ml/model')
# Load the model
def model_fn(model_dir):
    """Load the serialized XGBoost booster stored in *model_dir*.

    Args:
        model_dir: Directory containing the model artifact file named
            ``xgboost-model``.

    Returns:
        An ``xgboost.Booster`` populated from that file.
    """
    # Imported locally so the (de)serialization helpers in this module do not
    # need xgboost merely to be imported.
    import xgboost as xgb

    model_file = os.path.join(model_dir, 'xgboost-model')
    # A Booster instance must exist before load_model can populate it.
    model = xgb.Booster()
    model.load_model(model_file)
    return model
# Deserialize the input data
def input_fn(request_body, request_content_type):
    """Deserialize a JSON request body into a NumPy feature array.

    Args:
        request_body: Raw request payload; anything ``json.loads`` accepts
            (``str``, ``bytes`` or ``bytearray``). Expected to decode to an
            object with an ``instances`` key.
        request_content_type: MIME type of the request. Media-type parameters
            (e.g. ``; charset=utf-8``) and letter case are ignored.

    Returns:
        ``numpy.ndarray`` built from the decoded ``instances`` value.

    Raises:
        ValueError: If the media type is not ``application/json``, or the
            body is not valid JSON (``json.JSONDecodeError``).
        KeyError / TypeError: If the decoded payload is not an object with an
            ``instances`` key.
    """
    # Compare only the bare media type: clients commonly append a charset
    # parameter, and media types are case-insensitive.
    media_type = (request_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError("Unsupported content type: {}".format(request_content_type))
    payload = json.loads(request_body)
    return np.array(payload['instances'])
# Serialize the output data
def output_fn(prediction, response_content_type):
    """Serialize predictions into a JSON document.

    Args:
        prediction: Array-like predictions, such as the ``numpy.ndarray``
            returned by ``predict_fn``; plain Python lists are accepted too.
        response_content_type: Requested MIME type. Media-type parameters
            and letter case are ignored.

    Returns:
        A JSON string of the form ``{"predictions": [...]}``.

    Raises:
        ValueError: If the requested media type is not ``application/json``.
    """
    media_type = (response_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError("Unsupported content type: {}".format(response_content_type))
    # np.asarray is a no-op for ndarrays and also lets lists through;
    # tolist() turns NumPy scalars into JSON-serializable Python values.
    return json.dumps({'predictions': np.asarray(prediction).tolist()})
# Make predictions
def predict_fn(input_data, model):
    """Run *model* on *input_data* and return its raw predictions.

    Args:
        input_data: Feature matrix accepted by ``xgboost.DMatrix``, such as
            the NumPy array produced by ``input_fn``.
        model: The ``xgboost.Booster`` returned by ``model_fn``.

    Returns:
        Whatever ``model.predict`` returns for the wrapped input.
    """
    # Imported locally, matching model_fn, so xgboost is only needed when
    # the model is actually used.
    import xgboost as xgb

    # DMatrix is a class in the xgboost module, not an attribute of the
    # Booster, so it must be taken from xgb rather than from model.
    dmatrix = xgb.DMatrix(input_data)
    return model.predict(dmatrix)