# RQA-R1 / modeling_rqa.py
# Committed by skatzR — "Create modeling_rqa.py" (commit 36851e6, verified)
# modeling_rqa.py
import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
AutoModel,
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModel,
)
# ============================================================
# CONFIG
# ============================================================
class RQAModelConfig(PretrainedConfig):
    """Configuration for the RQA multi-head classifier (model type ``"rqa"``).

    Args:
        base_model_name: Hub id or local path of the encoder checkpoint.
        num_error_types: Number of logits produced by the errors head.
        has_issue_projection_dim: Width of the projection feeding the
            single-logit has-issue head.
        errors_projection_dim: Width of the projection feeding the errors head.
        has_issue_dropout: Dropout probability inside the has-issue projection.
        errors_dropout: Dropout probability inside the errors projection.
        temperature_has_issue: Temperature for the has-issue logit (only
            stored on the config here).
        temperature_errors: One temperature per error type. Defaults to
            ``1.0`` for each of the ``num_error_types`` types. A copy of the
            given sequence is stored, so editing the config later never
            mutates the caller's list.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "rqa"

    def __init__(
        self,
        base_model_name: str = "FacebookAI/xlm-roberta-large",
        num_error_types: int = 6,
        has_issue_projection_dim: int = 256,
        errors_projection_dim: int = 512,
        has_issue_dropout: float = 0.25,
        errors_dropout: float = 0.3,
        temperature_has_issue: float = 1.0,
        temperature_errors: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_error_types = num_error_types
        self.has_issue_projection_dim = has_issue_projection_dim
        self.errors_projection_dim = errors_projection_dim
        self.has_issue_dropout = has_issue_dropout
        self.errors_dropout = errors_dropout
        self.temperature_has_issue = temperature_has_issue
        # Copy rather than alias: the config is mutable and serialised later,
        # and must not share state with whatever list the caller passed in.
        self.temperature_errors = (
            list(temperature_errors)
            if temperature_errors is not None
            else [1.0] * num_error_types
        )
# ============================================================
# POOLING
# ============================================================
class MeanPooling(nn.Module):
    """Average token embeddings over the positions selected by a mask."""

    def forward(self, last_hidden_state, attention_mask=None):
        """Return the masked mean of ``last_hidden_state`` along dim 1.

        Args:
            last_hidden_state: Token embeddings; assumed to be
                (batch, seq_len, hidden).
            attention_mask: Per-token weights, assumed (batch, seq_len):
                1 for tokens to include, 0 for padding. ``None`` averages
                over every position.

        Returns:
            The pooled tensor with dim 1 reduced away. The mask is cast to
            float32 before the multiply, so lower-precision inputs come back
            upcast. Rows whose mask is all zeros yield zeros (not NaN)
            because the denominator is clamped.
        """
        if attention_mask is None:
            # Same code path as an explicit all-ones mask, so dtype and
            # numerics match the masked case exactly.
            attention_mask = torch.ones(
                last_hidden_state.shape[:2], device=last_hidden_state.device
            )
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(last_hidden_state * mask, dim=1)
        denom = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / denom
# ============================================================
# MODEL
# ============================================================
class RQAModelHF(PreTrainedModel):
    """Pretrained encoder + mean pooling + two independent classification heads.

    The encoder output is mean-pooled over the tokens selected by the
    attention mask. The pooled vector then feeds two separate
    ``Linear -> LayerNorm -> GELU -> Dropout`` projections:

    * the has-issue head produces a single logit per example;
    * the errors head produces ``config.num_error_types`` logits per example.

    Only raw logits are returned; the temperatures on the config are not
    applied here.
    """

    config_class = RQAModelConfig

    def __init__(self, config: RQAModelConfig):
        super().__init__(config)
        # NOTE(review): this loads pretrained encoder weights (and may hit the
        # network) on every construction, including inside
        # RQAModelHF.from_pretrained. There, the weights are presumably
        # overwritten by the checkpoint afterwards — confirm whether that
        # double load is acceptable.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        hidden_size = self.encoder.config.hidden_size

        self.pooler = MeanPooling()

        self.has_issue_projection = nn.Sequential(
            nn.Linear(hidden_size, config.has_issue_projection_dim),
            nn.LayerNorm(config.has_issue_projection_dim),
            nn.GELU(),
            nn.Dropout(config.has_issue_dropout),
        )
        self.errors_projection = nn.Sequential(
            nn.Linear(hidden_size, config.errors_projection_dim),
            nn.LayerNorm(config.errors_projection_dim),
            nn.GELU(),
            nn.Dropout(config.errors_dropout),
        )

        self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
        self.errors_head = nn.Linear(
            config.errors_projection_dim, config.num_error_types
        )

        self._init_custom_weights()

    def _init_custom_weights(self):
        """Xavier-uniform weights and zero biases for the newly added Linears.

        Only the first layer of each projection and the two heads are
        re-initialised. The encoder and the LayerNorms are left untouched.
        """
        for module in [
            self.has_issue_projection[0],
            self.errors_projection[0],
            self.has_issue_head,
            self.errors_head,
        ]:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode a batch and return raw logits for both heads.

        Args:
            input_ids: Token ids passed straight to the encoder; assumed
                (batch, seq_len).
            attention_mask: Optional mask with the same shape as
                ``input_ids`` (1 = real token, 0 = padding). When ``None``,
                the encoder gets no mask and pooling averages every position.
            **kwargs: Accepted for caller compatibility and ignored.

        Returns:
            dict with ``"has_issue_logits"`` (the head's trailing size-1 dim
            squeezed away) and ``"errors_logits"`` (last dim
            ``config.num_error_types``).
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = outputs.last_hidden_state

        pooling_mask = attention_mask
        if pooling_mask is None:
            # No mask given: treat every position as a real token instead of
            # letting the pooler dereference None.
            pooling_mask = torch.ones(
                hidden.shape[:2], dtype=torch.long, device=hidden.device
            )
        pooled = self.pooler(hidden, pooling_mask)

        # The pooler multiplies by a float32 mask, which upcasts fp16/bf16
        # hidden states. Match each projection's parameter dtype so a
        # half-precision model does not fail with a dtype mismatch in Linear.
        # This is a no-op for float32 models.
        has_issue_in = pooled.to(self.has_issue_projection[0].weight.dtype)
        errors_in = pooled.to(self.errors_projection[0].weight.dtype)

        has_issue_logits = self.has_issue_head(
            self.has_issue_projection(has_issue_in)
        ).squeeze(-1)
        errors_logits = self.errors_head(self.errors_projection(errors_in))

        return {
            "has_issue_logits": has_issue_logits,
            "errors_logits": errors_logits,
        }
# ============================================================
# TRANSFORMERS REGISTRATION (CRITICAL)
# ============================================================
# Importing this module makes the "rqa" model type resolvable through the
# Auto* factories: AutoConfig maps the type string to RQAModelConfig, and
# AutoModel maps RQAModelConfig to RQAModelHF.
_RQA_MODEL_TYPE = RQAModelConfig.model_type  # == "rqa"
AutoConfig.register(model_type=_RQA_MODEL_TYPE, config=RQAModelConfig)
AutoModel.register(config_class=RQAModelConfig, model_class=RQAModelHF)
# Message below reads: "RQA registered in Transformers".
print("✅ RQA зарегистрирован в Transformers")