Turn Taking
A fine-tuned Qwen2.5-0.5B-Instruct model that predicts turn-taking actions in conversations. Given a conversation context, the model predicts what action a voice AI agent should take next.
Unlike acoustic-based approaches (VAD, silence detection), this model uses the semantic content of the conversation to make turn-taking decisions.
The model predicts one of 4 actions:

| Action | Token | Description |
|---|---|---|
| `start_speaking` | `<\|start_speaking\|>` | User finished their turn, agent should respond |
| `continue_listening` | `<\|continue_listening\|>` | User is mid-utterance, keep listening |
| `start_listening` | `<\|start_listening\|>` | User interrupted the agent, stop talking |
| `continue_speaking` | `<\|continue_speaking\|>` | User gave a backchannel, agent keeps talking |
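To make the table concrete, here is a minimal dispatch sketch; the handler strings are illustrative assumptions, not part of this repo:

```python
# Hypothetical mapping from predicted action to agent behavior.
# The action names match the table above; the behaviors are illustrative.
ACTION_HANDLERS = {
    "start_speaking": "generate and speak a response",
    "continue_listening": "wait for more user speech",
    "start_listening": "stop TTS playback and listen",
    "continue_speaking": "keep current TTS playback going",
}

def dispatch(action: str) -> str:
    """Return the behavior the agent should execute for a predicted action."""
    if action not in ACTION_HANDLERS:
        raise ValueError(f"unknown action: {action}")
    return ACTION_HANDLERS[action]

print(dispatch("start_speaking"))  # → generate and speak a response
```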
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "anyreach-ai/semantic-turn-taking"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
model.eval()

# Format the conversation as ChatML, ending with the <|predict|> trigger token
conversation = """<|im_start|>user
I need help with my bill<|im_end|>
<|im_start|>assistant
Sure I can help with that what seems to be the issue<|im_end|>
<|im_start|>user
I was charged twice for the same order<|im_end|>
<|predict|>"""

inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits[:, -1, :]

# Look up the token ID for each action, then read its logit at the final position
action_tokens = {
    "start_speaking": tokenizer.convert_tokens_to_ids("<|start_speaking|>"),
    "continue_listening": tokenizer.convert_tokens_to_ids("<|continue_listening|>"),
    "start_listening": tokenizer.convert_tokens_to_ids("<|start_listening|>"),
    "continue_speaking": tokenizer.convert_tokens_to_ids("<|continue_speaking|>"),
}
action_logits = {name: logits[0, tid].item() for name, tid in action_tokens.items()}

# Softmax over the four action logits only
probs = torch.softmax(torch.tensor(list(action_logits.values())), dim=0)
for name, p in zip(action_logits, probs):
    print(f"  {name}: {p:.4f}")
# → start_speaking: 0.95+ (user is done, agent should respond)
```
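Downstream, an agent usually acts on the distribution rather than printing it. A sketch of argmax-with-fallback selection over an action-logits dict like the one computed above (the 0.6 threshold is an illustrative assumption, not a tuned value):

```python
import torch

def pick_action(action_logits: dict, threshold: float = 0.6) -> str:
    """Softmax the four action logits and return the argmax action,
    falling back to continue_listening when confidence is low."""
    names = list(action_logits.keys())
    probs = torch.softmax(torch.tensor(list(action_logits.values())), dim=0)
    best = int(torch.argmax(probs))
    if probs[best].item() < threshold:
        return "continue_listening"  # conservative fallback (assumption)
    return names[best]

example = {"start_speaking": 5.0, "continue_listening": 1.0,
           "start_listening": 0.5, "continue_speaking": 0.2}
print(pick_action(example))  # → start_speaking
```

Falling back to `continue_listening` under low confidence trades a little latency for fewer interruptions of the user; tune the threshold on your own traffic.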
```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("anyreach-ai/semantic-turn-taking")

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4
session = ort.InferenceSession(
    "onnx/model_q8.onnx",  # download from this repo
    providers=["CPUExecutionProvider"],
    sess_options=sess_options,
)

# Tokenize (ChatML format as above, ending with <|predict|>)
conversation = "..."
inputs = tokenizer(conversation, return_tensors="np")
input_ids = inputs["input_ids"].astype("int64")
seq_len = input_ids.shape[1]

# Build the feed (empty KV cache for a single forward pass)
feed = {
    "input_ids": input_ids,
    "attention_mask": inputs["attention_mask"].astype("int64"),
    "position_ids": np.arange(seq_len, dtype="int64").reshape(1, -1),
}
for i in range(24):
    feed[f"past_key_values.{i}.key"] = np.zeros((1, 2, 0, 64), dtype="float32")
    feed[f"past_key_values.{i}.value"] = np.zeros((1, 2, 0, 64), dtype="float32")

# Run inference
logits = session.run(None, feed)[0]  # [1, seq_len, vocab_size]
last_logits = logits[0, -1, :]

# Extract the action probabilities
ACTION_IDS = [151666, 151665, 151667, 151668]  # SS, CL, SLi, CS
action_logits = last_logits[ACTION_IDS]
probs = np.exp(action_logits) / np.sum(np.exp(action_logits))
```
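To turn the resulting probabilities back into named actions, the ID order (SS, CL, SLi, CS) can be zipped with the action names. A small self-contained sketch using a numerically stable softmax (pure NumPy, no model required):

```python
import numpy as np

# Order matches ACTION_IDS above: SS, CL, SLi, CS.
ACTION_NAMES = ["start_speaking", "continue_listening",
                "start_listening", "continue_speaking"]

def action_distribution(action_logits: np.ndarray) -> dict:
    """Stable softmax over the four action logits, keyed by action name."""
    shifted = action_logits - action_logits.max()  # avoid overflow in exp
    probs = np.exp(shifted) / np.exp(shifted).sum()
    return dict(zip(ACTION_NAMES, probs.tolist()))

dist = action_distribution(np.array([4.0, 1.0, 0.5, 0.2]))
print(max(dist, key=dist.get))  # → start_speaking
```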
Evaluated on anyreach-ai/semantic-turn-taking-benchmark.

The binary end-of-utterance (EOU) results below use only start_speaking and continue_listening ground-truth examples; predictions are mapped SS/CS → EOU and CL/SLi → Not-EOU.
| Subset | N | Accuracy | F1 (macro) |
|---|---|---|---|
| TEN | 428 | 91.82% | 91.80% |
| SwDA | 2,688 | 65.96% | 51.46% |
| Synthetic | 36 | 86.11% | 85.57% |
The table below reports results over the full per-subset action set:

| Subset | N | Classes | Accuracy | F1 (macro) |
|---|---|---|---|---|
| TEN | 428 | 2 | 91.82% | 91.80% |
| SwDA | 3,523 | 3 | 68.98% | 46.92% |
| Synthetic | 60 | 4 | 76.67% | 72.07% |
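The binary EOU mapping used above (SS/CS → EOU, CL/SLi → Not-EOU) can be written down directly; this is a sketch of the scoring rule, not the benchmark's official code:

```python
# Map 4-way predictions onto the binary end-of-utterance (EOU) task.
EOU_MAP = {
    "start_speaking": "EOU",
    "continue_speaking": "EOU",
    "continue_listening": "Not-EOU",
    "start_listening": "Not-EOU",
}

def to_eou(action: str) -> str:
    """Collapse a 4-way action prediction into the binary EOU label."""
    return EOU_MAP[action]

print(to_eou("continue_speaking"))  # → EOU
```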
Latency measured on single examples, on CPU (4 threads) and GPU (NVIDIA T4).
| Format | Size | Short (8 tok) | Medium (28 tok) | Long (54 tok) |
|---|---|---|---|---|
| PyTorch GPU (fp16) | 942 MB | 26 ms | 30 ms | 34 ms |
| PyTorch CPU (fp32) | 942 MB | 165 ms | 247 ms | 289 ms |
| ONNX CPU (q8) | 473 MB | 128 ms | 151 ms | 191 ms |
Special tokens added to the vocabulary: the `<|predict|>` trigger token and the 4 action tokens.

| File | Description |
|---|---|
| `model.safetensors` | PyTorch model weights (fp32) |
| `onnx/model_q8.onnx` | ONNX INT8 quantized (dynamic quantization) |
| `config.json` | Model configuration |
| `tokenizer.json` | Tokenizer |
```bibtex
@misc{semantic-turn-taking-2026,
  title={Semantic Turn-Taking Model},
  author={Shangeth Rajaa},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/anyreach-ai/semantic-turn-taking}
}
```
Apache 2.0