"""Text classifier built on a frozen pretrained embedding encoder.

A HuggingFace encoder (Qwen3-Embedding) produces mean-pooled sentence
embeddings; only a small MLP head on top is trainable.
"""

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

STEPS_ENCODER_NAME = "Qwen/Qwen3-Embedding-0.6B"
MAX_LEN = 512
NUM_CLASSES = 6
EMBED_DIM = 1024  # must equal the encoder's hidden size; validated in the model ctor
HIDDEN_DIMS = (512, 256)
DROPOUT = 0.1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(
    STEPS_ENCODER_NAME, trust_remote_code=True, use_fast=True
)


class TextCollator:
    """DataLoader collate_fn: turns (text, label) pairs into padded tensor batches."""

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        """Return (input_ids, attention_mask, labels) for a batch of (text, label)."""
        texts, labels = zip(*batch)
        labels = torch.tensor(labels, dtype=torch.long)
        enc = self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return enc["input_ids"], enc["attention_mask"], labels


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over the sequence dimension.

    Args:
        last_hidden_state: [B, L, H] token-level hidden states.
        attention_mask: [B, L] mask, 1 for real tokens and 0 for padding.

    Returns:
        [B, H] sentence embeddings with padding excluded from the average.
    """
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                  # [B, H]
    # clamp guards against division by zero on an all-padding row
    counts = mask.sum(dim=1).clamp(min=1e-6)                        # [B, 1]
    return summed / counts


class LegalSteps_Classifier(nn.Module):
    """Frozen pretrained encoder + trainable MLP classification head.

    The encoder's parameters never receive gradients; only ``mlp`` and
    ``classifier`` are trained.
    """

    def __init__(
        self,
        encoder_name: str,
        embedding_dim: int,
        num_classes: int,
        hidden_dims=(512, 256),
        dropout=0.1,
        trust_remote_code=True,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(
            encoder_name, trust_remote_code=trust_remote_code
        )

        # Fail fast on a configuration mismatch instead of an opaque shape
        # error inside the first Linear layer.
        hidden_size = getattr(self.encoder.config, "hidden_size", None)
        if hidden_size is not None and hidden_size != embedding_dim:
            raise ValueError(
                f"embedding_dim={embedding_dim} does not match encoder "
                f"hidden size {hidden_size}"
            )

        # Freeze the encoder: it is a pure feature extractor.
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False

        layers = []
        in_dim = embedding_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = h
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(in_dim, num_classes)

    def train(self, mode: bool = True):
        """Propagate train/eval to the head but pin the frozen encoder to eval.

        This keeps the encoder's dropout/etc. disabled even when the head is
        training, replacing the per-forward ``encoder.eval()`` call.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    @torch.no_grad()
    def _encode(self, input_ids, attention_mask):
        """Embed a batch with the frozen encoder; returns [B, H] pooled vectors."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return mean_pool(out.last_hidden_state, attention_mask)

    def forward(self, input_ids, attention_mask):
        """Return [B, num_classes] logits for a tokenized batch."""
        # _encode carries @torch.no_grad(), so no graph is built for the
        # encoder; gradients flow only through the MLP head.
        emb = self._encode(input_ids, attention_mask)
        x = self.mlp(emb)
        return self.classifier(x)


if __name__ == "__main__":
    model = LegalSteps_Classifier(
        encoder_name=STEPS_ENCODER_NAME,
        embedding_dim=EMBED_DIM,
        num_classes=NUM_CLASSES,
        hidden_dims=HIDDEN_DIMS,
        dropout=DROPOUT,
        trust_remote_code=True,
    ).to(device)  # previously `device` was computed but the model never moved