# Uploaded by heendung via huggingface_hub (commit d1c897a, verified).
from __future__ import annotations
from typing import Any, Dict, List
import torch
from transformers import PreTrainedTokenizer
from fastchat.model.model_adapter import get_model_adapter
from transformers.trainer_pt_utils import LabelSmoother
# Sentinel label masked out of the loss (transformers' LabelSmoother.ignore_index).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
def not_irrelevant(action_msg):
    """Return the message's ``"useful"`` flag, defaulting to ``True`` when absent."""
    if "useful" in action_msg:
        return action_msg["useful"]
    return True
def history_to_sft_sample(
    history: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
) -> Dict[str, torch.Tensor]:
    """Tokenize a conversation for supervised fine-tuning.

    The full conversation is rendered with the tokenizer's chat template,
    and ``labels`` is built so that only tokens belonging to assistant
    turns not flagged as irrelevant (see ``not_irrelevant``) contribute
    to the loss; every other position is set to ``IGNORE_TOKEN_ID``.

    Parameters
    ----------
    history : List[Dict[str, str]]
        Conversation in the HF chat format as returned by
        ``State.to_dict(format="hf")``. Each message must have ``"role"``
        and ``"content"`` fields.
    tokenizer : PreTrainedTokenizer
        Tokenizer used to apply the chat template. If it has no pad
        token, the EOS token is used for padding.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``input_ids``, ``attention_mask`` and ``labels`` with tokens for
        assistant turns only considered in the labels.

    Raises
    ------
    Exception
        Re-raised from ``apply_chat_template`` after printing a detailed
        diagnostic of the offending message.
    """
    # Copy so the caller's list is never mutated.
    msgs = list(history)

    # Some tokenizers ship without a pad token; fall back to EOS so that
    # padding="max_length" below does not fail.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    res = tokenizer.apply_chat_template(
        msgs,
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a tensor directly or a BatchEncoding-like
    # mapping; normalize to a 1-D id tensor either way.
    ids = res[0] if isinstance(res, torch.Tensor) else res["input_ids"][0]
    labels = torch.full_like(ids, IGNORE_TOKEN_ID)

    # Re-tokenize growing prefixes of the conversation; the token span
    # contributed by each message is [offset, len(part_ids)). Assistant spans
    # are copied into `labels`; everything else stays IGNORE_TOKEN_ID.
    # NOTE(review): this assumes prefix tokenization is consistent with the
    # full-conversation tokenization — TODO confirm for the template in use.
    offset = 0
    partial: List[Dict[str, str]] = []
    for msg in msgs:
        partial.append(msg)
        try:
            part_res = tokenizer.apply_chat_template(
                partial,
                tokenize=True,
                add_generation_prompt=False,
                padding=False,
                truncation=True,
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            )
            part_ids = part_res[0] if isinstance(part_res, torch.Tensor) else part_res["input_ids"][0]
        except Exception as e:
            _report_tokenize_failure(partial, msg, e)
            raise
        # Clamp the span to the (possibly truncated) full-sequence length.
        seg_len = min(len(part_ids) - offset, len(labels) - offset)
        if msg["role"] == "assistant" and seg_len > 0 and not_irrelevant(msg):
            labels[offset : offset + seg_len] = ids[offset : offset + seg_len]
        offset = len(part_ids)
        if offset >= len(labels):
            # The rest of the conversation was truncated away.
            break

    # NOTE(review): when pad_token == eos_token, genuine EOS tokens are also
    # masked out of the attention mask — confirm this is intended.
    attention_mask = ids.ne(tokenizer.pad_token_id)
    return {"input_ids": ids, "attention_mask": attention_mask, "labels": labels}


def _report_tokenize_failure(
    partial: List[Dict[str, str]],
    msg: Any,
    exc: Exception,
) -> None:
    """Print a detailed diagnostic when chat-template tokenization fails."""
    print("\n" + "="*80)
    print(f"[ERROR] Failed to tokenize at message index {len(partial)-1}")
    print(f"[ERROR] Exception: {exc}")
    print(f"[ERROR] Current message causing issue:")
    print(f" Type: {type(msg)}")
    if isinstance(msg, dict):
        print(f" Keys: {msg.keys()}")
        print(f" Role: {msg.get('role', 'MISSING')}")
        content = msg.get('content', 'MISSING')
        print(f" Content type: {type(content)}")
        if not isinstance(content, str):
            print(f" Content (MALFORMED - should be str): {content}")
        elif len(content) > 500:
            print(f" Content (truncated): {content[:500]}...")
        else:
            print(f" Content: {content}")
    else:
        print(f" MALFORMED message (not dict): {msg}")
    print(f"[ERROR] Full partial history ({len(partial)} messages):")
    for i, p_msg in enumerate(partial[-3:]):  # Show last 3 messages
        print(f" Message {len(partial)-3+i}: {type(p_msg)}")
        if isinstance(p_msg, dict):
            print(f" Role: {p_msg.get('role', 'MISSING')}")
            p_content = p_msg.get('content', 'MISSING')
            if not isinstance(p_content, str):
                print(f" Content (MALFORMED): {type(p_content)} - {p_content}")
            else:
                print(f" Content preview: {p_content[:100]}...")
    print("="*80 + "\n")