# med-jargon-crf / modeling_jargon.py
# DNivalis's picture
# Update modeling_jargon.py
# e002b8e verified
from transformers import AutoModel
from huggingface_hub import PyTorchModelHubMixin
from torchcrf import CRF
import torch.nn as nn
class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
    """Token-classification model: transformer encoder, linear head and CRF.

    The encoder's last hidden state goes through dropout and a linear layer
    to produce per-token emission scores ("logits"). A CRF layer built with
    ``batch_first=True`` scores label sequences for training and decodes the
    best label sequence at inference time.
    """

    def __init__(self, config):
        """Build the model from a configuration mapping.

        Args:
            config: Mapping indexed with the keys ``"pretrained_model_name"``,
                ``"hidden_dropout_prob"``, ``"hidden_size"``,
                ``"num_labels"`` and ``"label_map"``.
                NOTE(review): ``"label_map"`` presumably maps label name to
                integer id; it is inverted into ``self.id2label`` -- confirm
                against the training config.
        """
        super().__init__()

        # Base transformer encoder, loaded by name or path.
        self.transformer = AutoModel.from_pretrained(config["pretrained_model_name"])

        # Classification head producing one score per label for each token.
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
        self.classifier = nn.Linear(config["hidden_size"], config["num_labels"])

        # CRF layer for sequence labeling; inputs are batch-first.
        self.crf = CRF(config["num_labels"], batch_first=True)

        # Inverse of config["label_map"] (values become keys).
        self.id2label = {v: k for k, v in config["label_map"].items()}

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute emission scores and, when labels are given, the CRF loss.

        Args:
            input_ids: Token ids forwarded to the transformer.
            attention_mask: Optional mask forwarded to the transformer and,
                converted to bool, used as the CRF mask. When omitted, the
                CRF is called with ``mask=None`` so it treats every position
                as valid.
            labels: Optional gold label ids. When provided, the loss is
                computed.
            **kwargs: Accepted for caller compatibility and ignored.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``. ``loss`` is the negated CRF
            output computed with ``reduction='mean'``.
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        # Inference mode: emission scores only.
        if labels is None:
            return {"logits": logits}

        # attention_mask is optional in the signature, so do not call .bool()
        # on None; the CRF layer accepts mask=None.
        crf_mask = attention_mask.bool() if attention_mask is not None else None
        loss = -self.crf(logits, labels, mask=crf_mask, reduction='mean')
        return {"loss": loss, "logits": logits}

    def decode(self, logits, mask=None):
        """Decode the best label sequence for each batch item with the CRF.

        Args:
            logits: Emission scores as returned by :meth:`forward`.
            mask: Optional mask of valid positions; converted to bool before
                use. When omitted, the CRF is called with ``mask=None``.

        Returns:
            Whatever ``self.crf.decode`` returns for the given inputs.
        """
        crf_mask = mask.bool() if mask is not None else None
        return self.crf.decode(logits, crf_mask)