emanuelaboros committed on
Commit
47aff3b
·
1 Parent(s): d2959f2

testing the trick

Browse files
Files changed (1) hide show
  1. modeling_stacked.py +23 -4
modeling_stacked.py CHANGED
@@ -27,6 +27,22 @@ def get_info(label_map):
27
  # return cls()
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
31
  config_class = ImpressoConfig
32
  _keys_to_ignore_on_load_missing = [r"position_ids"]
@@ -37,16 +53,19 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
37
 
38
  # Load floret model
39
  self.dummy_param = nn.Parameter(torch.zeros(1))
40
- self.model_floret = floret.load_model(self.config.filename)
41
- input_ids = "this is a text"
42
- predictions, probabilities = self.model_floret.predict([input_ids], k=1)
 
 
43
 
 
44
  def forward(self, input_ids, attention_mask=None, **kwargs):
45
  # Convert input_ids to strings using tokenizer
46
  print(
47
  f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
48
  )
49
-
50
  # if input_ids is not None:
51
  # tokenizer = kwargs.get("tokenizer")
52
  # texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
 
27
  # return cls()
28
 
29
 
30
class SafeFloretWrapper(nn.Module):
    """
    A safe wrapper for a floret model that keeps it off-device to avoid
    segmentation faults.

    The wrapped floret model is stored as a plain attribute; it is not a
    torch module, so it carries no parameters and is never moved with
    ``.to(device)``.
    """

    def __init__(self, floret_model):
        """
        Args:
            floret_model: a loaded floret model exposing
                ``predict(texts, k)`` — presumably the object returned by
                ``floret.load_model``; TODO confirm against the caller.
        """
        super().__init__()
        self.floret_model = floret_model

    def forward(self, texts):
        # Floret expects raw strings, not tensors.
        # BUGFIX: the original called self.model_floret, an attribute that
        # does not exist on this wrapper (it is stored as self.floret_model),
        # which raised AttributeError on every call.
        # predict() returns (labels, probabilities); the labels are strings
        # and cannot be tensorized, so only the numeric probabilities are
        # converted for Hugging Face compatibility.
        _labels, probabilities = self.floret_model.predict([texts], k=1)
        return torch.tensor(probabilities)
46
  class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
47
  config_class = ImpressoConfig
48
  _keys_to_ignore_on_load_missing = [r"position_ids"]
 
53
 
54
  # Load floret model
55
  self.dummy_param = nn.Parameter(torch.zeros(1))
56
+ model_floret = floret.load_model(self.config.filename)
57
+ self.model_floret = SafeFloretWrapper(model_floret)
58
+ # input_ids = "this is a text"
59
+
60
+ # predictions, probabilities = self.model_floret.predict([input_ids], k=1)
61
 
62
+ #
63
  def forward(self, input_ids, attention_mask=None, **kwargs):
64
  # Convert input_ids to strings using tokenizer
65
  print(
66
  f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
67
  )
68
+ print(self.model_floret(input_ids))
69
  # if input_ids is not None:
70
  # tokenizer = kwargs.get("tokenizer")
71
  # texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)