wangleiofficial committed
Commit aa1f56f · verified · 1 Parent(s): 2967ddc

Update dnaflash.py

Files changed (1):
  1. dnaflash.py +78 -1
dnaflash.py CHANGED
@@ -7,7 +7,12 @@ from einops import rearrange
 from rotary_embedding_torch import RotaryEmbedding
 
 from transformers import PreTrainedModel, PretrainedConfig
-from transformers.modeling_outputs import MaskedLMOutput
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
+from typing import Optional, Tuple, Union  # needed by the forward() type hints below, if not already imported earlier in the file
+import torch.utils.checkpoint
+from torch import nn, Tensor
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 # helper functions
 
@@ -413,3 +418,75 @@ class FLASHTransformerForPretrained(PreTrainedModel):
         logits, x = self.model(inputs["input_ids"], mask=inputs["attention_mask"])
         return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
 
+class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+        if getattr(config, "use_mlp_classifier", False):
+            self.score = nn.Sequential(
+                nn.Linear(config.hidden_size, config.hidden_size),
+                nn.GELU(),
+                nn.Dropout(0.1),
+                nn.Linear(config.hidden_size, self.num_labels, bias=False),
+            )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        # Get the base model outputs; the base forward indexes
+        # inputs["input_ids"] / inputs["attention_mask"], so pass a dict.
+        outputs = super().forward(
+            {"input_ids": input_ids, "attention_mask": attention_mask}
+        )
+        hidden_states = outputs["hidden_states"]
+        # Expand the padding mask to the hidden dimension and mean-pool
+        # token embeddings over non-padded positions only.
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(hidden_states.dtype)
+        mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1).clamp(min=1e-9)
+        logits = self.score(mean_pooled)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(loss=loss, logits=logits)
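
For reference, a minimal sketch of the masked mean-pooling and scoring path the new class implements, run on random tensors in place of a real FLASH forward pass. The shapes and the hidden_size/num_labels values below are illustrative, not taken from the model config:

import torch
from torch import nn

batch_size, seq_len, hidden_size, num_labels = 2, 16, 64, 3

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[1, 8:] = 0  # pretend the second sequence is padded after 8 tokens

# Masked mean pooling, as in FLASHTransformerForSequenceClassification.forward:
# padded positions contribute neither to the sum nor to the denominator.
input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size())
mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1).clamp(min=1e-9)

score = nn.Linear(hidden_size, num_labels, bias=False)
logits = score(mean_pooled)

labels = torch.tensor([0, 2])
loss = nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(logits.shape, loss.item())  # torch.Size([2, 3]) and a scalar loss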
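
The loss selection follows the standard Hugging Face problem_type convention: when config.problem_type is unset, it is inferred from num_labels and the label dtype. A short sketch of the three resulting loss paths, with illustrative shapes:

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

logits = torch.randn(4, 3)

# single_label_classification: integer class ids of shape (batch,)
ce = CrossEntropyLoss()(logits.view(-1, 3), torch.tensor([0, 2, 1, 1]))

# multi_label_classification: float multi-hot targets of shape (batch, num_labels)
bce = BCEWithLogitsLoss()(logits, torch.tensor([[1.0, 0.0, 1.0]] * 4))

# regression: num_labels == 1, float targets
reg_logits = torch.randn(4, 1)
mse = MSELoss()(reg_logits.squeeze(), torch.randn(4))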