| |
| """ |
| 过滤出序列长度小于 65536 的样本 |
| """ |
|
|
| import json |
| from transformers import AutoTokenizer |
| from tqdm import tqdm |
|
|
def calculate_length(sample, tokenizer):
    """Return the total token count of one dataset sample.

    Sums the encoded lengths of the optional ``system`` prompt, the
    optional ``tools`` string, and the ``value`` text of every
    conversation turn.

    Args:
        sample: dict with optional ``system``/``tools`` fields and a
            ``conversations`` list of dicts carrying a ``value`` field.
        tokenizer: object exposing ``encode(text) -> list[int]``
            (e.g. a Hugging Face tokenizer).

    Returns:
        int: total number of tokens across all text fields.
    """
    total_length = 0

    # Optional system prompt; skipped when missing or empty.
    if sample.get('system'):
        total_length += len(tokenizer.encode(sample['system']))

    # Optional tool-description field.
    # NOTE(review): assumes 'tools' is a string — tokenizer.encode would
    # fail on a list/dict; confirm against the dataset format.
    if sample.get('tools'):
        total_length += len(tokenizer.encode(sample['tools']))

    # Every conversation turn. .get() guards both a missing
    # 'conversations' key and a missing 'value' key so one malformed
    # sample cannot crash the whole filtering pass.
    for conv in sample.get('conversations', []):
        content = conv.get('value', '')
        total_length += len(tokenizer.encode(content))

    return total_length
|
|
def main():
    """Filter a LLaMA-Factory JSON dataset to samples of <= 65536 tokens.

    Loads the Qwen3-8B-Base tokenizer, reads the input JSON file,
    computes each sample's token length, prints length statistics, and
    writes the samples that fit the context budget to a new JSON file.
    Paths are hard-coded for this one-off preprocessing run.
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B-Base", trust_remote_code=True)

    print("Loading data...")
    input_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Total samples: {len(data)}")

    print("Calculating lengths and filtering...")
    filtered_data = []
    length_stats = []
    max_length = 65536  # context budget; samples strictly above it are dropped

    for sample in tqdm(data):
        length = calculate_length(sample, tokenizer)
        length_stats.append(length)
        if length <= max_length:
            filtered_data.append(sample)

    print(f"\nResults:")
    print(f"- Original samples: {len(data)}")
    print(f"- Filtered samples (≤{max_length}): {len(filtered_data)}")
    print(f"- Removed: {len(data) - len(filtered_data)}")

    # Length distribution. Guarded so an empty input file doesn't crash
    # min()/max() or divide by zero in the percentage report below.
    if length_stats:
        length_stats.sort()
        print(f"\nLength statistics:")
        print(f"- Min: {min(length_stats)}")
        print(f"- Max: {max(length_stats)}")
        print(f"- Median: {length_stats[len(length_stats)//2]}")
        # int(len * p) < len for p < 1, so these indices are always valid.
        print(f"- 90th percentile: {length_stats[int(len(length_stats)*0.9)]}")
        print(f"- 95th percentile: {length_stats[int(len(length_stats)*0.95)]}")
        print(f"- 99th percentile: {length_stats[int(len(length_stats)*0.99)]}")

        print(f"\nSamples exceeding thresholds:")
        for threshold in [32768, 49152, 65536, 80000]:
            count = sum(1 for l in length_stats if l > threshold)
            print(f"- > {threshold}: {count} ({count/len(length_stats)*100:.2f}%)")

    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k_filtered.json"
    print(f"\nSaving filtered data to {output_file}...")

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)

    print(f"Done! Saved {len(filtered_data)} samples")

    # Tool-use breakdown of the kept samples. .get() tolerates turns
    # without a 'from' key instead of raising KeyError.
    has_func = sum(
        1
        for s in filtered_data
        if any(c.get('from') == 'function_call' for c in s['conversations'])
    )
    print(f"\nFiltered data statistics:")
    print(f"- With function_call: {has_func}")
    print(f"- Without function_call: {len(filtered_data) - has_func}")
|
|
| if __name__ == "__main__": |
| main() |
|
|