from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json


def model_fn(model_dir, *_):
    # Called once when the endpoint starts: load the tokenizer and model from model_dir.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # half-precision weights to reduce GPU memory use
        device_map="auto",          # let accelerate place layers on the available devices
    )
    return {"model": model, "tokenizer": tokenizer}


def input_fn(serialized_input, content_type, *_):
    # Deserialize the request body; JSON payloads carry the prompt under "inputs".
    if content_type == "application/json":
        return json.loads(serialized_input).get("inputs", "")
    return serialized_input


def predict_fn(prompt, model_bundle, *_):
    # Tokenize the prompt, move it to the model's device, and generate a completion.
    tok = model_bundle["tokenizer"]
    mdl = model_bundle["model"]
    inputs = tok(prompt, return_tensors="pt").to(mdl.device)
    outputs = mdl.generate(**inputs, max_new_tokens=128)
    # Note: for causal LMs the decoded text includes the original prompt.
    return tok.decode(outputs[0], skip_special_tokens=True)


def output_fn(prediction, accept, *_):
    # Serialize the response; return JSON when the client asks for it.
    if accept == "application/json":
        return json.dumps({"generated_text": prediction})
    return prediction
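

# A minimal local smoke test for the handler chain above (an illustrative sketch,
# not part of the handler contract; "./model" is an assumed path to a model
# directory you have already downloaded).
if __name__ == "__main__":
    bundle = model_fn("./model")
    prompt = input_fn(json.dumps({"inputs": "Hello, world"}), "application/json")
    prediction = predict_fn(prompt, bundle)
    print(output_fn(prediction, "application/json"))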