"""Retry pass: re-run LLM rephrasing for dataset entries that previously failed.

Loads an already-processed dataset from disk, re-submits every example whose
rephrasing is missing or not marked 'success' to an OpenAI-compatible chat API,
checkpoints progress periodically, and writes the final result atomically.
"""

import concurrent.futures
import http.client
import json
import logging
import os
import random
import shutil
import socket
import sys
import time
from concurrent.futures import ThreadPoolExecutor

from datasets import Dataset
from tqdm.auto import tqdm


DATASET_NAME = "virtuoussy/Multi-subject-RLVR"
DATASET_SPLIT = "train"
API_HOST = "api2.aigcbest.top"
API_PATH = "/v1/chat/completions"
LLM_MODEL = "gpt-4.1-mini"

# Read the key from the environment; never commit a real credential to source.
API_KEY = os.environ.get('AIGCBEST_API_KEY', "YOUR_API_KEY_HERE")
if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
    print("API key is not set. Please set the AIGCBEST_API_KEY environment variable.")
    sys.exit(1)
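
# For example, in a POSIX shell (the key value below is a placeholder):
#   export AIGCBEST_API_KEY="sk-..."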


# Directory layout: existing processed data is read from PROCESSED_DATA_PATH;
# periodic checkpoints and the final result go to FINAL_OUTPUT_PATH.
OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")

MAX_WORKERS = 20              # concurrent API worker threads
REQUEST_DELAY_SECONDS = 0.15  # base per-request throttle (jittered per call)
MAX_RETRIES = 3               # attempts per API call
SAVE_INTERVAL = 2000          # checkpoint after this many processed items

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)


def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
    """Ask the chat-completions endpoint to rephrase one question.

    Returns the rephrased text, or None when the call fails or the reply is unusable.
    """
    system_prompt = (
        "You are an expert linguist specializing in converting structured prompts or "
        "fill-in-the-blank problems into natural, spoken-language questions suitable for "
        "text-to-speech (TTS). Your goal is to make the question sound like how a person "
        "would naturally ask it. "
        "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
        "rephrase it as a direct question asking for the missing information. "
        "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
        "Focus only on rephrasing the *user's question* part provided. "
        "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
    )
    payload = json.dumps({
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": original_question}
        ],
    })
    headers = {
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'HuggingFace Dataset Processing Script (Retry w/ Save)',
        'Content-Type': 'application/json'
    }
    # Jittered throttle so concurrent workers do not hit the API in lockstep.
    time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))
    for attempt in range(retries):
        try:
            conn = http.client.HTTPSConnection(host, timeout=60)
            try:
                conn.request("POST", path, payload, headers)
                res = conn.getresponse()
                status = res.status
                data = res.read()
            finally:
                conn.close()

            if status == 200:
                response_json = json.loads(data.decode("utf-8"))
                if response_json.get("choices") and len(response_json["choices"]) > 0:
                    message = response_json["choices"][0].get("message")
                    if message and message.get("content"):
                        rephrased = message["content"].strip()

                        # Strip surrounding quotes the model sometimes adds.
                        if len(rephrased) > 1:
                            if (rephrased.startswith('"') and rephrased.endswith('"')) or \
                               (rephrased.startswith("'") and rephrased.endswith("'")):
                                rephrased = rephrased[1:-1]

                        # Strip leading boilerplate despite the system prompt forbidding it.
                        if rephrased.lower().startswith(("rephrased:", "here's the rephrased question:")):
                            parts = rephrased.split(":", 1)
                            if len(parts) > 1:
                                potential_rephrased = parts[1].strip()
                                if (potential_rephrased.startswith('"') and potential_rephrased.endswith('"')) or \
                                   (potential_rephrased.startswith("'") and potential_rephrased.endswith("'")):
                                    rephrased = potential_rephrased[1:-1]
                                else:
                                    rephrased = potential_rephrased

                        if rephrased and rephrased.strip().lower() != original_question.strip().lower():
                            return rephrased
                        elif not rephrased:
                            logging.warning(f"LLM returned empty/whitespace response for: {original_question[:50]}...")
                            return None
                        else:
                            logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
                            return None
                logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
                return None
            elif status == 429:
                # Honor the Retry-After header when present; otherwise back off exponentially.
                retry_after_header = res.getheader('Retry-After', '5')
                try:
                    wait_time = int(retry_after_header)
                except ValueError:
                    wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            elif status >= 500:
                wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            else:
                # Other 4xx errors will not succeed on retry; fail immediately.
                logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
                return None

        except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
            logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
            if attempt + 1 == retries:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)
        except json.JSONDecodeError as e:
            logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200] if data else 'N/A'}")
            if attempt + 1 == retries:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
            time.sleep(wait_time)
        except Exception as e:
            logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
            if attempt + 1 == retries:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)

    logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
    return None
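

# For reference, the parser in call_llm_api assumes an OpenAI-compatible
# response body shaped roughly like this (illustrative values only):
#
#   {"choices": [{"message": {"role": "assistant", "content": "What is ...?"}}]}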


def rephrase_query_entry(example):
    """Rephrase the user turn of one example's `query`; return the updated example."""
    processed_example = example.copy()
    original_query_list = example.get("query")
    processed_example['query_rephrased_status'] = 'processing_retry'

    # Validate the query column before touching the API.
    if original_query_list is None:
        processed_example['query_rephrased_status'] = 'skipped_missing_query_column'
        processed_example['query_rephrased'] = example.get('query_rephrased')
        return processed_example
    if not isinstance(original_query_list, list):
        processed_example['query_rephrased_status'] = 'skipped_query_not_list'
        processed_example['query_rephrased'] = example.get('query_rephrased')
        return processed_example
    if not original_query_list:
        processed_example['query_rephrased_status'] = 'skipped_query_list_empty'
        processed_example['query_rephrased'] = example.get('query_rephrased')
        return processed_example

    # Find the first user turn with non-empty string content.
    user_question = None
    found_user_message = False
    for message in original_query_list:
        if isinstance(message, dict) and message.get("role") == "user":
            found_user_message = True
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                user_question = content
                break
    if not user_question:
        processed_example['query_rephrased_status'] = (
            'skipped_invalid_user_content' if found_user_message
            else 'skipped_no_user_content_found'
        )
        processed_example['query_rephrased'] = example.get('query_rephrased')
        return processed_example

    rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)

    if rephrased_query_content:
        processed_example["query_rephrased"] = rephrased_query_content
        processed_example['query_rephrased_status'] = 'success'
    else:
        # Keep whatever value was already there so a later pass can retry.
        processed_example['query_rephrased'] = example.get('query_rephrased')
        processed_example['query_rephrased_status'] = 'failed_llm_call'

    return processed_example
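

# For reference, each example's `query` column is expected to hold a chat-style
# message list; a hypothetical entry might look like:
#
#   [{"role": "system", "content": "..."},
#    {"role": "user", "content": "The capital of France is -----."}]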


def save_dataset_atomically(data_list, output_path, features):
    """Save the data dictionaries via a temporary directory, then rename into
    place so readers never see a half-written dataset."""
    if not data_list:
        logging.info("No data provided for saving.")
        return False

    temp_output_path = output_path + "_saving"
    final_output_path = output_path

    logging.info(f"Attempting to save {len(data_list)} examples to temp path {temp_output_path}...")
    try:
        processed_dataset = Dataset.from_list(list(data_list), features=features)

        os.makedirs(os.path.dirname(final_output_path), exist_ok=True)

        # A stale temp directory from an earlier failed save would break save_to_disk.
        if os.path.exists(temp_output_path):
            logging.warning(f"Removing existing temporary save directory: {temp_output_path}")
            shutil.rmtree(temp_output_path)

        # Write to the temp path first, then swap it in, so a crash mid-write
        # never corrupts the final path.
        processed_dataset.save_to_disk(temp_output_path)
        logging.info(f"Successfully saved dataset to temporary path: {temp_output_path}")

        if os.path.exists(final_output_path):
            logging.debug(f"Removing existing final destination directory before rename: {final_output_path}")
            shutil.rmtree(final_output_path)

        os.rename(temp_output_path, final_output_path)
        logging.info(f"Successfully moved temporary save to final path: {final_output_path}")
        return True

    except Exception as e:
        logging.error(f"Failed during atomic save process to {final_output_path}: {e}", exc_info=True)
        if os.path.exists(temp_output_path):
            try:
                shutil.rmtree(temp_output_path)
                logging.info(f"Cleaned up temporary directory {temp_output_path} after error.")
            except Exception as cleanup_e:
                logging.error(f"Could not clean up temporary directory {temp_output_path} after error: {cleanup_e}")

        # Last resort: dump raw rows as JSON Lines so no work is lost.
        fallback_json_path = final_output_path + ".jsonl.failed_save"
        logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
        try:
            with open(fallback_json_path, 'w', encoding='utf-8') as f:
                for item in data_list:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
            logging.info("Successfully saved fallback JSON Lines file.")
        except Exception as json_e:
            logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
        return False
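

# If a fallback .jsonl.failed_save file was ever written, one way to recover it
# is a sketch like the following (assumes the rows share a single schema):
#
#   recovered = Dataset.from_json(FINAL_OUTPUT_PATH + ".jsonl.failed_save")
#   recovered.save_to_disk(FINAL_OUTPUT_PATH)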


def needs_retry(example):
    """An example needs a retry unless it has a rephrased query and 'success' status."""
    rephrased = example.get('query_rephrased')
    status = example.get('query_rephrased_status')
    return (rephrased is None) or (status != 'success')
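

# e.g. needs_retry({"query_rephrased": None, "query_rephrased_status": "failed_llm_call"}) -> True
#      needs_retry({"query_rephrased": "What is ...?", "query_rephrased_status": "success"}) -> False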


if __name__ == "__main__":
    start_time = time.time()
    logging.info("======================================================")
    logging.info(" Starting Dataset Processing - RETRY w/ PERIODIC SAVE")
    logging.info(f" Saving progress every {SAVE_INTERVAL} processed items.")
    logging.info("======================================================")
    logging.info(f"Loading existing data from: {PROCESSED_DATA_PATH}")
    logging.info(f"Intermediate and final output will be saved to: {FINAL_OUTPUT_PATH}")

    if not os.path.exists(PROCESSED_DATA_PATH):
        logging.error(f"Existing data directory not found at '{PROCESSED_DATA_PATH}'. Cannot run retry mode.")
        sys.exit(1)

    logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
    try:
        existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
        # Reuse the on-disk schema so periodic saves round-trip without type drift.
        dataset_features = existing_dataset.features
        logging.info(f"Dataset features detected: {dataset_features}")
        results_list = existing_dataset.to_list()
        total_examples = len(results_list)
        logging.info(f"Loaded {total_examples} examples.")
    except Exception as e:
        logging.error(f"Failed to load dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
        sys.exit(1)

    logging.info("Identifying examples needing retry...")
    indices_to_retry = [
        i for i, example in enumerate(tqdm(results_list, desc="Checking examples"))
        if needs_retry(example)
    ]
    num_to_retry = len(indices_to_retry)

    if num_to_retry == 0:
        logging.info("No examples need retrying ('query_rephrased' is None or status != 'success').")
        logging.info(f"Saving the existing dataset to the final location '{FINAL_OUTPUT_PATH}' as is...")
        if not save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
            logging.error("Failed to save the dataset to the final location even though no retries were needed.")
        sys.exit(0)

    logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")

    processed_count_total = 0
    processed_since_last_save = 0
    last_save_time = time.time()

    logging.info("Starting concurrent retries with periodic saving...")

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Map each future back to its index so results land in the right slot.
        futures = {
            executor.submit(rephrase_query_entry, results_list[i]): i
            for i in indices_to_retry
        }

        try:
            pbar = tqdm(total=num_to_retry, desc="Retrying examples", unit="example")
            for future in concurrent.futures.as_completed(futures):
                original_index = futures[future]
                try:
                    updated_example_dict = future.result()
                    results_list[original_index] = updated_example_dict
                    pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)

                except Exception as exc:
                    # A worker crashed; record the failure so this row is retried next run.
                    logging.error(f'Retry task for index {original_index} encountered an exception: {exc}', exc_info=True)
                    error_placeholder = results_list[original_index].copy()
                    error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
                    results_list[original_index] = error_placeholder
                    pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)

                finally:
                    processed_count_total += 1
                    processed_since_last_save += 1
                    pbar.update(1)

                    # Periodic checkpoint so a crash loses at most SAVE_INTERVAL items.
                    if processed_since_last_save >= SAVE_INTERVAL:
                        current_time = time.time()
                        time_since_last = current_time - last_save_time
                        logging.info(f"\n--- Processed {processed_since_last_save} items (Total: {processed_count_total}/{num_to_retry}). Time since last save: {time_since_last:.1f}s. Saving progress... ---")
                        if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
                            logging.info(f"--- Progress successfully saved to {FINAL_OUTPUT_PATH} ---")
                            processed_since_last_save = 0
                            last_save_time = current_time
                        else:
                            logging.error("--- FAILED TO SAVE PROGRESS! Check errors above. Will retry saving later. ---")

        except KeyboardInterrupt:
            logging.warning("\nCtrl+C detected! Cancelling pending tasks and attempting final save...")
            # Without this, leaving the `with` block would wait for every queued
            # task to finish (cancel_futures requires Python 3.9+).
            executor.shutdown(wait=False, cancel_futures=True)

        except Exception as e:
            logging.error(f"An unexpected error occurred during the main retry loop: {e}", exc_info=True)
            logging.error("Attempting final save...")

        finally:
            if 'pbar' in locals() and pbar is not None:
                pbar.close()

    logging.info("--- Processing loop finished or interrupted. ---")

    # One last save regardless of how the loop ended.
    logging.info(f"Attempting final save of the dataset ({len(results_list)} items) to: {FINAL_OUTPUT_PATH}")
    if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
        logging.info("--- Final dataset state saved successfully. ---")
    else:
        logging.error(">>> FINAL SAVE FAILED! <<< Check logs. Fallback JSON file might exist.")

    logging.info("------------------------------------------------------")
    logging.info("Verification: Attempting to load final saved dataset...")
    try:
        final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
        logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")

        # Tally statuses to summarize how the run went.
        status_counts = {}
        none_rephrased_count = 0
        for ex in final_reloaded_dataset:
            status = ex.get('query_rephrased_status', 'unknown_status')
            status_counts[status] = status_counts.get(status, 0) + 1
            if ex.get('query_rephrased') is None or not str(ex.get('query_rephrased')).strip():
                none_rephrased_count += 1

        logging.info("Final status counts:")
        for status, count in sorted(status_counts.items()):
            logging.info(f"  - {status}: {count}")

        final_success = status_counts.get('success', 0)
        final_failed = sum(count for st, count in status_counts.items() if st and (st.startswith('failed_') or st == 'processing_retry'))
        final_skipped = sum(count for st, count in status_counts.items() if st and st.startswith('skipped_'))
        other_count = len(final_reloaded_dataset) - final_success - final_failed - final_skipped

        logging.info(f"Summary: Success={final_success}, Failed/Incomplete={final_failed}, Skipped={final_skipped}, Other={other_count}")
        if none_rephrased_count > 0:
            logging.warning(f"WARNING: {none_rephrased_count} items have None/empty 'query_rephrased' in the final dataset.")
        if final_failed > 0:
            logging.warning(f"WARNING: {final_failed} items did not reach 'success' or 'skipped' status.")

    except FileNotFoundError:
        logging.error(f"Verification failed: Final dataset directory not found at {FINAL_OUTPUT_PATH}. Final save likely failed.")
    except Exception as e:
        logging.error(f"Verification failed: Could not reload/verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)

    end_time = time.time()
    logging.info("------------------------------------------------------")
    logging.info(f"Script finished in {end_time - start_time:.2f} seconds.")
    logging.info("======================================================")