# aiba-bert-bilstm / inference.py
# primel's picture
# Upload AIBA BERT-BiLSTM v2 (Avg F1: 0.9566)
# 0aaead9 verified
"""
Example inference script for AIBA BERT-BiLSTM model
This script demonstrates how to use the model for prediction.
"""
import torch
from transformers import AutoTokenizer
from nn_model import load_model_and_tokenizer
# Loaded (model, tokenizer, config) triples keyed by repo ID, so repeated
# predict() calls in one process do not reload the model each time.
_MODEL_CACHE = {}

# Literal special-token strings that must never become part of an entity.
_SPECIAL_TOKENS = frozenset({'[CLS]', '[SEP]', '[PAD]'})


def _load_cached(model_path):
    """Return the (model, tokenizer, config) triple for model_path, loading it once."""
    if model_path not in _MODEL_CACHE:
        model, tokenizer, config = load_model_and_tokenizer(model_path)
        # Make sure dropout and similar layers are in inference mode; this is a
        # no-op if the loader already did it (not visible from this file).
        model.eval()
        _MODEL_CACHE[model_path] = (model, tokenizer, config)
    return _MODEL_CACHE[model_path]


def _decode_entities(text, tokens, tags, offsets):
    """Group per-token BIO tags into entities and slice their text out of text.

    An entity opens on a ``B-<type>`` tag and is extended only by ``I-<type>``
    tags of the *same* type. Any other tag (``O``, a new ``B-``, a stray or
    mismatched ``I-``) closes the open entity; a stray or mismatched ``I-``
    token itself is treated as outside any entity.
    """
    entities = []
    entity_type = None
    span_start = span_end = 0

    for token, tag, (start, end) in zip(tokens, tags, offsets):
        # Skip special tokens. Besides the literal names, a zero-width offset
        # marks a token that maps to no characters of the input (Hugging Face
        # fast tokenizers report (0, 0) for special tokens).
        if token in _SPECIAL_TOKENS or start == end:
            continue

        if entity_type and tag.startswith('I-') and tag[2:] == entity_type:
            span_end = end
            continue

        # Every other tag terminates whatever entity is currently open.
        if entity_type:
            entities.append({'type': entity_type, 'value': text[span_start:span_end]})
            entity_type = None

        if tag.startswith('B-'):
            entity_type = tag[2:]
            span_start, span_end = start, end

    # Flush an entity that runs up to the end of the sequence.
    if entity_type:
        entities.append({'type': entity_type, 'value': text[span_start:span_end]})
    return entities


def _top_label(logits, id2label):
    """Return (label, probability) of the highest-scoring class in a 1-D logit vector."""
    probs = torch.nn.functional.softmax(logits, dim=0)
    best_id = torch.argmax(probs).item()
    return id2label[best_id], probs[best_id].item()


def predict(text: str, model_path: str = "YOUR_USERNAME/aiba-bert-bilstm") -> dict:
    """
    Make predictions on input text

    The model, tokenizer and config for ``model_path`` are loaded on the first
    call and reused on later calls with the same ``model_path``.

    Args:
        text: Input text to analyze
        model_path: Hugging Face model repo ID

    Returns:
        dict with keys ``text``, ``entities`` (list of ``{'type', 'value'}``
        dicts whose values are substrings of ``text``), ``intent``,
        ``intent_confidence``, ``language`` and ``language_confidence``.
    """
    model, tokenizer, config = _load_cached(model_path)

    # Config stores the label maps with string keys; convert to int class IDs.
    id2tag = {int(k): v for k, v in config['id2tag'].items()}
    id2intent = {int(k): v for k, v in config['id2intent'].items()}
    id2lang = {int(k): v for k, v in config['id2lang'].items()}

    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=config['max_length'],
        return_offsets_mapping=True,
    )
    # The offsets are only needed to map tokens back to characters; the model
    # must not receive them as an input.
    offsets = inputs.pop('offset_mapping')[0].tolist()

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)

    # Per-token tag IDs for the single sequence in the batch.
    tag_ids = torch.argmax(outputs['ner_logits'], dim=2)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    entities = _decode_entities(
        text, tokens, [id2tag[tag_id] for tag_id in tag_ids], offsets
    )

    intent, intent_confidence = _top_label(outputs['intent_logits'][0], id2intent)
    language, lang_confidence = _top_label(outputs['lang_logits'][0], id2lang)

    return {
        'text': text,
        'entities': entities,
        'intent': intent,
        'intent_confidence': intent_confidence,
        'language': language,
        'language_confidence': lang_confidence,
    }
# Demo: run the predictor over a few sample sentences and print the results.
if __name__ == "__main__":
    sample_inputs = (
        "Qabul qiluvchi Omad Biznes MCHJ STIR 123456789 summa 500000 UZS",
        "Получатель ООО Прогресс ИНН 987654321 сумма 1000000 руб",
        "Transfer 5000 USD to Starlight Ventures LLC TIN 555666777",
    )

    for sample in sample_inputs:
        print(f"\nInput: {sample}")
        prediction = predict(sample)

        # Classification heads: predicted label plus its probability as a percentage.
        intent_pct = prediction['intent_confidence']
        language_pct = prediction['language_confidence']
        print(f"Intent: {prediction['intent']} ({intent_pct:.2%})")
        print(f"Language: {prediction['language']} ({language_pct:.2%})")

        # One line per extracted entity.
        print("Entities:")
        for found in prediction['entities']:
            print(f" - {found['type']}: {found['value']}")