File size: 1,714 Bytes
047cc84
 
bb3b2b1
 
 
047cc84
 
dca09fa
047cc84
 
 
 
 
 
 
dca09fa
047cc84
 
 
 
 
 
 
 
 
dca09fa
047cc84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dca09fa
047cc84
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import json
import torch
import os
# Upgrade `transformers` at module import time, before the
# `from transformers import ...` line below runs. Invoking pip as
# `sys.executable -m pip` installs into the interpreter executing this file;
# a bare `pip` resolved through the shell PATH may belong to another environment.
# check=False keeps the original best-effort behaviour: a failed upgrade does not
# raise here.
# NOTE(review): if the hosting process imported `transformers` before loading this
# module, the running process keeps the previously imported version — confirm.
# Pinning the version in requirements.txt would be faster and reproducible.
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "transformers"],
    check=False,
)

from transformers import AutoModelForCausalLM, AutoTokenizer

def model_fn(model_dir, context=None):
    """
    Build the inference artifacts from a saved checkpoint directory.

    Args:
        model_dir: Path passed to ``from_pretrained`` for both the tokenizer
            and the causal language model.
        context: Accepted but not used by this function.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Tokenizer first, then the model — same loading order as before.
    tok = AutoTokenizer.from_pretrained(model_dir)
    lm = AutoModelForCausalLM.from_pretrained(model_dir)
    artifacts = (lm, tok)
    return artifacts

def input_fn(request_body, request_content_type,context=None):
    """
    Deserialize a JSON request body.

    Args:
        request_body: Raw payload — anything ``json.loads`` accepts
            (``str``, ``bytes`` or ``bytearray``).
        request_content_type: The request's content type. Letter case and
            media-type parameters are ignored, so
            ``"Application/JSON; charset=utf-8"`` is accepted.
        context: Accepted but not used by this function.

    Returns:
        The decoded JSON value. ``predict_fn`` expects a dict with an
        ``'inputs'`` key.

    Raises:
        ValueError: If the media type is not ``application/json``, or if the
            body is not valid JSON (``json.JSONDecodeError`` subclasses
            ``ValueError``).
    """
    # Media types are case-insensitive and may carry parameters such as a
    # charset, which many HTTP clients append. Compare only "type/subtype".
    media_type = (request_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)

def predict_fn(input_data, model_and_tokenizer,context=None):
    """
    Generate a translation for one request with the causal language model.

    Args:
        input_data: Parsed request dict. Keys read:
            ``'inputs'`` (required) -- source text to translate;
            ``'prompt'`` (optional, default ``''``) -- text placed before the
            source text to guide the model;
            ``'return_full_text'`` (optional, default ``False``) -- when truthy,
            return the echoed prompt/source followed by the generated text,
            i.e. what earlier versions of this function returned.
        model_and_tokenizer: ``(model, tokenizer)`` pair, as returned by
            ``model_fn``.
        context: Accepted but not used by this function.

    Returns:
        The generated text, decoded with special tokens skipped.

    Raises:
        KeyError: If ``input_data`` has no ``'inputs'`` key.
    """
    model, tokenizer = model_and_tokenizer
    source_text = input_data['inputs']
    prompt = input_data.get('prompt', '')  # Optional prompt for guiding translation
    return_full_text = bool(input_data.get('return_full_text', False))

    # Combine prompt with source text
    input_text = f"{prompt} {source_text}".strip()

    # Calling the tokenizer (instead of `encode`) also returns the attention
    # mask that `generate` expects. Only `input_ids`/`attention_mask` are kept
    # because some causal LMs reject extra tensors such as `token_type_ids`.
    # The tensors are moved to the model's device so a GPU-placed model works.
    encoded = tokenizer(input_text, return_tensors='pt')
    model_inputs = {
        name: tensor.to(model.device)
        for name, tensor in encoded.items()
        if name in ('input_ids', 'attention_mask')
    }
    input_ids = model_inputs.pop('input_ids')

    # NOTE: `max_length` counts the prompt tokens as well, so inputs near 512
    # tokens leave little or no room for output (`max_new_tokens` avoids this).
    output_ids = model.generate(
        input_ids,
        **model_inputs,
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )

    # Decoder-only models return the input tokens followed by the new ones.
    # Drop the input positions so only the generated translation is decoded.
    sequence = output_ids[0]
    if not return_full_text:
        sequence = sequence[input_ids.shape[-1]:]
    return tokenizer.decode(sequence, skip_special_tokens=True)

def output_fn(prediction, content_type, context=None):
    """
    Serialize the generated text into the JSON response body.

    Args:
        prediction: Text produced by ``predict_fn``.
        content_type: Requested response type; not consulted, because the
            body is always JSON.
        context: Accepted but not used by this function.

    Returns:
        A JSON string of the form ``{"translated_text": <prediction>}``.
    """
    payload = {'translated_text': prediction}
    body = json.dumps(payload)
    return body