import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# Hosted copy of the adapter; at runtime the handler loads the adapter from
# the local `path` the endpoint provides rather than from this repo id.
ADAPTER = "tans37/mistral-query-router"


class EndpointHandler:
    def __init__(self, path=""):
        # The adapter does not change the vocabulary, so the base model's
        # tokenizer is used as-is.
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

        # Load the base model in fp16; device_map="auto" lets accelerate
        # place the weights on whatever devices are available.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Attach the LoRA adapter from the repository directory the endpoint
        # mounts at startup, then switch to inference mode.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        print(f"[Handler] Loaded LoRA adapter from {path} onto {BASE_MODEL}")

    def __call__(self, data):
        """
        Args:
            data (:obj:`dict`):
                subset of the request body with the following keys:
                - `inputs`: the prompt to be processed
                - `parameters`: optional generation parameters
        """
        inputs = data.pop("inputs", data)
        # Deterministic defaults; note that `temperature` and `top_p` only
        # take effect when `do_sample` is True.
        parameters = data.pop("parameters", {
            "max_new_tokens": 128,
            "temperature": 0.1,
            "top_p": 0.9,
            "do_sample": False,
        })
        # Mistral's tokenizer has no pad token; fall back to EOS so
        # generate() does not warn about a missing pad_token_id.
        parameters.setdefault("pad_token_id", self.tokenizer.eos_token_id)

        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            output_tokens = self.model.generate(
                **inputs,
                **parameters,
            )

        # Slice off the prompt so only the newly generated tokens are decoded.
        input_len = inputs["input_ids"].shape[1]
        new_tokens = output_tokens[0][input_len:]
        prediction = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"generated_text": prediction}]
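

# --- Local smoke test: a minimal sketch, not part of the endpoint contract ---
# Inference Endpoints only require the EndpointHandler class above; this block
# just shows one way to exercise the handler locally with the same payload
# shape the endpoint sends. The adapter path "." and the example routing
# prompt are assumptions for illustration only.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes adapter weights sit beside this file
    payload = {
        "inputs": "Classify this query for routing: what was last month's revenue?",
        "parameters": {"max_new_tokens": 64, "do_sample": False},
    }
    print(handler(payload))  # -> [{"generated_text": "..."}]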