|
|
|
|
|
""" |
|
|
Simple inference example for Turnlet BERT Multilingual EOU model |
|
|
Demonstrates both PyTorch and ONNX usage |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import numpy as np |
|
|
|
|
|
def test_pytorch(text, threshold=0.86):
    """Classify *text* as EOU / Non-EOU with the PyTorch checkpoint in the CWD.

    Args:
        text: Utterance to classify.
        threshold: Probability above which the input is called EOU.

    Returns:
        Tuple ``(is_eou, prob_eou)`` — boolean decision and class-1 probability.
    """
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    print("🔥 Loading PyTorch model...")
    # Model and tokenizer are expected alongside this script.
    model = AutoModelForSequenceClassification.from_pretrained(".")
    tokenizer = AutoTokenizer.from_pretrained(".")
    model.eval()

    print(f"\n📝 Input: {text}")

    encoded = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=128
    )

    with torch.no_grad():
        logits = model(**encoded).logits

    # Index [0][1]: first (only) batch item, probability of the EOU class.
    prob_eou = torch.softmax(logits, dim=-1)[0][1].item()
    is_eou = prob_eou > threshold

    print(f"✅ EOU Probability: {prob_eou:.4f}")
    print(f"🎯 Prediction: {'EOU (End of Utterance)' if is_eou else 'Non-EOU (Incomplete)'}")
    print(f"📊 Threshold: {threshold}")

    return is_eou, prob_eou
|
|
|
|
|
def test_onnx(text, model_path="bert_model_optimized_dynamic_int8.onnx", threshold=0.86):
    """Classify *text* as EOU / Non-EOU using the quantized ONNX model (faster).

    Args:
        text: Utterance to classify.
        model_path: Path to the INT8 ONNX export.
        threshold: Probability above which the input is called EOU.

    Returns:
        Tuple ``(is_eou, prob_eou)`` — boolean decision and class-1 probability.
    """
    import time

    import onnxruntime as ort
    from transformers import AutoTokenizer

    print("⚡ Loading ONNX Quantized INT8 model...")

    tokenizer = AutoTokenizer.from_pretrained(".")
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    print(f"\n📝 Input: {text}")

    # ONNX graph was exported with a fixed sequence length, hence padding="max_length".
    inputs = tokenizer(text, padding="max_length", max_length=128, truncation=True, return_tensors="np")

    # ONNX Runtime expects int64 tensors for the token ids and mask.
    ort_inputs = {
        'input_ids': inputs['input_ids'].astype(np.int64),
        'attention_mask': inputs['attention_mask'].astype(np.int64)
    }

    start = time.time()
    outputs = session.run(None, ort_inputs)
    inference_time = (time.time() - start) * 1000

    logits = outputs[0][0]
    # Numerically stable softmax: subtract the max logit before exponentiating
    # so large logits cannot overflow np.exp into inf/nan (the previous
    # exp(logits)/sum(exp(logits)) form is unstable for large magnitudes).
    shifted = np.exp(logits - np.max(logits))
    probs = shifted / shifted.sum()
    prob_eou = float(probs[1])  # plain Python float, not np.float32
    is_eou = prob_eou > threshold

    print(f"✅ EOU Probability: {prob_eou:.4f}")
    print(f"🎯 Prediction: {'EOU (End of Utterance)' if is_eou else 'Non-EOU (Incomplete)'}")
    print(f"📊 Threshold: {threshold}")
    print(f"⚡ Inference Time: {inference_time:.2f}ms")

    return is_eou, prob_eou
|
|
|
|
|
def test_multiple_examples(use_onnx=True):
    """Run a small multilingual benchmark (English, Hindi, Spanish).

    Each example is (text, language code, expected-EOU flag); prints per-example
    results and an overall accuracy summary.
    """
    examples = [
        ("Thanks for your help!", "en", True),
        ("I need a train to Cambridge.", "en", True),
        ("What time does the", "en", False),
        ("धन्यवाद!", "hi", True),
        ("मुझे मदद चाहिए", "hi", False),
        ("¡Gracias por tu ayuda!", "es", True),
        ("Necesito un tren a", "es", False),
    ]

    print("\n" + "="*70)
    print("🌐 MULTILINGUAL EOU DETECTION TEST")
    print("="*70)

    correct = 0
    total = len(examples)

    # Both backends share the (text, threshold=...) calling convention.
    run_inference = test_onnx if use_onnx else test_pytorch

    for text, lang, expected_eou in examples:
        print(f"\n{'─'*70}")
        print(f"🌍 Language: {lang.upper()}")

        is_eou, prob = run_inference(text, threshold=0.86)

        expected_str = "EOU" if expected_eou else "Non-EOU"
        predicted_str = "EOU" if is_eou else "Non-EOU"

        is_correct = is_eou == expected_eou
        correct += is_correct  # bool adds as 0/1

        status = "✅ CORRECT" if is_correct else "❌ INCORRECT"
        print(f"💡 Expected: {expected_str} | Got: {predicted_str} | {status}")

    print(f"\n{'='*70}")
    print(f"📊 ACCURACY: {correct}/{total} ({correct/total*100:.1f}%)")
    print(f"{'='*70}\n")
|
|
|
|
|
def interactive_mode(use_onnx=True, threshold=0.86):
    """Interactive REPL: read lines from stdin and print an EOU prediction each.

    Loads the selected backend once up front, then loops until the user types
    'quit'/'exit'/'q' or presses Ctrl-C.
    """
    import time

    import onnxruntime as ort
    from transformers import AutoTokenizer

    print("\n" + "="*70)
    print("🎮 INTERACTIVE MODE - Multilingual EOU Detection")
    print("="*70)
    print("🌐 Supported languages: English, Hindi, Spanish")
    print("📊 Threshold: {:.2f}".format(threshold))

    # Load whichever backend was requested; both read model files from the CWD.
    if use_onnx:
        print("⚡ Using: ONNX Quantized INT8 model (fast)")
        tokenizer = AutoTokenizer.from_pretrained(".")
        session = ort.InferenceSession(
            "bert_model_optimized_dynamic_int8.onnx",
            providers=['CPUExecutionProvider'],
        )
    else:
        print("🔥 Using: PyTorch model")
        import torch
        from transformers import AutoModelForSequenceClassification
        tokenizer = AutoTokenizer.from_pretrained(".")
        model = AutoModelForSequenceClassification.from_pretrained(".")
        model.eval()

    print("\n💡 Type your text and press Enter to get EOU prediction")
    print("💡 Type 'quit' or 'exit' to stop")
    print("💡 Type 'examples' to see sample inputs")
    print("="*70 + "\n")

    sample_count = 0

    while True:
        try:
            user_input = input("📝 Enter text: ").strip()
            if not user_input:
                continue

            lowered = user_input.lower()

            if lowered in ['quit', 'exit', 'q']:
                print("\n👋 Goodbye! Tested {} samples.".format(sample_count))
                break

            if lowered == 'examples':
                print("\n📚 Example inputs to try:")
                print(" English:")
                print(" - 'Thanks for your help!' (EOU)")
                print(" - 'I need to book a' (Non-EOU)")
                print(" Hindi:")
                print(" - 'धन्यवाद!' (Thank you! - EOU)")
                print(" - 'मुझे मदद चाहिए' (I need help - could be EOU)")
                print(" Spanish:")
                print(" - '¡Muchas gracias!' (Thank you! - EOU)")
                print(" - 'Necesito un tren a' (I need a train to - Non-EOU)")
                print()
                continue

            sample_count += 1
            print()

            # Fixed-length padding keeps the ONNX graph happy; the same call
            # works for PyTorch with "pt" tensors.
            encoded = tokenizer(user_input, padding="max_length", max_length=128,
                                truncation=True,
                                return_tensors="np" if use_onnx else "pt")

            start = time.time()

            if use_onnx:
                ort_inputs = {
                    'input_ids': encoded['input_ids'].astype(np.int64),
                    'attention_mask': encoded['attention_mask'].astype(np.int64)
                }
                logits = session.run(None, ort_inputs)[0][0]
                exp_logits = np.exp(logits)
                prob_eou = (exp_logits / np.sum(exp_logits))[1]
            else:
                import torch
                with torch.no_grad():
                    class_logits = model(**encoded).logits
                prob_eou = torch.softmax(class_logits, dim=-1)[0][1].item()

            inference_time = (time.time() - start) * 1000

            is_eou = prob_eou > threshold

            print("─" * 70)
            if is_eou:
                print("✅ Prediction: EOU (End of Utterance)")
                print(" └─ The user has likely finished their thought")
            else:
                print("⏳ Prediction: Non-EOU (Incomplete)")
                print(" └─ The user may still be speaking")

            print(f"📊 Confidence: {prob_eou:.4f} (threshold: {threshold})")
            print(f"⚡ Inference time: {inference_time:.2f}ms")

            # ASCII-art confidence bar, 40 characters wide.
            bar_length = 40
            filled = int(bar_length * prob_eou)
            bar = "█" * filled + "░" * (bar_length - filled)
            print(f"📈 [{bar}] {prob_eou*100:.1f}%")
            print("─" * 70 + "\n")

        except KeyboardInterrupt:
            print("\n\n👋 Interrupted! Tested {} samples. Goodbye!".format(sample_count))
            break
        except Exception as e:
            # Keep the REPL alive on tokenizer/runtime errors.
            print(f"❌ Error: {e}\n")
            continue
|
|
|
|
|
def main():
    """CLI entry point: parse flags and dispatch to the requested mode.

    Priority: --interactive > --test-suite > --text; with no flags at all,
    falls back to interactive mode.
    """
    parser = argparse.ArgumentParser(description="Test Turnlet BERT Multilingual EOU model")
    parser.add_argument("--text", type=str, help="Text to classify")
    parser.add_argument("--threshold", type=float, default=0.86, help="EOU threshold (default: 0.86)")
    parser.add_argument("--pytorch", action="store_true", help="Use PyTorch instead of ONNX")
    parser.add_argument("--test-suite", action="store_true", help="Run full test suite")
    parser.add_argument("--interactive", "-i", action="store_true", help="Run in interactive mode")
    args = parser.parse_args()

    use_onnx = not args.pytorch

    if args.interactive:
        interactive_mode(use_onnx=use_onnx, threshold=args.threshold)
    elif args.test_suite:
        test_multiple_examples(use_onnx=use_onnx)
    elif args.text:
        if args.pytorch:
            test_pytorch(args.text, args.threshold)
        else:
            test_onnx(args.text, threshold=args.threshold)
    else:
        # Nothing requested — be friendly and drop into the REPL.
        print("No arguments provided. Starting interactive mode...")
        print("(Use --help to see all options)\n")
        interactive_mode(use_onnx=True, threshold=args.threshold)
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":


    main()
|
|
|
|
|
|