# Qwen3-0.6B-steps-classifier / steps_classifier.py
# Uploaded by shaqaqio to the Hugging Face Hub via huggingface_hub (commit 3ea50ad).
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
# Hugging Face model id of the embedding encoder used as the frozen backbone.
STEPS_ENCODER_NAME = "Qwen/Qwen3-Embedding-0.6B"
# Maximum tokenized sequence length passed to the collator/tokenizer.
MAX_LEN = 512
# Number of target classes emitted by the classifier head.
NUM_CLASSES = 6
# Width of the pooled encoder embedding fed to the MLP head
# (assumed to equal the encoder's hidden size — TODO confirm for this checkpoint).
EMBED_DIM = 1024
# Hidden-layer widths of the MLP classification head, applied in order.
HIDDEN_DIMS = (512, 256)
# Dropout probability after each hidden layer of the head.
DROPOUT = 0.1
# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): loading the tokenizer at import time is a module-level side
# effect (may download on first use) — fine for a script, surprising for a library.
tokenizer = AutoTokenizer.from_pretrained(
STEPS_ENCODER_NAME,
trust_remote_code=True,
use_fast=True
)
class TextCollator:
    """DataLoader collate helper: tokenize raw texts and stack their labels.

    Each batch element is a ``(text, label)`` pair.  Calling the collator
    returns ``(input_ids, attention_mask, labels)`` where the first two are
    padded/truncated tensors from the tokenizer and ``labels`` is a
    ``LongTensor``.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        # Split the (text, label) pairs into two parallel sequences.
        texts, raw_labels = zip(*batch)
        label_tensor = torch.tensor(raw_labels, dtype=torch.long)
        encoded = self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded["input_ids"], encoded["attention_mask"], label_tensor
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the unmasked positions of each sequence.

    Padding positions (mask == 0) contribute nothing to the mean; the
    per-sequence divisor is clamped away from zero so an all-padding row
    cannot cause a division by zero.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, L, 1]
    token_sum = (weights * last_hidden_state).sum(dim=1)               # [B, H]
    token_count = weights.sum(dim=1).clamp(min=1e-6)                   # [B, 1]
    return token_sum / token_count
class LegalSteps_Classifier(nn.Module):
    """Frozen-encoder text classifier.

    A pretrained transformer encoder produces token embeddings that are
    mean-pooled into one vector per input, then passed through a small
    trainable MLP head that emits class logits.  The encoder is frozen at
    construction (eval mode, ``requires_grad=False``) and never receives
    gradients.

    Args:
        encoder_name: Hugging Face model id or local path for the encoder.
        embedding_dim: Width of the pooled encoder embedding (input to the MLP).
        num_classes: Number of output classes.
        hidden_dims: Hidden-layer widths of the MLP head, applied in order.
        dropout: Dropout probability after each hidden layer of the head.
        trust_remote_code: Forwarded to ``AutoModel.from_pretrained``.
    """

    def __init__(
        self,
        encoder_name: str,
        embedding_dim: int,
        num_classes: int,
        hidden_dims=(512, 256),
        dropout=0.1,
        trust_remote_code=True,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name, trust_remote_code=trust_remote_code)
        # Freeze the encoder: eval mode (disables its dropout) and no gradients.
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False
        # Trainable head: Linear -> ReLU -> Dropout per hidden layer.
        layers = []
        in_dim = embedding_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = h
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(in_dim, num_classes)

    def train(self, mode: bool = True):
        """Switch the head between train/eval; the frozen encoder stays in eval.

        Without this override, ``model.train()`` would flip the encoder back
        into training mode and re-enable its internal dropout, making the
        supposedly frozen features non-deterministic.  (The original code
        worked around this by calling ``encoder.eval()`` on every forward.)
        """
        super().train(mode)
        self.encoder.eval()
        return self

    @torch.no_grad()
    def _encode(self, input_ids, attention_mask):
        """Run the frozen encoder and mean-pool its last hidden state to [B, H]."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return mean_pool(out.last_hidden_state, attention_mask)

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape [B, num_classes].

        ``_encode`` is already wrapped in ``@torch.no_grad()``, so no extra
        guard is needed here; gradients flow only through the MLP head.
        """
        emb = self._encode(input_ids, attention_mask)
        return self.classifier(self.mlp(emb))
if __name__ == "__main__":
    # Smoke test: instantiate the classifier from the module-level defaults.
    config = {
        "encoder_name": STEPS_ENCODER_NAME,
        "embedding_dim": EMBED_DIM,
        "num_classes": NUM_CLASSES,
        "hidden_dims": HIDDEN_DIMS,
        "dropout": DROPOUT,
        "trust_remote_code": True,
    }
    model = LegalSteps_Classifier(**config)