Upload code/ with huggingface_hub
Browse files- code/inference.py +5 -27
code/inference.py
CHANGED
|
@@ -268,12 +268,6 @@ def pipelinex(
|
|
| 268 |
|
| 269 |
|
| 270 |
|
| 271 |
-
# Module-level logger shared by the SageMaker handler functions below.
logger = logging.getLogger(__name__)
# DEBUG level so request handling is fully traced in the endpoint logs.
logger.setLevel(logging.DEBUG)

# MIME type accepted by input_fn and emitted by output_fn.
JSON_CONTENT_TYPE = 'application/json'
|
| 275 |
-
|
| 276 |
-
|
| 277 |
|
| 278 |
# loads the model into memory from disk and returns it
|
| 279 |
def model_fn(model_dir):
|
|
@@ -285,32 +279,16 @@ def model_fn(model_dir):
|
|
| 285 |
|
| 286 |
|
| 287 |
# Perform prediction on the deserialized object, with the loaded model
|
| 288 |
-
def predict_fn(
|
| 289 |
-
|
| 290 |
|
|
|
|
| 291 |
|
| 292 |
-
logger.info("Calling model")
|
| 293 |
-
start_time = time.time()
|
| 294 |
-
#pipelines.py script in the cloned repo
|
| 295 |
multimodel = pipelinex("multitask-qa-qg",tokenizer=tokenizer,model=model)
|
| 296 |
-
answers = multimodel(
|
| 297 |
-
|
| 298 |
-
|
| 299 |
|
| 300 |
return answers
|
| 301 |
-
|
| 302 |
-
def input_fn(request_body, content_type=JSON_CONTENT_TYPE):
    """Deserialize the incoming request into the payload for predict_fn.

    Parameters
    ----------
    request_body : the already-parsed request object; assumed to be a
        mapping with a "text" key -- TODO confirm against the caller.
    content_type : MIME type of the request (defaults to JSON).

    Returns
    -------
    The value stored under the "text" key of the request body.

    Raises
    ------
    ValueError
        If ``content_type`` is not supported.
    """
    logger.info('Deserializing the input data.')
    # process a jsonlines payload uploaded to the endpoint
    if content_type == JSON_CONTENT_TYPE:
        return request_body["text"]
    # ValueError instead of bare Exception: more specific, and still
    # backward compatible with any caller catching Exception.
    raise ValueError('Requested unsupported ContentType in content_type: {}'.format(content_type))
|
| 307 |
-
|
| 308 |
-
# Serialize the prediction result into the desired response content type
|
| 309 |
-
def output_fn(prediction, accept=JSON_CONTENT_TYPE):
    """Serialize the prediction result into the requested content type.

    Parameters
    ----------
    prediction : the object returned by predict_fn; must be JSON-serializable.
    accept : MIME type requested by the client (defaults to JSON).

    Returns
    -------
    A ``(body, content_type)`` tuple as expected by the serving stack.

    Raises
    ------
    ValueError
        If ``accept`` is not a supported content type.
    """
    logger.info('Serializing the generated output.')
    if accept == JSON_CONTENT_TYPE:
        return json.dumps(prediction), accept
    # ValueError instead of bare Exception, consistent with input_fn's
    # handling of unsupported content types.
    raise ValueError('Requested unsupported ContentType in Accept: {}'.format(accept))
|
| 313 |
-
|
| 314 |
|
| 315 |
|
| 316 |
|
|
|
|
| 268 |
|
| 269 |
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
# loads the model into memory from disk and returns it
|
| 273 |
def model_fn(model_dir):
|
|
|
|
| 279 |
|
| 280 |
|
| 281 |
# Perform prediction on the deserialized object, with the loaded model
|
| 282 |
+
def predict_fn(data, model_tokenizer):
    """Run the multitask QA/QG pipeline over the deserialized input.

    ``model_tokenizer`` is the (model, tokenizer) pair produced by
    ``model_fn``; ``data`` is whatever ``input_fn`` returned.
    Returns the pipeline's output unchanged.
    """
    model, tokenizer = model_tokenizer
    # NOTE(review): the pipeline is rebuilt on every invocation; if
    # construction is expensive, consider caching it in model_fn --
    # left unchanged here to preserve behavior.
    qa_qg_pipeline = pipelinex("multitask-qa-qg", tokenizer=tokenizer, model=model)
    return qa_qg_pipeline(data)
|
| 291 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
|
| 294 |
|