YAML Metadata Warning: empty or missing YAML metadata in repo card
Check out the documentation for more information.
IDSF_BERT
This is a BERT-based model for Joint Intent Detection and Slot Filling (IDSF).
Model Description
- Model Type: BERT-based Joint Intent Detection and Slot Filling
- Custom Architecture: BertIDSF with intent and slot classification heads
- Language: English
Usage
import torch
import json
from transformers import AutoTokenizer, BertConfig
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from torch import nn
# First define the model architecture
class IntentClassifier(nn.Module):
    """Classification head: dropout followed by a single linear projection.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_intent_labels: Number of output logits (one per intent label).
        dropout_rate: Dropout probability applied before the projection.
            Defaults to ``0.`` (no dropout).
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        """Return intent logits for ``x`` (last dimension becomes ``num_intent_labels``)."""
        x = self.dropout(x)
        return self.linear(x)
class SlotClassifier(nn.Module):
    """Classification head: dropout followed by a single linear projection.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_slot_labels: Number of output logits (one per slot label).
        dropout_rate: Dropout probability applied before the projection.
            Defaults to ``0.`` (no dropout).
    """

    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_slot_labels)

    def forward(self, x):
        """Return slot logits for ``x`` (last dimension becomes ``num_slot_labels``)."""
        x = self.dropout(x)
        return self.linear(x)
class BertIDSF(BertPreTrainedModel):
    """BERT encoder with two heads for joint intent detection and slot filling.

    The intent head classifies the encoder output at sequence position 0; the
    slot head classifies the encoder output at every position.
    """

    def __init__(self, config, intent_label_lst, slot_label_lst, n_layers=1):
        """Build the encoder and both classification heads.

        Args:
            config: BERT configuration. ``hidden_size``, ``classifier_dropout``
                and ``hidden_dropout_prob`` are read from it.
            intent_label_lst: Intent label names; its length is the width of
                the intent head.
            slot_label_lst: Slot label names; its length is the width of the
                slot head.
            n_layers: Accepted for backward compatibility; not used here.
        """
        super().__init__(config)
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.bert = BertModel(config=config)

        # Store dictionaries in config for later use. Keys are 1-based
        # indices rendered as strings ("1", "2", ...).
        self.config.dict2 = {str(idx + 1): label for idx, label in enumerate(slot_label_lst)}
        self.config.inte2 = {str(idx + 1): label for idx, label in enumerate(intent_label_lst)}

        # Fall back to the encoder's hidden dropout when no classifier-specific
        # dropout is configured.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        intents=None,
        output_attentions=True,
        lens=None,
        device=None
    ):
        """Run the encoder and both heads, optionally computing the joint loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
                inputs_embeds: Forwarded unchanged to the underlying ``BertModel``.
            labels: Optional slot targets. Entries equal to 0 are ignored by the
                slot loss.
            intents: Optional intent targets.
            output_attentions: Forwarded to the encoder; defaults to ``True``.
            lens: Accepted for backward compatibility; not used here.
            device: Accepted for backward compatibility; not used here.

        Returns:
            A tuple ``(total_loss, (intent_logits, slot_logits), *extras)``.
            ``total_loss`` is ``0.5 * intent_loss + 0.5 * slot_loss`` over the
            targets that were supplied, and the plain integer ``0`` when
            neither ``intents`` nor ``labels`` is given. ``extras`` is whatever
            the encoder returned after its first two outputs.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            # Honour the caller's choice instead of always forcing attentions on.
            output_attentions=output_attentions,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # Intent is predicted from position 0 only; slots from every position.
        intent_logits = self.intent_classifier(sequence_output[:, 0, :])
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0

        # Intent Softmax
        if intents is not None:
            intent_loss_fct = nn.CrossEntropyLoss()
            intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intents.view(-1))
            total_loss += 0.5 * intent_loss

        # Slot Softmax
        if labels is not None:
            slot_loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            flat_logits = slot_logits.view(-1, self.num_slot_labels)
            flat_labels = labels.view(-1)
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                flat_logits = flat_logits[active_loss]
                flat_labels = flat_labels[active_loss]
            slot_loss = slot_loss_fct(flat_logits, flat_labels)
            total_loss += 0.5 * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here
        outputs = (total_loss,) + outputs
        return outputs  # (loss), scores, (hidden_states), (attentions)
# Now load and use the model
model_path = "soltaniali/IDSF_BERT"

# Load dictionaries from JSON files. The paths are relative, so both files
# must be present in the current working directory. The encoding is given
# explicitly so decoding does not depend on the platform default.
with open('dict2.json', 'r', encoding='utf-8') as f:
    dict2 = json.load(f)
with open('inte2.json', 'r', encoding='utf-8') as f:
    inte2 = json.load(f)

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = BertConfig.from_pretrained(model_path)
# The label lists are taken in the order the values appear in each JSON file.
model = BertIDSF.from_pretrained(
    model_path,
    config=config,
    slot_label_lst=list(dict2.values()),
    intent_label_lst=list(inte2.values())
)

# Process a sentence
sentence = "I want to transfer 200 dollars to my savings account"
# ... process with your IDSFService class
Important Note
This model uses a custom architecture (BertIDSF), so loading it correctly requires both the class definitions shown above and the label dictionaries (dict2.json and inte2.json).
- Downloads last month
- -
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support