# gene-entity-recognition / custom_modeling.py
from transformers import PreTrainedModel, AutoModel, AutoConfig
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from torchcrf import CRF
from typing import Optional, Union, Tuple, List
import os

class TransformerCRFForTokenClassification(PreTrainedModel):
    """Token classifier with an optional CRF head on top of a transformer encoder.

    Subclasses PreTrainedModel rather than AutoModelForTokenClassification:
    the Auto* classes raise in __init__ and cannot be used as base classes.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Build the backbone from the config alone; weights are loaded later
        # by from_pretrained. (from_config does not accept use_safetensors.)
        self.base_model = AutoModel.from_config(config)
        hidden_size = getattr(config, "hidden_size", 768)
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden_size, config.num_labels)
        self.use_crf = getattr(config, "use_crf", False)
        self.crf = CRF(num_tags=self.num_labels, batch_first=True) if self.use_crf else None
        # ignore_index defaults to -100, matching the padding convention used below.
        self.loss_fn = nn.CrossEntropyLoss()
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # token_type_ids and head_mask are accepted for API compatibility but
        # deliberately not forwarded, so the wrapper also works with backbones
        # that do not take them.
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.crf is not None:
                # torchcrf cannot handle -100 padding labels: tags at masked
                # positions are still indexed when scoring, and the mask's first
                # timestep must be all on. Zero out padding labels and force the
                # first timestep so only real tokens drive the loss.
                mask = attention_mask.bool() & (labels != -100)
                mask[:, 0] = True
                crf_labels = labels.masked_fill(labels == -100, 0)
                loss = -self.crf(logits, crf_labels, mask=mask, reduction="mean")
            else:
                # CrossEntropyLoss ignores -100 labels by default.
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
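
    # Hypothetical inference helper, not part of the original file: a minimal
    # sketch of Viterbi decoding with the CRF head (torchcrf's CRF.decode),
    # falling back to an argmax over the logits when no CRF is configured.
    @torch.no_grad()
    def decode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> List[List[int]]:
        """Return the best label-id sequence per example, real tokens only."""
        outputs = self(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits if isinstance(outputs, TokenClassifierOutput) else outputs[0]
        mask = attention_mask.bool()
        if self.crf is not None:
            # CRF.decode runs Viterbi and already honors the mask.
            return self.crf.decode(logits, mask=mask)
        preds = logits.argmax(dim=-1)
        return [seq[m].tolist() for seq, m in zip(preds, mask)]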

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights, config, and the CRF layer."""
        # Record the CRF choice in the config so from_pretrained can rebuild it.
        self.config.use_crf = self.use_crf
        # super().save_pretrained writes the config as well; calling
        # config.save_pretrained separately (and passing it safe_serialization,
        # which it does not accept) is unnecessary.
        kwargs.setdefault("safe_serialization", True)
        super().save_pretrained(save_directory, **kwargs)
        if self.crf is not None:
            # The CRF parameters are already in the main checkpoint (self.crf is
            # a submodule); crf.pt is kept as a separate, explicit copy.
            crf_path = os.path.join(save_directory, "crf.pt")
            torch.save(self.crf.state_dict(), crf_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model with custom CRF layer."""
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        # Ensure use_crf is set so __init__ knows whether to build the CRF head.
        if not hasattr(config, "use_crf"):
            config.use_crf = False
        kwargs.setdefault("use_safetensors", True)
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, **kwargs
        )
        # __init__ already built the CRF when config.use_crf is set, and its
        # weights load from the main checkpoint; do not re-create it here, which
        # would discard them. crf.pt is a fallback for checkpoints that only
        # stored the CRF separately (local directories only).
        if config.use_crf and model.crf is not None:
            crf_path = os.path.join(pretrained_model_name_or_path, "crf.pt")
            if os.path.exists(crf_path):
                model.crf.load_state_dict(torch.load(crf_path, map_location="cpu"))
        return model
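
# Minimal usage sketch, not part of the original file: the checkpoint path is a
# placeholder and the decode helper above is assumed. Loads a saved model and
# tags one sentence, printing (token, label) pairs.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/checkpoint"  # hypothetical local directory
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = TransformerCRFForTokenClassification.from_pretrained(checkpoint)
    model.eval()

    enc = tokenizer("BRCA1 is associated with breast cancer.", return_tensors="pt")
    tag_ids = model.decode(enc["input_ids"], enc["attention_mask"])
    tags = [model.config.id2label[i] for i in tag_ids[0]]
    print(list(zip(tokenizer.convert_ids_to_tokens(enc["input_ids"][0]), tags)))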