taylorj94 commited on
Commit
28741c4
·
verified ·
1 Parent(s): 2d0dd6d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -24
handler.py CHANGED
@@ -1,33 +1,26 @@
1
  import os
2
  import torch
3
  from llama_cpp import Llama # Library for GGUF model handling
4
- from typing import Any, List, Dict, Union
5
- from transformers import LogitsProcessorList
6
 
7
 
8
- class FixedVocabLogitsProcessor(torch.nn.Module):
9
  """
10
  A custom logits processor for GGUF-compatible models.
11
  """
12
 
13
  def __init__(self, allowed_ids: set[int], fill_value=float('-inf')):
14
- super().__init__()
15
  self.allowed_ids = allowed_ids
16
  self.fill_value = fill_value
17
 
18
- def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
19
  """
20
  Modify logits to restrict to allowed token IDs.
21
- Args:
22
- input_ids (torch.Tensor): Input IDs.
23
- scores (torch.Tensor): Logits scores.
24
- Returns:
25
- torch.Tensor: Modified logits.
26
  """
27
- for token_id in range(scores.size(1)):
28
  if token_id not in self.allowed_ids:
29
- scores[:, token_id] = self.fill_value
30
- return scores
31
 
32
 
33
  class EndpointHandler:
@@ -38,8 +31,8 @@ class EndpointHandler:
38
  path (str): Path to the GGUF file.
39
  """
40
  self.model = Llama.from_pretrained(
41
- repo_id="taylorj94/Llama-3.2-1B",
42
- filename="model.gguf",
43
  )
44
  self.tokenizer = self.model.tokenizer # GGUF-specific tokenizer, if available
45
 
@@ -54,8 +47,11 @@ class EndpointHandler:
54
  # Extract inputs and parameters
55
  inputs = data.pop("inputs", data)
56
  parameters = data.pop("parameters", {})
57
- vocab_list = data.pop("vocab_list", None)
 
58
 
 
 
59
  if not vocab_list:
60
  raise ValueError("You must provide a 'vocab_list' to define allowed tokens.")
61
 
@@ -68,19 +64,15 @@ class EndpointHandler:
68
  # Tokenize input
69
  input_ids = self.model.tokenize(inputs)
70
 
71
- # Prepare logits processor
72
- logits_processor = LogitsProcessorList([
73
- FixedVocabLogitsProcessor(allowed_ids)
74
- ])
75
-
76
  # Perform inference
77
  output_ids = self.model.generate(
78
- input_ids=input_ids,
79
  max_tokens=parameters.get("max_length", 30),
80
- logits_processor=logits_processor
81
  )
82
 
83
  # Decode the output
84
  generated_text = self.model.detokenize(output_ids)
85
 
86
- return [{"generated_text": generated_text}]
 
1
  import os
2
  import torch
3
  from llama_cpp import Llama # Library for GGUF model handling
4
+ from typing import Any, List, Dict
 
5
 
6
 
7
+ class FixedVocabLogitsProcessor:
8
  """
9
  A custom logits processor for GGUF-compatible models.
10
  """
11
 
12
  def __init__(self, allowed_ids: set[int], fill_value=float('-inf')):
 
13
  self.allowed_ids = allowed_ids
14
  self.fill_value = fill_value
15
 
16
+ def apply(self, logits: torch.FloatTensor):
17
  """
18
  Modify logits to restrict to allowed token IDs.
 
 
 
 
 
19
  """
20
+ for token_id in range(len(logits)):
21
  if token_id not in self.allowed_ids:
22
+ logits[token_id] = self.fill_value
23
+ return logits
24
 
25
 
26
  class EndpointHandler:
 
31
  path (str): Path to the GGUF file.
32
  """
33
  self.model = Llama.from_pretrained(
34
+ repo_id="taylorj94/Llama-3.2-1B",
35
+ filename="model.gguf",
36
  )
37
  self.tokenizer = self.model.tokenizer # GGUF-specific tokenizer, if available
38
 
 
47
  # Extract inputs and parameters
48
  inputs = data.pop("inputs", data)
49
  parameters = data.pop("parameters", {})
50
+ print('Debug 1')
51
+ vocab_list = data.pop("vocab_list", [])
52
 
53
+ print('Debug 2')
54
+
55
  if not vocab_list:
56
  raise ValueError("You must provide a 'vocab_list' to define allowed tokens.")
57
 
 
64
  # Tokenize input
65
  input_ids = self.model.tokenize(inputs)
66
 
67
+ print('Debug 3')
 
 
 
 
68
  # Perform inference
69
  output_ids = self.model.generate(
70
+ input_ids,
71
  max_tokens=parameters.get("max_length", 30),
72
+ logits_processor=lambda logits: FixedVocabLogitsProcessor(allowed_ids).apply(logits)
73
  )
74
 
75
  # Decode the output
76
  generated_text = self.model.detokenize(output_ids)
77
 
78
+ return [{"generated_text": generated_text}]