skatzR commited on
Commit
36851e6
·
verified ·
1 Parent(s): f160184

Create modeling_rqa.py

Browse files
Files changed (1) hide show
  1. modeling_rqa.py +136 -0
modeling_rqa.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# modeling_rqa.py
"""Custom RQA (question-quality) model and config for the transformers Auto* API."""

import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
    AutoConfig,
    AutoModel,  # was imported twice in the original; deduplicated
    PretrainedConfig,
    PreTrainedModel,
)
13
+
14
+ # ============================================================
15
+ # CONFIG
16
+ # ============================================================
17
+
18
class RQAModelConfig(PretrainedConfig):
    """Configuration for the RQA dual-head model.

    Stores the encoder checkpoint name, the projection/dropout settings for
    the two classification heads, and per-head softmax temperatures.
    """

    model_type = "rqa"

    def __init__(
        self,
        base_model_name: str = "FacebookAI/xlm-roberta-large",
        num_error_types: int = 6,
        has_issue_projection_dim: int = 256,
        errors_projection_dim: int = 512,
        has_issue_dropout: float = 0.25,
        errors_dropout: float = 0.3,
        temperature_has_issue: float = 1.0,
        temperature_errors: Optional[List[float]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Encoder / head geometry.
        self.base_model_name = base_model_name
        self.num_error_types = num_error_types
        self.has_issue_projection_dim = has_issue_projection_dim
        self.errors_projection_dim = errors_projection_dim
        self.has_issue_dropout = has_issue_dropout
        self.errors_dropout = errors_dropout

        # Calibration temperatures; the error head defaults to 1.0 per type.
        self.temperature_has_issue = temperature_has_issue
        if temperature_errors is None:
            temperature_errors = [1.0] * num_error_types
        self.temperature_errors = temperature_errors
48
+
49
+ # ============================================================
50
+ # POOLING
51
+ # ============================================================
52
+
53
class MeanPooling(nn.Module):
    """Masked mean over the sequence dimension of token embeddings."""

    def forward(self, last_hidden_state, attention_mask):
        # Broadcast the (batch, seq) mask across the hidden dimension.
        weights = attention_mask.float().unsqueeze(-1)
        token_sum = (last_hidden_state * weights).sum(dim=1)
        # Tiny floor keeps all-padding rows from dividing by zero.
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        return token_sum / token_count
59
+
60
+ # ============================================================
61
+ # MODEL
62
+ # ============================================================
63
+
64
class RQAModelHF(PreTrainedModel):
    """Dual-head model on a shared pretrained encoder.

    Produces a binary "has issue" logit and per-error-type logits from a
    mean-pooled sentence embedding. Submodule names and layer ordering are
    part of the checkpoint contract (state-dict keys) and must not change.
    """

    config_class = RQAModelConfig

    def __init__(self, config: RQAModelConfig):
        super().__init__(config)

        # NOTE(review): loading encoder weights inside __init__ means every
        # instantiation (including from_pretrained of this composite model)
        # fetches the base checkpoint — confirm this is intended.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        enc_dim = self.encoder.config.hidden_size

        self.pooler = MeanPooling()

        # One projection trunk per head; identical Linear->LN->GELU->Dropout
        # layout so checkpoint keys stay `<name>.{0,1,3}.*`.
        self.has_issue_projection = self._projection(
            enc_dim, config.has_issue_projection_dim, config.has_issue_dropout
        )
        self.errors_projection = self._projection(
            enc_dim, config.errors_projection_dim, config.errors_dropout
        )

        self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
        self.errors_head = nn.Linear(
            config.errors_projection_dim, config.num_error_types
        )

        self._init_custom_weights()

    @staticmethod
    def _projection(in_dim, out_dim, dropout):
        # Shared builder for the two head trunks.
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def _init_custom_weights(self):
        """Xavier-initialize the task-specific linear layers only."""
        custom_layers = (
            self.has_issue_projection[0],
            self.errors_projection[0],
            self.has_issue_head,
            self.errors_head,
        )
        for layer in custom_layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode, pool, and score; returns a dict of raw (untempered) logits."""
        enc_out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        sentence_emb = self.pooler(enc_out.last_hidden_state, attention_mask)

        # (batch,) after squeezing the singleton output dim.
        has_issue_logits = self.has_issue_head(
            self.has_issue_projection(sentence_emb)
        ).squeeze(-1)

        # (batch, num_error_types)
        errors_logits = self.errors_head(self.errors_projection(sentence_emb))

        return {
            "has_issue_logits": has_issue_logits,
            "errors_logits": errors_logits,
        }
128
+
129
# ============================================================
# 🔥 TRANSFORMERS REGISTRATION (CRITICAL)
# ============================================================

# Register the custom config/model with the Auto* factories so that
# AutoConfig/AutoModel.from_pretrained can resolve model_type "rqa".
# NOTE(review): register() raises if "rqa" is already registered, so a
# second import of this module under another path may fail — confirm.
AutoConfig.register("rqa", RQAModelConfig)
AutoModel.register(RQAModelConfig, RQAModelHF)

# Import-time side effect; message is Russian for "RQA registered in Transformers".
print("✅ RQA зарегистрирован в Transformers")