valcore committed on
Commit
4ec421a
·
verified ·
1 Parent(s): 88e9f49

Upload model

Browse files
Files changed (4) hide show
  1. BranchyModel.py +61 -2
  2. README.md +2 -2
  3. config.json +1 -0
  4. generation_config.json +4 -0
BranchyModel.py CHANGED
@@ -5,13 +5,15 @@ import torch.nn as nn
5
  import torch.nn.functional as F
6
  import copy
7
 
8
- from dataclasses import dataclass
9
  from torch import Tensor
10
  from .BranchyModelConfig import BranchyModelConfig
11
  from typing import List, Optional, Dict, Tuple
12
  from transformers import AutoModelForCausalLM, PreTrainedModel
13
  from transformers.modeling_outputs import CausalLMOutputWithPast
14
  from transformers.utils import ModelOutput
 
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
@@ -306,7 +308,7 @@ class BranchyCausalModel(PreTrainedModel):
306
  config: BranchyModelConfig):
307
  super().__init__(config)
308
  self.model = BranchyModel(config)
309
- self.head_thresholds = torch.tensor(config.head_thresholds).to(self.model.device)
310
  if config.confidence_metric == "breaking_ties":
311
  self.confidence_metric_fn = breaking_ties
312
  elif config.confidence_metric == "max":
@@ -314,6 +316,63 @@ class BranchyCausalModel(PreTrainedModel):
314
  else:
315
  raise ValueError("confidence_metric must be 'breaking_ties' or 'max'.")
316
  self.post_init()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  def forward(self,
319
  input_ids: torch.LongTensor = None,
 
5
  import torch.nn.functional as F
6
  import copy
7
 
8
+ from dataclasses import dataclass
9
  from torch import Tensor
10
  from .BranchyModelConfig import BranchyModelConfig
11
  from typing import List, Optional, Dict, Tuple
12
  from transformers import AutoModelForCausalLM, PreTrainedModel
13
  from transformers.modeling_outputs import CausalLMOutputWithPast
14
  from transformers.utils import ModelOutput
15
+ from transformers.cache_utils import Cache, DynamicCache
16
+
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
 
308
  config: BranchyModelConfig):
309
  super().__init__(config)
310
  self.model = BranchyModel(config)
311
+ self.head_thresholds = torch.tensor(config.head_thresholds).to(config.device)
312
  if config.confidence_metric == "breaking_ties":
313
  self.confidence_metric_fn = breaking_ties
314
  elif config.confidence_metric == "max":
 
316
  else:
317
  raise ValueError("confidence_metric must be 'breaking_ties' or 'max'.")
318
  self.post_init()
319
+
320
+
321
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the per-step input dict for autoregressive generation.

    Follows the standard decoder-only recipe: tokens already covered by the
    KV cache are trimmed from ``input_ids``, ``position_ids`` are derived
    from the attention mask (padding slots filled with 1), and
    ``inputs_embeds`` is forwarded only on the very first step (no cache yet).
    """
    if past_key_values is not None:
        # Work out how much of the sequence the cache already covers.
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy tuple-of-tuples cache: length is the key tensor's seq dim.
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None

        n_tokens = input_ids.shape[1]
        if attention_mask is not None and attention_mask.shape[1] > n_tokens:
            # Some inputs live only in the cache (e.g. a prior inputs_embeds
            # pass): keep just the tail of input_ids not yet processed.
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
        elif past_length < n_tokens:
            # input_ids holds the whole sequence; drop the cached prefix.
            input_ids = input_ids[:, past_length:]
        # Otherwise assume input_ids already contains only unprocessed tokens.

        # If this step would overflow a bounded cache, crop the mask so the
        # model only attends within the cache window.
        would_overflow = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if would_overflow:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # Derive positions from the mask so batched/left-padded prompts
        # line up; padding positions get the placeholder value 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    # inputs_embeds is honored only for the first generation step; after
    # that the model consumes token ids.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
376
 
377
  def forward(self,
378
  input_ids: torch.LongTensor = None,
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- library_name: transformers
3
- license: mit
4
  language:
5
  - en
 
 
6
  pipeline_tag: text-generation
7
  ---
8
 
 
1
  ---
 
 
2
  language:
3
  - en
4
+ license: mit
5
+ library_name: transformers
6
  pipeline_tag: text-generation
7
  ---
8
 
config.json CHANGED
@@ -15,6 +15,7 @@
15
  "branch_number": 4,
16
  "confidence_metric": "breaking_ties",
17
  "copy_lm_head": false,
 
18
  "head_thresholds": [
19
  10.0,
20
  10.0,
 
15
  "branch_number": 4,
16
  "confidence_metric": "breaking_ties",
17
  "copy_lm_head": false,
18
+ "device": "cuda:1",
19
  "head_thresholds": [
20
  10.0,
21
  10.0,
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.40.2"
4
+ }