shallowblueQAQ commited on
Commit
eb9e779
·
verified ·
1 Parent(s): 5b18e0b

Update relevance_model/model.py

Browse files
Files changed (1) hide show
  1. relevance_model/model.py +0 -3
relevance_model/model.py CHANGED
@@ -2,8 +2,6 @@ import torch
2
  from torch import nn
3
  from transformers import AutoModel, PreTrainedModel, AutoConfig
4
 
5
- # 你可以直接复用你原来的类,稍微改一下使其兼容 HF 的 save_pretrained 更好
6
- # 但为了保持和你训练时完全一致,最简单的就是保留你原本的写法
7
  class BERTDiseaseClassifier(nn.Module):
8
  def __init__(self, model_type, num_symps) -> None:
9
  super().__init__()
@@ -16,7 +14,6 @@ class BERTDiseaseClassifier(nn.Module):
16
 
17
  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
18
  outputs = self.encoder(input_ids, attention_mask, token_type_ids)
19
- # 保持和你训练时完全一致的逻辑
20
  x = outputs.last_hidden_state[:, 0, :] # [CLS] pooling
21
  x = self.dropout(x)
22
  logits = self.clf(x)
 
2
  from torch import nn
3
  from transformers import AutoModel, PreTrainedModel, AutoConfig
4
 
 
 
5
  class BERTDiseaseClassifier(nn.Module):
6
  def __init__(self, model_type, num_symps) -> None:
7
  super().__init__()
 
14
 
15
  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
16
  outputs = self.encoder(input_ids, attention_mask, token_type_ids)
 
17
  x = outputs.last_hidden_state[:, 0, :] # [CLS] pooling
18
  x = self.dropout(x)
19
  logits = self.clf(x)