liyang-ict committed on
Commit bb5f351 · verified · 1 Parent(s): 529170f

Upload modeling_qwen4dual_2CE_w_logic.py with huggingface_hub

Files changed (1)
  1. modeling_qwen4dual_2CE_w_logic.py +171 -0
modeling_qwen4dual_2CE_w_logic.py ADDED
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
from transformers import Qwen2PreTrainedModel, Qwen2Model
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional
from .logic_consistency_loss import LogicConsistencyLoss


@dataclass
class DualTaskModelOutput(ModelOutput):
    """
    Output class for dual-task models: token-level logits for token
    classification plus sequence-level logits for sequence classification.
    """

    loss: Optional[torch.FloatTensor] = None
    token_logits: Optional[torch.FloatTensor] = None
    sequence_logits: Optional[torch.FloatTensor] = None

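# Note (added for clarity): ModelOutput subclasses support both attribute
# access (out.loss) and dict-style access (out["token_logits"]), and fields
# left as None are skipped when the output is converted to a tuple.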
class QwenForDualTask(Qwen2PreTrainedModel):
    """Qwen2 backbone with two heads: token classification and sequence classification."""

    # Presumably consumed by a custom trainer that injects a `report_metrics`
    # method (see the hasattr guard in forward()).
    supports_report_metrics: bool = True

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        # Token classification head
        self.dropout = nn.Dropout(0.1)
        self.token_classifier = nn.Linear(config.hidden_size, config.num_token_labels)
        # Sequence classification head
        self.sequence_classifier = nn.Linear(
            config.hidden_size, config.num_sequence_labels, bias=False
        )
        # Loss functions
        self.token_loss_fn = nn.CrossEntropyLoss()
        self.sequence_loss_fn = nn.CrossEntropyLoss()
        self.logic_loss_fn = LogicConsistencyLoss(
            n_classes=config.num_token_labels,
            reduce=config.logic_reduce,
            reduction="mean",
        )
        # Initialize weights and apply final processing
        self.post_init()

    def post_init(self):
        """
        Run the standard Hugging Face post-init, then re-initialize the two
        classification heads with Xavier-uniform weights.
        """
        # Keep the base-class behavior (weight init, hook setup) before the
        # custom head initialization; skipping it would bypass standard init.
        super().post_init()
        # Initialize token classification head
        nn.init.xavier_uniform_(self.token_classifier.weight)
        if self.token_classifier.bias is not None:
            nn.init.zeros_(self.token_classifier.bias)
        # Initialize sequence classification head (built with bias=False, so
        # the bias branch is a no-op here)
        nn.init.xavier_uniform_(self.sequence_classifier.weight)
        if self.sequence_classifier.bias is not None:
            nn.init.zeros_(self.sequence_classifier.bias)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        token_labels: torch.LongTensor | None = None,
        sequence_labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]  # (batch, seq_len, hidden_size) last hidden state

        # Sequence classification: pool the hidden state of the last non-pad token
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right-padding, take the rightmost token
            # that is not equal to pad_token_id.
            non_pad_mask = (input_ids != self.config.pad_token_id).to(
                hidden_states.device, torch.int32
            )
            token_indices = torch.arange(
                input_ids.shape[-1], device=hidden_states.device
            )
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1

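        # Worked example (illustrative, not in the original upload): with
        # pad_token_id = 0 and input_ids = [[5, 7, 0, 0], [0, 0, 5, 7]],
        # non_pad_mask is [[1, 1, 0, 0], [0, 0, 1, 1]] and token_indices is
        # [0, 1, 2, 3], so token_indices * non_pad_mask is [[0, 1, 0, 0],
        # [0, 0, 2, 3]] and argmax(-1) yields [1, 3] -- the rightmost real
        # token for both right- and left-padded rows.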
        sequence_logits = self.sequence_classifier(
            hidden_states[
                torch.arange(batch_size, device=hidden_states.device),
                last_non_pad_token,
            ]
        )
        sequence_loss = None
        if sequence_labels is not None:
            sequence_loss = self.sequence_loss_fn(sequence_logits, sequence_labels)

        # Token classification over every position
        hidden_states = self.dropout(hidden_states)
        token_logits = self.token_classifier(hidden_states)
        token_loss = None
        if token_labels is not None:
            token_loss = self.token_loss_fn(
                token_logits.view(-1, self.config.num_token_labels),
                token_labels.view(-1),
            )

        # Logic consistency loss: only computed when both label sets are present
        logic_loss = None
        if token_loss is not None and sequence_loss is not None:
            token_mask = (token_labels != self.config.ignore_index).to(
                token_logits.device, torch.int32
            )
            logic_loss = self.logic_loss_fn(sequence_logits, token_logits, token_mask)

        # Total loss: weighted sum alpha * token + beta * sequence + gamma * logic
        total_loss = None
        if (
            token_loss is not None
            and sequence_loss is not None
            and logic_loss is not None
        ):
            total_loss = (
                self.config.alpha * token_loss
                + self.config.beta * sequence_loss
                + self.config.gamma * logic_loss
            )

        # report_metrics is injected externally when available; guard with
        # hasattr so the model also runs outside that trainer.
        if hasattr(self, "report_metrics"):
            self.report_metrics(
                token_loss=token_loss,
                sequence_loss=sequence_loss,
                logic_loss=logic_loss,
            )

        if not return_dict:
            output = (token_logits, sequence_logits)
            return ((total_loss,) + output) if total_loss is not None else output

        return DualTaskModelOutput(
            loss=total_loss,
            token_logits=token_logits,
            sequence_logits=sequence_logits,
        )
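
The sibling logic_consistency_loss module is not part of this commit; only its interface is visible above (a constructor taking n_classes, reduce, and reduction, and a call taking sequence logits, token logits, and a token mask). Below is a minimal usage sketch for the model itself, assuming a Qwen2Config extended with the custom fields this class reads (num_token_labels, num_sequence_labels, logic_reduce, ignore_index, alpha, beta, gamma). Every concrete value is illustrative, and the sketch only runs once that loss module is importable alongside this file.

# Hypothetical usage sketch -- the extra config fields are assumptions
# inferred from the attributes QwenForDualTask reads, not a documented API.
import torch
from transformers import Qwen2Config

config = Qwen2Config(
    vocab_size=1000, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
config.pad_token_id = 0                  # required for batch sizes > 1
config.num_token_labels = 9              # e.g. BIO tags for token classification
config.num_sequence_labels = 2           # e.g. a binary sequence-level label
config.logic_reduce = "max"              # forwarded to LogicConsistencyLoss
config.ignore_index = -100               # token-label value treated as padding
config.alpha, config.beta, config.gamma = 1.0, 1.0, 0.5  # loss weights

model = QwenForDualTask(config)          # randomly initialized toy backbone

input_ids = torch.randint(1, config.vocab_size, (2, 16))
outputs = model(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    token_labels=torch.randint(0, config.num_token_labels, (2, 16)),
    sequence_labels=torch.tensor([0, 1]),
)
print(outputs.loss)                      # weighted sum of the three losses
print(outputs.token_logits.shape)        # torch.Size([2, 16, 9])
print(outputs.sequence_logits.shape)     # torch.Size([2, 2])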