File size: 235 Bytes
7575c08
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from transformers import BertConfig


class BertMultiTaskConfig(BertConfig):
    """BERT configuration that additionally carries a ``tasks`` mapping.

    Apart from the extra ``tasks`` attribute, every constructor argument is
    passed through unchanged to ``BertConfig``.

    Attributes:
        tasks: Optional mapping from a task name to an integer, or ``None``
            when no mapping was supplied.
            NOTE(review): the integer is presumably the number of output
            labels for that task -- confirm against the model code.
    """

    model_type = "bert"

    def __init__(
        self,
        tasks: dict[str, int] | None = None,
        **kwargs,
    ):
        """Store the task mapping and initialise the underlying BERT config.

        Args:
            tasks: Optional ``{task_name: int}`` mapping. It is stored as
                given (no copy is made). Defaults to ``None``.
            **kwargs: Forwarded verbatim to ``BertConfig.__init__``.
        """
        # The attribute is assigned before the parent constructor runs;
        # keep this order.
        self.tasks = tasks
        super().__init__(**kwargs)