| | from typing import Any, List, Dict |
| | from llama_cpp import Llama |
| | import numpy as np |
| | import torch |
| | from transformers import AutoTokenizer, LogitsProcessorList |
| |
|
class EndpointHandler:
    """Serve chat completions from a llama.cpp model, optionally restricted to a word list.

    When a request carries a ``vocab_list``, sampling is masked so that only
    token IDs derived from those words (plus, by default, the model's
    end-of-sequence token) can be generated.
    """

    def __init__(self, path=""):
        """
        Initialize the model handler using llama_cpp.

        Args:
            path: Accepted but unused; the model is always fetched from the
                hard-coded repository below. (Presumably kept to satisfy the
                hosting framework's handler constructor signature — confirm.)
        """
        self.model = Llama.from_pretrained(
            repo_id="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
            filename="Meta-Llama-3.1-8B-Instruct-Q6_K.gguf",
        )
        # NOTE(review): the vocabulary restriction computes token IDs with this
        # Hugging Face tokenizer, not with the GGUF model's own tokenizer. This
        # assumes both assign identical IDs to identical pieces — TODO confirm.
        self.tokenizer = AutoTokenizer.from_pretrained("taylorj94/Llama-3.2-1B")

    def get_allowed_token_ids(self, vocab_list: List[str]) -> set[int]:
        """
        Return the token IDs that make up the allowed words.

        Each word is encoded (without special tokens) in four variations: as
        given, space-prefixed, capitalized (``str.capitalize``), and
        space-prefixed capitalized. Every token ID produced by any variation
        is collected.

        The restriction is per token, not per word. A word that encodes to
        several sub-word tokens contributes each piece independently, so the
        model may recombine pieces into strings that are not in ``vocab_list``.

        Args:
            vocab_list: Words the model is allowed to produce.

        Returns:
            The union of token IDs over all variations of all words; empty
            for an empty ``vocab_list``.
        """
        allowed_ids: set[int] = set()
        for word in vocab_list:
            capitalized = word.capitalize()
            # A set, so a word that is already capitalized is encoded only once.
            variations = {word, " " + word, capitalized, " " + capitalized}
            for variation in variations:
                allowed_ids.update(
                    self.tokenizer.encode(variation, add_special_tokens=False)
                )
        return allowed_ids

    @staticmethod
    def _allowed_mask(vocab_size: int, allowed_token_ids: set[int]) -> np.ndarray:
        """Return a boolean mask of length ``vocab_size`` that is True at allowed IDs.

        IDs outside ``[0, vocab_size)`` are ignored, which is the same result as
        a membership test against ``range(vocab_size)``.
        """
        ids = np.fromiter(
            (tid for tid in allowed_token_ids if 0 <= tid < vocab_size),
            dtype=np.intp,
        )
        keep = np.zeros(vocab_size, dtype=bool)
        keep[ids] = True
        return keep

    def filter_allowed_tokens(self, input_ids: Any, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray:
        """
        Mask ``scores`` so that only tokens in ``allowed_token_ids`` stay selectable.

        Each position along the last axis whose index is not an allowed ID is
        set to ``-inf`` in place, and the same array is returned. Both a single
        score vector (1D) and a batch of vectors (2D, one row each) are
        supported. The mask is built once and applied to every row.

        Args:
            input_ids: Unused; accepted so the method fits the
                ``(input_ids, scores)`` logits-processor call shape.
            scores: Logits, indexed by token ID along the last axis.
            allowed_token_ids: Token IDs to keep. IDs outside the score range
                are ignored.

        Returns:
            ``scores``, modified in place.

        Raises:
            ValueError: If ``scores`` is neither 1D nor 2D.
        """
        if scores.ndim not in (1, 2):
            raise ValueError(f"Unsupported scores dimension: {scores.ndim}")
        keep = self._allowed_mask(scores.shape[-1], allowed_token_ids)
        # Ellipsis plus a boolean index on the last axis handles 1D and 2D alike.
        scores[..., ~keep] = float('-inf')
        return scores

    def _stop_token_ids(self) -> set[int]:
        """Return the model's end-of-sequence token ID, if it reports a valid one.

        Under a vocabulary mask, any token not explicitly allowed is set to
        ``-inf``. Without this ID the model could not sample EOS and could only
        stop at ``max_tokens``.
        """
        eos = self.model.token_eos()
        # A negative ID means "no such token"; never put it in the mask.
        return {int(eos)} if eos is not None and eos >= 0 else set()

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        Handle the request, performing inference with an optionally restricted vocabulary.

        Args:
            data: Request mapping with keys:
                - ``inputs`` (required): the user message text.
                - ``parameters`` (optional mapping; ``null`` is treated as
                  omitted):
                    - ``max_length`` (default 30),
                    - ``temperature`` (default 1),
                    - ``repeat_penalty`` (default 1.0),
                    - ``allow_eos`` (default True): keep the model's
                      end-of-sequence token selectable under a vocabulary
                      restriction.
                - ``vocab_list`` (optional): words to restrict the output to.
                  A single string is split on whitespace.

        Returns:
            A one-element list holding ``generated_text`` and the sorted
            ``allowed_token_ids`` derived from ``vocab_list``. The list is
            empty when no restriction was requested.

        Raises:
            ValueError: If ``inputs`` is missing or empty.
        """
        inputs = data.get("inputs", None)
        parameters = data.get("parameters") or {}
        vocab_list = data.get("vocab_list", None)

        if not inputs:
            raise ValueError("The 'inputs' field is required.")

        # Iterating a bare string would allow only its individual characters.
        if isinstance(vocab_list, str):
            vocab_list = vocab_list.split()

        logits_processors = None
        allowed_token_ids: set[int] = set()

        # Install the mask only when a restriction was requested: with an empty
        # allowed set every logit would become -inf.
        if vocab_list:
            allowed_token_ids = self.get_allowed_token_ids(vocab_list)
            sampling_ids = set(allowed_token_ids)
            if parameters.get("allow_eos", True):
                sampling_ids |= self._stop_token_ids()
            logits_processors = LogitsProcessorList([
                lambda input_ids, scores: self.filter_allowed_tokens(input_ids, scores, sampling_ids)
            ])

        response = self.model.create_chat_completion(
            messages=[
                {"role": "user", "content": inputs}
            ],
            max_tokens=parameters.get("max_length", 30),
            logits_processor=logits_processors,
            temperature=parameters.get("temperature", 1),
            repeat_penalty=parameters.get("repeat_penalty", 1.0),
        )

        generated_text = response["choices"][0]["message"]["content"]

        return [{"generated_text": generated_text, "allowed_token_ids": sorted(allowed_token_ids)}]