|
|
| import math |
|
|
| import torch |
| import torch.utils.checkpoint |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
|
|
| from transformers.modeling_outputs import (ModelOutput,) |
|
|
|
|
class CounterModel(PreTrainedModel):
    """Small model: a config-driven affine transform feeding a 1->1 linear layer.

    The coefficients ``weight`` and ``bias`` are copied from the config object
    at construction time. ``forward`` applies ``weight * x + bias`` and then
    passes the result through ``torch.nn.Linear(1, 1)``.
    """

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: configuration object handed to ``PreTrainedModel``; it must
                expose ``weight`` and ``bias`` attributes (presumably scalars or
                tensors broadcastable against the forward input -- TODO confirm).
        """
        super().__init__(config)
        # Kept explicitly even though the base class also stores the config.
        self.config = config
        # Affine coefficients read straight from the config.
        self.weight = config.weight
        self.bias = config.bias
        # One input feature -> one output feature.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x, **kwargs):
        """Return ``linear(weight * x + bias)``.

        Args:
            x: input tensor; after the affine step its last dimension must be 1
                to match ``Linear(1, 1)``.
            **kwargs: accepted and ignored.

        Returns:
            The raw output tensor of ``self.linear`` (not wrapped in a
            ``ModelOutput``).
        """
        affine = self.weight * x + self.bias
        return self.linear(affine)

    def add(self):
        """Return the sum of the stored ``weight`` and ``bias`` coefficients."""
        total = self.weight + self.bias
        return total