tim1900 commited on
Commit
bd17486
·
verified ·
1 Parent(s): ca6c40c

Update modeling_bertchunker.py

Browse files
Files changed (1) hide show
  1. modeling_bertchunker.py +10 -9
modeling_bertchunker.py CHANGED
@@ -48,7 +48,6 @@ class BertChunker(PreTrainedModel):
48
  CLS=input_ids[:,0].unsqueeze(0)
49
  SEP=input_ids[:,-1].unsqueeze(0)
50
  input_ids=input_ids[:,1:-1]
51
- # model= model.to(device)
52
  self.eval()
53
  split_str_poses=[]
54
 
@@ -57,26 +56,28 @@ class BertChunker(PreTrainedModel):
57
 
58
  while windows_end <= input_ids.shape[1]:
59
  windows_end= windows_start + MAX_TOKENS-2
60
-
61
  ids=torch.cat((CLS, input_ids[:,windows_start:windows_end],SEP),1)
62
-
63
  ids=ids.to(self.device)
64
-
65
  output=self(input_ids=ids,attention_mask=attention_mask[:,:len(ids)])
66
  logits = output['logits'][:, 1:-1,:]
67
  is_left_greater = ((logits[:,:, 0] + threshold) < logits[:,:, 1])
68
- greater_rows_indices = torch.where(is_left_greater)
69
 
70
  # null or not
71
- if greater_rows_indices[1].numel():
72
- split_token_idx = greater_rows_indices[1] + windows_start + 1
73
 
74
- split_str_pos=[tokens.token_to_chars(sp).start for sp in split_token_idx.tolist()]
 
75
 
76
  split_str_poses += split_str_pos
77
 
78
- windows_start = greater_rows_indices[1][-1] + windows_start
 
79
  else:
 
80
  windows_start = windows_end
81
 
82
  substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])]
 
48
  CLS=input_ids[:,0].unsqueeze(0)
49
  SEP=input_ids[:,-1].unsqueeze(0)
50
  input_ids=input_ids[:,1:-1]
 
51
  self.eval()
52
  split_str_poses=[]
53
 
 
56
 
57
  while windows_end <= input_ids.shape[1]:
58
  windows_end= windows_start + MAX_TOKENS-2
59
+
60
  ids=torch.cat((CLS, input_ids[:,windows_start:windows_end],SEP),1)
61
+
62
  ids=ids.to(self.device)
63
+
64
  output=self(input_ids=ids,attention_mask=attention_mask[:,:len(ids)])
65
  logits = output['logits'][:, 1:-1,:]
66
  is_left_greater = ((logits[:,:, 0] + threshold) < logits[:,:, 1])
67
+ greater_rows_indices = torch.where(is_left_greater)[1].tolist()
68
 
69
  # null or not
70
+ if len(greater_rows_indices)>0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices)==1)):
 
71
 
72
+
73
+ split_str_pos=[tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices]
74
 
75
  split_str_poses += split_str_pos
76
 
77
+ windows_start = greater_rows_indices[-1] + windows_start
78
+
79
  else:
80
+
81
  windows_start = windows_end
82
 
83
  substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])]