Update modeling.py
Browse files- modeling.py +31 -0
modeling.py
CHANGED
|
@@ -171,3 +171,34 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 171 |
logits = self.lm_head(x)
|
| 172 |
|
| 173 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
logits = self.lm_head(x)
|
| 172 |
|
| 173 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 174 |
+
|
| 175 |
+
@classmethod
def from_pretrained(cls, *args, **kwargs):
    """Reconstruct a model from a checkpoint written by ``save_pretrained``.

    Accepts the checkpoint path either positionally (the usual
    ``Model.from_pretrained(path)`` convention) or via the
    ``pretrained_model_name_or_path`` keyword. An explicit ``config=``
    keyword bypasses config loading.

    Returns:
        An instance of ``cls`` with the saved state dict loaded.

    Raises:
        ValueError: if no checkpoint path was supplied.
    """
    import os

    config = kwargs.pop('config', None)
    if config is None:
        config = CustomTransformerConfig.from_pretrained(*args, **kwargs)
    model = cls(config)

    # The original indexed kwargs directly, which raises KeyError when the
    # path is passed positionally — accept both calling conventions.
    path = kwargs.get('pretrained_model_name_or_path')
    if path is None and args:
        path = args[0]
    if path is None:
        raise ValueError("a pretrained_model_name_or_path must be provided")

    # save_pretrained writes {dir}/pytorch_model.bin, so resolve a directory
    # to that file for a consistent save/load round-trip.
    if os.path.isdir(path):
        path = os.path.join(path, "pytorch_model.bin")

    # weights_only=True restricts unpickling to tensors/containers — avoids
    # arbitrary-code execution from an untrusted checkpoint file.
    state_dict = torch.load(path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    return model
|
| 184 |
+
|
| 185 |
+
def save_pretrained(self, save_directory):
    """Write the model weights and config to *save_directory*.

    Saves the state dict as ``pytorch_model.bin`` inside the directory and
    delegates config serialization to ``self.config.save_pretrained``.

    Args:
        save_directory: target directory; created if it does not exist.
    """
    import os

    # The original crashed with FileNotFoundError when the target directory
    # was missing — create it first, as the HF API does.
    os.makedirs(save_directory, exist_ok=True)
    torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
    self.config.save_pretrained(save_directory)
|
| 188 |
+
|
| 189 |
+
def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
    """Generate token sequences from *input_ids* via the parent ``generate``.

    Installs a max-length stopping criterion, and — when *min_length* is
    given — a min-length logits processor keyed on the config's EOS token,
    then delegates to ``super().generate``.
    """
    processors = LogitsProcessorList()
    if min_length is not None:
        # Suppress EOS until min_length tokens have been produced.
        processors.append(
            MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id)
        )

    criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])

    return super().generate(
        input_ids=input_ids,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        logits_processor=processors,
        stopping_criteria=criteria,
    )
|