| from __future__ import annotations |
|
|
| from typing import Any, Dict, List |
|
|
| import torch |
| from transformers import PreTrainedTokenizer |
|
|
| from fastchat.model.model_adapter import get_model_adapter |
| from transformers.trainer_pt_utils import LabelSmoother |
| IGNORE_TOKEN_ID = LabelSmoother.ignore_index |
|
|
def not_irrelevant(action_msg):
    """Return the message's ``useful`` flag, treating a missing flag as True."""
    if "useful" in action_msg:
        return action_msg["useful"]
    return True
|
|
|
|
def history_to_sft_sample(
    history: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
) -> Dict[str, torch.Tensor]:
    """Tokenize a conversation for supervised fine‑tuning.

    The full conversation is tokenized once (padded/truncated to
    ``tokenizer.model_max_length``), then prefixes of the conversation are
    re-tokenized one message at a time to locate each message's token span.
    Only spans belonging to assistant turns whose ``useful`` flag is not
    False (see ``not_irrelevant``) are copied into ``labels``; every other
    position stays ``IGNORE_TOKEN_ID`` so the loss ignores it.

    Parameters
    ----------
    history : List[Dict[str, str]]
        Conversation in the HF chat format as returned by
        ``State.to_dict(format="hf")``. Each message must have ``"role"``
        and ``"content"`` fields.
    tokenizer : PreTrainedTokenizer
        Tokenizer used to apply the chat template.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``input_ids``, ``attention_mask`` and ``labels`` with tokens for
        assistant turns only considered in the labels.

    Raises
    ------
    Exception
        Re-raises whatever ``apply_chat_template`` raises on a malformed
        message, after printing diagnostics about the offending prefix.
    """
    msgs = list(history)

    # padding="max_length" below requires a pad token; fall back to EOS.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    res = tokenizer.apply_chat_template(
        msgs,
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a tensor or a BatchEncoding-like dict.
    ids = res[0] if isinstance(res, torch.Tensor) else res["input_ids"][0]
    labels = torch.full_like(ids, IGNORE_TOKEN_ID)

    # Walk growing prefixes: after tokenizing messages [0..k], the length of
    # the result marks where message k ends inside ``ids``, so the span
    # [offset, len(part_ids)) belongs to message k.
    # NOTE(review): this assumes prefix tokenization is consistent with the
    # full-conversation tokenization and that padding is on the right —
    # confirm for the chat template in use.
    offset = 0
    partial: List[Dict[str, str]] = []
    for msg in msgs:
        partial.append(msg)
        try:
            part_res = tokenizer.apply_chat_template(
                partial,
                tokenize=True,
                add_generation_prompt=False,
                padding=False,
                truncation=True,
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            )
            part_ids = part_res[0] if isinstance(part_res, torch.Tensor) else part_res["input_ids"][0]
        except Exception as e:
            _log_tokenization_error(e, msg, partial)
            raise
        # Clamp to the (possibly truncated) label length.
        seg_len = min(len(part_ids) - offset, len(labels) - offset)
        if msg["role"] == "assistant" and seg_len > 0 and not_irrelevant(msg):
            labels[offset : offset + seg_len] = ids[offset : offset + seg_len]
        offset = len(part_ids)
        if offset >= len(labels):
            # Remaining messages were truncated away; nothing left to label.
            break

    # NOTE(review): if pad_token == eos_token this also masks genuine EOS
    # tokens inside the conversation — confirm that is intended.
    attention_mask = ids.ne(tokenizer.pad_token_id)
    return {"input_ids": ids, "attention_mask": attention_mask, "labels": labels}


def _log_tokenization_error(e, msg, partial):
    """Print diagnostics for a prefix that failed ``apply_chat_template``.

    Dumps the message that triggered the failure and a preview of the last
    three messages of the partial history so malformed entries (non-dict
    messages, non-string content) can be spotted in the training logs.
    """
    print("\n" + "="*80)
    print(f"[ERROR] Failed to tokenize at message index {len(partial)-1}")
    print(f"[ERROR] Exception: {e}")
    print(f"[ERROR] Current message causing issue:")
    print(f"  Type: {type(msg)}")
    if isinstance(msg, dict):
        print(f"  Keys: {msg.keys()}")
        print(f"  Role: {msg.get('role', 'MISSING')}")
        content = msg.get('content', 'MISSING')
        print(f"  Content type: {type(content)}")
        if not isinstance(content, str):
            print(f"  Content (MALFORMED - should be str): {content}")
        elif len(content) > 500:
            print(f"  Content (truncated): {content[:500]}...")
        else:
            print(f"  Content: {content}")
    else:
        print(f"  MALFORMED message (not dict): {msg}")
    print(f"[ERROR] Full partial history ({len(partial)} messages):")
    for i, p_msg in enumerate(partial[-3:]):
        print(f"  Message {len(partial)-3+i}: {type(p_msg)}")
        if isinstance(p_msg, dict):
            print(f"    Role: {p_msg.get('role', 'MISSING')}")
            p_content = p_msg.get('content', 'MISSING')
            if not isinstance(p_content, str):
                print(f"    Content (MALFORMED): {type(p_content)} - {p_content}")
            else:
                print(f"    Content preview: {p_content[:100]}...")
    print("="*80 + "\n")
|
|
|
|
|
|
|
|