abhitopia
commited on
Commit
·
5b49d65
1
Parent(s):
4383992
better logging
Browse files- code/inference.py +8 -1
code/inference.py
CHANGED
|
@@ -8,18 +8,23 @@ JSON_CONTENT_TYPE = 'application/json'
|
|
| 8 |
|
| 9 |
|
| 10 |
def model_fn(model_dir):
|
| 11 |
-
|
| 12 |
model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
|
| 13 |
return model
|
| 14 |
|
| 15 |
|
| 16 |
def predict_fn(input_data, model):
|
|
|
|
| 17 |
logger.info("input text: {}".format(input_data))
|
| 18 |
prediction = model(input_data)
|
| 19 |
logger.info("prediction: {}".format(input_data))
|
| 20 |
|
| 21 |
|
| 22 |
def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
if content_type == JSON_CONTENT_TYPE:
|
| 24 |
input_data = json.loads(serialized_input_data)
|
| 25 |
return input_data
|
|
@@ -28,6 +33,8 @@ def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
|
|
| 28 |
|
| 29 |
|
| 30 |
def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
|
|
|
|
|
|
|
| 31 |
if accept == JSON_CONTENT_TYPE:
|
| 32 |
return json.dumps(prediction_output), accept
|
| 33 |
raise Exception('Unsupported Content Type')
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def model_fn(model_dir):
|
| 11 |
+
logging.info('[### model_fn ###] Loading model from {}'.format(model_dir))
|
| 12 |
model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
|
| 13 |
return model
|
| 14 |
|
| 15 |
|
| 16 |
def predict_fn(input_data, model):
|
| 17 |
+
logging.info('[### predict_fn ###] Entering predict_fn() method')
|
| 18 |
logger.info("input text: {}".format(input_data))
|
| 19 |
prediction = model(input_data)
|
| 20 |
logger.info("prediction: {}".format(input_data))
|
| 21 |
|
| 22 |
|
| 23 |
def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
|
| 24 |
+
logging.info('[### input_fn ###] Entering input_fn() method')
|
| 25 |
+
logging.info('[### input_fn ###] request_content_type: {}'.format(content_type))
|
| 26 |
+
logging.info('[### input_fn ###] request_body: {}'.format(type(serialized_input_data)))
|
| 27 |
+
|
| 28 |
if content_type == JSON_CONTENT_TYPE:
|
| 29 |
input_data = json.loads(serialized_input_data)
|
| 30 |
return input_data
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
|
| 36 |
+
logging.info('[### output_fn ###] Entering output_fn() method')
|
| 37 |
+
logging.info('[### output_fn ###] prediction: {}'.format(prediction_output))
|
| 38 |
if accept == JSON_CONTENT_TYPE:
|
| 39 |
return json.dumps(prediction_output), accept
|
| 40 |
raise Exception('Unsupported Content Type')
|