# code/inference.py
import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def model_fn(model_dir, *_):
    # Load tokenizer and model from the unpacked model artifact;
    # trust_remote_code covers Qwen3 variants that ship custom code.
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.eval()  # inference mode: disables dropout
    return {"model": model, "tokenizer": tokenizer}

def input_fn(serialized_input, content_type, *_):
    # SageMaker delivers the request body as bytes; accept either
    # JSON {"inputs": "..."} or a raw text payload.
    if isinstance(serialized_input, (bytes, bytearray)):
        serialized_input = serialized_input.decode("utf-8")
    if content_type == "application/json":
        return json.loads(serialized_input).get("inputs", "")
    return serialized_input

def predict_fn(prompt, model_bundle, *_):
    tok = model_bundle["tokenizer"]
    mdl = model_bundle["model"]
    # Tokenize and move tensors to the device the weights were placed on.
    inputs = tok(prompt, return_tensors="pt").to(mdl.device)
    outputs = mdl.generate(**inputs, max_new_tokens=128)
    # Decoding the full sequence echoes the prompt before the completion;
    # slice outputs[0] past the input length to return only new text.
    return tok.decode(outputs[0], skip_special_tokens=True)

def output_fn(prediction, accept, *_):
    # Honor an Accept header of application/json (possibly carrying
    # parameters, e.g. "; charset=utf-8"); otherwise return plain text.
    if accept and accept.startswith("application/json"):
        return json.dumps({"generated_text": prediction})
    return prediction
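
# --- Local smoke test (not part of the SageMaker handler contract) ---
# A minimal sketch of the full request path. The "./model" directory and
# the prompt are assumptions for illustration; point model_fn at a real
# unpacked checkpoint to run it.
if __name__ == "__main__":
    bundle = model_fn("./model")
    body = json.dumps({"inputs": "Explain KV caching in one sentence."})
    prompt = input_fn(body.encode("utf-8"), "application/json")
    completion = predict_fn(prompt, bundle)
    print(output_fn(completion, "application/json"))

# Once deployed, the endpoint can be invoked with boto3 (sketch; the
# endpoint name "qwen3-endpoint" is an assumption):
#
#   import boto3
#   runtime = boto3.client("sagemaker-runtime")
#   resp = runtime.invoke_endpoint(
#       EndpointName="qwen3-endpoint",
#       ContentType="application/json",
#       Accept="application/json",
#       Body=json.dumps({"inputs": "Hello"}),
#   )
#   print(resp["Body"].read().decode("utf-8"))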