permutans commited on
Commit
621c79f
·
verified ·
1 Parent(s): 47ff542

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_havelock.py +10 -5
modeling_havelock.py CHANGED
@@ -2,11 +2,12 @@
2
 
3
  import torch
4
  import torch.nn as nn
5
- from transformers import AutoConfig, AutoModel, PreTrainedModel, PretrainedConfig
6
 
7
 
8
  class HavelockTokenConfig(PretrainedConfig):
9
  """Config that wraps any backbone config + our custom fields."""
 
10
  model_type = "havelock_token_classifier"
11
 
12
  def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs):
@@ -18,7 +19,9 @@ class HavelockTokenConfig(PretrainedConfig):
18
  class HavelockTokenClassifier(PreTrainedModel):
19
  config_class = HavelockTokenConfig
20
 
21
- def __init__(self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None):
 
 
22
  super().__init__(config)
23
  self.num_types = config.num_types
24
  self.use_crf = config.use_crf
@@ -29,7 +32,7 @@ class HavelockTokenClassifier(PreTrainedModel):
29
  else:
30
  self.backbone = AutoModel.from_config(config)
31
 
32
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
33
  self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)
34
 
35
  if self.use_crf:
@@ -75,7 +78,9 @@ class HavelockTokenClassifier(PreTrainedModel):
75
  mask = (
76
  attention_mask.bool()
77
  if attention_mask is not None
78
- else torch.ones(logits.shape[:2], dtype=torch.bool, device=logits.device)
 
 
79
  )
80
  return self.crf.decode(logits, mask)
81
- return logits.argmax(dim=-1)
 
2
 
3
  import torch
4
  import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
6
 
7
 
8
  class HavelockTokenConfig(PretrainedConfig):
9
  """Config that wraps any backbone config + our custom fields."""
10
+
11
  model_type = "havelock_token_classifier"
12
 
13
  def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs):
 
19
  class HavelockTokenClassifier(PreTrainedModel):
20
  config_class = HavelockTokenConfig
21
 
22
+ def __init__(
23
+ self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None
24
+ ):
25
  super().__init__(config)
26
  self.num_types = config.num_types
27
  self.use_crf = config.use_crf
 
32
  else:
33
  self.backbone = AutoModel.from_config(config)
34
 
35
+ self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
36
  self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)
37
 
38
  if self.use_crf:
 
78
  mask = (
79
  attention_mask.bool()
80
  if attention_mask is not None
81
+ else torch.ones(
82
+ logits.shape[:2], dtype=torch.bool, device=logits.device
83
+ )
84
  )
85
  return self.crf.decode(logits, mask)
86
+ return logits.argmax(dim=-1)