Update README.md
README.md
@@ -245,14 +245,12 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
 
     unchunk_tokens = 0
    backup_pos = None
-    best_logits = torch.finfo(torch.float32).min
-
-
-    STEP = round(((MAX_TOKENS - 2) // 2) * 1.75)  # (MAX_TOKENS - 2)//2
+    best_logits = torch.finfo(torch.float32).min
+    STEP = round(((MAX_TOKENS - 2) // 2) * 1.75)
     print(f"Processing {input_ids.shape[1]} tokens...")
-    while windows_end <= input_ids.shape[1]:
-
-        windows_end = windows_start + MAX_TOKENS - 2
+    # while windows_end <= input_ids.shape[1]:  # remember to change to windows_start
+    while windows_start < input_ids.shape[1]:  # remember to change to windows_start
+        windows_end = windows_start + MAX_TOKENS - 2
         ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
         ids = ids.to(model.device)
         output = model(
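The loop-condition change above is the core fix: the old `while windows_end <= input_ids.shape[1]` test exits before a final partial window is processed, silently dropping the trailing tokens, while looping on `windows_start` still visits them. A minimal sketch with toy numbers (not the README's full loop, which can also jump `windows_start` back to `backup_pos`):

n_tokens = 10                                  # stand-in for input_ids.shape[1]
MAX_TOKENS = 6
STEP = round(((MAX_TOKENS - 2) // 2) * 1.75)   # 3.5 rounds to 4 here

windows_start = 0
while windows_start < n_tokens:                    # new condition
    windows_end = windows_start + MAX_TOKENS - 2   # slicing clamps past the end
    print(f"window covers tokens [{windows_start}, {min(windows_end, n_tokens)})")
    windows_start += STEP
# window covers tokens [0, 4)
# window covers tokens [4, 8)
# window covers tokens [8, 10)   <- tail window the old condition skipped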
@@ -279,10 +277,7 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
             # manually chunk
             if unchunk_tokens + unchunk_tokens_this_window > max_tokens_per_chunk:
                 big_windows_end = max_tokens_per_chunk - unchunk_tokens
-
-                max_value, max_index = logit_diff[:, 1:big_windows_end].max(), logit_diff[:, 1:big_windows_end].argmax() + 1
-            else:
-                max_value, max_index = logit_diff[:, 1:big_windows_end].max(), logit_diff[:, 1:big_windows_end].argmax() + 1
+                max_value, max_index = logit_diff[:, 1:big_windows_end].max(), logit_diff[:, 1:big_windows_end].argmax() + 1
                 if best_logits < max_value:
                     backup_pos = windows_start + max_index
 
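This hunk collapses two branches that computed the same thing into one assignment. The `+ 1` is easy to misread: since column 0 of `logit_diff` is sliced off, the argmax over the slice must be re-based to an index in the full row. A small sketch with made-up values:

import torch

logit_diff = torch.tensor([[0.5, 0.25, 0.75, 0.125]])
big_windows_end = 4
max_value = logit_diff[:, 1:big_windows_end].max()         # 0.75
max_index = logit_diff[:, 1:big_windows_end].argmax() + 1  # 2, index in the full row
print(max_value.item(), max_index.item())                  # 0.75 2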
@@ -295,15 +290,17 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
                 best_logits = torch.finfo(torch.float32).min
                 backup_pos = -1
                 unchunk_tokens = 0
-                is_chunk_start = True
 
         # auto chunk
         else:
+
             if len(greater_rows_indices) >= 2:
                 for gi, (gri0, gri1) in enumerate(zip(greater_rows_indices[:-1], greater_rows_indices[1:])):
+
                     if gri1 - gri0 > max_tokens_per_chunk:
                         greater_rows_indices = greater_rows_indices[:gi+1]
                         break
+
             split_str_pos = [tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices if sp > 0]
             split_str_poses = split_str_poses + split_str_pos
             token_pos = token_pos + [sp + windows_start for sp in greater_rows_indices if sp > 0]
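Note that the `is_chunk_start` bookkeeping is dropped throughout this commit; the loop state is carried by `best_logits`, `backup_pos`, and `unchunk_tokens` alone. The gap check in the auto-chunk branch is worth a sketch: it truncates the candidate split list at the first pair of neighbours farther apart than `max_tokens_per_chunk`, so no emitted chunk can exceed the cap (hypothetical positions below):

greater_rows_indices = [3, 40, 95, 430, 470]   # in-window token positions
max_tokens_per_chunk = 256

for gi, (gri0, gri1) in enumerate(zip(greater_rows_indices[:-1],
                                      greater_rows_indices[1:])):
    if gri1 - gri0 > max_tokens_per_chunk:
        # keep split points up to the oversized gap; later positions are
        # handled when the window is re-entered
        greater_rows_indices = greater_rows_indices[:gi + 1]
        break
print(greater_rows_indices)   # [3, 40, 95]; the 95 -> 430 gap exceeds 256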
@@ -312,20 +309,20 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
                 best_logits = torch.finfo(torch.float32).min
                 backup_pos = -1
                 unchunk_tokens = 0
-                is_chunk_start = True
 
         else:
 
-            unchunk_tokens_this_window = min(windows_end - windows_start, STEP)
+            # unchunk_tokens_this_window = min(windows_end - windows_start, STEP)
+            unchunk_tokens_this_window = min(windows_start + STEP, input_ids.shape[1]) - windows_start
+
             # manually chunk
             if unchunk_tokens + unchunk_tokens_this_window > max_tokens_per_chunk:
                 big_windows_end = max_tokens_per_chunk - unchunk_tokens
-                if
-
-                else:
+                if logit_diff.shape[1] > 1:
+
                     max_value, max_index = logit_diff[:, 1:big_windows_end].max(), logit_diff[:, 1:big_windows_end].argmax() + 1
-
-
+                    if best_logits < max_value:
+                        backup_pos = windows_start + max_index
 
 
                 windows_start = backup_pos
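Two fixes land here: `unchunk_tokens_this_window` is now clamped to the tokens that actually remain, and the `logit_diff` indexing is guarded by a shape check. A toy-number sketch of the clamp, reusing the values from the loop sketch above:

n_tokens = 10                    # input_ids.shape[1]
MAX_TOKENS, STEP = 6, 4
windows_start = 8                # the final, partial window
windows_end = windows_start + MAX_TOKENS - 2                 # 12, past the end

old = min(windows_end - windows_start, STEP)                 # 4: counts tokens that don't exist
new = min(windows_start + STEP, n_tokens) - windows_start    # 2: tokens actually left
print(old, new)   # 4 2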
@@ -335,21 +332,16 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
                 best_logits = torch.finfo(torch.float32).min
                 backup_pos = -1
                 unchunk_tokens = 0
-                is_chunk_start = True
             else:
                 # auto leave
-                if
+                if logit_diff.shape[1] > 1:
                     max_value, max_index = logit_diff[:, 1:].max(), logit_diff[:, 1:].argmax() + 1
-
-
-
-                if best_logits < max_value:
-                    best_logits = max_value
-                    backup_pos = windows_start + max_index
+                    if best_logits < max_value:
+                        best_logits = max_value
+                        backup_pos = windows_start + max_index
 
             unchunk_tokens = unchunk_tokens + STEP
             windows_start = windows_start + STEP
-            is_chunk_start = False
 
     substrings = [
         text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])
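The same `logit_diff.shape[1] > 1` guard appears in this auto-leave branch. The reason: slicing off column 0 of a one-column tensor leaves an empty tensor, and reducing an empty tensor with `.max()` raises a RuntimeError in PyTorch, so a window with a single candidate position would crash the old code. A minimal sketch:

import torch

logit_diff = torch.tensor([[0.7]])       # shape (1, 1): only one position
if logit_diff.shape[1] > 1:
    max_value = logit_diff[:, 1:].max()  # only reached with >= 2 columns
else:
    print("single-position window, keep the previous backup_pos")
# without the guard, logit_diff[:, 1:].max() raises a
# RuntimeError (reduction over an empty tensor)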
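Finally, the closing context lines show how the collected character offsets become chunks: `split_str_poses` is fenced with `0` and `len(text)` and consumed pairwise. A self-contained sketch with hypothetical offsets:

text = "Sentence one. Sentence two. Sentence three."
split_str_poses = [14, 28]                     # hypothetical split offsets
substrings = [text[i:j] for i, j in zip([0] + split_str_poses,
                                        split_str_poses + [len(text)])]
print(substrings)
# ['Sentence one. ', 'Sentence two. ', 'Sentence three.']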