import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Configuration
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "tans37/mistral-query-router"  # published adapter repo id; at runtime the handler loads weights from `path`


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer for the base model
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

        # Load the base model in half precision for efficiency
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Load the PEFT adapter
        # 'path' is the directory where the handler is located (the adapter repo)
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()
        print(f"[Handler] Loaded LoRA adapter from {path} onto {BASE_MODEL}")
    def __call__(self, data):
        """
        Args:
            data (:obj:`dict`):
                subset of the request body with the following keys:
                - `inputs`: the prompt to be processed
                - `parameters`: optional generation parameters

        Returns:
            A list with a single dict: [{"generated_text": ...}]
        """
        prompt = data.pop("inputs", data)

        # Merge caller-supplied parameters over the defaults so a partial
        # `parameters` dict does not silently drop the remaining defaults.
        # Note: temperature/top_p only take effect when do_sample is True.
        parameters = {
            "max_new_tokens": 128,
            "temperature": 0.1,
            "top_p": 0.9,
            "do_sample": False,
        }
        parameters.update(data.pop("parameters", {}))

        # Tokenize the prompt and move tensors to the model's device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Generate
        with torch.no_grad():
            output_tokens = self.model.generate(
                **inputs,
                **parameters,
            )

        # Decode only the newly generated tokens, so slice off the prompt
        input_len = inputs["input_ids"].shape[1]
        new_tokens = output_tokens[0][input_len:]
        prediction = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"generated_text": prediction}]
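

# A minimal local smoke test, assuming the adapter weights (adapter_config.json
# etc.) are present in the current directory. The payload shape mirrors what
# HF Inference Endpoints sends to the handler; the prompt below is illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": "Route this query: what is the capital of France?",
        "parameters": {"max_new_tokens": 32},
    }
    # Expected output shape: [{"generated_text": "..."}]
    print(handler(payload))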