# Origin: grsdfdf / r1-a / dataset / retry_rewrite.py
# Page residue kept as a comment: "1f's picture"
# Commit message: "Add files using upload-large-folder tool"
# Commit: 19891ba (verified)
import os
import http.client
import json
import time
import random
# Import necessary types from datasets
from datasets import load_dataset, Dataset, DatasetDict, Features, Value, Sequence
from tqdm.auto import tqdm
import sys
import logging
# Removed unused imports (like socket, already used by http.client indirectly)
# import socket # Not directly needed now
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import shutil # Needed for atomic directory removal
# --- Configuration ---
DATASET_NAME = "virtuoussy/Multi-subject-RLVR"  # Or the original source if needed
DATASET_SPLIT = "train"
API_HOST = "api2.aigcbest.top"
API_PATH = "/v1/chat/completions"
LLM_MODEL = "gpt-4.1-mini"
# The API key is read from the environment only. Never embed a real key as the
# fallback default: it ships with the source, and it makes the placeholder check
# below unreachable.
# NOTE(review): the key that was previously hard-coded here is exposed and should be revoked.
API_KEY_PLACEHOLDER = "YOUR_API_KEY_HERE"
API_KEY = os.environ.get('AIGCBEST_API_KEY', API_KEY_PLACEHOLDER).strip()
if not API_KEY or API_KEY == API_KEY_PLACEHOLDER:
    # logging is not configured yet at this point, so report directly on stderr.
    print(
        "API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.",
        file=sys.stderr,
    )
    sys.exit(1)
OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
# Path to the existing, potentially incomplete, processed dataset (LOAD ONLY)
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
# Path where intermediate and final results will be saved (SAVE ONLY)
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")
MAX_WORKERS = 20  # Adjust based on your system and API rate limits
REQUEST_DELAY_SECONDS = 0.15  # Base delay before each API call (jittered +/-20%)
MAX_RETRIES = 3  # Max attempts for each API call
SAVE_INTERVAL = 2000  # How often to save progress (in number of processed items)
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)  # Quiet down filelock warnings during save
# --- LLM API Function (call_llm_api) ---
# call_llm_api performs the HTTP round-trips. The private helpers just above it
# turn the raw chat-completion output into a clean, single question string.

_QUOTE_CHARS = ('"', "'")
_ANSWER_PREFIXES = ("rephrased:", "here's the rephrased question:")


def _strip_matching_quotes(text):
    """Remove one pair of identical surrounding quote characters from ``text``.

    A lone quote character is treated as an empty quoted string and becomes "".
    """
    if text and text[0] in _QUOTE_CHARS and text[-1] == text[0]:
        return text[1:-1]
    return text


def _clean_rephrased_text(raw):
    """Normalize raw model output into a bare rephrased question.

    Strips whitespace, one pair of surrounding quotes, and a leading label such
    as 'Rephrased:' (together with any quotes around the text after the label).
    """
    text = _strip_matching_quotes(raw.strip()).strip()
    if text.lower().startswith(_ANSWER_PREFIXES):
        _, _, remainder = text.partition(":")
        text = _strip_matching_quotes(remainder.strip())
    return text.strip()


def _extract_message_content(response_json):
    """Return the first choice's non-empty message content string, or None if absent/malformed."""
    if not isinstance(response_json, dict):
        return None
    choices = response_json.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return None
    message = choices[0].get("message")
    if not isinstance(message, dict):
        return None
    content = message.get("content")
    if isinstance(content, str) and content:
        return content
    return None


def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
    """Ask the chat-completions endpoint to rephrase ``original_question``.

    Args:
        original_question: The user question text to rephrase.
        api_key: Bearer token sent in the Authorization header.
        host: API host name (contacted over HTTPS).
        path: Request path of the chat-completions endpoint.
        model: Model name placed in the request payload.
        retries: Maximum number of HTTP attempts.

    Returns:
        The cleaned rephrased question. Returns None in any of these cases:
        the model's answer is empty or identical to the input; the response
        structure is unexpected; a non-retryable status (anything other than
        200, 429 or >= 500) is received; or every attempt fails. Errors are
        logged, never raised.
    """
    system_prompt = (
        "You are an expert linguist specializing in converting structured prompts or "
        "fill-in-the-blank problems into natural, spoken-language questions suitable for "
        "text-to-speech (TTS). Your goal is to make the question sound like how a person "
        "would naturally ask it. "
        "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
        "rephrase it as a direct question asking for the missing information. "
        "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
        "Focus only on rephrasing the *user's question* part provided. "
        "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
    )
    payload = json.dumps({
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": original_question}
        ],
    })
    headers = {
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'HuggingFace Dataset Processing Script (Retry w/ Save)',
        'Content-Type': 'application/json'
    }
    # Jittered per-call delay spreads out requests issued by concurrent worker threads.
    time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))

    normalized_question = original_question.strip().lower()
    for attempt in range(retries):
        is_last_attempt = attempt + 1 >= retries
        data = b""
        try:
            conn = http.client.HTTPSConnection(host, timeout=60)
            try:
                conn.request("POST", path, payload, headers)
                res = conn.getresponse()
                status = res.status
                data = res.read()
            finally:
                # Release the socket even if the request or the read fails.
                conn.close()

            if status == 200:
                response_json = json.loads(data.decode("utf-8"))
                content = _extract_message_content(response_json)
                if content is None:
                    logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
                    return None
                rephrased = _clean_rephrased_text(content)
                if not rephrased:
                    logging.warning(f"LLM returned empty/whitespace response for: {original_question[:50]}...")
                    return None
                if rephrased.lower() == normalized_question:
                    logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
                    return None  # Treat identical as failure
                return rephrased

            if status == 429:  # Rate limit
                try:
                    # Clamp: time.sleep rejects negative values.
                    wait_time = max(0, int(res.getheader('Retry-After', '5')))
                except ValueError:
                    # Retry-After may also be an HTTP date; fall back to exponential backoff.
                    wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                reason = "Rate limit exceeded"
            elif status >= 500:  # Server error
                wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
                reason = "Server error"
            else:  # Other statuses (e.g. 4xx) - don't retry these
                logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8', errors='replace')}")
                return None

            if is_last_attempt:
                # No point sleeping when no further attempt will be made.
                logging.warning(f"{reason} (HTTP {status}) on final attempt {attempt + 1}/{retries}; giving up.")
                break
            logging.warning(f"{reason} (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
            time.sleep(wait_time)
        except (http.client.HTTPException, OSError) as e:
            # OSError covers ConnectionError, TimeoutError, socket.timeout and
            # socket.gaierror, so the socket module is not needed here.
            logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
            if is_last_attempt:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)
        except json.JSONDecodeError as e:
            logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200] if data else 'N/A'}")
            if is_last_attempt:
                return None
            time.sleep(REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5))
        except Exception as e:
            logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
            if is_last_attempt:
                return None
            wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
            logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
            time.sleep(wait_time)
    logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
    return None
# --- Dataset Processing Function (rephrase_query_entry) ---
# Re-attempts the LLM rephrasing for a single dataset row.
def rephrase_query_entry(example):
    """Return a shallow copy of ``example`` carrying a fresh rephrasing attempt.

    The copy always has a 'query_rephrased_status'. 'query_rephrased' keeps the
    value ``example`` already had in these cases:
      * 'query' is missing/None, is not a list, or is empty;
      * the first message with role 'user' has non-string or blank content;
      * no message has role 'user';
      * the LLM call fails.
    On success 'query_rephrased' holds the new text and the status is 'success'.
    The input mapping itself is not modified.
    """
    updated = example.copy()
    previous_rephrasing = example.get('query_rephrased')
    updated['query_rephrased_status'] = 'processing_retry'

    def finish(status, rephrasing=previous_rephrasing):
        # Record the outcome on the copy and hand it back to the caller.
        updated['query_rephrased_status'] = status
        updated['query_rephrased'] = rephrasing
        return updated

    messages = example.get("query")
    if messages is None:
        return finish('skipped_missing_query_column')
    if not isinstance(messages, list):
        return finish('skipped_query_not_list')
    if not messages:
        return finish('skipped_query_list_empty')

    # Only the first user-role message is considered.
    first_user_message = next(
        (msg for msg in messages if isinstance(msg, dict) and msg.get("role") == "user"),
        None,
    )
    if first_user_message is None:
        return finish('skipped_no_user_content_found')

    user_question = first_user_message.get("content")
    if not isinstance(user_question, str) or not user_question.strip():
        return finish('skipped_invalid_user_content')

    new_rephrasing = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)
    if not new_rephrasing:
        return finish('failed_llm_call')
    return finish('success', new_rephrasing)
# --- Function to Save Dataset Atomically ---
# Writes to a temporary directory first, then swaps it into place with renames.


def _write_jsonl_fallback(data_list, fallback_json_path):
    """Best-effort dump of ``data_list`` as JSON Lines; logs (never raises) on failure."""
    logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
    try:
        with open(fallback_json_path, 'w', encoding='utf-8') as f:
            for item in data_list:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info("Successfully saved fallback JSON Lines file.")
    except Exception as json_e:
        logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)


def save_dataset_atomically(data_list, output_path, features):
    """Save a list of row dicts as a ``datasets.Dataset`` directory at ``output_path``.

    The dataset is built with ``features`` and written to ``output_path + "_saving"``.
    It is then swapped into place. Any previous copy is first renamed to
    ``output_path + "_previous"`` rather than deleted, so a complete copy stays
    on disk throughout. If the final rename fails, the previous copy is put back.

    Args:
        data_list: Sequence of row dictionaries.
        output_path: Destination directory.
        features: Schema passed to ``Dataset.from_list``.

    Returns:
        True on success. False if ``data_list`` is empty or any step fails; on
        failure a JSON Lines copy is attempted at
        ``output_path + ".jsonl.failed_save"``.
    """
    if not data_list:
        logging.info("No data provided for saving.")
        return False
    temp_output_path = output_path + "_saving"  # Temporary directory
    backup_output_path = output_path + "_previous"  # Holds the old copy during the swap
    final_output_path = output_path
    logging.info(f"Attempting to save {len(data_list)} examples to temp path {temp_output_path}...")
    try:
        # Create dataset from the list of dictionaries using the defined features
        processed_dataset = Dataset.from_list(list(data_list), features=features)
        # Ensure parent directory exists (dirname is "" for a bare relative name)
        parent_dir = os.path.dirname(final_output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Remove any leftover temporary directory from an earlier failed save
        if os.path.exists(temp_output_path):
            logging.warning(f"Removing existing temporary save directory: {temp_output_path}")
            shutil.rmtree(temp_output_path)
        processed_dataset.save_to_disk(temp_output_path)
        logging.info(f"Successfully saved dataset to temporary path: {temp_output_path}")

        # --- Swap into place ---
        if os.path.exists(backup_output_path):
            shutil.rmtree(backup_output_path)
        had_previous = os.path.exists(final_output_path)
        if had_previous:
            logging.debug(f"Moving existing final directory aside before rename: {final_output_path}")
            os.rename(final_output_path, backup_output_path)
        try:
            os.rename(temp_output_path, final_output_path)
        except OSError:
            if had_previous:
                # Put the previous complete copy back before reporting failure.
                os.rename(backup_output_path, final_output_path)
            raise
        if had_previous:
            shutil.rmtree(backup_output_path, ignore_errors=True)
        logging.info(f"Successfully moved temporary save to final path: {final_output_path}")
        return True
    except Exception as e:
        logging.error(f"Failed during atomic save process to {final_output_path}: {e}", exc_info=True)
        # Attempt to clean up temporary directory if it still exists after failure
        if os.path.exists(temp_output_path):
            try:
                shutil.rmtree(temp_output_path)
                logging.info(f"Cleaned up temporary directory {temp_output_path} after error.")
            except Exception as cleanup_e:
                logging.error(f"Could not clean up temporary directory {temp_output_path} after error: {cleanup_e}")
        _write_jsonl_fallback(data_list, final_output_path + ".jsonl.failed_save")
        return False
# --- Function to Check if Retry is Needed ---
def needs_retry(example):
    """Return True when ``example`` should go through the rephraser again.

    A row counts as done only when 'query_rephrased' is not None AND
    'query_rephrased_status' is exactly 'success'. Failed, skipped, and
    never-processed rows are all retried.
    """
    is_done = (
        example.get('query_rephrased') is not None
        and example.get('query_rephrased_status') == 'success'
    )
    return not is_done
# --- Main Execution ---


def _cancel_pending_futures(futures):
    """Cancel every submitted future that has not started yet; return how many were cancelled."""
    cancelled = sum(1 for future in futures if future.cancel())
    if cancelled:
        logging.warning(f"Cancelled {cancelled} pending retry tasks; they keep their previous status.")
    return cancelled


def _retry_examples(results_list, indices_to_retry, dataset_features):
    """Re-run rephrase_query_entry for ``indices_to_retry``, updating ``results_list`` in place.

    Progress is saved to FINAL_OUTPUT_PATH every SAVE_INTERVAL completed items.
    A final save is made when the loop ends, on Ctrl+C, or on an unexpected
    error. On Ctrl+C or an unexpected error, tasks that have not started are
    cancelled so that leaving the executor waits only for tasks already running.

    Returns:
        Number of completed retry tasks handled in this run.
    """
    num_to_retry = len(indices_to_retry)
    processed_count_total = 0  # Total processed in this run
    processed_since_last_save = 0  # Counter for periodic saving
    last_save_time = time.time()  # Track time for saving message
    logging.info("Starting concurrent retries with periodic saving...")
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit tasks only for the identified indices; map each future to its row index.
        futures = {
            executor.submit(rephrase_query_entry, results_list[i]): i
            for i in indices_to_retry
        }
        pbar = tqdm(total=num_to_retry, desc="Retrying examples", unit="example")
        try:
            for future in concurrent.futures.as_completed(futures):
                original_index = futures[future]
                try:
                    updated_example_dict = future.result()
                    # Update the shared list immediately so periodic saves include it.
                    results_list[original_index] = updated_example_dict
                    pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
                except Exception as exc:
                    # Exceptions raised inside rephrase_query_entry surface here.
                    logging.error(f'Retry task for index {original_index} encountered an exception: {exc}', exc_info=True)
                    error_placeholder = results_list[original_index].copy()  # Start with original data
                    error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
                    # The old query_rephrased value is kept.
                    results_list[original_index] = error_placeholder
                    pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
                finally:
                    processed_count_total += 1
                    processed_since_last_save += 1
                    pbar.update(1)

                # --- Periodic Save Check ---
                if processed_since_last_save >= SAVE_INTERVAL:
                    current_time = time.time()
                    time_since_last = current_time - last_save_time
                    logging.info(f"\n--- Processed {processed_since_last_save} items (Total: {processed_count_total}/{num_to_retry}). Time since last save: {time_since_last:.1f}s. Saving progress... ---")
                    if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
                        logging.info(f"--- Progress successfully saved to {FINAL_OUTPUT_PATH} ---")
                        processed_since_last_save = 0
                        last_save_time = current_time
                    else:
                        # Counter is not reset, so the next completed item triggers another attempt.
                        logging.error("--- FAILED TO SAVE PROGRESS! Check errors above. Will retry saving later. ---")
        except KeyboardInterrupt:
            logging.warning("\nCtrl+C detected! Attempting final save...")
            _cancel_pending_futures(futures)
        except Exception as e:
            logging.error(f"An unexpected error occurred during the main retry loop: {e}", exc_info=True)
            logging.error("Attempting final save...")
            _cancel_pending_futures(futures)
        finally:
            pbar.close()
            logging.info("--- Processing loop finished or interrupted. ---")
            # results_list was updated incrementally, so it is saved as-is.
            logging.info(f"Attempting final save of the dataset ({len(results_list)} items) to: {FINAL_OUTPUT_PATH}")
            if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
                logging.info("--- Final dataset state saved successfully. ---")
            else:
                logging.error(">>> FINAL SAVE FAILED! <<< Check logs. Fallback JSON file might exist.")
    return processed_count_total


def _log_final_verification(dataset_path):
    """Reload the dataset saved at ``dataset_path`` and log per-status counts and warnings."""
    logging.info("------------------------------------------------------")
    logging.info("Verification: Attempting to load final saved dataset...")
    try:
        final_reloaded_dataset = Dataset.load_from_disk(dataset_path)
        logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {dataset_path}.")
        status_counts = {}
        none_rephrased_count = 0
        for ex in final_reloaded_dataset:
            status = ex.get('query_rephrased_status', 'unknown_status')
            status_counts[status] = status_counts.get(status, 0) + 1
            rephrased = ex.get('query_rephrased')
            if rephrased is None or not str(rephrased).strip():
                none_rephrased_count += 1
        logging.info("Final status counts:")
        # Sort on str(key): rows never retried may carry a None status, and None
        # cannot be ordered against str keys.
        for status, count in sorted(status_counts.items(), key=lambda item: str(item[0])):
            logging.info(f" - {status}: {count}")
        final_success = status_counts.get('success', 0)
        # 'processing_retry' would indicate an item stuck mid-processing.
        final_failed = sum(count for st, count in status_counts.items() if st and (st.startswith('failed_') or st == 'processing_retry'))
        final_skipped = sum(count for st, count in status_counts.items() if st and st.startswith('skipped_'))
        other_count = len(final_reloaded_dataset) - final_success - final_failed - final_skipped
        logging.info(f"Summary: Success={final_success}, Failed/Incomplete={final_failed}, Skipped={final_skipped}, Other={other_count}")
        if none_rephrased_count > 0:
            logging.warning(f"WARNING: {none_rephrased_count} items have None/empty 'query_rephrased' in the final dataset.")
        if final_failed > 0:
            logging.warning(f"WARNING: {final_failed} items did not reach 'success' or 'skipped' status.")
    except FileNotFoundError:
        logging.error(f"Verification failed: Final dataset directory not found at {dataset_path}. Final save likely failed.")
    except Exception as e:
        logging.error(f"Verification failed: Could not reload/verify final dataset from {dataset_path}: {e}", exc_info=True)


if __name__ == "__main__":
    start_time = time.time()
    logging.info("======================================================")
    logging.info(" Starting Dataset Processing - RETRY w/ PERIODIC SAVE")
    logging.info(f" Saving progress every {SAVE_INTERVAL} processed items.")
    logging.info("======================================================")
    logging.info(f"Loading existing data from: {PROCESSED_DATA_PATH}")
    logging.info(f"Intermediate and final output will be saved to: {FINAL_OUTPUT_PATH}")
    # NOTE(review): input is always PROCESSED_DATA_PATH while progress is saved to
    # FINAL_OUTPUT_PATH, so a restarted run does not resume from earlier saved
    # progress. Confirm this is intended.

    # --- Load Existing Processed Dataset ---
    if not os.path.exists(PROCESSED_DATA_PATH):
        logging.error(f"Existing data directory not found at '{PROCESSED_DATA_PATH}'. Cannot run retry mode.")
        sys.exit(1)
    logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
    try:
        existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
        # Capture the schema before converting rows to plain dicts.
        dataset_features = existing_dataset.features
        logging.info(f"Dataset features detected: {dataset_features}")
        results_list = existing_dataset.to_list()
        total_examples = len(results_list)
        logging.info(f"Loaded {total_examples} examples.")
    except Exception as e:
        logging.error(f"Failed to load dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
        # Exiting is safer than guessing at another source and risking inconsistent state.
        sys.exit(1)

    # --- Identify Indices to Retry ---
    logging.info("Identifying examples needing retry...")
    indices_to_retry = [
        i for i, example in enumerate(tqdm(results_list, desc="Checking examples")) if needs_retry(example)
    ]
    num_to_retry = len(indices_to_retry)
    if num_to_retry == 0:
        logging.info("No examples found needing retry based on the criteria ('query_rephrased' is None or status != 'success').")
        logging.info(f"Saving the existing dataset to the final location '{FINAL_OUTPUT_PATH}' as is...")
        if not save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
            logging.error("Failed to save the dataset to the final location even though no retries were needed.")
            sys.exit(1)  # Report the failed save through the exit status.
        sys.exit(0)
    logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")

    _retry_examples(results_list, indices_to_retry, dataset_features)

    # --- Final Verification (Optional but Recommended) ---
    _log_final_verification(FINAL_OUTPUT_PATH)

    # --- Script End ---
    end_time = time.time()
    logging.info("------------------------------------------------------")
    logging.info(f"Script finished in {end_time - start_time:.2f} seconds.")
    logging.info("======================================================")