| import json |
| import pandas as pd |
| import random |
| import numpy as np |
| import os |
|
|
| import argparse |
|
|
|
|
def lim(file_path):
    """Score each example by how closely its history follows the mean history.

    The JSON file must contain an object mapping example keys to lists of
    per-epoch numeric values.  All histories are compared over a common
    number of epochs: the length of the shortest non-empty history, capped
    at 1000.  With ``mean[i]`` being the average of epoch ``i`` over all
    scored examples, every example receives::

        1 - sum_i (values[i] - mean[i])**2 / sum_i (1 - mean[i])**2

    Examples with an empty history cannot be scored.  They are left out of
    the averages and of the result, so that a single empty entry no longer
    discards the ranking of every other example.

    Args:
        file_path: Path to the JSON file to read.

    Returns:
        dict: ``{key: score}`` ordered from the highest to the lowest score.
        An empty dict is returned, after printing an error, when the file is
        missing, is not valid JSON, contains no non-empty history, has a
        zero denominator (every epoch mean equals 1), or any other error
        occurs while scoring.
    """
    # Upper bound on the number of epochs compared.
    max_epochs = 1000

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Only non-empty histories can take part in the computation.
        histories = {key: values for key, values in data.items() if values}
        skipped = len(data) - len(histories)
        if skipped:
            print(f"Skipping {skipped} entries with empty score lists.")
        if not histories:
            print(f"Error: File '{file_path}' contains no non-empty score lists.")
            return {}

        min_length = min(max_epochs, min(len(values) for values in histories.values()))
        print(min_length)
        print("Datasize size: ", len(data))

        # Column-wise mean over the first `min_length` epochs.
        truncated = [values[:min_length] for values in histories.values()]
        epoch_averages = [sum(column) / len(truncated) for column in zip(*truncated)]

        # Shared denominator: squared distance of the mean history from 1.
        minus_square_sum = sum((1 - avg) * (1 - avg) for avg in epoch_averages)
        if minus_square_sum == 0:
            print("Error: every epoch average equals 1, so the score denominator is zero.")
            return {}

        result = {}
        for key, values in histories.items():
            # zip() stops at len(epoch_averages) == min_length.
            local_minus_square_sum = sum(
                (value - avg) * (value - avg)
                for value, avg in zip(values, epoch_averages)
            )
            result[key] = 1 - local_minus_square_sum / minus_square_sum

        return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))

    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
        return {}
    except json.JSONDecodeError:
        print(f"Error: File '{file_path}' is not a valid JSON file.")
        return {}
    except Exception as e:
        # Callers rely on an empty dict (not an exception) to signal failure.
        print(f"Error: {str(e)}")
        return {}
|
|
def acc_score(file_path, method='std'):
    """Score every entry of a JSON history file and rank the entries.

    The JSON file must contain an object mapping keys to lists of numbers.

    Supported methods:
        * ``'std'``    - ``np.std`` of the entry's list.  Entries with an
          empty list are skipped, because their standard deviation is NaN
          and NaN values make the sort order unreliable.
        * ``'random'`` - a fresh ``random.random()`` value per entry.
        * ``'lim'``    - delegated entirely to ``lim(file_path)``.

    Args:
        file_path: Path to the JSON file to read.
        method: One of ``'std'``, ``'random'`` or ``'lim'``.

    Returns:
        dict: ``{key: score}`` ordered from the highest to the lowest score.
        An empty dict is returned, after printing an error, when the method
        is unknown, the file is missing or not valid JSON, or any other
        error occurs while scoring.
    """
    if method == 'lim':
        # lim() reads the file itself, so there is no need to parse it here.
        return lim(file_path)

    if method not in ('std', 'random'):
        print(f"Error: Unknown method '{method}'. Expected 'std', 'random' or 'lim'.")
        return {}

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        result = {}
        skipped = 0
        for key, values in data.items():
            if method == 'random':
                result[key] = random.random()
            elif values:
                result[key] = np.std(values)
            else:
                skipped += 1

        if skipped:
            print(f"Skipping {skipped} entries with empty score lists.")

        return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))

    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
        return {}
    except json.JSONDecodeError:
        print(f"Error: File '{file_path}' is not a valid JSON file.")
        return {}
    except Exception as e:
        # Callers rely on an empty dict (not an exception) to signal failure.
        print(f"Error: {str(e)}")
        return {}
|
|
def sample_parquet_by_indices(json_file_path, parquet_file_path, output_parquet_path,
                              index_column='index', top_n=128, top_start=None, top_end=None,
                              method='std', repeat_time=1, is_save=1):
    """Select parquet rows by their rank in the JSON score file and save them.

    The keys of the JSON file are ranked with ``acc_score(json_file_path,
    method)``.  Either the ranks ``top_start..top_end`` (inclusive, zero
    based; used when both are given) or the first ``top_n`` ranks are kept.
    The kept keys are converted to ``int`` and every parquet row whose
    ``row['extra_info']['index']`` equals one of them is selected.  Rows
    without a readable ``extra_info['index']`` are ignored.

    Up to ten randomly chosen selected rows are printed for inspection.

    Args:
        json_file_path: JSON file mapping index keys to score lists.
        parquet_file_path: Parquet file to filter (read with pyarrow).
        output_parquet_path: Destination path.  When ``repeat_time > 1`` the
            suffix ``_repeatto<expected_count * repeat_time>`` is inserted
            before the extension.
        index_column: Currently unused; the index is always read from
            ``extra_info['index']``.
        top_n: Number of top-ranked keys to keep when no range is given.
        top_start: First rank of the range selection (inclusive).
        top_end: Last rank of the range selection (inclusive).
        method: Scoring method forwarded to ``acc_score``.
        repeat_time: Number of copies of the selected rows to write.
        is_save: Truthy to write the output file, falsy for a dry run.

    Returns:
        tuple: ``(number of output rows, number of unique prompts)``;
        ``(0, 0)`` when no scores could be obtained.
    """
    use_range = top_start is not None and top_end is not None
    if use_range:
        print(f"Sampling records from index {top_start} to {top_end} from {parquet_file_path} using {method} method.")
        expected_sample_count = (top_end - top_start + 1)
    else:
        print(f"Sampling {top_n} records from {parquet_file_path} using {method} method.")
        expected_sample_count = top_n

    sorted_counts = acc_score(json_file_path, method)
    if not sorted_counts:
        print("No valid counts obtained from the JSON file / parquet file.")
        return 0, 0

    # The raw histories are only needed for the inspection printout below.
    with open(json_file_path, 'r') as f:
        original_json_data = json.load(f)

    ranked_keys = list(sorted_counts)
    if use_range:
        ranked_keys = ranked_keys[top_start:top_end + 1]
    else:
        ranked_keys = ranked_keys[:top_n]

    indices_to_keep = [int(idx) for idx in ranked_keys]
    print(indices_to_keep)
    if not indices_to_keep:
        print("Warning: the requested rank selection matched no entries in the JSON file.")

    # A set gives constant-time membership tests inside the row loop.
    wanted_indices = set(indices_to_keep)

    df = pd.read_parquet(parquet_file_path, engine='pyarrow')

    result_data = []
    for _, row in df.iterrows():
        try:
            if row['extra_info']['index'] in wanted_indices:
                result_data.append(row)
        except (KeyError, IndexError, TypeError, ValueError):
            # Row has no usable extra_info['index']; it cannot be selected.
            continue

    filtered_df = pd.DataFrame(result_data)
    # Kept for the unique-prompt count: repeating rows adds no new prompts.
    selected_df = filtered_df

    if repeat_time > 1:
        filtered_df = pd.concat([filtered_df] * repeat_time, ignore_index=True)

        total_count = expected_sample_count * repeat_time
        output_base, output_ext = os.path.splitext(output_parquet_path)
        output_parquet_path = f"{output_base}_repeatto{total_count}{output_ext}"

    sample_size = min(10, len(filtered_df))
    sample_indices = random.sample(range(len(filtered_df)), sample_size)
    print("\n=== Randomly selected samples ===")

    for i, idx in enumerate(sample_indices):
        sample_row = filtered_df.iloc[idx]
        row_index = str(sample_row['extra_info']['index'])

        print(f"\nSample {i+1} (Index: {row_index}):")
        print("Raw count list from JSON:")
        print(original_json_data.get(row_index, []))
        print("Sample row data:")
        print(sample_row)
        print("prompt:")
        print(sample_row['prompt'][0]['content'])
        print("-" * 80)

    unique_prompts = {row['prompt'][0]['content'] for _, row in selected_df.iterrows()}

    print(f"Sampled {len(filtered_df)} records out of {len(df)} total.")
    print(f"Number of unique prompts in the output file: {len(unique_prompts)}")

    if is_save:
        filtered_df.to_parquet(output_parquet_path)
        print(f"Saved to {output_parquet_path}")
    else:
        print(f"Not saving to file (is_save=0). Would have saved to {output_parquet_path}")

    return len(filtered_df), len(unique_prompts)
|
|
def random_sample_parquet(parquet_file_path, output_parquet_path, sample_size=128, is_save=1):
    """Write a uniform random sample of a parquet file's rows to a new file.

    When the file holds no more than ``sample_size`` rows, all rows are kept
    and a warning is printed.  The ``is_save`` flag is honored in that case
    as well, so a dry run never writes a file.

    Args:
        parquet_file_path: Parquet file to sample from (read with pyarrow).
        output_parquet_path: Destination path for the sampled rows.
        sample_size: Number of rows to draw without replacement.
        is_save: Truthy to write the output file, falsy for a dry run.

    Returns:
        int: Number of rows in the sample.
    """
    df = pd.read_parquet(parquet_file_path, engine='pyarrow')
    total_records = len(df)

    if total_records <= sample_size:
        print(f"Warning: Requested sample size ({sample_size}) is greater than or equal to "
              f"the total number of records ({total_records}). Returning all records.")
        sampled_df = df
    else:
        random_indices = random.sample(range(total_records), sample_size)
        sampled_df = df.iloc[random_indices]

    if is_save:
        sampled_df.to_parquet(output_parquet_path)
        print(f"Saved to {output_parquet_path}")
    else:
        print(f"Not saving to file (is_save=0). Would have saved to {output_parquet_path}")

    return len(sampled_df)
|
|
if __name__ == "__main__":
    # Command-line entry point: rank the entries of the JSON score file and
    # write the selected parquet rows (optionally repeated) to data_dir.
    parser = argparse.ArgumentParser(description='Sample parquet files based on different methods.')
    parser.add_argument("--index_json_path", type=str, default="acc_step_500.json", help="Path to the index json file")
    parser.add_argument("--data_dir", type=str, default="data/train/one_shot_rlvr", help="Path to the data directory")
    parser.add_argument("--parquet_file_name", type=str, default="dsr_sub.parquet", help="Path to the parquet file")
    parser.add_argument('--top_n', type=int, default=1, help='Number of top records to sample')
    parser.add_argument('--repeat_time', type=int, default=128, help='Number of times to repeat the sampling')
    parser.add_argument('--top_index', type=int, default=1200, help='Index of the top record')
    parser.add_argument('--top_start', type=int, default=None, help='Start index of the range selection')
    parser.add_argument('--top_end', type=int, default=None, help='End index of the range selection')
    parser.add_argument('--method', type=str, default='std', help='Method to sample the parquet file')
    parser.add_argument('--is_save', type=int, default=1, help='Whether to save the output parquet file (1=yes, 0=no)')
    args = parser.parse_args()

    index_json_path = args.index_json_path
    data_dir = args.data_dir
    parquet_file_path = f"{data_dir}/{args.parquet_file_name}"

    top_n = args.top_n
    repeat_time = args.repeat_time
    method_name = args.method
    is_save = args.is_save

    # An explicit --top_start/--top_end range takes precedence.  Otherwise the
    # single rank given by --top_index is selected, which is the default.
    top_start = None
    top_end = None
    if args.top_start is not None and args.top_end is not None:
        top_start = args.top_start
        top_end = args.top_end
    else:
        if args.top_start is not None or args.top_end is not None:
            print("Warning: --top_start and --top_end must be given together; "
                  "falling back to --top_index.")
        if args.top_index is not None:
            top_start = args.top_index
            top_end = args.top_index
    print(f"top_start: {top_start}, top_end: {top_end}, top_n: {top_n}")

    # Ranks are zero based internally but one based in the output file name.
    if top_start is not None and top_end is not None:
        if top_start == top_end:
            output_parquet_path = f"{data_dir}/dsr_sub_{method_name}_pi{top_start+1}_r{repeat_time}.parquet"
        else:
            output_parquet_path = f"{data_dir}/dsr_sub_{method_name}_pi{top_start+1}-{top_end+1}_r{repeat_time}.parquet"
    else:
        output_parquet_path = f"{data_dir}/dsr_sub_{method_name}_sample{top_n}_r{repeat_time}.parquet"

    # NOTE: this ranking is computed independently of the one used inside
    # sample_parquet_by_indices, so with method='random' the two differ.
    print(acc_score(index_json_path, method=method_name))
    sample_parquet_by_indices(index_json_path, parquet_file_path, output_parquet_path,
                              top_n=top_n, top_start=top_start, top_end=top_end,
                              method=method_name, repeat_time=repeat_time, is_save=is_save)
|
|
|
|