import json
import os

# Upgrade transformers at container startup so newer model classes are
# available even when the serving image ships an older version.
os.system("pip install --upgrade transformers")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer from the model directory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    # Move the model to GPU when one is available so generation does not
    # silently run on CPU inside a GPU container.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, tokenizer


def input_fn(request_body, request_content_type, context=None):
    """
    Parse the input data from the request.
    """
    if request_content_type == 'application/json':
        return json.loads(request_body)
    raise ValueError(f"Unsupported content type: {request_content_type}")


def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions from the input data using the model.
    """
    model, tokenizer = model_and_tokenizer
    source_text = input_data['inputs']
    prompt = input_data.get('prompt', '')

    # Prepend the optional prompt to the source text.
    input_text = f"{prompt} {source_text}".strip()
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)

    # Beam-search decoding; early_stopping ends generation once all beams
    # have produced an EOS token. no_grad avoids building autograd state
    # during inference.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )

    translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return translated_text


def output_fn(prediction, content_type, context=None):
    """
    Format the prediction output.
    """
    if content_type == 'application/json':
        return json.dumps({'translated_text': prediction})
    raise ValueError(f"Unsupported content type: {content_type}")
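

# Local smoke test: a minimal sketch that exercises the four inference
# handlers end to end without a serving container. The "./model" path is a
# placeholder assumption; point it at any directory holding a causal-LM
# checkpoint saved with save_pretrained().
if __name__ == "__main__":
    model_and_tokenizer = model_fn("./model")  # hypothetical local path
    request = json.dumps({
        "inputs": "Hello, world!",
        "prompt": "Translate to German:",
    })
    data = input_fn(request, "application/json")
    prediction = predict_fn(data, model_and_tokenizer)
    print(output_fn(prediction, "application/json"))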