# Origin: grsdfdf / r1-a / dataset / examqa_rewrite.py
# Commit 19891ba (verified): "Add files using upload-large-folder tool"
import os
import http.client
import json
import time
import random
from datasets import load_dataset, Dataset, DatasetDict, Features, Value
from tqdm.auto import tqdm
import sys
import logging
import getpass
import signal
import socket
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import argparse # For command-line arguments
# --- Configuration ---
DATASET_NAME = "virtuoussy/Multi-subject-RLVR"
DATASET_SPLIT = "train"
API_HOST = "api2.aigcbest.top"
API_PATH = "/v1/chat/completions"
LLM_MODEL = "gpt-4.1-mini"
# SECURITY: the key is read from the environment only. A literal key was
# previously committed here as the fallback value; treat that key as leaked
# and rotate it. With a literal fallback the guard below could never trigger.
API_KEY = os.environ.get('AIGCBEST_API_KEY', "")
if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
    print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable.")
    sys.exit(1)
OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
# Path where the *potentially incomplete* processed dataset already exists.
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
# Path where the *final, retried* dataset is saved (a new location, so the
# input dataset is never overwritten).
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")
BATCH_SAVE_SIZE = 500  # Intended interval for intermediate saves during retry
MAX_WORKERS = 20  # Size of the thread pool used for concurrent API calls
REQUEST_DELAY_SECONDS = 0.15  # Base delay used for pacing and backoff
MAX_RETRIES = 3  # Default number of attempts per API call
# Setup logging; quieten the dataset libraries so progress output stays readable.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
# --- LLM API Function (call_llm_api) ---
# Sends one question to the chat-completions endpoint, retrying with backoff on transient failures.
def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
    """Ask the chat-completions endpoint to rephrase one question for TTS.

    Args:
        original_question: The user question text to rephrase.
        api_key: Bearer token sent in the ``Authorization`` header.
        host: Host name of the API; the request is made over HTTPS.
        path: Request path of the chat-completions endpoint.
        model: Model identifier placed in the request payload.
        retries: Maximum number of request attempts.

    Returns:
        The rephrased question, or ``None`` when every attempt fails, the API
        answers with a non-retryable status or an unexpected structure, or
        the model returns an empty answer or one equal (ignoring case and
        surrounding whitespace) to the input.
    """
    system_prompt = (
        "You are an expert linguist specializing in converting structured prompts or "
        "fill-in-the-blank problems into natural, spoken-language questions suitable for "
        "text-to-speech (TTS). Your goal is to make the question sound like how a person "
        "would naturally ask it. "
        "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
        "rephrase it as a direct question asking for the missing information. "
        "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
        "Focus only on rephrasing the *user's question* part provided. "
        "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
    )
    payload = json.dumps({
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": original_question}
        ],
    })
    headers = {
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'HuggingFace Dataset Processing Script (Retry Mode)',
        'Content-Type': 'application/json'
    }
    # Jittered pause before the first request so concurrent workers do not
    # all hit the API at the same instant.
    time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))
    for attempt in range(retries):
        is_last_attempt = attempt + 1 == retries
        logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...")
        try:
            conn = http.client.HTTPSConnection(host, timeout=60)
            try:
                conn.request("POST", path, payload, headers)
                res = conn.getresponse()
                status = res.status
                data = res.read()
                retry_after_header = res.getheader('Retry-After', '5')
            finally:
                # Close on every path; previously a failed request or read
                # left the connection open.
                conn.close()
            if status == 200:
                response_json = json.loads(data.decode("utf-8"))
                if response_json.get("choices") and len(response_json["choices"]) > 0:
                    message = response_json["choices"][0].get("message")
                    if message and message.get("content"):
                        rephrased = message["content"].strip()
                        # Drop one pair of matching quotes wrapped around the answer.
                        if len(rephrased) > 1 and ((rephrased.startswith('"') and rephrased.endswith('"')) or
                                                   (rephrased.startswith("'") and rephrased.endswith("'"))):
                            rephrased = rephrased[1:-1]
                        if rephrased and rephrased.strip().lower() != original_question.strip().lower():
                            logging.debug(f"Successfully rephrased: {rephrased[:80]}...")
                            return rephrased
                        elif not rephrased:
                            logging.warning(f"LLM returned empty response for: {original_question[:50]}...")
                            return None
                        else:
                            logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
                            return None  # Treat identical as failure for rephrasing
                logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
                return None
            elif status == 429:
                if is_last_attempt:
                    # No retry follows, so sleeping here would only waste time.
                    logging.warning(f"Rate limit exceeded (HTTP {status}) on final attempt.")
                    break
                try:
                    # Clamp to zero: time.sleep() rejects negative values.
                    wait_time = max(0, int(retry_after_header))
                except ValueError:
                    # Retry-After may also be an HTTP date; fall back to backoff.
                    wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            elif status >= 500:
                if is_last_attempt:
                    logging.warning(f"Server error (HTTP {status}) on final attempt.")
                    break
                wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            else:
                # Other statuses are treated as client errors and not retried.
                logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
                return None
        except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
            logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
            if is_last_attempt:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)
        except json.JSONDecodeError as e:
            logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200]}")
            if is_last_attempt:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
            time.sleep(wait_time)  # Wait before next attempt
        except Exception as e:
            logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
            if is_last_attempt:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)
    logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
    return None
# --- Dataset Processing Function (rephrase_query_entry) ---
# Returns a copy of the example with 'query_rephrased' and 'query_rephrased_status' updated.
def rephrase_query_entry(example):
    """Rephrase the first user message of one example and record the outcome.

    The input mapping is not modified; a shallow copy is returned with
    ``query_rephrased_status`` set to a ``skipped_*`` value (invalid input),
    ``success_retried`` or ``failed_llm_retry``. On a skip,
    ``query_rephrased`` is set to ``None``; on an API failure any existing
    ``query_rephrased`` value is left untouched.
    """
    result = example.copy()
    # Guarantee the status field exists even for rows from an older format.
    result.setdefault('query_rephrased_status', 'unprocessed')

    def _skipped(reason):
        # Mark the copy as skipped for `reason` and clear any rephrased text.
        result['query_rephrased_status'] = reason
        result['query_rephrased'] = None
        return result

    # --- Input validation ---
    messages = example.get("query")
    if messages is None:
        return _skipped('skipped_missing_query_column')
    if not isinstance(messages, list):
        return _skipped('skipped_query_not_list')
    if not messages:
        return _skipped('skipped_query_list_empty')

    # --- Locate the first user message ---
    user_question = None
    for entry in messages:
        if not (isinstance(entry, dict) and entry.get("role") == "user"):
            continue
        text = entry.get("content")
        if not isinstance(text, str) or not text.strip():
            # The first user message decides; blank or non-string content is a skip.
            return _skipped('skipped_invalid_user_content')
        user_question = text
        break
    if not user_question:
        return _skipped('skipped_no_user_content_found')

    # --- Call the LLM and record the result ---
    logging.info(f"Attempting to rephrase: {user_question[:60]}...")
    rephrased = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)
    if not rephrased:
        logging.warning(f"Retry failed for user question: {user_question[:50]}...")
        # Existing rephrased content (likely None) is kept; only the status changes.
        result['query_rephrased_status'] = 'failed_llm_retry'
        return result
    logging.debug(f"Rephrased '{user_question[:30]}...' to '{rephrased[:30]}...'")
    result["query_rephrased"] = rephrased
    result['query_rephrased_status'] = 'success_retried'
    return result
# --- Function to Save Progress ---
# Saves the *entire list* of dictionaries
def save_final_dataset(data_list, output_path):
    """Save the final list of processed example dictionaries to disk.

    The list is converted to a ``Dataset`` with an explicit schema and written
    with ``save_to_disk``. If that fails for any reason, the rows are written
    to ``output_path + ".jsonl"`` as a best-effort fallback.

    Args:
        data_list: List of example dictionaries to save.
        output_path: Directory path for the saved dataset.

    Returns:
        ``True`` only when the dataset itself was saved. ``False`` when
        ``data_list`` is empty or the dataset save failed, even if the JSONL
        fallback was written.
    """
    if not data_list:
        logging.info("No data provided for saving.")
        return False
    logging.info(f"Attempting to save {len(data_list)} final examples to {output_path}...")
    try:
        # Create the parent directory before anything else so the JSONL
        # fallback below can write there too. Skip it when the path has no
        # directory part, because os.makedirs('') raises.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Explicit schema so None values in the string columns are accepted
        # and column types stay consistent.
        # IMPORTANT: this must list every column present in the loaded dataset.
        features = Features({
            'query': [{'role': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None)}],
            'query_rephrased': Value(dtype='string', id=None),  # Allow nulls
            'query_rephrased_status': Value(dtype='string', id=None),  # Allow nulls
            'query_code': Value(dtype='string', id=None),
            'answer': Value(dtype='string', id=None),
            'answer_code': Value(dtype='string', id=None),
            'subject': Value(dtype='string', id=None),
            'grade': Value(dtype='string', id=None),
            'source': Value(dtype='string', id=None),
            'split': Value(dtype='string', id=None),
            '__index_level_0__': Value(dtype='int64', id=None)  # Check if this column exists
        })
        processed_dataset = Dataset.from_list(list(data_list), features=features)
        processed_dataset.save_to_disk(output_path)
        logging.info(f"Successfully saved final dataset ({len(data_list)} items) to {output_path}")
        return True
    except Exception as e:
        logging.error(f"Failed to save final dataset to {output_path}: {e}", exc_info=True)
        # Try saving as JSON Lines as a fallback so the in-memory results are not lost.
        fallback_json_path = output_path + ".jsonl"
        logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
        try:
            with open(fallback_json_path, 'w', encoding='utf-8') as f:
                for item in data_list:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
            logging.info(f"Successfully saved fallback JSON Lines file to {fallback_json_path}")
        except Exception as json_e:
            logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
        return False
# --- Helper to get original user query ---
def get_user_query(example):
    """Return the content of the first user message that has non-blank text.

    Looks at ``example["query"]``; yields ``None`` when it is not a list or
    when no user message carries a non-empty string.
    """
    messages = example.get("query")
    if not isinstance(messages, list):
        return None
    user_contents = (
        entry.get("content")
        for entry in messages
        if isinstance(entry, dict) and entry.get("role") == "user"
    )
    # Blank or non-string contents are passed over rather than ending the search.
    return next(
        (text for text in user_contents if isinstance(text, str) and text.strip()),
        None,
    )
# --- Function to Check if Retry is Needed ---
def needs_retry(example):
    """Decide whether an example must be sent through the LLM again.

    Args:
        example: Mapping that may contain ``query_rephrased_status`` and
            ``query_rephrased``.

    Returns:
        ``True`` when the status marks a failed attempt, or when the
        rephrased text is missing or blank and the example was not
        deliberately skipped for bad input; ``False`` otherwise.
    """
    failed_statuses = ('failed_llm_call', 'failed_llm_retry', 'failed_processing_exception')
    # Skips caused by bad input; retrying cannot fix these.
    skipped_statuses = (
        'skipped_missing_query_column', 'skipped_query_not_list',
        'skipped_query_list_empty', 'skipped_invalid_user_content',
        'skipped_no_user_content_found',
    )
    status = example.get('query_rephrased_status')
    rephrased_text = example.get('query_rephrased')

    # Condition 1: explicit failure status from a previous run.
    if status in failed_statuses:
        return True
    # The retry loop records worker exceptions as
    # 'failed_retry_exception_<ExceptionName>', so match on the prefix;
    # otherwise such rows are never retried when they still carry old text.
    if isinstance(status, str) and status.startswith('failed_retry_exception'):
        return True

    # Condition 2: the status claims success (or is missing/old) but there is
    # no usable rephrased text. This catches inconsistent rows.
    has_text = rephrased_text is not None and bool(str(rephrased_text).strip())
    if not has_text and status not in skipped_statuses:
        return True

    # NOTE: rows whose rephrased text equals the original user query are not
    # detected here; that check would need the original query text.
    return False
# --- Main Execution ---
if __name__ == "__main__":
    # Retry-mode driver: load the previously processed dataset, re-run the
    # examples that need it through the LLM concurrently, save the merged
    # result to FINAL_OUTPUT_PATH and report the status distribution.
    start_time = time.time()
    logging.info("======================================================")
    logging.info(" Starting Dataset Processing Script in RETRY MODE")
    logging.info("======================================================")
    logging.info(f"Dataset: {DATASET_NAME}, Split: {DATASET_SPLIT}")
    logging.info(f"Loading existing processed data from: {PROCESSED_DATA_PATH}")
    logging.info(f"Final output will be saved to: {FINAL_OUTPUT_PATH}")
    logging.info(f"Max concurrent workers: {MAX_WORKERS}")

    # --- Load Existing Processed Dataset ---
    if not os.path.exists(PROCESSED_DATA_PATH):
        logging.error(f"Existing processed data not found at '{PROCESSED_DATA_PATH}'. Cannot run in retry mode.")
        sys.exit(1)
    logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
    try:
        existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
        # A list of dicts allows in-place replacement by index; the whole
        # dataset is held in memory.
        results_list = existing_dataset.to_list()
        total_examples = len(results_list)
        logging.info(f"Loaded {total_examples} examples.")
        # Ensure the bookkeeping columns exist for rows saved in an older format.
        for record in results_list:
            record.setdefault('query_rephrased', None)
            record.setdefault('query_rephrased_status', 'unknown_original_status')
    except Exception as e:
        logging.error(f"Failed to load existing dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
        sys.exit(1)

    # --- Identify Indices to Retry ---
    indices_to_retry = [
        i for i, example in enumerate(results_list) if needs_retry(example)
    ]
    num_to_retry = len(indices_to_retry)
    if num_to_retry == 0:
        logging.info("No examples found needing retry based on the criteria.")
        logging.info(f"The dataset at {PROCESSED_DATA_PATH} is considered final.")
        sys.exit(0)
    logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")

    # --- Concurrent Retries ---
    # Results are collected by original index and merged into results_list
    # only after the pool has shut down.
    retry_results_dict = {}
    logging.info("Starting concurrent processing for examples needing retry...")
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit jobs only for the indices needing retry.
        futures = {
            executor.submit(rephrase_query_entry, results_list[i]): i
            for i in indices_to_retry
        }
        pbar = None
        try:
            pbar = tqdm(total=num_to_retry, desc="Retrying failed examples", unit="example")
            for future in concurrent.futures.as_completed(futures):
                original_index = futures[future]  # Index into the full results_list
                try:
                    updated_example_dict = future.result()
                    retry_results_dict[original_index] = updated_example_dict
                    pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
                except Exception as exc:
                    # The worker itself raised; keep the original row but
                    # record the exception type in its status.
                    logging.error(f'Retry for example index {original_index} generated an exception: {exc}', exc_info=True)
                    error_placeholder = results_list[original_index].copy()
                    error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
                    retry_results_dict[original_index] = error_placeholder
                    pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
                finally:
                    pbar.update(1)
        except KeyboardInterrupt:
            logging.warning("\nCtrl+C detected during retry! Attempting to save progress...")
            # Without this, leaving the `with` block waits for every queued
            # task to run. Tasks already running cannot be cancelled and are
            # still awaited.
            cancelled = sum(1 for pending in futures if pending.cancel())
            logging.warning(f"Cancelled {cancelled} pending retry tasks.")
        except Exception as e:
            logging.error(f"An unexpected error occurred during the retry loop: {e}", exc_info=True)
        finally:
            if pbar is not None:
                pbar.close()

    # --- Update the main results list with all completed retries ---
    logging.info("Updating main results list with completed retry attempts...")
    update_count = 0
    for idx, updated_item in retry_results_dict.items():
        if idx < len(results_list):
            results_list[idx] = updated_item
            update_count += 1
        else:
            logging.error(f"Index {idx} from retry results is out of bounds for results_list (size {len(results_list)}). Skipping update.")
    logging.info(f"Applied updates for {update_count} retried items.")

    # --- Final Save ---
    logging.info(f"Attempting to save the final updated dataset to: {FINAL_OUTPUT_PATH}")
    saved_ok = save_final_dataset(results_list, FINAL_OUTPUT_PATH)
    if saved_ok:
        logging.info("Final dataset saved successfully.")
    else:
        logging.error(">>> FINAL SAVE FAILED! <<<")
        logging.error("Check the logs. The latest state might be in memory or a fallback JSON file if created.")

    # --- Final Verification ---
    logging.info("------------------------------------------------------")
    if not saved_ok:
        # Reading FINAL_OUTPUT_PATH now could report a stale dataset from an
        # earlier run as if it were this run's output.
        logging.warning("Skipping verification because the final save failed.")
    else:
        logging.info("Verification: Loading final saved dataset for status check...")
        try:
            final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
            logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
            status_counts = {}
            for ex in final_reloaded_dataset:
                # Fold a null status into the placeholder so the keys stay
                # sortable strings.
                status = ex.get('query_rephrased_status') or 'unknown_status_field'
                status_counts[status] = status_counts.get(status, 0) + 1
            logging.info("Status counts in the final saved file:")
            for status, count in sorted(status_counts.items()):
                logging.info(f"  - {status}: {count}")
            # Count every failure status by prefix: exception statuses carry
            # the exception name as a suffix, so an exact-key lookup of
            # 'failed_retry_exception' never matched them.
            remaining_failures = sum(
                count for status, count in status_counts.items()
                if status.startswith('failed_')
            )
            if remaining_failures > 0:
                logging.warning(f"Found {remaining_failures} examples still marked as failed after retry attempts.")
            else:
                logging.info("All identified failures appear to have been successfully retried or were not retried.")
        except FileNotFoundError:
            logging.error(f"Verification failed: Final saved dataset not found at {FINAL_OUTPUT_PATH}.")
        except Exception as e:
            logging.error(f"Failed to reload or verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)

    end_time = time.time()
    logging.info("------------------------------------------------------")
    logging.info(f"Retry script finished in {end_time - start_time:.2f} seconds.")
    logging.info("======================================================")