#!/usr/bin/env python3
"""
Simple inference example for the Turnlet BERT Multilingual EOU model.
Demonstrates both PyTorch and quantized ONNX (INT8) usage.
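
Example usage (flags are defined in main() below):
    python inference_example.py --text "Thanks for your help!"
    python inference_example.py --test-suite
    python inference_example.py --interactive --pytorch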
"""
import argparse
import numpy as np


def test_pytorch(text, threshold=0.86):
    """Test using PyTorch model"""
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import torch

    print("🔥 Loading PyTorch model...")
    model = AutoModelForSequenceClassification.from_pretrained(".")
    tokenizer = AutoTokenizer.from_pretrained(".")
    model.eval()

    print(f"\n📝 Input: {text}")

    # Tokenize and predict
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        prob_eou = probs[0][1].item()

    is_eou = prob_eou > threshold
    print(f"✅ EOU Probability: {prob_eou:.4f}")
    print(f"🎯 Prediction: {'EOU (End of Utterance)' if is_eou else 'Non-EOU (Incomplete)'}")
    print(f"📊 Threshold: {threshold}")
    return is_eou, prob_eou


def test_onnx(text, model_path="bert_model_optimized_dynamic_int8.onnx", threshold=0.86):
    """Test using ONNX quantized model (faster)"""
    import onnxruntime as ort
    from transformers import AutoTokenizer

    print("⚡ Loading ONNX Quantized INT8 model...")

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(".")
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    print(f"\n📝 Input: {text}")

    # Tokenize
    inputs = tokenizer(text, padding="max_length", max_length=128, truncation=True, return_tensors="np")

    # Prepare ONNX inputs
    ort_inputs = {
        'input_ids': inputs['input_ids'].astype(np.int64),
        'attention_mask': inputs['attention_mask'].astype(np.int64)
    }

    # Run inference
    import time
    start = time.time()
    outputs = session.run(None, ort_inputs)
    inference_time = (time.time() - start) * 1000

    logits = outputs[0][0]
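    # Manual softmax over the two classes; index 1 is the EOU class, matching
    # probs[0][1] in the PyTorch path. (A numerically stable variant would
    # subtract logits.max() before exponentiating.)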
    probs = np.exp(logits) / np.sum(np.exp(logits))
    prob_eou = probs[1]
    is_eou = prob_eou > threshold

    print(f"✅ EOU Probability: {prob_eou:.4f}")
    print(f"🎯 Prediction: {'EOU (End of Utterance)' if is_eou else 'Non-EOU (Incomplete)'}")
    print(f"📊 Threshold: {threshold}")
    print(f"⚡ Inference Time: {inference_time:.2f}ms")
    return is_eou, prob_eou


def test_multiple_examples(use_onnx=True):
    """Test multiple examples in different languages"""
    examples = [
        ("Thanks for your help!", "en", True),
        ("I need a train to Cambridge.", "en", True),
        ("What time does the", "en", False),
        ("धन्यवाद!", "hi", True),  # Hindi: "Thank you!"
        ("मुझे मदद चाहिए", "hi", False),  # Hindi: "I need help" (incomplete)
        ("¡Gracias por tu ayuda!", "es", True),  # Spanish: "Thanks for your help!"
        ("Necesito un tren a", "es", False),  # Spanish: "I need a train to" (incomplete)
    ]

    print("\n" + "="*70)
    print("🌐 MULTILINGUAL EOU DETECTION TEST")
    print("="*70)

    correct = 0
    total = len(examples)

    for text, lang, expected_eou in examples:
        print(f"\n{'─'*70}")
        print(f"🌍 Language: {lang.upper()}")
        if use_onnx:
            is_eou, prob = test_onnx(text, threshold=0.86)
        else:
            is_eou, prob = test_pytorch(text, threshold=0.86)

        expected_str = "EOU" if expected_eou else "Non-EOU"
        predicted_str = "EOU" if is_eou else "Non-EOU"
        is_correct = is_eou == expected_eou
        correct += is_correct

        status = "✅ CORRECT" if is_correct else "❌ INCORRECT"
        print(f"💡 Expected: {expected_str} | Got: {predicted_str} | {status}")

    print(f"\n{'='*70}")
    print(f"📊 ACCURACY: {correct}/{total} ({correct/total*100:.1f}%)")
    print(f"{'='*70}\n")


def interactive_mode(use_onnx=True, threshold=0.86):
    """Interactive mode - continuously ask for input and predict"""
    import onnxruntime as ort
    from transformers import AutoTokenizer
    import time

    print("\n" + "="*70)
    print("🎮 INTERACTIVE MODE - Multilingual EOU Detection")
    print("="*70)
    print("🌐 Supported languages: English, Hindi, Spanish")
    print("📊 Threshold: {:.2f}".format(threshold))

    if use_onnx:
        print("⚡ Using: ONNX Quantized INT8 model (fast)")
        tokenizer = AutoTokenizer.from_pretrained(".")
        session = ort.InferenceSession("bert_model_optimized_dynamic_int8.onnx",
                                       providers=['CPUExecutionProvider'])
    else:
        print("🔥 Using: PyTorch model")
        from transformers import AutoModelForSequenceClassification
        import torch
        tokenizer = AutoTokenizer.from_pretrained(".")
        model = AutoModelForSequenceClassification.from_pretrained(".")
        model.eval()

    print("\n💡 Type your text and press Enter to get EOU prediction")
    print("💡 Type 'quit' or 'exit' to stop")
    print("💡 Type 'examples' to see sample inputs")
    print("="*70 + "\n")

    sample_count = 0
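
    # Main loop: read a line from stdin, classify it, and print the verdict
    # until the user types 'quit'/'exit'/'q' or presses Ctrl+C.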
    while True:
        try:
            # Get user input
            user_input = input("📝 Enter text: ").strip()
            if not user_input:
                continue

            # Check for exit commands
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\n👋 Goodbye! Tested {} samples.".format(sample_count))
                break

            # Show examples
            if user_input.lower() == 'examples':
                print("\n📚 Example inputs to try:")
                print(" English:")
                print(" - 'Thanks for your help!' (EOU)")
                print(" - 'I need to book a' (Non-EOU)")
                print(" Hindi:")
                print(" - 'धन्यवाद!' (Thank you! - EOU)")
                print(" - 'मुझे मदद चाहिए' (I need help - could be EOU)")
                print(" Spanish:")
                print(" - '¡Muchas gracias!' (Thank you! - EOU)")
                print(" - 'Necesito un tren a' (I need a train to - Non-EOU)")
                print()
                continue

            sample_count += 1
            print()

            # Tokenize
            inputs = tokenizer(user_input, padding="max_length", max_length=128,
                               truncation=True, return_tensors="np" if use_onnx else "pt")

            # Predict
            start = time.time()
            if use_onnx:
                # ONNX inference
                ort_inputs = {
                    'input_ids': inputs['input_ids'].astype(np.int64),
                    'attention_mask': inputs['attention_mask'].astype(np.int64)
                }
                outputs = session.run(None, ort_inputs)
                logits = outputs[0][0]
                probs = np.exp(logits) / np.sum(np.exp(logits))
                prob_eou = probs[1]
            else:
                # PyTorch inference
                with torch.no_grad():
                    outputs = model(**inputs)
                    probs = torch.softmax(outputs.logits, dim=-1)
                    prob_eou = probs[0][1].item()
            inference_time = (time.time() - start) * 1000

            # Determine prediction
            is_eou = prob_eou > threshold

            # Display results with color coding
            print("─" * 70)
            if is_eou:
                print("✅ Prediction: EOU (End of Utterance)")
                print(" └─ The user has likely finished their thought")
            else:
                print("⏳ Prediction: Non-EOU (Incomplete)")
                print(" └─ The user may still be speaking")
            print(f"📊 Confidence: {prob_eou:.4f} (threshold: {threshold})")
            print(f"⚡ Inference time: {inference_time:.2f}ms")

            # Confidence bar
            bar_length = 40
            filled = int(bar_length * prob_eou)
            bar = "█" * filled + "░" * (bar_length - filled)
            print(f"📈 [{bar}] {prob_eou*100:.1f}%")
            print("─" * 70 + "\n")

        except KeyboardInterrupt:
            print("\n\n👋 Interrupted! Tested {} samples. Goodbye!".format(sample_count))
            break
        except Exception as e:
            print(f"❌ Error: {e}\n")
            continue


def main():
    parser = argparse.ArgumentParser(description="Test Turnlet BERT Multilingual EOU model")
    parser.add_argument("--text", type=str, help="Text to classify")
    parser.add_argument("--threshold", type=float, default=0.86, help="EOU threshold (default: 0.86)")
    parser.add_argument("--pytorch", action="store_true", help="Use PyTorch instead of ONNX")
    parser.add_argument("--test-suite", action="store_true", help="Run full test suite")
    parser.add_argument("--interactive", "-i", action="store_true", help="Run in interactive mode")
    args = parser.parse_args()

    if args.interactive:
        interactive_mode(use_onnx=not args.pytorch, threshold=args.threshold)
    elif args.test_suite:
        test_multiple_examples(use_onnx=not args.pytorch)
    elif args.text:
        if args.pytorch:
            test_pytorch(args.text, args.threshold)
        else:
            test_onnx(args.text, threshold=args.threshold)
    else:
        # Default to interactive mode if no arguments provided
        print("No arguments provided. Starting interactive mode...")
        print("(Use --help to see all options)\n")
        interactive_mode(use_onnx=True, threshold=args.threshold)


if __name__ == "__main__":
    main()