# hackmol-hackathon / Instadeep_NT_500M_CPT / PathScan / modeling_pathopreter.py
# Provenance (Hugging Face Hub page header captured with this file):
#   uploaded by Wrostdevil via huggingface_hub, revision 9270c9a (verified).
import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
class SequenceClassificationWithFeatures(nn.Module):
    """Sequence classifier combining a transformer encoder with extra features.

    Wraps a pretrained encoder and classifies the concatenation of the pooled
    sequence representation (``pooler_output`` when available, otherwise a
    masked mean over ``last_hidden_state``) with a per-example feature vector
    of size ``feature_dim``, via a two-layer MLP head.
    """

    def __init__(self, encoder, hidden_size, feature_dim, num_labels):
        """
        Args:
            encoder: HuggingFace-style encoder module; its forward must return
                an object exposing ``last_hidden_state`` (and optionally
                ``pooler_output``) when called with ``return_dict=True``.
            hidden_size: dimensionality of the encoder's hidden states.
            feature_dim: number of extra per-example input features.
            num_labels: number of output classes.
        """
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size + feature_dim, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )

    @classmethod
    def from_pretrained(cls, path, device="cpu"):
        """Load an encoder + classifier head previously saved under ``path``.

        Expects ``path`` to contain ``model_metadata.json`` (with keys
        ``feature_dim`` and ``num_labels``), an encoder loadable via
        ``AutoModel.from_pretrained``, and the combined state dict in
        ``model.safetensors``.

        Args:
            path: directory holding the checkpoint files.
            device: device string or ``torch.device`` to move the model to.

        Returns:
            The reconstructed model, moved to ``device``.
        """
        import json
        import os

        from safetensors.torch import load_file

        # Head hyper-parameters are stored alongside the weights.
        with open(os.path.join(path, "model_metadata.json"), "r") as f:
            meta = json.load(f)

        # NOTE(review): trust_remote_code executes code shipped with the
        # checkpoint; only load checkpoints from trusted sources.
        encoder = AutoModel.from_pretrained(path, trust_remote_code=True)

        model = cls(
            encoder=encoder,
            hidden_size=encoder.config.hidden_size,
            feature_dim=meta["feature_dim"],
            num_labels=meta["num_labels"],
        )

        # The safetensors file holds the full state dict (encoder + head).
        state_dict = load_file(os.path.join(path, "model.safetensors"))
        model.load_state_dict(state_dict)
        return model.to(device)

    def forward(self, input_ids=None, attention_mask=None, features=None, labels=None):
        """Run the encoder, fuse pooled output with ``features``, and classify.

        Args:
            input_ids: token ids, shape ``(batch, seq_len)``.
            attention_mask: 1-for-token / 0-for-padding mask, same shape as
                ``input_ids``; optional (all positions treated as valid if
                omitted and the encoder has no ``pooler_output``).
            features: extra features, shape ``(batch, feature_dim)``. Required.
            labels: optional class indices, shape ``(batch,)``; when given,
                a cross-entropy loss is computed.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` and (optionally) ``loss``.

        Raises:
            ValueError: if ``features`` is omitted.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

        if getattr(out, "pooler_output", None) is not None:
            pooled = out.pooler_output
        else:
            # Fallback: mean-pool hidden states over valid (unmasked) positions.
            last = out.last_hidden_state
            if attention_mask is None:
                # No mask supplied: treat every position as valid.
                mask = torch.ones_like(last[..., :1])
            else:
                # Cast to the hidden-state dtype so the clamp below stays in
                # floating point; clamping an integer mask sum with min=1e-9
                # either errors or loses the divide-by-zero guard.
                mask = attention_mask.unsqueeze(-1).to(last.dtype)
            pooled = (last * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

        if features is None:
            # Fail loudly instead of the opaque AttributeError on features.to.
            raise ValueError(
                "`features` is required: expected a tensor of shape "
                "(batch, feature_dim)."
            )

        # Match the pooled representation's device/dtype before concatenation.
        features = features.to(device=pooled.device, dtype=pooled.dtype)
        x = torch.cat([pooled, features], dim=1)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits)