|
|
|
|
|
""" |
|
|
使用字符长度估算样本长度(粗略估计) |
|
|
字符数 * 1.5 ≈ token 数(保守估计) |
|
|
""" |
|
|
|
|
|
import json
import os

from tqdm import tqdm
|
|
|
|
|
def calculate_char_length(sample):
    """Return the total character count of a dataset sample.

    Sums the lengths of the optional ``system`` and ``tools`` strings plus
    every conversation turn's ``value``.  The result is used upstream as a
    cheap proxy for token count (chars * 1.5 ≈ tokens, conservative).

    Args:
        sample: dict with optional ``system``/``tools`` string fields and an
            optional ``conversations`` list of dicts, each carrying a
            ``value`` string.  (Presumably LLaMA-Factory ShareGPT format —
            confirm against the dataset schema.)

    Returns:
        int: total number of characters across all text fields; 0 for an
        empty or field-less sample.
    """
    total_length = 0

    # .get() covers both "key missing" and "key present but empty/None";
    # empty values contribute nothing, so skipping them is equivalent.
    if sample.get('system'):
        total_length += len(sample['system'])

    if sample.get('tools'):
        total_length += len(sample['tools'])

    # Tolerate samples without 'conversations' instead of raising KeyError.
    for turn in sample.get('conversations', []):
        total_length += len(turn.get('value', ''))

    return total_length
|
|
|
|
|
def main():
    """Filter a LLaMA-Factory JSON dataset by estimated token length.

    Loads the dataset, estimates each sample's token count as
    ``char_length * 1.5``, prints length statistics, drops samples whose
    estimate exceeds ``max_token``, and writes the survivors to a new
    JSON file.
    """
    input_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k_filtered.json"

    print("Loading data...")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Total samples: {len(data)}")

    # Compute each sample's length exactly once; reused below for both the
    # statistics and the filtering pass (the original recomputed every
    # length in a second full pass over the data).
    print("Calculating character lengths...")
    lengths = [calculate_char_length(sample) for sample in tqdm(data)]

    length_stats = sorted(lengths)
    n = len(length_stats)

    print(f"\nCharacter length statistics:")
    print(f"- Min: {length_stats[0]:,}")
    print(f"- Max: {length_stats[-1]:,}")
    print(f"- Median: {length_stats[n // 2]:,}")
    # min(..., n - 1) guards against int(n * pct) == n, which would raise
    # IndexError for small datasets.
    for pct in (0.9, 0.95, 0.99):
        idx = min(int(n * pct), n - 1)
        print(f"- {int(pct * 100)}th percentile: {length_stats[idx]:,}")

    print(f"\nEstimated token length (chars * 1.5):")
    max_token = 65536

    # Largest character count whose estimated token count fits the budget.
    char_threshold = int(max_token / 1.5)
    print(f"- Character threshold for {max_token} tokens: {char_threshold:,}")

    exceeding = sum(1 for length in lengths if length * 1.5 > max_token)
    print(f"- Samples estimated to exceed {max_token} tokens: {exceeding} ({exceeding / n * 100:.2f}%)")

    print(f"\nFiltering samples with estimated tokens ≤ {max_token}...")
    filtered_data = [
        sample
        for sample, length in zip(data, lengths)
        if length * 1.5 <= max_token
    ]

    print(f"\nResults:")
    print(f"- Original samples: {len(data)}")
    print(f"- Filtered samples: {len(filtered_data)}")
    print(f"- Removed: {len(data) - len(filtered_data)}")

    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)

    # Summary of the surviving data.  conv.get('from') tolerates turns
    # without a 'from' key instead of raising KeyError.
    has_func = sum(
        1
        for s in filtered_data
        if any(conv.get('from') == 'function_call' for conv in s['conversations'])
    )
    has_tools = sum(1 for s in filtered_data if s.get('tools'))

    print(f"\nFiltered data statistics:")
    print(f"- Total: {len(filtered_data)}")
    print(f"- With function_call: {has_func}")
    print(f"- With tools: {has_tools}")
    print(f"- Without function_call: {len(filtered_data) - has_func}")

    file_size = os.path.getsize(output_file) / (1024 * 1024)
    print(f"- File size: {file_size:.2f} MB")
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|