| """ |
| BertRegressor — truncated bert-base-uncased + single-Linear regression head. |
| |
| Architecture used in the Ace-CEFR baseline reproduction |
| (https://arxiv.org/abs/2506.14046, §4.5.1). |
| |
| The model loads the first `num_layers` transformer blocks of |
| `bert-base-uncased`, plus its embeddings and pooler, and predicts a CEFR |
| difficulty score on a 1.0 to 6.0 scale (A1 = 1, A2 = 2, B1 = 3, B2 = 4, |
| C1 = 5, C2 = 6). The raw output is an unbounded float; clamp it to |
| [1.0, 6.0] at inference time, as the example below does. |
| |
| Example: |
| >>> import torch |
| >>> from transformers import BertTokenizerFast |
| >>> from modeling import BertRegressor |
| >>> model = BertRegressor("bert-base-uncased", num_layers=3) |
| >>> sd = torch.load("pytorch_model.bin", map_location="cpu") |
| >>> model.load_state_dict(sd) |
| >>> model.eval() |
| >>> tok = BertTokenizerFast.from_pretrained("bert-base-uncased") |
| >>> enc = tok(["Hello, how are you?"], return_tensors="pt", |
| ... padding="max_length", truncation=True, max_length=128) |
| >>> with torch.no_grad(): |
| ... score = model(enc["input_ids"], enc["attention_mask"], |
| ... enc["token_type_ids"]).clamp(1.0, 6.0).item() |
| >>> print(score) # e.g. 1.4 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import BertConfig, BertModel |
|
|
|
|
class BertRegressor(nn.Module):
    """Truncated BERT encoder followed by a single-Linear regression head.

    The encoder keeps only the first ``num_layers`` transformer blocks of the
    pretrained checkpoint, together with its embeddings and pooler. The head
    maps the pooled output to one scalar per input sequence. The scalar is
    not clamped here; callers that need a bounded score must clamp it.
    """

    def __init__(self, model_name: str = "bert-base-uncased", num_layers: int = 3):
        """Build the truncated encoder and copy in the pretrained weights.

        Args:
            model_name: Checkpoint name or path handed to ``from_pretrained``.
            num_layers: Number of leading transformer blocks to keep.

        Raises:
            ValueError: If ``num_layers`` is negative or exceeds the number of
                blocks declared by the checkpoint's config.
        """
        super().__init__()
        cfg = BertConfig.from_pretrained(model_name)

        # Fail fast with a clear message, before the full checkpoint is
        # loaded below. Without this check a too-large value only surfaces as
        # an IndexError in the copy loop, and a negative value silently
        # yields an encoder with no blocks.
        available = cfg.num_hidden_layers
        if not 0 <= num_layers <= available:
            raise ValueError(
                f"num_layers must be between 0 and {available} for "
                f"{model_name!r}, got {num_layers}"
            )

        cfg.num_hidden_layers = num_layers
        self.bert = BertModel(cfg)

        # Load the full pretrained model once and copy over only the parts
        # that the truncated encoder keeps.
        pretrained = BertModel.from_pretrained(model_name)
        self.bert.embeddings.load_state_dict(pretrained.embeddings.state_dict())
        # zip stops at the shorter list, i.e. after the first num_layers blocks.
        for kept, source in zip(self.bert.encoder.layer, pretrained.encoder.layer):
            kept.load_state_dict(source.state_dict())
        self.bert.pooler.load_state_dict(pretrained.pooler.state_dict())
        # Drop the full-size model so its weights can be freed.
        del pretrained

        self.regressor = nn.Linear(cfg.hidden_size, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
        token_type_ids: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return one raw (unclamped) regression score per input sequence.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Optional padding mask. When omitted, ``None`` is
                forwarded and the encoder applies its own default, so padded
                batches should always pass a mask.
            token_type_ids: Optional segment ids. When omitted, ``None`` is
                forwarded and the encoder applies its own default.

        Returns:
            The regression head applied to the encoder's ``pooler_output``,
            with the trailing size-1 dimension squeezed away.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        scores = self.regressor(encoded.pooler_output)
        return scores.squeeze(-1)
|
|