Update handler.py
Browse files- handler.py +8 -7
handler.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Any, List, Dict
|
|
| 2 |
from llama_cpp import Llama
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
-
from transformers import AutoTokenizer
|
| 6 |
|
| 7 |
class EndpointHandler:
|
| 8 |
def __init__(self, path=""):
|
|
@@ -49,15 +49,16 @@ class EndpointHandler:
|
|
| 49 |
if not inputs:
|
| 50 |
raise ValueError("The 'inputs' field is required.")
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
if vocab_list:
|
| 55 |
# Define allowed tokens dynamically
|
| 56 |
allowed_token_ids = self.get_allowed_token_ids(vocab_list)
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
|
|
|
| 61 |
|
| 62 |
# Tokenize input
|
| 63 |
input_ids = torch.tensor([self.tokenizer.encode(inputs, add_special_tokens=False)])
|
|
@@ -68,7 +69,7 @@ class EndpointHandler:
|
|
| 68 |
{"role": "user", "content": inputs}
|
| 69 |
],
|
| 70 |
max_tokens=parameters.get("max_length", 30),
|
| 71 |
-
logits_processor=
|
| 72 |
temperature=parameters.get("temperature", 1),
|
| 73 |
repeat_penalty=parameters.get("repeat_penalty", 1.0)
|
| 74 |
)
|
|
|
|
| 2 |
from llama_cpp import Llama
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
+
from transformers import AutoTokenizer, LogitsProcessorList
|
| 6 |
|
| 7 |
class EndpointHandler:
|
| 8 |
def __init__(self, path=""):
|
|
|
|
| 49 |
if not inputs:
|
| 50 |
raise ValueError("The 'inputs' field is required.")
|
| 51 |
|
| 52 |
+
# Prepare logits processor
|
| 53 |
+
logits_processors = None
|
| 54 |
if vocab_list:
|
| 55 |
# Define allowed tokens dynamically
|
| 56 |
allowed_token_ids = self.get_allowed_token_ids(vocab_list)
|
| 57 |
|
| 58 |
+
# Create LogitsProcessorList with filtering function
|
| 59 |
+
logits_processors = LogitsProcessorList([
|
| 60 |
+
lambda input_ids, scores: self.filter_allowed_tokens(input_ids, scores, allowed_token_ids)
|
| 61 |
+
])
|
| 62 |
|
| 63 |
# Tokenize input
|
| 64 |
input_ids = torch.tensor([self.tokenizer.encode(inputs, add_special_tokens=False)])
|
|
|
|
| 69 |
{"role": "user", "content": inputs}
|
| 70 |
],
|
| 71 |
max_tokens=parameters.get("max_length", 30),
|
| 72 |
+
logits_processor=logits_processors, # Pass the LogitsProcessorList here
|
| 73 |
temperature=parameters.get("temperature", 1),
|
| 74 |
repeat_penalty=parameters.get("repeat_penalty", 1.0)
|
| 75 |
)
|