ishanjmukherjee committed on
Commit
00d76a9
·
1 Parent(s): 4caf2e9

Change past_key_values indexing from hyena (inherited from Together's Evo 1 HF code) to hcl, hcm and hcs

Browse files
Files changed (1) hide show
  1. modeling_evo2.py +15 -3
modeling_evo2.py CHANGED
@@ -104,17 +104,29 @@ class Evo2ForCausalLM(Evo2PreTrainedModel):
104
  past_key_values = self.backbone.initialize_inference_params()
105
  batch_size = input_ids.shape[0]
106
  past_key_values["mha"].max_batch_size = batch_size
107
- past_key_values["hyena"].max_batch_size = batch_size
 
 
 
 
 
 
108
  else:
109
  seqlen_offset = past_key_values["mha"].seqlen_offset
110
  if seqlen_offset == 0:
111
  # second loop through generate will have prompt_len + 1 as seqlen
112
  seqlen_offset = input_ids.shape[-1] - 1
113
- past_key_values["hyena"].seqlen_offset = seqlen_offset
 
 
 
114
  past_key_values["mha"].seqlen_offset = seqlen_offset
115
  else:
116
  past_key_values["mha"].seqlen_offset += 1
117
- past_key_values["hyena"].seqlen_offset += 1
 
 
 
118
 
119
  inputs = input_ids[
120
  :,
 
104
  past_key_values = self.backbone.initialize_inference_params()
105
  batch_size = input_ids.shape[0]
106
  past_key_values["mha"].max_batch_size = batch_size
107
+ # The commented-out line below is inherited from Together's HF code.
108
+ # For Evo 2 the "hyena" entry is replaced by separate "hcl", "hcm"
109
+ # and "hcs" entries, so each of those is updated here instead.
110
+ # past_key_values["hyena"].max_batch_size = batch_size
111
+ past_key_values["hcl"].max_batch_size = batch_size
112
+ past_key_values["hcm"].max_batch_size = batch_size
113
+ past_key_values["hcs"].max_batch_size = batch_size
114
  else:
115
  seqlen_offset = past_key_values["mha"].seqlen_offset
116
  if seqlen_offset == 0:
117
  # second loop through generate will have prompt_len + 1 as seqlen
118
  seqlen_offset = input_ids.shape[-1] - 1
119
+ # past_key_values["hyena"].seqlen_offset = seqlen_offset
120
+ past_key_values["hcl"].seqlen_offset = seqlen_offset
121
+ past_key_values["hcm"].seqlen_offset = seqlen_offset
122
+ past_key_values["hcs"].seqlen_offset = seqlen_offset
123
  past_key_values["mha"].seqlen_offset = seqlen_offset
124
  else:
125
  past_key_values["mha"].seqlen_offset += 1
126
+ # past_key_values["hyena"].seqlen_offset += 1
127
+ past_key_values["hcl"].seqlen_offset += 1
128
+ past_key_values["hcs"].seqlen_offset += 1
129
+ past_key_values["hcm"].seqlen_offset += 1
130
 
131
  inputs = input_ids[
132
  :,