import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


STEPS_ENCODER_NAME = "Qwen/Qwen3-Embedding-0.6B"
MAX_LEN = 512             # max token length handed to the tokenizer
NUM_CLASSES = 6
EMBED_DIM = 1024          # hidden size of the Qwen3-Embedding-0.6B encoder
HIDDEN_DIMS = (512, 256)
DROPOUT = 0.1


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer for the frozen embedding encoder.
tokenizer = AutoTokenizer.from_pretrained(
    STEPS_ENCODER_NAME,
    trust_remote_code=True,
    use_fast=True,
)

class TextCollator:
    """Collates (text, label) pairs into padded tensors for the encoder."""

    def __init__(self, tokenizer, max_length=MAX_LEN):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        texts, labels = zip(*batch)
        labels = torch.tensor(labels, dtype=torch.long)

        # Pad to the longest sequence in the batch, truncating at max_length.
        enc = self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return enc["input_ids"], enc["attention_mask"], labels
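

# Example usage of the collator (a sketch; `toy_data` is hypothetical
# placeholder data, not the real dataset):
#
# from torch.utils.data import DataLoader
#
# toy_data = [("The court granted the motion.", 0),
#             ("The filing deadline was extended.", 1)]
# loader = DataLoader(toy_data, batch_size=2,
#                     collate_fn=TextCollator(tokenizer, MAX_LEN))
# input_ids, attention_mask, labels = next(iter(loader))

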
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean pooling: average token embeddings, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-6)  # guard against all-padding rows
    return summed / counts
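

# Quick shape check (illustrative): pooling a (batch=2, seq=4, dim) tensor
# with a padding mask yields one embedding per sequence.
#
# h = torch.randn(2, 4, EMBED_DIM)
# m = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
# assert mean_pool(h, m).shape == (2, EMBED_DIM)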


class LegalSteps_Classifier(nn.Module):
    """Frozen text encoder with a trainable MLP classification head."""

    def __init__(
        self,
        encoder_name: str,
        embedding_dim: int,
        num_classes: int,
        hidden_dims=(512, 256),
        dropout=0.1,
        trust_remote_code=True,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name, trust_remote_code=trust_remote_code)

        # Freeze the encoder so it acts as a fixed feature extractor;
        # only the MLP head and classifier below receive gradients.
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Trainable head: Linear -> ReLU -> Dropout for each hidden layer.
        layers = []
        in_dim = embedding_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = h
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(in_dim, num_classes)

    @torch.no_grad()
    def _encode(self, input_ids, attention_mask):
        """Embed a batch with the frozen encoder (no gradients)."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        emb = mean_pool(out.last_hidden_state, attention_mask)
        return emb

    def forward(self, input_ids, attention_mask):
        # model.train() also flips the frozen encoder into train mode, which
        # would re-enable its dropout, so pin it to eval mode on every call.
        # _encode already runs under no_grad via its decorator.
        self.encoder.eval()
        emb = self._encode(input_ids, attention_mask)
        x = self.mlp(emb)
        logits = self.classifier(x)
        return logits
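

# Training sketch (illustrative): with the encoder frozen, only the head
# parameters require grad, so the optimizer is built over those alone.
# `train_loader` is a hypothetical DataLoader using TextCollator, and
# `model` is constructed as in the __main__ block below.
#
# model.to(device)
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.AdamW(
#     (p for p in model.parameters() if p.requires_grad), lr=1e-3
# )
# model.train()
# for input_ids, attention_mask, labels in train_loader:
#     logits = model(input_ids.to(device), attention_mask.to(device))
#     loss = criterion(logits, labels.to(device))
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

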
if __name__ == "__main__":
    model = LegalSteps_Classifier(
        encoder_name=STEPS_ENCODER_NAME,
        embedding_dim=EMBED_DIM,
        num_classes=NUM_CLASSES,
        hidden_dims=HIDDEN_DIMS,
        dropout=DROPOUT,
        trust_remote_code=True,
    )
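
    # Minimal smoke test (illustrative; the sample sentences are placeholder
    # data): run two texts through the model and check the logit shape.
    model.to(device)
    model.eval()
    sample_texts = ["The court granted the motion.", "The filing deadline was extended."]
    enc = tokenizer(sample_texts, padding=True, truncation=True,
                    max_length=MAX_LEN, return_tensors="pt")
    with torch.no_grad():
        logits = model(enc["input_ids"].to(device), enc["attention_mask"].to(device))
    print(logits.shape)  # expected: (2, NUM_CLASSES) -> torch.Size([2, 6])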