Commit
·
00d76a9
1
Parent(s):
4caf2e9
Change past_key_values indexing from hyena (inherited from Together's Evo 1 HF code) to hcl, hcm and hcs
Browse files
- modeling_evo2.py +15 -3
modeling_evo2.py
CHANGED
|
@@ -104,17 +104,29 @@ class Evo2ForCausalLM(Evo2PreTrainedModel):
|
|
| 104 |
past_key_values = self.backbone.initialize_inference_params()
|
| 105 |
batch_size = input_ids.shape[0]
|
| 106 |
past_key_values["mha"].max_batch_size = batch_size
|
| 107 |
-
past_key_values["hyena"].max_batch_size = batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
else:
|
| 109 |
seqlen_offset = past_key_values["mha"].seqlen_offset
|
| 110 |
if seqlen_offset == 0:
|
| 111 |
# second loop through generate will have prompt_len + 1 as seqlen
|
| 112 |
seqlen_offset = input_ids.shape[-1] - 1
|
| 113 |
-
past_key_values["hyena"].seqlen_offset = seqlen_offset
|
|
|
|
|
|
|
|
|
|
| 114 |
past_key_values["mha"].seqlen_offset = seqlen_offset
|
| 115 |
else:
|
| 116 |
past_key_values["mha"].seqlen_offset += 1
|
| 117 |
-
past_key_values["hyena"].seqlen_offset += 1
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
inputs = input_ids[
|
| 120 |
:,
|
|
|
|
| 104 |
past_key_values = self.backbone.initialize_inference_params()
|
| 105 |
batch_size = input_ids.shape[0]
|
| 106 |
past_key_values["mha"].max_batch_size = batch_size
|
| 107 |
+
# This line is inherited from Together's HF code. It needs to
|
| 108 |
+
# change for Evo 2 (specifically, we need to access hcl, hcm and
|
| 109 |
+
# hcs instead).
|
| 110 |
+
# past_key_values["hyena"].max_batch_size = batch_size
|
| 111 |
+
past_key_values["hcl"].max_batch_size = batch_size
|
| 112 |
+
past_key_values["hcm"].max_batch_size = batch_size
|
| 113 |
+
past_key_values["hcs"].max_batch_size = batch_size
|
| 114 |
else:
|
| 115 |
seqlen_offset = past_key_values["mha"].seqlen_offset
|
| 116 |
if seqlen_offset == 0:
|
| 117 |
# second loop through generate will have prompt_len + 1 as seqlen
|
| 118 |
seqlen_offset = input_ids.shape[-1] - 1
|
| 119 |
+
# past_key_values["hyena"].seqlen_offset = seqlen_offset
|
| 120 |
+
past_key_values["hcl"].seqlen_offset = seqlen_offset
|
| 121 |
+
past_key_values["hcm"].seqlen_offset = seqlen_offset
|
| 122 |
+
past_key_values["hcs"].seqlen_offset = seqlen_offset
|
| 123 |
past_key_values["mha"].seqlen_offset = seqlen_offset
|
| 124 |
else:
|
| 125 |
past_key_values["mha"].seqlen_offset += 1
|
| 126 |
+
# past_key_values["hyena"].seqlen_offset += 1
|
| 127 |
+
past_key_values["hcl"].seqlen_offset += 1
|
| 128 |
+
past_key_values["hcs"].seqlen_offset += 1
|
| 129 |
+
past_key_values["hcm"].seqlen_offset += 1
|
| 130 |
|
| 131 |
inputs = input_ids[
|
| 132 |
:,
|