m97j commited on
Commit
deb604d
·
1 Parent(s): ac17ed0

fix(llm_model): align token chunking and prefix handling with engine

Browse files
Files changed (1) hide show
  1. models/llm_model.py +16 -21
models/llm_model.py CHANGED
@@ -37,12 +37,12 @@ class LLMService:
37
  if isinstance(prompt, torch.Tensor):
38
  if mode.lower() == "instruct":
39
  if "instruct" not in self._prefix_cache:
40
- self._prefix_cache["instruct"] = self.engine.tokenize("/no_think\n")
41
  return torch.cat([self._prefix_cache["instruct"], prompt], dim=-1)
42
 
43
  if mode.lower() == "think":
44
  if "think" not in self._prefix_cache:
45
- self._prefix_cache["think"] = self.engine.tokenize("/think\n")
46
  return torch.cat([self._prefix_cache["think"], prompt], dim=-1)
47
 
48
  return prompt
@@ -73,41 +73,36 @@ class LLMService:
73
  """
74
  Util: split text into token chunks not exceeding max_tokens,
75
  trying to respect sentence boundaries where possible.
 
76
  """
77
  max_tokens = min(14000, max_tokens)
78
  encodings = self.engine.tokenize(text, return_offsets=True)
79
- tokens = encodings["input_ids"]
80
- offsets = encodings["offset_mapping"]
81
 
82
- # detect sentence boundaries
83
  sentence_boundaries = set(split_content(text, return_boundaries=True))
84
 
85
  chunks = []
86
- current_chunk = []
87
- current_len = 0
88
 
89
- for i, tok in enumerate(tokens):
90
- current_chunk.append(tok)
91
- current_len += 1
92
- _, end = offsets[i]
93
-
94
- if current_len >= max_tokens:
95
  boundary_candidates = [b for b in sentence_boundaries if b <= end]
96
  if boundary_candidates:
97
  boundary_index = max(boundary_candidates)
98
  cutoff_token_index = max(
99
  j for j, (s, e) in enumerate(offsets[:i+1]) if e <= boundary_index
100
  )
101
- chunks.append(current_chunk[:cutoff_token_index+1])
102
- current_chunk = current_chunk[cutoff_token_index+1:]
103
- current_len = len(current_chunk)
104
  else:
105
- chunks.append(current_chunk)
106
- current_chunk = []
107
- current_len = 0
108
 
109
- if current_chunk:
110
- chunks.append(current_chunk)
111
 
112
  return chunks
113
 
 
37
  if isinstance(prompt, torch.Tensor):
38
  if mode.lower() == "instruct":
39
  if "instruct" not in self._prefix_cache:
40
+ self._prefix_cache["instruct"] = self.engine.tokenize("/no_think\n")["input_ids"]
41
  return torch.cat([self._prefix_cache["instruct"], prompt], dim=-1)
42
 
43
  if mode.lower() == "think":
44
  if "think" not in self._prefix_cache:
45
+ self._prefix_cache["think"] = self.engine.tokenize("/think\n")["input_ids"]
46
  return torch.cat([self._prefix_cache["think"], prompt], dim=-1)
47
 
48
  return prompt
 
73
  """
74
  Util: split text into token chunks not exceeding max_tokens,
75
  trying to respect sentence boundaries where possible.
76
+ Returns: List[torch.Tensor] (each tensor is a chunk of token IDs, still on CPU)
77
  """
78
  max_tokens = min(14000, max_tokens)
79
  encodings = self.engine.tokenize(text, return_offsets=True)
80
+ tokens = encodings["input_ids"][0] # shape: (N,)
81
+ offsets = encodings["offset_mapping"][0] # shape: (N, 2)
82
 
83
+ # detect sentence boundaries (character-level positions in original text)
84
  sentence_boundaries = set(split_content(text, return_boundaries=True))
85
 
86
  chunks = []
87
+ start = 0
 
88
 
89
+ for i, (_, end) in enumerate(offsets):
90
+ # If the current chunk length is greater than max_tokens, break it based on the boundary.
91
+ if (i - start + 1) >= max_tokens:
 
 
 
92
  boundary_candidates = [b for b in sentence_boundaries if b <= end]
93
  if boundary_candidates:
94
  boundary_index = max(boundary_candidates)
95
  cutoff_token_index = max(
96
  j for j, (s, e) in enumerate(offsets[:i+1]) if e <= boundary_index
97
  )
98
+ chunks.append(tokens[start:cutoff_token_index+1])
99
+ start = cutoff_token_index + 1
 
100
  else:
101
+ chunks.append(tokens[start:i+1])
102
+ start = i + 1
 
103
 
104
+ if start < len(tokens):
105
+ chunks.append(tokens[start:])
106
 
107
  return chunks
108