| import os |
| import http.client |
| import json |
| import time |
| import random |
| from datasets import load_dataset, Dataset, DatasetDict, Features, Value |
| from tqdm.auto import tqdm |
| import sys |
| import logging |
| import getpass |
| import signal |
| import socket |
| import concurrent.futures |
| from concurrent.futures import ThreadPoolExecutor |
| import argparse |
|
|
| |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DATASET_NAME = "virtuoussy/Multi-subject-RLVR"
DATASET_SPLIT = "train"
API_HOST = "api2.aigcbest.top"
API_PATH = "/v1/chat/completions"
LLM_MODEL = "gpt-4.1-mini"

# SECURITY FIX: the key must come from the environment only. The previous
# default embedded a real-looking secret in source control; never commit a
# live credential as a fallback literal.
API_KEY = os.environ.get('AIGCBEST_API_KEY', "")
if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
    print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.")
    sys.exit(1)

# Output locations derived from the dataset name / split.
OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
# Previously processed dataset that this retry pass loads.
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
# Destination for the dataset after retried rows are merged back in.
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")

BATCH_SAVE_SIZE = 500            # retained for compatibility; unused in retry mode
MAX_WORKERS = 20                 # concurrent API worker threads
REQUEST_DELAY_SECONDS = 0.15     # base per-request throttle (jittered per call)
MAX_RETRIES = 3                  # attempts per API call before giving up

# Console logging; quiet the chatty third-party loggers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
|
| |
| |
def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
    """Ask the chat-completions endpoint to rephrase ``original_question`` for TTS.

    Args:
        original_question: Raw user question text to rephrase.
        api_key: Bearer token for the API.
        host: HTTPS host of the API (no scheme).
        path: Request path of the chat-completions endpoint.
        model: Model identifier to request.
        retries: Maximum attempts before giving up.

    Returns:
        The rephrased question string, or ``None`` on client error, malformed
        response, empty/identical model output, or exhausted retries.
    """
    system_prompt = (
        "You are an expert linguist specializing in converting structured prompts or "
        "fill-in-the-blank problems into natural, spoken-language questions suitable for "
        "text-to-speech (TTS). Your goal is to make the question sound like how a person "
        "would naturally ask it. "
        "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
        "rephrase it as a direct question asking for the missing information. "
        "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
        "Focus only on rephrasing the *user's question* part provided. "
        "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
    )
    payload = json.dumps({
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": original_question}
        ],
    })
    headers = {
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'HuggingFace Dataset Processing Script (Retry Mode)',
        'Content-Type': 'application/json'
    }
    # Jittered throttle so many concurrent workers do not burst the endpoint.
    time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))

    for attempt in range(retries):
        logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...")
        try:
            conn = http.client.HTTPSConnection(host, timeout=60)
            try:
                conn.request("POST", path, payload, headers)
                res = conn.getresponse()
                status = res.status
                data = res.read()
                # Capture before closing; needed by the 429 branch below.
                retry_after_header = res.getheader('Retry-After', '5')
            finally:
                # BUGFIX: close in a finally so the socket is released even
                # when request()/getresponse()/read() raises mid-flight
                # (previously a network error leaked the connection on every
                # retried attempt).
                conn.close()

            if status == 200:
                response_json = json.loads(data.decode("utf-8"))
                if response_json.get("choices") and len(response_json["choices"]) > 0:
                    message = response_json["choices"][0].get("message")
                    if message and message.get("content"):
                        rephrased = message["content"].strip()
                        # Strip one layer of symmetric quoting the model sometimes adds.
                        if len(rephrased) > 1 and ((rephrased.startswith('"') and rephrased.endswith('"')) or
                                                   (rephrased.startswith("'") and rephrased.endswith("'"))):
                            rephrased = rephrased[1:-1]
                        if rephrased and rephrased.strip().lower() != original_question.strip().lower():
                            logging.debug(f"Successfully rephrased: {rephrased[:80]}...")
                            return rephrased
                        elif not rephrased:
                            logging.warning(f"LLM returned empty response for: {original_question[:50]}...")
                            return None
                        else:
                            # Model echoed the input unchanged; treat as failure.
                            logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
                            return None
                logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
                return None
            elif status == 429:
                # Honour Retry-After when parseable, else exponential backoff.
                try:
                    wait_time = int(retry_after_header)
                except ValueError:
                    wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            elif status >= 500:
                wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            else:
                # Other 4xx: request is malformed/unauthorized; retrying won't help.
                logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
                return None
        except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
            logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
            if attempt + 1 == retries:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)
        except json.JSONDecodeError as e:
            logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200]}")
            if attempt + 1 == retries:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
            time.sleep(wait_time)
        except Exception as e:
            logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
            if attempt + 1 == retries:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)

    logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
    return None
|
|
|
|
| |
| |
def rephrase_query_entry(example):
    """Retry LLM rephrasing for a single dataset row.

    Finds the first user-role message in ``example['query']``, asks the LLM to
    rephrase it, and records the outcome in the ``query_rephrased`` /
    ``query_rephrased_status`` fields. Operates on a shallow copy, so the
    caller's dict is never mutated.

    Args:
        example: One dataset row; expected to carry a ``query`` column that is
            a list of ``{"role": ..., "content": ...}`` message dicts.

    Returns:
        The updated copy of the row, with a status of ``success_retried``,
        ``failed_llm_retry``, or one of the ``skipped_*`` markers.
    """
    processed_example = example.copy()
    processed_example.setdefault('query_rephrased_status', 'unprocessed')

    original_query_list = example.get("query")

    # Validate the 'query' column shape before touching its contents.
    if original_query_list is None:
        processed_example['query_rephrased_status'] = 'skipped_missing_query_column'
        processed_example['query_rephrased'] = None
        return processed_example
    if not isinstance(original_query_list, list):
        processed_example['query_rephrased_status'] = 'skipped_query_not_list'
        processed_example['query_rephrased'] = None
        return processed_example
    if not original_query_list:
        processed_example['query_rephrased_status'] = 'skipped_query_list_empty'
        processed_example['query_rephrased'] = None
        return processed_example

    # Locate the first user-role message with usable (non-empty string) content.
    # FIX: dropped the unused enumerate() index from the original loop.
    user_question = None
    for message in original_query_list:
        if isinstance(message, dict) and message.get("role") == "user":
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                user_question = content
                break
            # A user message exists but its content is missing/blank/non-string.
            processed_example['query_rephrased_status'] = 'skipped_invalid_user_content'
            processed_example['query_rephrased'] = None
            return processed_example

    if not user_question:
        # No user-role message at all in the conversation.
        processed_example['query_rephrased_status'] = 'skipped_no_user_content_found'
        processed_example['query_rephrased'] = None
        return processed_example

    logging.info(f"Attempting to rephrase: {user_question[:60]}...")
    rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)

    if rephrased_query_content:
        logging.debug(f"Rephrased '{user_question[:30]}...' to '{rephrased_query_content[:30]}...'")
        processed_example["query_rephrased"] = rephrased_query_content
        processed_example['query_rephrased_status'] = 'success_retried'
    else:
        logging.warning(f"Retry failed for user question: {user_question[:50]}...")
        # Keep any previous 'query_rephrased' value; only the status changes.
        processed_example['query_rephrased_status'] = 'failed_llm_retry'

    return processed_example
|
|
|
|
| |
| |
def save_final_dataset(data_list, output_path):
    """Persist the processed rows to disk as a HuggingFace dataset.

    On success returns True. On any failure the rows are dumped as a JSON
    Lines file next to the target path (best effort) and False is returned,
    since the primary arrow save still failed.
    """
    if not data_list:
        logging.info("No data provided for saving.")
        return False

    logging.info(f"Attempting to save {len(data_list)} final examples to {output_path}...")
    try:
        # Explicit schema keeps the column types stable across runs.
        features = Features({
            'query': [{'role': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None)}],
            'query_rephrased': Value(dtype='string', id=None),
            'query_rephrased_status': Value(dtype='string', id=None),
            'query_code': Value(dtype='string', id=None),
            'answer': Value(dtype='string', id=None),
            'answer_code': Value(dtype='string', id=None),
            'subject': Value(dtype='string', id=None),
            'grade': Value(dtype='string', id=None),
            'source': Value(dtype='string', id=None),
            'split': Value(dtype='string', id=None),
            '__index_level_0__': Value(dtype='int64', id=None)
        })

        final_dataset = Dataset.from_list(list(data_list), features=features)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        final_dataset.save_to_disk(output_path)
        logging.info(f"Successfully saved final dataset ({len(data_list)} items) to {output_path}")
        return True
    except Exception as e:
        logging.error(f"Failed to save final dataset to {output_path}: {e}", exc_info=True)
        fallback_json_path = output_path + ".jsonl"
        logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
        try:
            with open(fallback_json_path, 'w', encoding='utf-8') as f:
                for item in data_list:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
            logging.info(f"Successfully saved fallback JSON Lines file to {fallback_json_path}")
        except Exception as json_e:
            logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
        # Primary save failed, so report failure even when the fallback worked.
        return False
|
|
|
|
| |
def get_user_query(example):
    """Return the first non-empty user-message content in example['query'], or None."""
    messages = example.get("query")
    if not isinstance(messages, list):
        return None
    for msg in messages:
        if not isinstance(msg, dict) or msg.get("role") != "user":
            continue
        text = msg.get("content")
        if isinstance(text, str) and text.strip():
            return text
    return None
|
|
|
|
| |
def needs_retry(example):
    """Decide whether this example should go through another LLM attempt.

    A row is retried when its status records an outright failure, or when it
    has no usable rephrased text and was not deliberately skipped.
    """
    status = example.get('query_rephrased_status')
    rephrased = example.get('query_rephrased')

    # Explicit failure markers always trigger a retry.
    if status in {'failed_llm_call', 'failed_llm_retry', 'failed_processing_exception'}:
        return True

    deliberate_skips = {
        'skipped_missing_query_column',
        'skipped_query_not_list',
        'skipped_query_list_empty',
        'skipped_invalid_user_content',
        'skipped_no_user_content_found',
    }
    # Missing/blank output that was not a deliberate skip is retried too.
    has_text = rephrased is not None and bool(str(rephrased).strip())
    if not has_text and status not in deliberate_skips:
        return True

    return False
|
|
|
|
| |
if __name__ == "__main__":
    start_time = time.time()
    logging.info("======================================================")
    logging.info(f" Starting Dataset Processing Script in RETRY MODE")
    logging.info("======================================================")
    logging.info(f"Dataset: {DATASET_NAME}, Split: {DATASET_SPLIT}")
    logging.info(f"Loading existing processed data from: {PROCESSED_DATA_PATH}")
    logging.info(f"Final output will be saved to: {FINAL_OUTPUT_PATH}")
    logging.info(f"Max concurrent workers: {MAX_WORKERS}")

    # Retry mode can only amend an existing processed dataset.
    if not os.path.exists(PROCESSED_DATA_PATH):
        logging.error(f"Existing processed data not found at '{PROCESSED_DATA_PATH}'. Cannot run in retry mode.")
        sys.exit(1)

    logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
    try:
        existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
        # Work on plain dicts so individual rows can be replaced in place.
        results_list = existing_dataset.to_list()
        total_examples = len(results_list)
        logging.info(f"Loaded {total_examples} examples.")
        # Backfill bookkeeping columns for rows written by older script versions.
        for i in range(total_examples):
            if 'query_rephrased' not in results_list[i]:
                results_list[i]['query_rephrased'] = None
            if 'query_rephrased_status' not in results_list[i]:
                results_list[i]['query_rephrased_status'] = 'unknown_original_status'
    except Exception as e:
        logging.error(f"Failed to load existing dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
        sys.exit(1)

    # Select the rows that still need another LLM attempt.
    indices_to_retry = [
        i for i, example in enumerate(results_list) if needs_retry(example)
    ]
    num_to_retry = len(indices_to_retry)

    if num_to_retry == 0:
        logging.info("No examples found needing retry based on the criteria.")
        logging.info(f"The dataset at {PROCESSED_DATA_PATH} is considered final.")
        sys.exit(0)

    logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")

    processed_count_in_retry = 0
    # Maps original row index -> updated row; merged back after the pool drains
    # so a partial run (Ctrl+C) still applies whatever completed.
    retry_results_dict = {}

    logging.info("Starting concurrent processing for examples needing retry...")

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {
            executor.submit(rephrase_query_entry, results_list[i]): i
            for i in indices_to_retry
        }

        try:
            pbar = tqdm(total=num_to_retry, desc="Retrying failed examples", unit="example")
            for future in concurrent.futures.as_completed(futures):
                original_index = futures[future]
                try:
                    updated_example_dict = future.result()
                    retry_results_dict[original_index] = updated_example_dict
                    pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
                except Exception as exc:
                    # A worker crashed: keep the original row but tag it with a
                    # status that embeds the exception type for later triage.
                    logging.error(f'Retry for example index {original_index} generated an exception: {exc}', exc_info=True)
                    error_placeholder = results_list[original_index].copy()
                    error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
                    retry_results_dict[original_index] = error_placeholder
                    pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
                finally:
                    processed_count_in_retry += 1
                    pbar.update(1)
        except KeyboardInterrupt:
            logging.warning("\nCtrl+C detected during retry! Attempting to save progress...")
        except Exception as e:
            logging.error(f"An unexpected error occurred during the retry loop: {e}", exc_info=True)
        finally:
            # pbar may not exist if tqdm construction itself raised.
            if 'pbar' in locals():
                pbar.close()

    # Merge the retry outcomes back into the full result list.
    logging.info("Updating main results list with completed retry attempts...")
    update_count = 0
    for idx, updated_item in retry_results_dict.items():
        if idx < len(results_list):
            results_list[idx] = updated_item
            update_count += 1
        else:
            logging.error(f"Index {idx} from retry results is out of bounds for results_list (size {len(results_list)}). Skipping update.")

    logging.info(f"Applied updates for {update_count} retried items.")

    logging.info(f"Attempting to save the final updated dataset to: {FINAL_OUTPUT_PATH}")
    if save_final_dataset(results_list, FINAL_OUTPUT_PATH):
        logging.info("Final dataset saved successfully.")
    else:
        logging.error(">>> FINAL SAVE FAILED! <<<")
        logging.error(f"Check the logs. The latest state might be in memory or a fallback JSON file if created.")

    # Sanity check: reload the saved dataset and report per-status counts.
    logging.info("------------------------------------------------------")
    logging.info("Verification: Loading final saved dataset for status check...")
    try:
        final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
        logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
        status_counts = {}
        for ex in final_reloaded_dataset:
            status = ex.get('query_rephrased_status', 'unknown_status_field')
            status_counts[status] = status_counts.get(status, 0) + 1

        logging.info("Status counts in the final saved file:")
        for status, count in sorted(status_counts.items()):
            logging.info(f"  - {status}: {count}")

        # BUGFIX: worker-exception rows are tagged
        # 'failed_retry_exception_<ExceptionName>' (see the pool loop above),
        # so the previous exact-key lookup of 'failed_retry_exception' never
        # matched and those failures were silently omitted from the total.
        # Count them with a prefix match instead.
        exception_failures = sum(
            count for status, count in status_counts.items()
            if status.startswith('failed_retry_exception')
        )
        remaining_failures = (status_counts.get('failed_llm_retry', 0)
                              + exception_failures
                              + status_counts.get('failed_llm_call', 0))

        if remaining_failures > 0:
            logging.warning(f"Found {remaining_failures} examples still marked as failed after retry attempts.")
        else:
            logging.info("All identified failures appear to have been successfully retried or were not retried.")
    except FileNotFoundError:
        logging.error(f"Verification failed: Final saved dataset not found at {FINAL_OUTPUT_PATH}.")
    except Exception as e:
        logging.error(f"Failed to reload or verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)

    end_time = time.time()
    logging.info("------------------------------------------------------")
    logging.info(f"Retry script finished in {end_time - start_time:.2f} seconds.")
    logging.info("======================================================")