Spaces:
Sleeping
Sleeping
fix(llm_model): align token chunking and prefix handling with engine
Browse files- models/llm_model.py +16 -21
models/llm_model.py
CHANGED
|
@@ -37,12 +37,12 @@ class LLMService:
|
|
| 37 |
if isinstance(prompt, torch.Tensor):
|
| 38 |
if mode.lower() == "instruct":
|
| 39 |
if "instruct" not in self._prefix_cache:
|
| 40 |
-
self._prefix_cache["instruct"] = self.engine.tokenize("/no_think\n")
|
| 41 |
return torch.cat([self._prefix_cache["instruct"], prompt], dim=-1)
|
| 42 |
|
| 43 |
if mode.lower() == "think":
|
| 44 |
if "think" not in self._prefix_cache:
|
| 45 |
-
self._prefix_cache["think"] = self.engine.tokenize("/think\n")
|
| 46 |
return torch.cat([self._prefix_cache["think"], prompt], dim=-1)
|
| 47 |
|
| 48 |
return prompt
|
|
@@ -73,41 +73,36 @@ class LLMService:
|
|
| 73 |
"""
|
| 74 |
Util: split text into token chunks not exceeding max_tokens,
|
| 75 |
trying to respect sentence boundaries where possible.
|
|
|
|
| 76 |
"""
|
| 77 |
max_tokens = min(14000, max_tokens)
|
| 78 |
encodings = self.engine.tokenize(text, return_offsets=True)
|
| 79 |
-
tokens = encodings["input_ids"]
|
| 80 |
-
offsets = encodings["offset_mapping"]
|
| 81 |
|
| 82 |
-
# detect sentence boundaries
|
| 83 |
sentence_boundaries = set(split_content(text, return_boundaries=True))
|
| 84 |
|
| 85 |
chunks = []
|
| 86 |
-
|
| 87 |
-
current_len = 0
|
| 88 |
|
| 89 |
-
for i,
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
_, end = offsets[i]
|
| 93 |
-
|
| 94 |
-
if current_len >= max_tokens:
|
| 95 |
boundary_candidates = [b for b in sentence_boundaries if b <= end]
|
| 96 |
if boundary_candidates:
|
| 97 |
boundary_index = max(boundary_candidates)
|
| 98 |
cutoff_token_index = max(
|
| 99 |
j for j, (s, e) in enumerate(offsets[:i+1]) if e <= boundary_index
|
| 100 |
)
|
| 101 |
-
chunks.append(
|
| 102 |
-
|
| 103 |
-
current_len = len(current_chunk)
|
| 104 |
else:
|
| 105 |
-
chunks.append(
|
| 106 |
-
|
| 107 |
-
current_len = 0
|
| 108 |
|
| 109 |
-
if
|
| 110 |
-
chunks.append(
|
| 111 |
|
| 112 |
return chunks
|
| 113 |
|
|
|
|
| 37 |
if isinstance(prompt, torch.Tensor):
|
| 38 |
if mode.lower() == "instruct":
|
| 39 |
if "instruct" not in self._prefix_cache:
|
| 40 |
+
self._prefix_cache["instruct"] = self.engine.tokenize("/no_think\n")["input_ids"]
|
| 41 |
return torch.cat([self._prefix_cache["instruct"], prompt], dim=-1)
|
| 42 |
|
| 43 |
if mode.lower() == "think":
|
| 44 |
if "think" not in self._prefix_cache:
|
| 45 |
+
self._prefix_cache["think"] = self.engine.tokenize("/think\n")["input_ids"]
|
| 46 |
return torch.cat([self._prefix_cache["think"], prompt], dim=-1)
|
| 47 |
|
| 48 |
return prompt
|
|
|
|
| 73 |
"""
|
| 74 |
Util: split text into token chunks not exceeding max_tokens,
|
| 75 |
trying to respect sentence boundaries where possible.
|
| 76 |
+
Returns: List[torch.Tensor] (each tensor is a chunk of token IDs, still on CPU)
|
| 77 |
"""
|
| 78 |
max_tokens = min(14000, max_tokens)
|
| 79 |
encodings = self.engine.tokenize(text, return_offsets=True)
|
| 80 |
+
tokens = encodings["input_ids"][0] # shape: (N,)
|
| 81 |
+
offsets = encodings["offset_mapping"][0] # shape: (N, 2)
|
| 82 |
|
| 83 |
+
# detect sentence boundaries (character-level positions in original text)
|
| 84 |
sentence_boundaries = set(split_content(text, return_boundaries=True))
|
| 85 |
|
| 86 |
chunks = []
|
| 87 |
+
start = 0
|
|
|
|
| 88 |
|
| 89 |
+
for i, (_, end) in enumerate(offsets):
|
| 90 |
+
# If the current chunk length is greater than max_tokens, break it based on the boundary.
|
| 91 |
+
if (i - start + 1) >= max_tokens:
|
|
|
|
|
|
|
|
|
|
| 92 |
boundary_candidates = [b for b in sentence_boundaries if b <= end]
|
| 93 |
if boundary_candidates:
|
| 94 |
boundary_index = max(boundary_candidates)
|
| 95 |
cutoff_token_index = max(
|
| 96 |
j for j, (s, e) in enumerate(offsets[:i+1]) if e <= boundary_index
|
| 97 |
)
|
| 98 |
+
chunks.append(tokens[start:cutoff_token_index+1])
|
| 99 |
+
start = cutoff_token_index + 1
|
|
|
|
| 100 |
else:
|
| 101 |
+
chunks.append(tokens[start:i+1])
|
| 102 |
+
start = i + 1
|
|
|
|
| 103 |
|
| 104 |
+
if start < len(tokens):
|
| 105 |
+
chunks.append(tokens[start:])
|
| 106 |
|
| 107 |
return chunks
|
| 108 |
|