Upload model
Browse files- BranchyModel.py +61 -2
- README.md +2 -2
- config.json +1 -0
- generation_config.json +4 -0
BranchyModel.py
CHANGED
|
@@ -5,13 +5,15 @@ import torch.nn as nn
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
import copy
|
| 7 |
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
from torch import Tensor
|
| 10 |
from .BranchyModelConfig import BranchyModelConfig
|
| 11 |
from typing import List, Optional, Dict, Tuple
|
| 12 |
from transformers import AutoModelForCausalLM, PreTrainedModel
|
| 13 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 14 |
from transformers.utils import ModelOutput
|
|
|
|
|
|
|
| 15 |
logging.basicConfig(level=logging.INFO)
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
@@ -306,7 +308,7 @@ class BranchyCausalModel(PreTrainedModel):
|
|
| 306 |
config: BranchyModelConfig):
|
| 307 |
super().__init__(config)
|
| 308 |
self.model = BranchyModel(config)
|
| 309 |
-
self.head_thresholds = torch.tensor(config.head_thresholds).to(
|
| 310 |
if config.confidence_metric == "breaking_ties":
|
| 311 |
self.confidence_metric_fn = breaking_ties
|
| 312 |
elif config.confidence_metric == "max":
|
|
@@ -314,6 +316,63 @@ class BranchyCausalModel(PreTrainedModel):
|
|
| 314 |
else:
|
| 315 |
raise ValueError("confidence_metric must be 'breaking_ties' or 'max'.")
|
| 316 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
def forward(self,
|
| 319 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
import copy
|
| 7 |
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
from torch import Tensor
|
| 10 |
from .BranchyModelConfig import BranchyModelConfig
|
| 11 |
from typing import List, Optional, Dict, Tuple
|
| 12 |
from transformers import AutoModelForCausalLM, PreTrainedModel
|
| 13 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 14 |
from transformers.utils import ModelOutput
|
| 15 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 16 |
+
|
| 17 |
logging.basicConfig(level=logging.INFO)
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
|
|
|
| 308 |
config: BranchyModelConfig):
|
| 309 |
super().__init__(config)
|
| 310 |
self.model = BranchyModel(config)
|
| 311 |
+
self.head_thresholds = torch.tensor(config.head_thresholds).to(config.device)
|
| 312 |
if config.confidence_metric == "breaking_ties":
|
| 313 |
self.confidence_metric_fn = breaking_ties
|
| 314 |
elif config.confidence_metric == "max":
|
|
|
|
| 316 |
else:
|
| 317 |
raise ValueError("confidence_metric must be 'breaking_ties' or 'max'.")
|
| 318 |
self.post_init()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def prepare_inputs_for_generation(
|
| 322 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 323 |
+
):
|
| 324 |
+
if past_key_values is not None:
|
| 325 |
+
if isinstance(past_key_values, Cache):
|
| 326 |
+
cache_length = past_key_values.get_seq_length()
|
| 327 |
+
past_length = past_key_values.seen_tokens
|
| 328 |
+
max_cache_length = past_key_values.get_max_length()
|
| 329 |
+
else:
|
| 330 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 331 |
+
max_cache_length = None
|
| 332 |
+
|
| 333 |
+
# Keep only the unprocessed tokens:
|
| 334 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 335 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 336 |
+
# input)
|
| 337 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 338 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 339 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 340 |
+
# input_ids based on the past_length.
|
| 341 |
+
elif past_length < input_ids.shape[1]:
|
| 342 |
+
input_ids = input_ids[:, past_length:]
|
| 343 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 344 |
+
|
| 345 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 346 |
+
if (
|
| 347 |
+
max_cache_length is not None
|
| 348 |
+
and attention_mask is not None
|
| 349 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
| 350 |
+
):
|
| 351 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
| 352 |
+
|
| 353 |
+
position_ids = kwargs.get("position_ids", None)
|
| 354 |
+
if attention_mask is not None and position_ids is None:
|
| 355 |
+
# create position_ids on the fly for batch generation
|
| 356 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 357 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 358 |
+
if past_key_values:
|
| 359 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 360 |
+
|
| 361 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 362 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 363 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 364 |
+
else:
|
| 365 |
+
model_inputs = {"input_ids": input_ids}
|
| 366 |
+
|
| 367 |
+
model_inputs.update(
|
| 368 |
+
{
|
| 369 |
+
"position_ids": position_ids,
|
| 370 |
+
"past_key_values": past_key_values,
|
| 371 |
+
"use_cache": kwargs.get("use_cache"),
|
| 372 |
+
"attention_mask": attention_mask,
|
| 373 |
+
}
|
| 374 |
+
)
|
| 375 |
+
return model_inputs
|
| 376 |
|
| 377 |
def forward(self,
|
| 378 |
input_ids: torch.LongTensor = None,
|
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
library_name: transformers
|
| 3 |
-
license: mit
|
| 4 |
language:
|
| 5 |
- en
|
|
|
|
|
|
|
| 6 |
pipeline_tag: text-generation
|
| 7 |
---
|
| 8 |
|
|
|
|
| 1 |
---
|
|
|
|
|
|
|
| 2 |
language:
|
| 3 |
- en
|
| 4 |
+
license: mit
|
| 5 |
+
library_name: transformers
|
| 6 |
pipeline_tag: text-generation
|
| 7 |
---
|
| 8 |
|
config.json
CHANGED
|
@@ -15,6 +15,7 @@
|
|
| 15 |
"branch_number": 4,
|
| 16 |
"confidence_metric": "breaking_ties",
|
| 17 |
"copy_lm_head": false,
|
|
|
|
| 18 |
"head_thresholds": [
|
| 19 |
10.0,
|
| 20 |
10.0,
|
|
|
|
| 15 |
"branch_number": 4,
|
| 16 |
"confidence_metric": "breaking_ties",
|
| 17 |
"copy_lm_head": false,
|
| 18 |
+
"device": "cuda:1",
|
| 19 |
"head_thresholds": [
|
| 20 |
10.0,
|
| 21 |
10.0,
|
generation_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"transformers_version": "4.40.2"
|
| 4 |
+
}
|