katrjohn commited on
Commit
6182cc7
Β·
verified Β·
1 Parent(s): 09a88c9

Upload 4 files

Browse files
configuration_mbert_greek_news_bert.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig
2
+
3
+
4
+ class MBertGreekNewsConfig(BertConfig):
5
+
6
+ model_type = "mbert_greek_news"
7
+
8
+ def __init__(
9
+ self,
10
+ num_labels_class: int = 19,
11
+ num_labels_ner: int = 32,
12
+ ner_loss_weight: float = 3.0,
13
+ **kwargs,
14
+ ):
15
+ super().__init__(**kwargs)
16
+ self.num_labels_class = num_labels_class
17
+ self.num_labels_ner = num_labels_ner
18
+ self.ner_loss_weight = ner_loss_weight
19
+
20
+
21
+ MBertGreekNewsConfig.register_for_auto_class()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:895f362839b1c433b65b0e56acc86aaead0e2033e7ea633953ea2b3ef88df0ea
3
+ size 713956908
modeling_mbert_greek_news_bert.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import BertModel, BertPreTrainedModel
3
+ # relative import β†’ required for remote code
4
+ from .configuration_mbert_greek_news import MBertGreekNewsConfig
5
+
6
+
7
+ class MBertGreekNews(BertPreTrainedModel):
8
+ config_class = MBertGreekNewsConfig
9
+ _auto_class = "AutoModel" # appears in auto_map
10
+
11
+ def __init__(self, config):
12
+ super().__init__(config)
13
+
14
+ self.bert = BertModel(config)
15
+
16
+ n_cls = config.num_labels_class
17
+ n_ner = config.num_labels_ner
18
+ self.ner_loss_weight = getattr(config, "ner_loss_weight", 3.0)
19
+
20
+ # ── classification head ─────────────────────────────
21
+ self.class_dropout = nn.Dropout(0.3)
22
+ self.class_fc = nn.Linear(config.hidden_size, 768)
23
+ self.class_relu = nn.ReLU()
24
+ self.classifier = nn.Linear(768, n_cls)
25
+
26
+ # ── NER head ────────────────────────────────────────
27
+ self.ner_classifier = nn.Linear(config.hidden_size, n_ner)
28
+
29
+ # helpers for dynamic-normalised training
30
+ self.initial_cls_loss = None
31
+ self.initial_ner_loss = None
32
+
33
+ self.init_weights()
34
+
35
+ # ----------------------------------------------------------
36
+ def forward(
37
+ self,
38
+ input_ids,
39
+ attention_mask=None,
40
+ token_type_ids=None,
41
+ labels_class=None,
42
+ labels_ner=None,
43
+ ):
44
+ outputs = self.bert(
45
+ input_ids,
46
+ attention_mask=attention_mask,
47
+ token_type_ids=token_type_ids,
48
+ return_dict=True,
49
+ )
50
+ seq_out = outputs.last_hidden_state # (B, L, H)
51
+ pooled_out= outputs.pooler_output # (B, H)
52
+
53
+ # ── classification branch ───────────────────────────
54
+ x = self.class_dropout(pooled_out)
55
+ x = self.class_fc(x)
56
+ x = self.class_relu(x)
57
+ logits_class = self.classifier(x)
58
+
59
+ # ── NER branch ──────────────────────────────────────
60
+ logits_ner = self.ner_classifier(seq_out)
61
+
62
+ # inference path
63
+ if labels_class is None or labels_ner is None:
64
+ return logits_class, logits_ner
65
+
66
+ # β€” classification loss
67
+ loss_cls = nn.CrossEntropyLoss()(logits_class, labels_class)
68
+
69
+ # β€” NER loss: summed, averaged over non-pad tokens
70
+ ner_loss_sum = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")(
71
+ logits_ner.view(-1, logits_ner.size(-1)),
72
+ labels_ner.view(-1),
73
+ )
74
+ mask = (labels_ner != -100).view(-1).float()
75
+ loss_ner = ner_loss_sum / (mask.sum() + 1e-9)
76
+
77
+ # β€” dynamic normalisation
78
+ if self.initial_cls_loss is None and self.training:
79
+ self.initial_cls_loss = loss_cls.item()
80
+ if self.initial_ner_loss is None and self.training:
81
+ self.initial_ner_loss = loss_ner.item()
82
+
83
+ if (self.initial_cls_loss is not None) and (self.initial_ner_loss is not None):
84
+ norm_cls = loss_cls / (self.initial_cls_loss + 1e-8)
85
+ norm_ner = loss_ner / (self.initial_ner_loss + 1e-8)
86
+ else:
87
+ norm_cls, norm_ner = loss_cls, loss_ner
88
+
89
+ loss = norm_cls + self.ner_loss_weight * norm_ner
90
+ return loss, logits_class, logits_ner
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5a6b23224aa3fb90e0be01245560f0a84d5b537f60de8dbd37a1dd790aacec7
3
+ size 5304