emanuelaboros commited on
Commit
a886816
·
1 Parent(s): 427163b

lets try to change the pipeline

Browse files
Files changed (1) hide show
  1. configuration_stacked.py +8 -2
configuration_stacked.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import PretrainedConfig
2
  import torch
3
 
 
4
  class ImpressoConfig(PretrainedConfig):
5
  model_type = "stacked_bert"
6
 
@@ -25,6 +26,7 @@ class ImpressoConfig(PretrainedConfig):
25
  pretrained_config=None,
26
  values_override=None,
27
  label_map=None,
 
28
  **kwargs,
29
  ):
30
  super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -46,6 +48,7 @@ class ImpressoConfig(PretrainedConfig):
46
  self.classifier_dropout = classifier_dropout
47
  self.pretrained_config = pretrained_config
48
  self.label_map = label_map
 
49
 
50
  self.values_override = values_override or {}
51
  self.outputs = {
@@ -72,7 +75,9 @@ class ImpressoConfig(PretrainedConfig):
72
  """
73
  return None
74
 
75
- def generate_dummy_inputs(self, tokenizer, batch_size=1, seq_length=8, framework="pt"):
 
 
76
  """
77
  Generate dummy inputs for testing or export.
78
  Args:
@@ -88,12 +93,13 @@ class ImpressoConfig(PretrainedConfig):
88
  low=0,
89
  high=self.vocab_size,
90
  size=(batch_size, seq_length),
91
- dtype=torch.long
92
  )
93
  attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
94
  return {"input_ids": input_ids, "attention_mask": attention_mask}
95
  else:
96
  raise ValueError("Framework '{}' not supported.".format(framework))
97
 
 
98
  # Register the configuration with the transformers library
99
  ImpressoConfig.register_for_auto_class()
 
1
  from transformers import PretrainedConfig
2
  import torch
3
 
4
+
5
  class ImpressoConfig(PretrainedConfig):
6
  model_type = "stacked_bert"
7
 
 
26
  pretrained_config=None,
27
  values_override=None,
28
  label_map=None,
29
+ filename=None,
30
  **kwargs,
31
  ):
32
  super().__init__(pad_token_id=pad_token_id, **kwargs)
 
48
  self.classifier_dropout = classifier_dropout
49
  self.pretrained_config = pretrained_config
50
  self.label_map = label_map
51
+ self.filename = filename
52
 
53
  self.values_override = values_override or {}
54
  self.outputs = {
 
75
  """
76
  return None
77
 
78
+ def generate_dummy_inputs(
79
+ self, tokenizer, batch_size=1, seq_length=8, framework="pt"
80
+ ):
81
  """
82
  Generate dummy inputs for testing or export.
83
  Args:
 
93
  low=0,
94
  high=self.vocab_size,
95
  size=(batch_size, seq_length),
96
+ dtype=torch.long,
97
  )
98
  attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
99
  return {"input_ids": input_ids, "attention_mask": attention_mask}
100
  else:
101
  raise ValueError("Framework '{}' not supported.".format(framework))
102
 
103
+
104
  # Register the configuration with the transformers library
105
  ImpressoConfig.register_for_auto_class()