File size: 4,802 Bytes
d4398e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""

Tokenization Controls Module

==============================

Tokenizer selection, token counting, truncation, and splitting.

Supports tiktoken (OpenAI) and HuggingFace tokenizers.

"""

from dataclasses import dataclass
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np


@dataclass
class TokenizationConfig:
    """
    Configuration for tokenizer selection and the truncate/split controls.

    Within this module, only ``tokenizer_name`` and ``tiktoken_encoding`` are
    read (by ``get_tokenizer``). The remaining fields are presumably consumed
    by the code that decides whether and how to call ``truncate_samples`` /
    ``split_long_samples``. Confirm this against that caller.
    """
    # "tiktoken" selects tiktoken; any other value is passed to
    # AutoTokenizer.from_pretrained as a HuggingFace model name/path.
    tokenizer_name: str = "tiktoken"          # "tiktoken" or HF model name
    # Encoding name passed to tiktoken.get_encoding (tiktoken backend only).
    tiktoken_encoding: str = "cl100k_base"    # for tiktoken
    # Presumably the token budget used as max_tokens for truncation/splitting -- TODO confirm.
    max_total_tokens: int = 2048
    # Presumably toggles truncate_samples -- TODO confirm against caller.
    truncate_long: bool = False
    # Presumably toggles split_long_samples -- TODO confirm against caller.
    split_long: bool = False
    # Presumably forwarded as `overlap` to split_long_samples -- TODO confirm.
    split_overlap: int = 50                   # overlap tokens when splitting


def get_tokenizer(config: TokenizationConfig):
    """
    Return a tokenizer-like object selected by ``config``.

    - ``config.tokenizer_name == "tiktoken"``: returns the encoding from
      ``tiktoken.get_encoding(config.tiktoken_encoding)``.
    - Any other name: returns
      ``transformers.AutoTokenizer.from_pretrained(config.tokenizer_name)``.

    Raises:
        ImportError: If the required backend package (tiktoken or
            transformers) cannot be imported. The original import failure is
            chained as ``__cause__``.

    Errors raised while *loading* the tokenizer are not relabelled as a
    missing package. This covers an unknown encoding name, missing model
    files, or an ImportError from an optional dependency needed by a
    specific HF tokenizer. Those errors propagate unchanged.
    """
    if config.tokenizer_name == "tiktoken":
        # Keep the try body to the import alone so only a genuinely missing
        # package produces the "install tiktoken" hint.
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError("tiktoken is required. Install with: pip install tiktoken") from err
        return tiktoken.get_encoding(config.tiktoken_encoding)

    try:
        from transformers import AutoTokenizer
    except ImportError as err:
        raise ImportError("transformers is required for HF tokenizers.") from err
    return AutoTokenizer.from_pretrained(config.tokenizer_name)


def count_tokens(text: str, tokenizer, is_tiktoken: bool = True) -> int:
    """
    Count the tokens ``tokenizer`` produces for ``text``.

    Args:
        text: Text to measure. Non-string values (e.g. None/NaN cells) and
            strings that are empty or whitespace-only count as 0.
        tokenizer: A tiktoken encoding (``is_tiktoken=True``) or a
            HuggingFace tokenizer (``is_tiktoken=False``).
        is_tiktoken: Selects which ``encode`` calling convention is used.

    Returns:
        The number of tokens in ``text``. For HF tokenizers,
        ``add_special_tokens=False`` is passed, so no BOS/EOS tokens are
        counted. For tiktoken, special-token markers that appear literally in
        the text (e.g. ``"<|endoftext|>"``) are encoded as ordinary text
        rather than raising.
    """
    if not isinstance(text, str) or not text.strip():
        return 0
    if is_tiktoken:
        # By default tiktoken's encode() raises ValueError when the text
        # contains a special-token string. Dataset text can legitimately
        # contain such strings, so count them as plain text instead.
        return len(tokenizer.encode(text, disallowed_special=()))
    return len(tokenizer.encode(text, add_special_tokens=False))


def compute_token_stats(
    df: pd.DataFrame,
    columns: List[str],
    tokenizer,
    is_tiktoken: bool = True,
) -> Dict[str, Dict[str, float]]:
    """
    Summarise per-row token counts for each requested column.

    Names in ``columns`` that are not present in ``df`` are skipped without
    error. Each remaining column maps to a dict with the keys ``min``,
    ``max``, ``mean`` (rounded to one decimal), ``median``, ``p95`` and
    ``total``. ``median`` and ``p95`` are truncated to ints. When ``df`` has
    no rows, every entry is 0. Per-row counts come from ``count_tokens``.
    """
    result: Dict[str, Dict[str, float]] = {}
    wanted = [name for name in columns if name in df.columns]

    for name in wanted:
        per_row = df[name].apply(lambda cell: count_tokens(cell, tokenizer, is_tiktoken))

        if len(per_row) == 0:
            # Nothing to aggregate: min/max/median/percentile are undefined
            # on an empty series, so report zeros.
            summary = {
                'min': 0,
                'max': 0,
                'mean': 0,
                'median': 0,
                'p95': 0,
                'total': int(per_row.sum()),
            }
        else:
            summary = {
                'min': int(per_row.min()),
                'max': int(per_row.max()),
                'mean': round(float(per_row.mean()), 1),
                'median': int(per_row.median()),
                'p95': int(np.percentile(per_row, 95)),
                'total': int(per_row.sum()),
            }

        result[name] = summary

    return result


def truncate_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
) -> pd.DataFrame:
    """
    Return a copy of ``df`` with the text in ``col`` cut to ``max_tokens`` tokens.

    Args:
        df: Input frame. It is not modified.
        col: Name of the text column to truncate.
        max_tokens: Maximum number of tokens to keep per cell (>= 0).
        tokenizer: A tiktoken encoding (``is_tiktoken=True``) or a
            HuggingFace tokenizer (``is_tiktoken=False``).
        is_tiktoken: Selects which ``encode`` calling convention is used.

    Returns:
        A copy of ``df``. Cells in ``col`` whose token count exceeds
        ``max_tokens`` are replaced by the decoded first ``max_tokens``
        tokens. Non-string cells and cells already within budget are left
        unchanged.

    Raises:
        ValueError: If ``max_tokens`` is negative. A negative slice bound
            would drop tokens from the end instead of truncating.
    """
    if max_tokens < 0:
        raise ValueError(f"max_tokens must be >= 0, got {max_tokens}")

    if is_tiktoken:
        def _encode(text):
            # Treat literal special-token strings as plain text instead of
            # letting tiktoken raise ValueError on them.
            return tokenizer.encode(text, disallowed_special=())
    else:
        def _encode(text):
            return tokenizer.encode(text, add_special_tokens=False)

    def _truncate(text):
        if not isinstance(text, str):
            return text
        tokens = _encode(text)
        if len(tokens) <= max_tokens:
            # Return the original string rather than a decode round-trip.
            return text
        # NOTE(review): cutting at a token boundary may split a multi-byte
        # character; depending on the tokenizer's decode, the result may end
        # in a replacement character -- confirm if this matters downstream.
        return tokenizer.decode(tokens[:max_tokens])

    out = df.copy()
    out[col] = out[col].apply(_truncate)
    return out


def split_long_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
    overlap: int = 50,
) -> pd.DataFrame:
    """
    Split rows whose text in ``col`` exceeds ``max_tokens`` into multiple rows.

    Each long text is cut into windows of at most ``max_tokens`` tokens.
    Consecutive windows share ``overlap`` tokens, so each window starts
    ``max_tokens - overlap`` tokens after the previous one; this step is
    clamped to at least 1. The last window is the first one that reaches the
    end of the text. No window is emitted that lies entirely inside the
    previous one. Every other column of the row is copied into each chunk
    row.

    Args:
        df: Input frame. It is not modified.
        col: Name of the text column to split.
        max_tokens: Maximum tokens per chunk (>= 1).
        tokenizer: A tiktoken encoding (``is_tiktoken=True``) or a
            HuggingFace tokenizer (``is_tiktoken=False``).
        is_tiktoken: Selects which ``encode`` calling convention is used.
        overlap: Tokens shared between consecutive chunks (>= 0).

    Returns:
        A new frame with the same columns as ``df`` (even when empty) and a
        fresh 0..n-1 index. Non-string cells and texts within budget pass
        through unchanged.

    Raises:
        ValueError: If ``max_tokens < 1`` (rows would be silently dropped)
            or ``overlap < 0`` (tokens between chunks would be skipped).
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if overlap < 0:
        raise ValueError(f"overlap must be >= 0, got {overlap}")

    # Always advance at least one token, even if overlap >= max_tokens.
    step = max(1, max_tokens - overlap)

    new_rows = []
    for _, row in df.iterrows():
        text = row[col]
        if not isinstance(text, str):
            new_rows.append(row)
            continue

        if is_tiktoken:
            # Treat literal special-token strings as plain text instead of
            # letting tiktoken raise ValueError on them.
            tokens = tokenizer.encode(text, disallowed_special=())
        else:
            tokens = tokenizer.encode(text, add_special_tokens=False)

        if len(tokens) <= max_tokens:
            new_rows.append(row)
            continue

        for start in range(0, len(tokens), step):
            end = start + max_tokens
            new_row = row.copy()
            new_row[col] = tokenizer.decode(tokens[start:end])
            new_rows.append(new_row)
            # This chunk already reaches the end of the text. Any later
            # start would give a chunk fully contained in this one.
            if end >= len(tokens):
                break

    # Pass columns explicitly so an empty result still carries df's schema.
    return pd.DataFrame(new_rows, columns=df.columns).reset_index(drop=True)