File size: 8,147 Bytes
f29b6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#!/usr/bin/env python3
# analysis_deberta_lengths.py
#
# Purpose:
# - Load one-text-per-line CSV: train_clean.csv (no header, each field quoted)
# - Tokenize with microsoft/mdeberta-v3-base
# - Report counts exceeding max_length (default 512)
# - Report length stats, histogram bins, and estimated sliding-window counts given doc_stride
#
# Notes:
# - Paths are relative to this script. Do not rename input file.
# - Shows a progress bar with n/N and ETA.

import os
import sys
import argparse
import math
import csv
from typing import List, Tuple

try:
    from tqdm import tqdm
except ImportError:
    print("tqdm not installed. Install with: pip install tqdm", file=sys.stderr)
    sys.exit(1)

try:
    import pandas as pd
except ImportError:
    print("pandas not installed. Install with: pip install pandas", file=sys.stderr)
    sys.exit(1)

try:
    from transformers import AutoTokenizer
except ImportError:
    print("transformers not installed. Install with: pip install transformers", file=sys.stderr)
    sys.exit(1)

def load_texts(csv_path: str) -> List[str]:
    """Load one text per row from a headerless, fully-quoted CSV.

    Args:
        csv_path: Path to the CSV file (no header, one quoted field per row).

    Returns:
        List of texts, one per row, in file order.
    """
    # keep_default_na=False: otherwise pandas converts empty/NA-looking fields
    # ("", "NA", "null", ...) to NaN, which astype(str) would render as the
    # literal string "nan" and silently corrupt the token-length statistics.
    df = pd.read_csv(
        csv_path,
        header=None,
        quoting=csv.QUOTE_ALL,
        dtype=str,
        keep_default_na=False,
    )
    df.columns = ["text"]
    return df["text"].astype(str).tolist()

def compute_lengths(
    texts: List[str],
    tokenizer_name: str,
    max_length: int,
    batch_size: int = 64
) -> Tuple[List[int], List[int], int]:
    """Tokenize *texts* in batches and collect per-sample token counts.

    Args:
        texts: Input strings, one per sample.
        tokenizer_name: HF hub name passed to AutoTokenizer.from_pretrained.
        max_length: Unused here; kept for interface compatibility with callers.
        batch_size: Number of texts tokenized per batch.

    Returns:
        A tuple of:
          - token lengths with add_special_tokens=True
          - token lengths with add_special_tokens=False
          - tokenizer.num_special_tokens_to_add(pair=False)
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)

    # Specials added around a single (unpaired) sequence.
    special_tokens_to_add = tokenizer.num_special_tokens_to_add(pair=False)

    with_special: List[int] = []
    no_special: List[int] = []

    # Batch tokenization for speed; truncation stays off so true lengths show.
    for start in tqdm(range(0, len(texts), batch_size), desc="Tokenizing", unit="batch"):
        chunk = texts[start:start + batch_size]
        for add_special, sink in ((True, with_special), (False, no_special)):
            encoded = tokenizer(
                chunk,
                add_special_tokens=add_special,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False
            )
            # 'input_ids' is a list of token-id lists, one per sample.
            sink.extend(len(ids) for ids in encoded["input_ids"])

    return with_special, no_special, special_tokens_to_add

def summarize(lengths: List[int]):
    """Return basic distribution statistics for a list of token lengths.

    An empty input yields all-zero statistics rather than raising.
    """
    import numpy as np
    data = np.array(lengths, dtype=int)
    if not data.size:
        return {
            "count": 0, "mean": 0.0, "median": 0.0, "min": 0,
            "p90": 0.0, "p95": 0.0, "p99": 0.0, "max": 0,
        }
    return {
        "count": int(data.size),
        "mean": float(data.mean()),
        "median": float(np.median(data)),
        "min": int(data.min()),
        "p90": float(np.percentile(data, 90)),
        "p95": float(np.percentile(data, 95)),
        "p99": float(np.percentile(data, 99)),
        "max": int(data.max()),
    }

def estimate_windows(
    lengths_no_special: List[int],
    max_length: int,
    special_tokens_to_add: int,
    doc_stride: int
) -> Tuple[int, int]:
    """Estimate sliding-window chunk counts for splitting long documents.

    Per-sample window capacity is max_length minus the special tokens the
    tokenizer adds; doc_stride is the overlap applied on content tokens.

    Returns:
        (total_windows, num_samples_needing_chunking)

    Raises:
        ValueError: if special tokens leave no room for content.
    """
    content_capacity = max_length - special_tokens_to_add
    if content_capacity <= 0:
        raise ValueError(f"Invalid content capacity: {content_capacity}. Check tokenizer specials and max_length.")

    # Content tokens advanced by each additional window (loop-invariant).
    step = max(content_capacity - doc_stride, 1)

    total_windows = 0
    need_chunking = 0
    for length in lengths_no_special:
        if length <= content_capacity:
            # Fits in a single window.
            total_windows += 1
            continue
        need_chunking += 1
        # First window covers content_capacity tokens; extra windows each
        # advance by `step` until the overflow is covered.
        overflow = length - content_capacity
        total_windows += 1 + math.ceil(overflow / step)

    return total_windows, need_chunking

def histogram_counts(lengths_with_special: List[int], max_length: int) -> List[Tuple[str, int]]:
    """Return (label, count) pairs bucketing lengths into quarter-width bins.

    Bins span quarters of max_length, plus a final ">max_length" overflow
    bin. (Previously the bins were hard-coded to [128, 256, 384, 512] and
    the max_length argument was ignored; output is unchanged for the
    default max_length=512.)
    """
    # Four equal-width bins covering 1..max_length.
    bins = [max_length * i // 4 for i in range(1, 5)]
    rows: List[Tuple[str, int]] = []
    prev = 0
    for upper in bins:
        label = f"{prev+1:>4}-{upper:>4}"
        rows.append((label, sum(1 for n in lengths_with_special if prev < n <= upper)))
        prev = upper
    # Overflow bin: everything beyond the last bin edge.
    rows.append((f">{bins[-1]}", sum(1 for n in lengths_with_special if n > bins[-1])))
    return rows

def main():
    """Parse CLI args, tokenize the input CSV, and print length statistics."""
    parser = argparse.ArgumentParser(description="DeBERTa v3 base token-length analysis for train_clean.csv")
    parser.add_argument("--input_csv", default="train_clean.csv", help="Input CSV (one quoted text per line).")
    parser.add_argument("--tokenizer", default="microsoft/mdeberta-v3-base", help="HF tokenizer name.")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length (incl. specials).")
    parser.add_argument("--doc_stride", type=int, default=128, help="Sliding window overlap on content tokens.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for tokenization.")
    args = parser.parse_args()

    # Resolve the input relative to this script, not the current directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_path = os.path.join(script_dir, args.input_csv)
    if not os.path.isfile(input_path):
        print(f"Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    texts = load_texts(input_path)
    lengths_with_special, lengths_no_special, specials = compute_lengths(
        texts, args.tokenizer, args.max_length, batch_size=args.batch_size
    )

    stats_with = summarize(lengths_with_special)
    stats_no = summarize(lengths_no_special)

    # How many samples exceed max_length once specials are included.
    total = len(lengths_with_special)
    exceed = sum(1 for n in lengths_with_special if n > args.max_length)
    frac = (exceed / total) * 100.0 if total else 0.0

    hist = histogram_counts(lengths_with_special, args.max_length)
    total_windows, need_chunking = estimate_windows(
        lengths_no_special, args.max_length, specials, args.doc_stride
    )

    print("\n=== Token Length Analysis (DeBERTa v3 base) ===")
    print(f"Tokenizer: {args.tokenizer}")
    print(f"max_length: {args.max_length} (includes {specials} special tokens per sequence)")
    print(f"doc_stride: {args.doc_stride} (applies to content tokens)")
    print(f"Samples: {total}")
    print(f"Exceeding max_length: {exceed} ({frac:.2f}%)")
    print(f"Samples needing chunking (by content capacity): {need_chunking}")

    for title, stats in (
        ("\n-- Length stats WITH specials --", stats_with),
        ("\n-- Length stats WITHOUT specials --", stats_no),
    ):
        print(title)
        for key, value in stats.items():
            print(f"{key:>6}: {value}")

    print("\n-- Histogram (WITH specials) --")
    for label, cnt in hist:
        print(f"{label}: {cnt}")

    content_capacity = args.max_length - specials
    print(f"\nContent capacity per window (tokens excluding specials): {content_capacity}")
    print(f"Estimated total windows if chunked with doc_stride={args.doc_stride}: {total_windows}")

    print("\nDone.")

if __name__ == "__main__":
    main()