# Scraped page header (not code): FINGU-AI's picture
# Commit: Update code/inference.py
# Revision: bb3b2b1 (verified)
import json
import torch
import os
# Run `pip install --upgrade transformers` in a shell when this module is
# loaded, before the `from transformers import ...` line below executes.
# The exit status returned by os.system is discarded, so a failed upgrade
# does not by itself stop the module from loading.
# NOTE(review): bare `pip` is resolved via PATH; presumably it belongs to the
# same environment as the running interpreter - confirm.
os.system("pip install --upgrade transformers")
from transformers import AutoModelForCausalLM, AutoTokenizer
def model_fn(model_dir, context=None):
    """
    Build the (model, tokenizer) pair used by the other handlers.

    Both objects are restored from ``model_dir`` with ``from_pretrained``.
    ``context`` is accepted but not used.

    Returns:
        tuple: ``(model, tokenizer)``, in that order.
    """
    # Tokenizer is loaded first, then the causal-LM model.
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_dir)
    artifacts = (loaded_model, loaded_tokenizer)
    return artifacts
def input_fn(request_body, request_content_type, context=None):
    """
    Parse the request payload into Python data.

    Only JSON is supported. The media type is compared case-insensitively and
    any parameters after ``;`` (for example ``; charset=utf-8``) are ignored,
    so ``application/json; charset=utf-8`` is accepted as well as the bare
    ``application/json``.

    Args:
        request_body: JSON document in any form ``json.loads`` accepts.
        request_content_type: Content type announced for the request.
        context: Accepted but unused.

    Returns:
        The object decoded by ``json.loads``.

    Raises:
        ValueError: If the media type is not ``application/json`` (the message
            echoes the value exactly as received), or if the body is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    # Reduce "Application/JSON; charset=utf-8" to "application/json" so that
    # equivalent spellings of the same media type are treated alike.
    # `or ""` keeps a missing (None) content type on the error path.
    media_type = str(request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate text for one request with the loaded model.

    Args:
        input_data: Mapping holding the request fields:
            ``inputs`` (required) - the source text;
            ``prompt`` (optional, default ``''``) - text placed before the
            source text, separated by a single space;
            ``return_full_text`` (optional, default ``True``) - when false,
            only the tokens that follow the encoded input are decoded, so the
            prompt and source text are not echoed back.
        model_and_tokenizer: ``(model, tokenizer)`` pair, in that order.
        context: Accepted but unused.

    Returns:
        str: Decoded text of the first returned sequence, with special tokens
        skipped.

    Raises:
        KeyError: If ``input_data`` has no ``inputs`` entry.
    """
    model, tokenizer = model_and_tokenizer
    source_text = input_data['inputs']
    prompt = input_data.get('prompt', '')  # Optional prompt for guiding translation
    # Defaulting to True keeps the response identical for existing clients.
    return_full_text = input_data.get('return_full_text', True)

    # Combine prompt with source text. strip() removes leading/trailing
    # whitespace, including the separator left over when the prompt is empty.
    input_text = f"{prompt} {source_text}".strip()

    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Put the ids on the same device as the model so generation also works
    # when the model was loaded onto an accelerator; a no-op on CPU.
    device = getattr(model, 'device', None)
    if device is not None:
        input_ids = input_ids.to(device)

    # Generate translation
    # NOTE(review): per the transformers generation docs, max_length appears
    # to bound prompt + generated tokens for decoder-only models, so long
    # inputs leave less room for output - confirm 512 is intended.
    output_ids = model.generate(
        input_ids,
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )

    # Decode the output tokens
    sequence = output_ids[0]
    if not return_full_text:
        # Assumes a decoder-only model whose output starts with the input
        # ids (the file loads AutoModelForCausalLM) - verify if that changes.
        sequence = sequence[len(input_ids[0]):]
    return tokenizer.decode(sequence, skip_special_tokens=True)
def output_fn(prediction, content_type, context=None):
    """
    Serialize the prediction as a JSON string.

    The result is a JSON object with a single ``translated_text`` member that
    holds ``prediction``. ``content_type`` and ``context`` are accepted but
    not consulted.
    """
    response_payload = dict(translated_text=prediction)
    serialized = json.dumps(response_payload)
    return serialized