Makeshift MTP: A dumb idea that might work
Multi-token prediction is having a moment. DeepMind released a paper on it. Everyone's talking about how models should predict multiple tokens ahead instead of just one. The problem? Most implementations require architecture changes. New training objectives. More parameters. More compute. More everything.
But what if we could fake it?
Here's the idea. You have your model and a prompt like "The cat ". Normal inference predicts one token. Boring. But what if we spawned multiple continuations in parallel, each making their own guesses?
"The cat rna..."
"The cat cank..."
"The cat ran..."
Each of these runs through the model as a forward pass. Nothing fancy. No architectural changes. Then we compute loss on all of them — there's no ground truth at inference time, so "loss" here means the negative log-likelihood the model assigns to its own sampled tokens — and pick the winner. The one with the lowest loss gets to continue.
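Here's a minimal sketch of that loop. To keep it self-contained, a toy bigram "model" stands in for a real LM — `toy_next_token_probs`, `VOCAB`, and the rest are hypothetical names, and in practice each call would be a forward pass through your actual model:

```python
import math
import random

# Toy stand-in for a real LM: returns a probability distribution over a
# tiny vocabulary. In practice this would be a model forward pass.
VOCAB = ["ran", "sat", "rna", "cank", "away", "."]

def toy_next_token_probs(last_token):
    # Bias toward plausible words so "good" branches exist to find.
    weights = {tok: 0.05 for tok in VOCAB}
    weights["ran"] = 0.4
    weights["sat"] = 0.3
    total = sum(weights.values())
    return {tok: w / total for tok, w in weights.items()}

def sample_branch(prompt_tokens, length, rng):
    """Sample one continuation, accumulating its negative log-likelihood."""
    tokens = list(prompt_tokens)
    nll = 0.0
    for _ in range(length):
        probs = toy_next_token_probs(tokens[-1])
        tok = rng.choices(list(probs), weights=list(probs.values()))[0]
        nll -= math.log(probs[tok])  # loss the model assigns its own sample
        tokens.append(tok)
    return tokens, nll

def best_of_n(prompt_tokens, n_branches=4, length=3, seed=0):
    """Spawn n branches, score each by total NLL, keep the lowest."""
    rng = random.Random(seed)
    branches = [sample_branch(prompt_tokens, length, rng)
                for _ in range(n_branches)]
    return min(branches, key=lambda b: b[1])

winner, loss = best_of_n(["The", "cat"], n_branches=8)
```

The selection rule is the whole trick: branches that sample low-probability garbage like "rna" rack up a high NLL and lose automatically.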
Think about what this actually buys us. We're running inference N times instead of once, sure. But we're also sampling several distinct trajectories from the model's output distribution at once. The model is essentially exploring different branches of its probability tree and letting us pick the most coherent one.
And here's the nice part. The number of branches can be anything. Running on a potato? Generate two continuations and pick the better one. Have a GPU cluster sitting around? Spawn fifty. Time-constrained? Pick based on next-token loss only. Got all day? Evaluate the full generated sequence. The tradeoff between compute and quality becomes a dial you can turn.
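That dial can be made concrete with a tiny scoring helper. This is a sketch under the assumptions above (per-token NLLs already collected per branch); `branch_score` and its `horizon` parameter are hypothetical names, not part of any existing API:

```python
def branch_score(token_nlls, horizon=None):
    """Score a branch from its per-token negative log-likelihoods.

    horizon=None evaluates the full generated sequence (slow, thorough);
    horizon=1 evaluates only the next token (cheap, greedy). The mean is
    used so branches of different lengths stay comparable."""
    nlls = token_nlls if horizon is None else token_nlls[:horizon]
    return sum(nlls) / len(nlls)
```

So `branch_score([0.5, 2.0, 3.5], horizon=1)` scores 0.5 on the next token alone, while the full-sequence score is 2.0 — same branch, different position on the compute/quality dial.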
Why this feels like MTP
Traditional multi-token prediction trains the model to output multiple tokens in a single forward pass. The model learns to think ahead. Our approach does something similar at inference time. We explore multiple futures and commit to the best one.
The difference is we never taught the model to do this. We just throw compute at the problem until it works. Crude? Maybe. But it runs on any model without retraining.
The actual benefits
First, no more regenerating bad outputs. If a branch goes off the rails, its loss spikes, and we simply don't pick it. The bad branch dies quietly without wasting user time on a regeneration request.
Second, no architecture changes. Your model stays the same. Your training pipeline stays the same. You just add a wrapper around inference that handles the branching and selection logic.
Third, compute flexibility. Real MTP bakes multi-token prediction into the model weights. Our approach lets you decide at runtime how much exploration you can afford.
Why this is probably a bad idea
Loss is a proxy for what we actually want, which is coherence, helpfulness, and correctness. A branch might have lower loss but still say something stupid. A model confidently predicting nonsense still gets a low loss, because loss measures confidence, not truth.
Also, this scales poorly. If you want to explore N branches for M tokens, you're doing N times the forward passes. At some point, just using a bigger model becomes cheaper.
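The scaling claim is just arithmetic. A sketch, assuming one forward pass per token per branch with no batching or KV-cache sharing (the function name is illustrative):

```python
def forward_passes(n_branches, m_tokens):
    """Forward passes to decode m_tokens with n_branches independent
    branches, assuming every branch decodes every token separately."""
    return n_branches * m_tokens

plain = forward_passes(1, 100)      # standard decoding: 100 passes
branched = forward_passes(8, 100)   # 8 branches: 800 passes, 8x compute
```

Batching all branches into one batched forward pass can hide much of the wall-clock cost on a GPU with headroom, but the FLOPs are still N times the baseline either way.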
But for small models? For experiments? For cases where you have time but not parameters? This might be genuinely useful.
We're planning to test this on FMN-GPT. The model is small enough that running multiple forward passes is actually affordable. Whether it helps or not, we'll write up the results. Probably the failures will be more interesting than the successes.