tim1900 committed on
Commit
75880dd
verified
1 Parent(s): 472b08f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +20 -28
README.md CHANGED
@@ -245,14 +245,12 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
245
 
246
  unchunk_tokens = 0
247
  backup_pos = None
248
- best_logits = torch.finfo(torch.float32).min
249
- is_chunk_start = True
250
-
251
- STEP = round(((MAX_TOKENS - 2)//2 )*1.75) #(MAX_TOKENS - 2)//2
252
  print(f"Processing {input_ids.shape[1]} tokens...")
253
- while windows_end <= input_ids.shape[1]:
254
-
255
- windows_end = windows_start + MAX_TOKENS - 2
256
  ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
257
  ids = ids.to(model.device)
258
  output = model(
@@ -279,10 +277,7 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
279
  # manually chunk
280
  if unchunk_tokens + unchunk_tokens_this_window > max_tokens_per_chunk:
281
  big_windows_end = max_tokens_per_chunk - unchunk_tokens
282
- if is_chunk_start:
283
- max_value, max_index= logit_diff[:,1:big_windows_end].max(), logit_diff[:,1:big_windows_end].argmax() + 1
284
- else:
285
- max_value, max_index= logit_diff[:,1:big_windows_end].max(), logit_diff[:,1:big_windows_end].argmax() + 1
286
  if best_logits < max_value:
287
  backup_pos = windows_start + max_index
288
 
@@ -295,15 +290,17 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
295
  best_logits = torch.finfo(torch.float32).min
296
  backup_pos = -1
297
  unchunk_tokens = 0
298
- is_chunk_start = True
299
 
300
  # auto chunk
301
  else:
 
302
  if len(greater_rows_indices) >= 2:
303
  for gi, (gri0,gri1) in enumerate(zip(greater_rows_indices[:-1],greater_rows_indices[1:])):
 
304
  if gri1 - gri0 > max_tokens_per_chunk:
305
  greater_rows_indices=greater_rows_indices[:gi+1]
306
  break
 
307
  split_str_pos = [tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices if sp > 0]
308
  split_str_poses = split_str_poses + split_str_pos
309
  token_pos = token_pos+ [sp + windows_start for sp in greater_rows_indices if sp > 0]
@@ -312,20 +309,20 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
312
  best_logits = torch.finfo(torch.float32).min
313
  backup_pos = -1
314
  unchunk_tokens = 0
315
- is_chunk_start = True
316
 
317
  else:
318
 
319
- unchunk_tokens_this_window = min(windows_end - windows_start,STEP)
 
 
320
  # manually chunk
321
  if unchunk_tokens + unchunk_tokens_this_window > max_tokens_per_chunk:
322
  big_windows_end = max_tokens_per_chunk - unchunk_tokens
323
- if is_chunk_start:
324
- max_value, max_index= logit_diff[:,1:big_windows_end].max(), logit_diff[:,1:big_windows_end].argmax() + 1
325
- else:
326
  max_value, max_index= logit_diff[:,1:big_windows_end].max(), logit_diff[:,1:big_windows_end].argmax() + 1
327
- if best_logits < max_value:
328
- backup_pos = windows_start + max_index
329
 
330
 
331
  windows_start = backup_pos
@@ -335,21 +332,16 @@ def chunk_text_with_max_chunk_size(model, text, tokenizer, prob_threshold=0.5,ma
335
  best_logits = torch.finfo(torch.float32).min
336
  backup_pos = -1
337
  unchunk_tokens = 0
338
- is_chunk_start = True
339
  else:
340
  # auto leave
341
- if is_chunk_start:
342
  max_value, max_index= logit_diff[:,1:].max(), logit_diff[:,1:].argmax() + 1
343
-
344
- else:
345
- max_value, max_index= logit_diff[:,1:].max(), logit_diff[:,1:].argmax() + 1
346
- if best_logits < max_value:
347
- best_logits = max_value
348
- backup_pos = windows_start + max_index
349
 
350
  unchunk_tokens = unchunk_tokens + STEP
351
  windows_start = windows_start + STEP
352
- is_chunk_start = False
353
 
354
  substrings = [
355
  text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])
 
245
 
246
  unchunk_tokens = 0
247
  backup_pos = None
248
+ best_logits = torch.finfo(torch.float32).min
249
+ STEP = round(((MAX_TOKENS - 2)//2)*1.75 )
 
 
250
  print(f"Processing {input_ids.shape[1]} tokens...")
251
+ # while windows_end <= input_ids.shape[1]:  # remember to change to windows_start
252
+ while windows_start < input_ids.shape[1]:  # remember to change to windows_start
253
+ windows_end = windows_start + MAX_TOKENS - 2
254
  ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
255
  ids = ids.to(model.device)
256
  output = model(
 
277
  # manually chunk
278
  if unchunk_tokens + unchunk_tokens_this_window > max_tokens_per_chunk:
279
  big_windows_end = max_tokens_per_chunk - unchunk_tokens
280
+ max_value, max_index= logit_diff[:,1:big_windows_end].max(), logit_diff[:,1:big_windows_end].argmax() + 1
 
 
 
281
  if best_logits < max_value:
282
  backup_pos = windows_start + max_index
283
 
 
290
  best_logits = torch.finfo(torch.float32).min
291
  backup_pos = -1
292
  unchunk_tokens = 0
 
293
 
294
  # auto chunk
295
  else:
296
+
297
  if len(greater_rows_indices) >= 2:
298
  for gi, (gri0,gri1) in enumerate(zip(greater_rows_indices[:-1],greater_rows_indices[1:])):
299
+
300
  if gri1 - gri0 > max_tokens_per_chunk:
301
  greater_rows_indices=greater_rows_indices[:gi+1]
302
  break
303
+
304
  split_str_pos = [tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices if sp > 0]
305
  split_str_poses = split_str_poses + split_str_pos
306
  token_pos = token_pos+ [sp + windows_start for sp in greater_rows_indices if sp > 0]
 
309
  best_logits = torch.finfo(torch.float32).min
310
  backup_pos = -1
311
  unchunk_tokens = 0
 
312
 
313
  else:
314
 
315
+ # unchunk_tokens_this_window = min(windows_end - windows_start,STEP)
316
+ unchunk_tokens_this_window = min(windows_start+STEP,input_ids.shape[1]) - windows_start
317
+
318
  # manually chunk
319
  if unchunk_tokens + unchunk_tokens_this_window > max_tokens_per_chunk:
320
  big_windows_end = max_tokens_per_chunk - unchunk_tokens
321
+ if logit_diff.shape[1] > 1:
322
+
 
323
  max_value, max_index= logit_diff[:,1:big_windows_end].max(), logit_diff[:,1:big_windows_end].argmax() + 1
324
+ if best_logits < max_value:
325
+ backup_pos = windows_start + max_index
326
 
327
 
328
  windows_start = backup_pos
 
332
  best_logits = torch.finfo(torch.float32).min
333
  backup_pos = -1
334
  unchunk_tokens = 0
 
335
  else:
336
  # auto leave
337
+ if logit_diff.shape[1] > 1:
338
  max_value, max_index= logit_diff[:,1:].max(), logit_diff[:,1:].argmax() + 1
339
+ if best_logits < max_value:
340
+ best_logits = max_value
341
+ backup_pos = windows_start + max_index
 
 
 
342
 
343
  unchunk_tokens = unchunk_tokens + STEP
344
  windows_start = windows_start + STEP
 
345
 
346
  substrings = [
347
  text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])