abhitopia commited on
Commit
5b49d65
·
1 Parent(s): 4383992

better logging

Browse files
Files changed (1) hide show
  1. code/inference.py +8 -1
code/inference.py CHANGED
@@ -8,18 +8,23 @@ JSON_CONTENT_TYPE = 'application/json'
8
 
9
 
10
  def model_fn(model_dir):
11
- logger.info(f"model_dir: {model_dir}")
12
  model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
13
  return model
14
 
15
 
16
  def predict_fn(input_data, model):
 
17
  logger.info("input text: {}".format(input_data))
18
  prediction = model(input_data)
19
  logger.info("prediction: {}".format(input_data))
20
 
21
 
22
  def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
 
 
 
 
23
  if content_type == JSON_CONTENT_TYPE:
24
  input_data = json.loads(serialized_input_data)
25
  return input_data
@@ -28,6 +33,8 @@ def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
28
 
29
 
30
  def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
 
 
31
  if accept == JSON_CONTENT_TYPE:
32
  return json.dumps(prediction_output), accept
33
  raise Exception('Unsupported Content Type')
 
8
 
9
 
10
  def model_fn(model_dir):
11
+ logging.info('[### model_fn ###] Loading model from {}'.format(model_dir))
12
  model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
13
  return model
14
 
15
 
16
  def predict_fn(input_data, model):
17
+ logging.info('[### predict_fn ###] Entering predict_fn() method')
18
  logger.info("input text: {}".format(input_data))
19
  prediction = model(input_data)
20
  logger.info("prediction: {}".format(input_data))
21
 
22
 
23
  def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
24
+ logging.info('[### input_fn ###] Entering input_fn() method')
25
+ logging.info('[### input_fn ###] request_content_type: {}'.format(content_type))
26
+ logging.info('[### input_fn ###] request_body: {}'.format(type(serialized_input_data)))
27
+
28
  if content_type == JSON_CONTENT_TYPE:
29
  input_data = json.loads(serialized_input_data)
30
  return input_data
 
33
 
34
 
35
  def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
36
+ logging.info('[### output_fn ###] Entering output_fn() method')
37
+ logging.info('[### output_fn ###] prediction: {}'.format(prediction_output))
38
  if accept == JSON_CONTENT_TYPE:
39
  return json.dumps(prediction_output), accept
40
  raise Exception('Unsupported Content Type')