haukurpj committed
Commit 3bce6c8 · 1 Parent(s): 6edf0cd

improve comments

Files changed (1)
  1. modeling.py +138 -140
modeling.py CHANGED
@@ -25,37 +25,41 @@ class MultiLabelTokenClassificationHead(nn.Module):
25
  self.num_labels = config.num_labels
26
  self.hidden_size = config.hidden_size
27
 
 
28
  self.dense = nn.Linear(self.hidden_size, self.hidden_size)
29
  self.activation_fn = F.relu
30
  self.dropout = nn.Dropout(p=config.classifier_dropout)
31
  self.layer_norm = nn.LayerNorm(self.hidden_size)
32
 
33
- # Category projection: hidden_size -> num_categories
 
34
  self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)
35
-
36
- # Attribute projection: (hidden_size + num_categories) -> num_labels
37
  self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)
38
 
39
  def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
40
  """
 
 
41
  Args:
42
- features: Word-level features of shape (batch_size, max_words, hidden_size)
43
 
44
  Returns:
45
- cat_logits: Category logits of shape (batch_size, max_words, num_categories)
46
- attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
47
  """
48
- x = self.dropout(features)
49
- x = self.dense(x)
50
- x = self.layer_norm(x)
51
- x = self.activation_fn(x)
52
 
53
- # Predict categories
54
  cat_logits = self.cat_proj(x)
55
- cat_probs = torch.softmax(cat_logits, dim=-1)
56
 
57
- # Predict attributes using concatenated features
58
  attr_input = torch.cat((cat_probs, x), dim=-1)
 
59
  attr_logits = self.out_proj(attr_input)
60
 
61
  return cat_logits, attr_logits
@@ -94,22 +98,22 @@ class IceBertPosForTokenClassification(PreTrainedModel):
94
  # Create tensors as regular attributes (not buffers to avoid init warnings)
95
  self.group_mask = schema.get_group_masks()
96
  self.group_name_to_group_attr_indices = schema.get_group_name_to_group_attr_indices()
97
-
98
  # Category name to index mapping (regular dict, no device movement needed)
99
  self.category_name_to_index = schema.get_category_name_to_index()
100
-
101
- def _apply(self, fn):
102
  """Override _apply to move our custom tensors with the model."""
103
  super()._apply(fn)
104
-
105
  # Move our custom tensors when model.to(device) is called
106
- if hasattr(self, 'group_mask'):
107
  self.group_mask = fn(self.group_mask)
108
-
109
- if hasattr(self, 'group_name_to_group_attr_indices'):
110
  for group_name, tensor in self.group_name_to_group_attr_indices.items():
111
  self.group_name_to_group_attr_indices[group_name] = fn(tensor)
112
-
113
  return self
114
 
115
  def forward(
@@ -126,14 +130,16 @@ class IceBertPosForTokenClassification(PreTrainedModel):
126
  return_dict: Optional[bool] = None,
127
  ) -> Tuple[torch.Tensor, torch.Tensor]:
128
  """
 
 
129
  Args:
130
- input_ids: Token indices of shape (batch_size, sequence_length)
131
- attention_mask: Attention mask of shape (batch_size, sequence_length)
132
- word_mask: Binary mask indicating word boundaries (1 = word start) of shape (batch_size, sequence_length)
133
 
134
  Returns:
135
- cat_logits: Category logits of shape (batch_size, max_words, num_categories)
136
- attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
137
  """
138
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
139
 
@@ -150,112 +156,87 @@ class IceBertPosForTokenClassification(PreTrainedModel):
150
  return_dict=return_dict,
151
  )
152
 
153
- x = outputs[0] # (batch_size, seq_len, hidden)
154
-
155
- # Copy exact logic from old model
156
- _, _, inner_dim = x.shape
157
-
158
- # use first bpe token of word as representation
159
- x = x[:, 1:-1, :]
160
- starts = word_mask[:, 1:-1] # remove bos, eos
161
- ends = starts.roll(-1, dims=[-1]).nonzero()[:, -1] + 1
162
- starts = starts.nonzero().tolist()
163
- mean_words = []
164
- for (seq_idx, token_idx), end in zip(starts, ends):
165
- mean_words.append(x[seq_idx, token_idx:end, :].mean(dim=0))
166
- mean_words = torch.stack(mean_words)
167
- words = mean_words
168
- # Innermost dimension is mask for tokens at head of word.
169
- nwords = word_mask.sum(dim=-1)
170
- (cat_logits, attr_logits) = self.classifier(words)
171
-
172
- # (Batch * Time) x Depth -> Batch x Time x Depth
173
- cat_logits = pad_sequence(cat_logits.split((nwords).tolist()), padding_value=0, batch_first=True)
174
- attr_logits = pad_sequence(
175
- attr_logits.split((nwords).tolist()),
176
- padding_value=0,
177
- batch_first=True,
178
- )
179
  return cat_logits, attr_logits
180
 
181
  def _aggregate_subword_tokens(
182
- self, sequence_output: torch.Tensor, word_mask: torch.Tensor
183
- ) -> Tuple[torch.Tensor, torch.Tensor]:
184
  """
185
- Aggregate subword token representations to word-level representations.
186
- Following the original fairseq approach by averaging subword tokens within each word.
187
-
 
188
  Args:
189
- sequence_output: subword token representations (batch_size, seq_len, hidden_size)
190
- word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)
 
191
 
192
  Returns:
193
- word_features: Word-level features (batch_size, max_words, hidden_size)
194
- nwords: Number of words per sequence (batch_size,)
195
  """
196
- # TODO: Verify that BOS and EOS are handled correctly - I'm worried that this does not correctly handle padding
197
- # Remove BOS and EOS tokens (first and last positions)
198
- x = sequence_output[:, 1:-1, :] # (batch_size, seq_len-2, hidden_size)
199
- starts = word_mask[:, 1:-1] # (batch_size, seq_len-2)
200
-
201
- # Count words per sequence
202
- nwords = starts.sum(dim=-1) # (batch_size,)
203
-
204
- # Find word boundaries and average tokens within each word
205
  mean_words = []
206
- batch_size, seq_len, hidden_size = x.shape
207
 
208
  for batch_idx in range(batch_size):
209
- seq_starts = starts[batch_idx] # (seq_len-2,)
210
- seq_x = x[batch_idx] # (seq_len-2, hidden_size)
211
-
212
- # Find start positions of words
213
- start_positions = seq_starts.nonzero(as_tuple=True)[0] # positions where words start
214
-
215
- if len(start_positions) == 0:
 
 
216
  continue
217
-
218
- # Calculate end positions (start of next word or end of sequence)
219
- end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)])
220
-
221
- # Average tokens within each word
222
- for start_pos, end_pos in zip(start_positions, end_positions):
223
- word_tokens = seq_x[start_pos:end_pos] # tokens in this word
224
- word_repr = word_tokens.mean(dim=0) # average representation
225
  mean_words.append(word_repr)
226
 
227
  if len(mean_words) == 0:
228
- return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords
229
 
230
- return torch.stack(mean_words), nwords
231
 
232
- def _reshape_to_batch_format(
233
- self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
234
- ) -> Tuple[torch.Tensor, torch.Tensor]:
235
  """
236
- Reshape word-level predictions back to batch format.
237
- Following the original fairseq approach with pad_sequence.
238
-
 
239
  Args:
240
- cat_logits: Category logits (total_words, num_categories)
241
- attr_logits: Attribute logits (total_words, num_labels)
242
- nwords: Number of words per sequence (batch_size,)
243
-
244
  Returns:
245
- cat_logits_batch: (batch_size, max_words, num_categories)
246
- attr_logits_batch: (batch_size, max_words, num_labels)
247
  """
248
-
249
- # Split logits by sequence using word counts
250
- words_per_seq = nwords.tolist()
251
- cat_logits_split = cat_logits.split(words_per_seq)
252
- attr_logits_split = attr_logits.split(words_per_seq)
253
-
254
- # Pad to same length (matching original fairseq approach)
255
- cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
256
- attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)
257
-
258
- return cat_logits_batch, attr_logits_batch
259
 
260
  @torch.no_grad()
261
  def predict_labels(
@@ -281,24 +262,26 @@ class IceBertPosForTokenClassification(PreTrainedModel):
281
 
282
  def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
283
  """
284
- Convert word_ids to word_mask (binary mask indicating word boundaries).
285
 
 
 
286
  Args:
287
- word_ids: List of word id sequences
288
- input_shape: Shape of input_ids tensor (batch_size, seq_len)
289
 
290
  Returns:
291
- word_mask: Binary tensor where 1 indicates start of word (batch_size, seq_len)
292
  """
293
  batch_size, seq_len = input_shape
294
- word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
295
 
296
  for batch_idx, seq_word_ids in enumerate(word_ids):
297
  prev_word_id = None
298
  for token_idx, word_id in enumerate(seq_word_ids):
299
  # Skip None values (special tokens and padding)
300
  if word_id is not None and word_id != prev_word_id:
301
- word_mask[batch_idx, token_idx] = 1
302
  # Only update prev_word_id for valid (non-None) word_ids
303
  if word_id is not None:
304
  prev_word_id = word_id
@@ -310,10 +293,12 @@ class IceBertPosForTokenClassification(PreTrainedModel):
310
 
311
  def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
312
  """
313
- Predict POS labels from raw text using fairseq-style preprocessing.
314
 
 
 
315
  Args:
316
- sentences: List of input sentences
317
  tokenizer: HuggingFace tokenizer
318
 
319
  Returns:
@@ -334,9 +319,9 @@ class IceBertPosForTokenClassification(PreTrainedModel):
334
 
335
  # Debug logging to match fairseq model
336
  for i in range(len(sentences)):
337
- logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
338
  logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
339
- logger.debug(f"Word IDs: {word_ids_list[i]}")
340
 
341
  return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)
342
 
@@ -345,48 +330,59 @@ class IceBertPosForTokenClassification(PreTrainedModel):
345
  ) -> List[List[Tuple[str, List[str]]]]:
346
  """
347
  Convert logits to human-readable labels using schema-based logic.
348
  """
349
- # logits: Batch x Time x Labels
350
  bsz, _, num_cats = cat_logits.shape
351
  _, _, num_attrs = attr_logits.shape
352
- nwords = word_mask.sum(-1)
353
 
354
  assert num_attrs == len(self.config.label_schema.labels)
355
  assert num_cats == len(self.config.label_schema.label_categories)
356
 
357
  predictions = []
358
  schema = self.config.label_schema
359
-
360
  for seq_idx in range(bsz):
361
  seq_nwords = nwords[seq_idx]
 
362
  pred_cat_indices = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
363
-
364
  seq_predictions = []
365
  for word_idx in range(seq_nwords):
366
- cat_idx = int(pred_cat_indices[word_idx].item())
367
  cat_name = schema.label_categories[cat_idx]
368
-
369
  # Get valid groups for this category
370
  valid_groups = schema.category_to_group_names.get(cat_name, [])
371
-
372
  # Collect attributes for this word
373
  attributes = []
374
  for group_name in valid_groups:
375
  if group_name in self.group_name_to_group_attr_indices:
376
- group_indices = self.group_name_to_group_attr_indices[group_name]
377
  if len(group_indices) > 0:
 
378
  group_logits = attr_logits[seq_idx, word_idx, group_indices]
379
  if len(group_indices) == 1:
380
- # Binary decision
381
  if group_logits.sigmoid().item() > 0.5:
382
- attr_idx = int(group_indices[0].item())
383
  attributes.append(schema.labels[attr_idx])
384
  else:
385
- # Multi-class decision
386
- best_idx = int(group_logits.max(dim=-1).indices.item())
387
- attr_idx = int(group_indices[best_idx].item())
388
  attributes.append(schema.labels[attr_idx])
389
-
390
  # Apply specific rules from original model
391
  if len(attributes) == 1 and attributes[0] == "pos":
392
  # This label is used as a default for training but implied in mim format
@@ -394,9 +390,9 @@ class IceBertPosForTokenClassification(PreTrainedModel):
394
  elif cat_name == "sl" and "act" in attributes:
395
  # Number and tense are not shown for sl act in mim format
396
  attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
397
-
398
  seq_predictions.append((cat_name, attributes))
399
-
400
  predictions.append(seq_predictions)
401
 
402
  return predictions
@@ -405,12 +401,14 @@ class IceBertPosForTokenClassification(PreTrainedModel):
405
  """
406
  Predict IFD format labels from raw text.
407
 
 
 
408
  Args:
409
- sentences: List of input sentences
410
  tokenizer: HuggingFace tokenizer
411
 
412
  Returns:
413
- List of sequences, each containing IFD format labels per word
414
  """
415
  # Get model predictions in (category, [attributes]) format
416
  predictions = self.predict_labels_from_text(sentences, tokenizer)
@@ -418,7 +416,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
418
  # Convert each sentence's predictions to IFD format
419
  ifd_predictions = []
420
  for sentence_predictions in predictions:
421
- ifd_labels = convert_predictions_to_ifd(sentence_predictions)
422
  ifd_predictions.append(ifd_labels)
423
 
424
  return ifd_predictions
 
25
  self.num_labels = config.num_labels
26
  self.hidden_size = config.hidden_size
27
 
28
+ # (*, H) -> (*, H)
29
  self.dense = nn.Linear(self.hidden_size, self.hidden_size)
30
  self.activation_fn = F.relu
31
  self.dropout = nn.Dropout(p=config.classifier_dropout)
32
  self.layer_norm = nn.LayerNorm(self.hidden_size)
33
 
34
+ # Projection heads for multilabel classification
35
+ # (*, H) -> (*, C)
36
  self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)
37
+ # (*, H + C) -> (*, A)
 
38
  self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)
39
 
40
  def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
41
  """
42
+ H = hidden_size, C = num_categories, A = num_attributes, Wt = total_words
43
+
44
  Args:
45
+ features: Word-level features (Wt x H)
46
 
47
  Returns:
48
+ cat_logits: Category logits (Wt x C)
49
+ attr_logits: Attribute logits (Wt x A)
50
  """
51
+ x = self.dropout(features) # (Wt x H)
52
+ x = self.dense(x) # (Wt x H)
53
+ x = self.layer_norm(x) # (Wt x H)
54
+ x = self.activation_fn(x) # (Wt x H)
55
 
56
+ # (Wt x H) -> (Wt x C)
57
  cat_logits = self.cat_proj(x)
58
+ cat_probs = torch.softmax(cat_logits, dim=-1) # (Wt x C)
59
 
60
+ # (Wt x H) + (Wt x C) -> (Wt x H+C)
61
  attr_input = torch.cat((cat_probs, x), dim=-1)
62
+ # (Wt x H+C) -> (Wt x A)
63
  attr_logits = self.out_proj(attr_input)
64
 
65
  return cat_logits, attr_logits
 
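As a quick sanity check of the new shape comments, here is a minimal standalone sketch of the head's data flow. The layer sizes (768/26/60), the dropout probability, and the freshly initialised layers are invented for illustration; only the sequence of operations and the H/C/A/Wt notation mirror the code in this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Invented sizes, for illustration only (real values come from the model config).
hidden_size, num_categories, num_labels, total_words = 768, 26, 60, 5

dense = nn.Linear(hidden_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size)
dropout = nn.Dropout(p=0.1)
cat_proj = nn.Linear(hidden_size, num_categories)
out_proj = nn.Linear(hidden_size + self.num_categories if False else hidden_size + num_categories, num_labels)

features = torch.randn(total_words, hidden_size)            # (Wt x H) word-level features
x = F.relu(layer_norm(dense(dropout(features))))            # (Wt x H)
cat_logits = cat_proj(x)                                     # (Wt x C)
cat_probs = torch.softmax(cat_logits, dim=-1)                # (Wt x C)
attr_logits = out_proj(torch.cat((cat_probs, x), dim=-1))    # (Wt x A)
print(cat_logits.shape, attr_logits.shape)                   # torch.Size([5, 26]) torch.Size([5, 60])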
98
  # Create tensors as regular attributes (not buffers to avoid init warnings)
99
  self.group_mask = schema.get_group_masks()
100
  self.group_name_to_group_attr_indices = schema.get_group_name_to_group_attr_indices()
101
+
102
  # Category name to index mapping (regular dict, no device movement needed)
103
  self.category_name_to_index = schema.get_category_name_to_index()
104
+
105
+ def _apply(self, fn): # type: ignore
106
  """Override _apply to move our custom tensors with the model."""
107
  super()._apply(fn)
108
+
109
  # Move our custom tensors when model.to(device) is called
110
+ if hasattr(self, "group_mask"):
111
  self.group_mask = fn(self.group_mask)
112
+
113
+ if hasattr(self, "group_name_to_group_attr_indices"):
114
  for group_name, tensor in self.group_name_to_group_attr_indices.items():
115
  self.group_name_to_group_attr_indices[group_name] = fn(tensor)
116
+
117
  return self
118
 
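The _apply override above exists because group_mask and group_name_to_group_attr_indices are stored as plain attributes rather than registered buffers, so Module.to() would not move them on its own. A generic toy illustration of that pattern (not the repo's class) follows.

import torch
import torch.nn as nn

class ModuleWithSideTensors(nn.Module):
    """Toy module: plain-attribute tensors are moved by forwarding fn in _apply."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Plain attributes, deliberately not buffers (mirrors group_mask above).
        self.group_mask = torch.ones(4, dtype=torch.bool)
        self.group_name_to_group_attr_indices = {"case": torch.tensor([0, 1])}

    def _apply(self, fn):
        super()._apply(fn)
        if hasattr(self, "group_mask"):
            self.group_mask = fn(self.group_mask)
        if hasattr(self, "group_name_to_group_attr_indices"):
            for name, t in self.group_name_to_group_attr_indices.items():
                self.group_name_to_group_attr_indices[name] = fn(t)
        return self

m = ModuleWithSideTensors().to("cpu")  # the same call moves the side tensors on CUDA
print(m.group_mask.device, m.group_name_to_group_attr_indices["case"].device)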
119
  def forward(
 
130
  return_dict: Optional[bool] = None,
131
  ) -> Tuple[torch.Tensor, torch.Tensor]:
132
  """
133
+ B = batch_size, L = seq_len, H = hidden_size, C = num_categories, A = num_attributes, W = max_words
134
+
135
  Args:
136
+ input_ids: Token indices (B x L)
137
+ attention_mask: Attention mask (B x L)
138
+ word_mask: Binary mask indicating word boundaries, 1 = word start (B x L)
139
 
140
  Returns:
141
+ cat_logits: Category logits (B x W x C)
142
+ attr_logits: Attribute logits (B x W x A)
143
  """
144
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
145
 
 
156
  return_dict=return_dict,
157
  )
158
 
159
+ hidden_states = outputs[0] # (B x L x H)
160
+
161
+ # (B x L x H) -> (Wt x H)
162
+ word_embeddings = self._aggregate_subword_tokens(hidden_states, word_mask, attention_mask)
163
+
164
+ # (Wt x H) -> (Wt x C), (Wt x A)
165
+ cat_logits, attr_logits = self.classifier(word_embeddings)
166
+
167
+ # (Wt x C) -> (B x W x C), (Wt x A) -> (B x W x A)
168
+ nwords = word_mask.sum(dim=-1) # (B,)
169
+ cat_logits = self._reshape_to_batch_format(cat_logits, nwords)
170
+ attr_logits = self._reshape_to_batch_format(attr_logits, nwords)
171
  return cat_logits, attr_logits
172
 
173
  def _aggregate_subword_tokens(
174
+ self, sequence_output: torch.Tensor, word_mask: torch.Tensor, attention_mask: torch.Tensor
175
+ ) -> torch.Tensor:
176
  """
177
+ Average subword tokens within each word to get word-level representations.
178
+
179
+ B = batch_size, L = seq_len, H = hidden_size, Wt = total_words
180
+
181
  Args:
182
+ sequence_output: Subword token representations (B x L x H)
183
+ word_mask: Binary mask where 1 indicates start of word (B x L)
184
+ attention_mask: Attention mask to exclude padding tokens (B x L)
185
 
186
  Returns:
187
+ word_features: Concatenated word-level features (Wt x H)
 
188
  """
189
+ batch_size, seq_len, hidden_size = sequence_output.shape
190
  mean_words = []
 
191
 
192
  for batch_idx in range(batch_size):
193
+ # Get valid (non-padding) tokens for this sequence
194
+ valid_mask = attention_mask[batch_idx].bool() # (L,) -> (Lv,)
195
+ seq_output = sequence_output[batch_idx, valid_mask] # (Lv x H)
196
+ seq_word_mask = word_mask[batch_idx, valid_mask] # (Lv,)
197
+
198
+ # Find word start positions
199
+ word_starts = seq_word_mask.nonzero(as_tuple=True)[0] # (Ws,)
200
+
201
+ if len(word_starts) == 0:
202
  continue
203
+
204
+ # For each word, find its token span and average
205
+ for i, start_pos in enumerate(word_starts):
206
+ # Find end position (start of next word or end of valid sequence)
207
+ if i + 1 < len(word_starts):
208
+ end_pos = word_starts[i + 1]
209
+ else:
210
+ end_pos = len(seq_output)
211
+
212
+ # Average tokens within this word (excluding padding)
213
+ word_tokens = seq_output[start_pos:end_pos] # (Tw x H)
214
+ word_repr = word_tokens.mean(dim=0) # (H,)
215
  mean_words.append(word_repr)
216
 
217
  if len(mean_words) == 0:
218
+ return torch.empty(0, hidden_size, device=sequence_output.device)
219
 
220
+ return torch.stack(mean_words) # (Wt x H)
221
 
222
+ def _reshape_to_batch_format(self, logits: torch.Tensor, nwords: torch.Tensor) -> torch.Tensor:
 
 
223
  """
224
+ Reshape concatenated word predictions back to padded batch format.
225
+
226
+ B = batch_size, W = max_words, Wt = total_words, K = num_classes
227
+
228
  Args:
229
+ logits: Concatenated word predictions (Wt x K)
230
+ nwords: Number of words per sequence (B,)
231
+
 
232
  Returns:
233
+ batch_logits: Batched predictions (B x W x K)
 
234
  """
235
+ return pad_sequence(
236
+ logits.split(nwords.tolist()),
237
+ padding_value=0,
238
+ batch_first=True,
239
+ )
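The (Wt x K) -> (B x W x K) round-trip performed by _reshape_to_batch_format can be worked through on toy numbers. The values below are invented; only the split/pad mechanics mirror the code above.

import torch
from torch.nn.utils.rnn import pad_sequence

nwords = torch.tensor([3, 1])                # (B,) words per sequence
flat_logits = torch.arange(8.0).view(4, 2)   # (Wt x K) with Wt = 3 + 1, K = 2

per_seq = flat_logits.split(nwords.tolist())                        # [(3 x 2), (1 x 2)]
batched = pad_sequence(per_seq, padding_value=0, batch_first=True)  # (B x W x K) = (2 x 3 x 2)
print(batched.shape)  # torch.Size([2, 3, 2]); the second row is zero-padded after its single word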
240
 
241
  @torch.no_grad()
242
  def predict_labels(
 
262
 
263
  def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
264
  """
265
+ Convert word_ids to binary mask indicating word boundaries.
266
 
267
+ B = batch_size, L = seq_len
268
+
269
  Args:
270
+ word_ids: List of word id sequences for each batch item
271
+ input_shape: Shape of input_ids tensor (B x L)
272
 
273
  Returns:
274
+ word_mask: Binary tensor where 1 indicates start of word (B x L)
275
  """
276
  batch_size, seq_len = input_shape
277
+ word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long) # (B x L)
278
 
279
  for batch_idx, seq_word_ids in enumerate(word_ids):
280
  prev_word_id = None
281
  for token_idx, word_id in enumerate(seq_word_ids):
282
  # Skip None values (special tokens and padding)
283
  if word_id is not None and word_id != prev_word_id:
284
+ word_mask[batch_idx, token_idx] = 1 # Mark word start
285
  # Only update prev_word_id for valid (non-None) word_ids
286
  if word_id is not None:
287
  prev_word_id = word_id
 
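A hand-worked example of the word_mask construction above. The word_ids values are made up but follow the convention of a HuggingFace fast tokenizer's BatchEncoding.word_ids(): None for special tokens and padding, repeated ids for subword continuations.

import torch

word_ids = [[None, 0, 0, 1, 2, 2, None, None]]   # one sequence, B = 1, L = 8
word_mask = torch.zeros(1, 8, dtype=torch.long)

prev = None
for tok_idx, wid in enumerate(word_ids[0]):
    if wid is not None and wid != prev:
        word_mask[0, tok_idx] = 1                 # mark the first subword of each word
    if wid is not None:
        prev = wid

print(word_mask.tolist())  # [[0, 1, 0, 1, 1, 0, 0, 0]]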
293
 
294
  def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
295
  """
296
+ Predict POS labels from raw text.
297
 
298
+ B = batch_size, L = seq_len
299
+
300
  Args:
301
+ sentences: List of input sentences (B,)
302
  tokenizer: HuggingFace tokenizer
303
 
304
  Returns:
 
319
 
320
  # Debug logging to match fairseq model
321
  for i in range(len(sentences)):
322
+ logger.debug(f"Encoded tokens: {batch_input_ids[i]}") # (L,)
323
  logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
324
+ logger.debug(f"Word IDs: {word_ids_list[i]}") # (L,)
325
 
326
  return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)
327
 
 
330
  ) -> List[List[Tuple[str, List[str]]]]:
331
  """
332
  Convert logits to human-readable labels using schema-based logic.
333
+
334
+ B = batch_size, W = max_words, C = num_categories, A = num_attributes, L = seq_len
335
+
336
+ Args:
337
+ cat_logits: Category logits (B x W x C)
338
+ attr_logits: Attribute logits (B x W x A)
339
+ word_mask: Binary mask for valid words (B x L)
340
+
341
+ Returns:
342
+ predictions: List of [(category, [attributes])] for each sequence
343
  """
 
344
  bsz, _, num_cats = cat_logits.shape
345
  _, _, num_attrs = attr_logits.shape
346
+ nwords = word_mask.sum(-1) # (B,)
347
 
348
  assert num_attrs == len(self.config.label_schema.labels)
349
  assert num_cats == len(self.config.label_schema.label_categories)
350
 
351
  predictions = []
352
  schema = self.config.label_schema
353
+
354
  for seq_idx in range(bsz):
355
  seq_nwords = nwords[seq_idx]
356
+ # (W x C) -> (seq_nwords,)
357
  pred_cat_indices = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
358
+
359
  seq_predictions = []
360
  for word_idx in range(seq_nwords):
361
+ cat_idx = pred_cat_indices[word_idx].item()
362
  cat_name = schema.label_categories[cat_idx]
363
+
364
  # Get valid groups for this category
365
  valid_groups = schema.category_to_group_names.get(cat_name, [])
366
+
367
  # Collect attributes for this word
368
  attributes = []
369
  for group_name in valid_groups:
370
  if group_name in self.group_name_to_group_attr_indices:
371
+ group_indices = self.group_name_to_group_attr_indices[group_name] # (Gs,)
372
  if len(group_indices) > 0:
373
+ # (A,) -> (Gs,)
374
  group_logits = attr_logits[seq_idx, word_idx, group_indices]
375
  if len(group_indices) == 1:
376
+ # Binary decision for single-item groups
377
  if group_logits.sigmoid().item() > 0.5:
378
+ attr_idx = group_indices[0].item()
379
  attributes.append(schema.labels[attr_idx])
380
  else:
381
+ # Multi-class decision for multi-item groups
382
+ best_idx = group_logits.max(dim=-1).indices.item()
383
+ attr_idx = group_indices[best_idx].item()
384
  attributes.append(schema.labels[attr_idx])
385
+
386
  # Apply specific rules from original model
387
  if len(attributes) == 1 and attributes[0] == "pos":
388
  # This label is used as a default for training but implied in mim format
 
390
  elif cat_name == "sl" and "act" in attributes:
391
  # Number and tense are not shown for sl act in mim format
392
  attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
393
+
394
  seq_predictions.append((cat_name, attributes))
395
+
396
  predictions.append(seq_predictions)
397
 
398
  return predictions
 
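The per-group decoding rule above (single-attribute groups decided by a sigmoid threshold, multi-attribute groups by argmax) can be traced on a tiny invented schema; the label names and logits below are illustrative only.

import torch

labels = ["proper", "nom", "acc", "dat"]
group_name_to_group_attr_indices = {
    "proper": torch.tensor([0]),       # single-item group -> binary decision
    "case": torch.tensor([1, 2, 3]),   # multi-item group  -> argmax decision
}
attr_logits_for_word = torch.tensor([-0.3, 0.2, 1.5, -1.0])  # (A,)

attributes = []
for group_indices in group_name_to_group_attr_indices.values():
    group_logits = attr_logits_for_word[group_indices]
    if len(group_indices) == 1:
        if group_logits.sigmoid().item() > 0.5:
            attributes.append(labels[int(group_indices[0])])
    else:
        best = int(group_logits.max(dim=-1).indices.item())
        attributes.append(labels[int(group_indices[best])])

print(attributes)  # ['acc'] — 'proper' is dropped (sigmoid(-0.3) < 0.5); case argmax picks index 2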
401
  """
402
  Predict IFD format labels from raw text.
403
 
404
+ B = batch_size, Ws = seq_words
405
+
406
  Args:
407
+ sentences: List of input sentences (B,)
408
  tokenizer: HuggingFace tokenizer
409
 
410
  Returns:
411
+ ifd_predictions: List of IFD labels per sentence (B x Ws)
412
  """
413
  # Get model predictions in (category, [attributes]) format
414
  predictions = self.predict_labels_from_text(sentences, tokenizer)
 
416
  # Convert each sentence's predictions to IFD format
417
  ifd_predictions = []
418
  for sentence_predictions in predictions:
419
+ ifd_labels = convert_predictions_to_ifd(sentence_predictions) # (Ws,)
420
  ifd_predictions.append(ifd_labels)
421
 
422
  return ifd_predictions
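For completeness, a hedged sketch of how the public entry point shown in this diff might be called. The checkpoint path is a placeholder and importing the class straight from modeling.py is an assumption about how the repo is consumed (it may instead be loaded via trust_remote_code); only the method name and signature come from the code above.

from transformers import AutoTokenizer
from modeling import IceBertPosForTokenClassification  # assumed local import

checkpoint = "path/to/icebert-pos-checkpoint"  # hypothetical placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = IceBertPosForTokenClassification.from_pretrained(checkpoint).eval()

sentences = ["Hún las bókina í gær."]
# One list per sentence of (category, [attributes]) tuples, as documented above.
predictions = model.predict_labels_from_text(sentences, tokenizer)
print(predictions)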