# Source: google-links/_count.py -- uploaded by dejanseo in commit f29b6e6 ("Upload 22 files", verified).
#!/usr/bin/env python3
# analysis_deberta_lengths.py
#
# Purpose:
# - Load one-text-per-line CSV: train_clean.csv (no header, each field quoted)
# - Tokenize with microsoft/mdeberta-v3-base
# - Report counts exceeding max_length (default 512)
# - Report length stats, histogram bins, and estimated sliding-window counts given doc_stride
#
# Notes:
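# - Relative input paths (--input_csv, default train_clean.csv) are resolved against this script's directory; keep that default name or pass --input_csv explicitly.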
# - Shows a progress bar with n/N and ETA.
import os
import sys
import argparse
import math
import csv
from typing import List, Tuple
try:
from tqdm import tqdm
except ImportError:
print("tqdm not installed. Install with: pip install tqdm", file=sys.stderr)
sys.exit(1)
try:
import pandas as pd
except ImportError:
print("pandas not installed. Install with: pip install pandas", file=sys.stderr)
sys.exit(1)
try:
from transformers import AutoTokenizer
except ImportError:
print("transformers not installed. Install with: pip install transformers", file=sys.stderr)
sys.exit(1)
def load_texts(csv_path: str) -> List[str]:
    """Read a header-less, single-column CSV and return one text per row.

    Args:
        csv_path: Path to a UTF-8 CSV file with no header row and exactly one
            (quoted) field per row.

    Returns:
        The field values as Python strings, in file order. A file with no
        parseable content yields an empty list.

    Raises:
        ValueError: if the file contains more than one column.
    """
    try:
        df = pd.read_csv(
            csv_path,
            header=None,
            quoting=csv.QUOTE_ALL,
            dtype=str,
            # Disable NA detection: otherwise texts such as "NA", "null", "nan"
            # or an empty quoted field become NaN, which astype(str) below
            # would silently rewrite to the literal string "nan".
            na_filter=False,
            encoding="utf-8",
        )
    except pd.errors.EmptyDataError:
        # Zero-byte / blank file: nothing to analyse.
        return []
    if df.shape[1] != 1:
        raise ValueError(
            f"Expected exactly one column in {csv_path}, found {df.shape[1]}."
        )
    df.columns = ["text"]
    # Ensure plain Python strings (train_clean.csv should already be single-line).
    return df["text"].astype(str).tolist()
def compute_lengths(
    texts: List[str],
    tokenizer_name: str,
    max_length: int,
    batch_size: int = 64
) -> Tuple[List[int], List[int], int]:
    """Measure per-text token counts with and without special tokens.

    Each batch of ``batch_size`` texts is tokenized twice: first with
    ``add_special_tokens=True``, then with ``False``. Truncation is off, so
    the lengths are the full untruncated counts. A tqdm progress bar tracks
    batches.

    Args:
        texts: Input strings.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``
            (a fast tokenizer is requested).
        max_length: Not used inside this function; accepted for call-site
            compatibility.
        batch_size: Number of texts per tokenizer call.

    Returns:
        ``(lengths_with_special, lengths_no_special, special_tokens_to_add)``.
        The two lists are aligned with ``texts``. The int is
        ``tokenizer.num_special_tokens_to_add(pair=False)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    special_count = tokenizer.num_special_tokens_to_add(pair=False)
    # One destination list per add_special_tokens setting. Dict insertion
    # order makes the True pass run before the False pass for every batch.
    sinks = {True: [], False: []}
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="Tokenizing", unit="batch"):
        chunk = texts[start:start + batch_size]
        for add_specials, sink in sinks.items():
            encoded = tokenizer(
                chunk,
                add_special_tokens=add_specials,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            # The tokenizer returns one id list per input text.
            sink.extend(len(ids) for ids in encoded["input_ids"])
    return sinks[True], sinks[False], special_count
def summarize(lengths: List[int]):
    """Return descriptive statistics for a list of token lengths.

    The keys, in insertion order, are: count, mean, median, min, p90, p95,
    p99, max. This is also the order the report prints them in.

    ``count``, ``min`` and ``max`` are ints; the other values are floats.
    Percentiles use NumPy's default (linear) interpolation. An empty input
    yields zero for every key.
    """
    import numpy as np

    values = np.array(lengths, dtype=int)
    if not values.size:
        return {
            "count": 0,
            "mean": 0.0,
            "median": 0.0,
            "min": 0,
            "p90": 0.0,
            "p95": 0.0,
            "p99": 0.0,
            "max": 0,
        }
    p90, p95, p99 = (float(np.percentile(values, q)) for q in (90, 95, 99))
    return {
        "count": int(values.size),
        "mean": float(values.mean()),
        "median": float(np.median(values)),
        "min": int(values.min()),
        "p90": p90,
        "p95": p95,
        "p99": p99,
        "max": int(values.max()),
    }
def estimate_windows(
    lengths_no_special: List[int],
    max_length: int,
    special_tokens_to_add: int,
    doc_stride: int
) -> Tuple[int, int]:
    """Estimate how many sliding-window chunks the dataset would produce.

    Each window holds ``content_capacity = max_length - special_tokens_to_add``
    content tokens. Consecutive windows overlap by ``doc_stride`` content
    tokens, so each new window advances ``content_capacity - doc_stride``
    tokens. A sample of ``n`` content tokens therefore needs
    ``1 + ceil((n - content_capacity) / step)`` windows when
    ``n > content_capacity``, and exactly one window otherwise.

    Args:
        lengths_no_special: Per-sample token counts excluding special tokens.
        max_length: Maximum sequence length including special tokens.
        special_tokens_to_add: Special tokens added per single sequence.
        doc_stride: Overlap, in content tokens, between consecutive windows.

    Returns:
        ``(total_windows, num_samples_needing_chunking)``.

    Raises:
        ValueError: if the content capacity is not positive, or if
            ``doc_stride`` is negative or not smaller than the content
            capacity. Such a window would never advance; the previous
            behaviour silently clamped the step to 1 token and reported
            absurd counts.
    """
    content_capacity = max_length - special_tokens_to_add
    if content_capacity <= 0:
        raise ValueError(f"Invalid content capacity: {content_capacity}. Check tokenizer specials and max_length.")
    if not 0 <= doc_stride < content_capacity:
        raise ValueError(
            f"Invalid doc_stride: {doc_stride}. It must satisfy "
            f"0 <= doc_stride < content capacity ({content_capacity})."
        )
    step = content_capacity - doc_stride
    total_windows = 0
    need_chunking = 0
    for n in lengths_no_special:
        if n <= content_capacity:
            total_windows += 1
            continue
        need_chunking += 1
        # First window covers content_capacity tokens; each extra window
        # advances `step`. -(-a // b) is an exact integer ceil(a / b).
        total_windows += 1 + -(-(n - content_capacity) // step)
    return total_windows, need_chunking
def histogram_counts(
    lengths_with_special: List[int],
    max_length: int,
    bins: Tuple[int, ...] = (128, 256, 384, 512),
) -> List[Tuple[str, int]]:
    """Bucket token lengths into right-closed bins plus an overflow bucket.

    Args:
        lengths_with_special: Token lengths to bucket.
        max_length: Not used by the computation; kept so existing callers
            keep working.
        bins: Strictly increasing, positive upper edges. Bucket ``i`` covers
            ``(previous edge, bins[i]]``, and the first bucket starts just
            above 0. Defaults to the original fixed edges 128/256/384/512.

    Returns:
        ``(label, count)`` pairs: one per edge, labelled ``"lo-hi"`` with
        both bounds right-aligned to width 4. A final ``">last_edge"`` pair
        counts lengths above the last edge. Lengths <= 0 fall in no bucket.

    Raises:
        ValueError: if ``bins`` is empty, not strictly increasing, or
            starts at a non-positive edge.
    """
    from bisect import bisect_left

    edges = list(bins)
    if not edges:
        raise ValueError("bins must contain at least one edge")
    if edges[0] <= 0 or any(hi <= lo for lo, hi in zip(edges, edges[1:])):
        raise ValueError(f"bins must be strictly increasing positive integers, got {edges}")

    # Single pass: bisect_left finds the first edge >= length, which is
    # exactly the right-closed bucket (prev, edge]. Index len(edges) is the
    # overflow bucket.
    counts = [0] * (len(edges) + 1)
    for length in lengths_with_special:
        if length <= 0:
            # Below the first bucket's lower bound (0, edges[0]]; not counted.
            continue
        counts[bisect_left(edges, length)] += 1

    lower_bounds = [0] + edges[:-1]
    labels = [f"{lo + 1:>4}-{hi:>4}" for lo, hi in zip(lower_bounds, edges)]
    labels.append(f">{edges[-1]}")
    return list(zip(labels, counts))
def main():
    """Command-line entry point: tokenize the input CSV and print a length report.

    Reads ``--input_csv``. A relative path is joined to this script's
    directory; an absolute path is used as-is, per ``os.path.join``. Every
    row is tokenized with ``--tokenizer``. The function prints length stats,
    a histogram and a sliding-window estimate to stdout.

    Exit status:
        1 -- the input file is missing, or the window estimate is impossible
             for the tokenizer's special-token count.
        2 -- invalid numeric command-line arguments (reported by argparse).
    """
    parser = argparse.ArgumentParser(description="DeBERTa v3 base token-length analysis for train_clean.csv")
    parser.add_argument("--input_csv", default="train_clean.csv", help="Input CSV (one quoted text per line).")
    parser.add_argument("--tokenizer", default="microsoft/mdeberta-v3-base", help="HF tokenizer name.")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length (incl. specials).")
    parser.add_argument("--doc_stride", type=int, default=128, help="Sliding window overlap on content tokens.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for tokenization.")
    args = parser.parse_args()

    # Reject nonsensical settings before the (potentially long) tokenization
    # pass rather than crashing after it. doc_stride must also be below the
    # content capacity; that bound needs the tokenizer's special-token count,
    # so it is checked by estimate_windows below.
    if args.max_length <= 0:
        parser.error("--max_length must be a positive integer")
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    if args.doc_stride < 0:
        parser.error("--doc_stride must be >= 0")
    if args.doc_stride >= args.max_length:
        parser.error("--doc_stride must be smaller than --max_length")

    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_path = os.path.join(script_dir, args.input_csv)
    if not os.path.isfile(input_path):
        print(f"Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    texts = load_texts(input_path)
    lengths_with_special, lengths_no_special, specials = compute_lengths(
        texts, args.tokenizer, args.max_length, batch_size=args.batch_size
    )

    # Basic summary
    stats_with = summarize(lengths_with_special)
    stats_no = summarize(lengths_no_special)

    # Exceed counts
    exceed = sum(1 for L in lengths_with_special if L > args.max_length)
    total = len(lengths_with_special)
    frac = (exceed / total) * 100.0 if total else 0.0

    # Histogram and window estimates
    hist = histogram_counts(lengths_with_special, args.max_length)
    try:
        total_windows, need_chunking = estimate_windows(
            lengths_no_special, args.max_length, specials, args.doc_stride
        )
    except ValueError as err:
        print(f"Cannot estimate windows: {err}", file=sys.stderr)
        sys.exit(1)

    # Output
    print("\n=== Token Length Analysis (DeBERTa v3 base) ===")
    print(f"Tokenizer: {args.tokenizer}")
    print(f"max_length: {args.max_length} (includes {specials} special tokens per sequence)")
    print(f"doc_stride: {args.doc_stride} (applies to content tokens)")
    print(f"Samples: {total}")
    print(f"Exceeding max_length: {exceed} ({frac:.2f}%)")
    print(f"Samples needing chunking (by content capacity): {need_chunking}")
    for title, stats in (("WITH", stats_with), ("WITHOUT", stats_no)):
        print(f"\n-- Length stats {title} specials --")
        for k, v in stats.items():
            print(f"{k:>6}: {v}")
    print("\n-- Histogram (WITH specials) --")
    for label, cnt in hist:
        print(f"{label}: {cnt}")
    content_capacity = args.max_length - specials
    print(f"\nContent capacity per window (tokens excluding specials): {content_capacity}")
    print(f"Estimated total windows if chunked with doc_stride={args.doc_stride}: {total_windows}")
    print("\nDone.")


# Run the analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()