Spaces:
Running
- Introduction expliquant le contexte et le probleme: TITO (Token-In Token-Out). Pour faire simple, c'est un probleme qui se presente quand on entraine un llm avec du RL, et qu'on fait du multi-turn. En gros, le llm va call une ou plusieurs fois des tools avant de donner sa reponse finale, et que l'on calcule la loss etc.
Une breve introduction au RL et au multi-turn RL
explique que le rl basiquement c'est une question, le modele genere une réponse, et en fonction de cette réponse, on calcule la récompense (reward) et on ajuste les poids du modèle pour améliorer ses performances.
Les modèle ajd sont capables d'utiliser des outils externe pour satisfaire la demande du prompt. Par exemple, une calculatrice. Le prompt demande au modele de calculer une expression mathématique, le modèle va d'abord generer un appel à la calculatrice, récuperer le résultat de la calculatrice, puis utiliser ce résultat pour formuler sa réponse finale.
Pour ce qui est du RL, il y un point fondamental à retenir. Le modèle doit être entrainé sur la sequence qu'il a généré. Bah oui ca prait evident, est-ce qu'il y un truc que je capte pas? Ok, je re-dis : il faut qu'il soit entrainé PRECISEMENT sur la sequence qu'il a généré. Au token près. Le RL est très sensible à ça. Si la question est "2+2", et que le modèle genere "4" il faut l'netrainer sur "4", pas sur " 4" ou "4.", ni sur "REPONSE: 4". Garde ca dans un coin de ta tete, ca va etre important pour la suite.
Une brave introduction au chat templating
Un modele prend en entrée des tokens. Hors les data sont structurés en message. Comment passe-t-on de l'un à l'autre? En fait il y a deux stages: l'application du chat template, et la tokenization. Le chat template est une structure qui permet de transformer les messages structurés en texte, et la tokenization transforme le texte en tokens.
La specificité du multi-turn, c'est que lorsque le modèle genere des token, il faut que l'on ait un moyen de savoir s'il requiret l'appel à un outil ou pas. Cela implique donc de faire l'operation inverse: decoder (ou parser) les tokens generés.
<example de réponse du modèle qui une fois décodé corresponds à un appel d'outil>
Ce qui est crutial de comprendre pour ce blog, c'est que ce processus n'est pas reversible! en langauge mathématique, on pourrait parler d'une fonction non injective. Quoi? j'ai pas juste à ré-encoder cet appel? Oui tu as raison, en general ca te donne la même sequence de tokens que celle generée. Mais l'important est qu'il n'existe aucune garantie!!
<exemple de de tool call qui une fois ré-encodé ne correspond pas à la même sequence de tokens que celle générée initialement>
Ok, je vois il y a un token ou deux de différence, mais pourquoi? est-ce un bug? non, en fait il y plein de raisons pour lesquelles le decoding->econding n'est pas l'identité. Le BPE, s'il y a une façon d'encoder un texte, plusieurs suite de token peuvent donner le même texte. C'en est une mais pas la seule.
Est-ce que c'est pas juste un peu de la branlette? En vrai c'est globabelment la même sequence on s'en fiche non? En inference oui. Le modèle est complement capable de generer la suite de la réponse si tu lui donnes cette sequence legeerent modifée. Mais pour le training c'est un gros probleme.
La façon naturelle que n'importe qui utiliserait pour entrainer un LLM multi-turn avec du RL
Je te donne le pseudo code que n'importe qui utiliserait pour entrainer un LLM multi-turn avec du RL.
Sample un prompt Tant que le modele doit generer un nouveau turn Tokenize la conversation jusqu'à maintenant Genere des tokens jusqu'à ce que le modele s'arrete Decode la réponse S'il y a un appel d'outil Execute l'outil avec les arguments fournis Ajoute l'appel d'outil et la réponse du modele à la conversation Le modèle doit generer un nouveau turn Sinon Le modèle ne doit pas generer de nouveau turn Calcule la récompense Tokenize la conversation complete Calcule la loss en utilisant la récompense et la conversation tokenizée Backpropagate la loss et update les poids du modèle
Et voila, tu es tombé dans le piège du TITO. Mais ou? Je vais te donner un exemple, étape par etape, sois attentif aux détails:
<faire un figure qui va step par step dans l'algorithme ci-dessus, et montrer comment le TITO se produit. Il faut faire attention à être très clair ici, on pourra utiliser un double espace ou qq chose comme ça pour provoquer le mismatch>
Sample un prompt `[{"role": "user", "content": "What's 2+2?"}]`
Tant que le modele doit generer un nouveau turn
Tokenize la conversation jusqu'à maintenant
Genere des tokens jusqu'à ce que le modele s'arrete
Decode la réponse `{"tool_call": {"name": "add", "arguments": {"a": 2, "b": 2}}}`
S'il y a un appel d'outil
Execute l'outil avec les arguments fournis `4`
Ajoute l'appel d'outil et la réponse du modele à la conversation
`[{"role": "user", "content": "What's 2+2?"}, {"role": "assistant", "tool_call": {"name": "add", "arguments": {"a": 2, "b": 2}}}, {"role": "tool", "name": "add", "content": "4"}]`
Le modèle doit generer un nouveau turn
Sinon
Le modèle ne doit pas generer de nouveau turn
Calcule la récompense
Tokenize la conversation complete
Calcule la loss en utilisant la récompense et la conversation tokenizée
Backpropagate la loss et update les poids du modèle
Ok maintenant tu vois clairement le probleme, la sequence de token utilisée à la fin pour le calcul de la loss n'est pas exactement la même que celle générée par le modèle.
Oui mais c'est un cas très particulier, en general il y a pas de mismatch, non? Oui en effet, en general, il n'y pas de mismatch. Mais quelque fois il y en a un. Et ces quelques fois suffisent pour destabiliser fortement le training.
Comment s'assurer s'assurer du TITO?
La seule façon de s'assurer que le training reste stable, est de s'assurer que l'on entraine le modele sur les token qu'il a produit, comme ce qu'on a affirmé dans la section sur le RL. Mais comment faire? On est obliger de decoder à un moment non? En effet, il faut qu'on implemente le training en s'assurant que l'on ne ré-encode jamais les token decodé! C'est la regle d'or.
Pour cela, il faut modifer un peut l'algorithme:
Sample un prompt
Tokenize le prompt
Tant que le modele doit generer un nouveau turn
Genere des tokens jusqu'à ce que le modele s'arrete
Ajoute les tokens générés par le modèle à la sequence de tokens
Decode les tokens generés
S'il y a un appel d'outil
Execute l'outil avec les arguments fournis
Tokenize la réponse de l'outil et ajoute les tokens à la conversation
Le modèle doit generer un nouveau turn
Sinon
Le modèle ne doit pas generer de nouveau turn
Regarde bien, on a de nouveau un decoding, mais les tokens qui sont décodé ne sont jamais ré-encodés. Avec la même animantion que précédement ce sera plus simple à comprendre.
<faire un figure qui va step par step dans l'algorithme ci-dessus>
La seule phase critique: tokenizer la réponse
Il y a une phase dont on a peut être eludé la complexité: "Tokenize la réponse de l'outil et ajoute les tokens à la conversation". Par exemple, si la sequence est:
[151644, 872, 198, 3838, 374, 220, 17, 10, 17, 13, 151645, 198, 151644, 77091, 198, 151657, 198, 4913, 606, 788, 330, 718, 497, 330, 16370, 788, 5212, 64, 788, 220, 17, 11, 330, 65, 788, 220, 17, 11248, 151658, 151645, 198, 151644, 872, 198, 151665, 198, 19, 198, 151666, 151645, 198]
(cela correspond à la conversation
messages = [
{"role": "user", "content": "What is 2+2."},
{"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "add", "arguments": {"a": 2, "b": 2}}}]},
]
)
et que la réponse de l'outil est "4", on ne peut pas juste naivement tokenizer "4" (cela donnerait [19]), et l'ajouter à la sequence de tokens. Pour Qwen3, le modèle s'attend à ce que la réponse de l'outil soit encadrée par des tokens spéciaux, en l'occurence \n<|im_start|>user\n<tool_response>\n4\n</tool_response><|im_end|>\n<|im_start|>assistant\n (on préparer aussi la generation pour le prochain turn).
Alors lorsque l'on connait ce format c'est possiblement simple. Mais comment fait-on dans le cas general? Pour que ce fonctionne avec tous tous les modèles?
Il existe une manipulation assez simple pour obtenir les tokens qu'il faut ajouter à la séquence de tokens. La voici:
def get_tool_response_tokens(tool_response, tokenizer):
# On construit une conversation dummy
dummy_messages = [
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "dummy"},
]
# on a le tool message
tool_message = [{"role": "tool", "content": tool_response}]
# On tokenize d'abord la conversation sans le tool message pour avoir le prefix
prefix = tokenizer.apply_chat_template(dummy_messages, add_generation_prompt=False, tokenizer=True, return_dict=False)
# On tokenize ensuite la conversation complete avec le tool message pour avoir le suffix
full = tokenizer.apply_chat_template(dummy_messages + tool_message, add_generation_prompt=True, tokenizer=True, return_dict=False)
# On retourne la difference entre les deux sequences de tokens
return full[len(prefix):]
et voila c'est assez simple, petite demo avec Qwen2.5:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
>>> tool_response = "4"
>>> tokens = get_tool_response_tokens(tool_response, tokenizer)
>>> tokens
[151644, 872, 198, 27, 14172, 9655, 397, 19, 198, 522, 14172, 9655, 29, 151645, 198, 151644, 77091, 198]
>>> tokenizer.decode(tokens)
'<|im_start|>user\n<tool_response>\n4\n</tool_response><|im_end|>\n<|im_start|>assistant\n'
et c'est tout, aussi simple que ça.
Alors oui, pour que cette fonction fonctionne, il existe une condition importante. C'est que lorsqu'on ajoute le tool message, le prefix soit conservé. C'est ce qu'on apple la prefix conservation. Heuresement, tous les modèles testés (ajouter la liste des modèles) verifie cette condition.
Il est important de noter qu'il n'est pas necessaire que le template soit prefix-preserving pour les message de l'utilisateur, pour le thinking, ou pour quoique ce soit d'autre. La plupart des chat template ne le sont pas. Soit ils collapse le thinking ou autre, et en plus c'est une trend de plus en plus populaire. Il est important de noter que cette condition est vraiment limité à un tool message. C'est en cela que "renderes" blog post de Prime se trompe. Le preservation du prefix est en fait une hypothese très faible et très largement respectée.
Pour aller plus loin: thinking collapse
Certain modèles ont un template qui collapse le thinking, c'est à dire qui ne garde uniquement que le dernier thinking. Par exemple:
c'est assez pratique pour l'inference car les modèle de reasoning sont assez gourmand en tokens. Néanmoins, pour le training, vous remarquerez qui si on suit l'algo de la section precédente, ce thinking collapse n'est jamais mis en oeuvre, les tokens ne font que s'accumuler
en fait il ne s'agit pas d'une limitation mais plutot d'une benediction car essayer de collasper le thinking pour le training invaliderait la conddition que l'on a établit dans la toute premiere section: la sequence finale doit correspondre exactement à la sequence générée par le modèle. Collapser le thinking MODIFIE l'historique.
Mai si on entraine que sur des sequence à full thinking, ne risque-t-on pas de créer un gap entre l'inference et le training qui peut être detrimental? En d'autre terme, si le modele au cours de son apprentissage n'a vu que des conversation complete avec du thinking a chaque tours, ne risque-t-il pas de rencontrer des difficultés lors de l'inference, pendant laquelle on va collapse le thinking? La litterature est encore assez pauvre sur le sujet, et mon intuition, pour ce qu'elle vaut est non.
Il existe néanmoins un façon d'activer ce thinking collapse pour training RL. Cela consiste en l'idée de forker la conversation. En gros, lors de l'inference, au lieu de continuer l'inference, tu prend la conversation, que tu decode complement, et tu l'ajoute au dataset de prompt, sans t'entrainer dessus. Et tu ne t'entraine dessus que lors du dernier turn. Mais attends, ça veut dire qu'on re-encode la conversation generé, es-tce qu'on ne revient pas au probleme initial? Non, car cette partie decodée sera considérée comme un prompt, et le modele ne sera pas entrainé dessus. Il n'y a donc pas de mismatch. Et lors du dernier turn, on entraine le modele sur la sequence complete, qui est exactement celle qu'il a généré. C'est pas ideal comme entrainement, pour plusieurs raisons qui sont hors du scope de ce blog, mais c'est une solution qui permet de faire voir au modèle avec ce genre de pattern que que l'inference soit moins hors domaine.