Spaces:
Running
Running
| """Checkpointing utilities for resumable pipeline execution.""" | |
| import asyncio | |
| import csv | |
| import json | |
| from pathlib import Path | |
| from src.data_processing.models import InferenceLogEntry | |
| from src.utils.common import sort_qids | |
| _jsonl_lock = asyncio.Lock() | |
| def load_log_entries(log_path: Path) -> dict[str, dict]: | |
| """Load all log entries from JSONL file as a dictionary. | |
| Args: | |
| log_path: Path to the JSONL log file | |
| Returns: | |
| Dictionary mapping qid to full entry data | |
| """ | |
| entries = {} | |
| if not log_path.exists(): | |
| return entries | |
| with open(log_path, encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| entry = json.loads(line) | |
| if "qid" in entry: | |
| entries[entry["qid"]] = entry | |
| except json.JSONDecodeError: | |
| continue | |
| return entries | |
| def load_processed_qids(log_path: Path) -> set[str]: | |
| """Load already processed question IDs from JSONL log. | |
| Args: | |
| log_path: Path to the JSONL log file | |
| Returns: | |
| Set of question IDs that have been processed | |
| """ | |
| return set(load_log_entries(log_path).keys()) | |
| async def append_log_entry(log_path: Path, entry: InferenceLogEntry) -> None: | |
| """Append a single log entry to JSONL file (thread-safe). | |
| Args: | |
| log_path: Path to the JSONL log file | |
| entry: InferenceLogEntry to append | |
| """ | |
| async with _jsonl_lock: | |
| log_path.parent.mkdir(parents=True, exist_ok=True) | |
| with open(log_path, "a", encoding="utf-8") as f: | |
| f.write(entry.model_dump_json() + "\n") | |
| def consolidate_log_file(log_path: Path) -> None: | |
| """Consolidate and sort log file by qid. | |
| Reads all entries, removes duplicates (keeps latest), and writes back sorted. | |
| Args: | |
| log_path: Path to the JSONL log file | |
| """ | |
| if not log_path.exists(): | |
| return | |
| entries = load_log_entries(log_path) | |
| if not entries: | |
| return | |
| sorted_qids = sort_qids(list(entries.keys())) | |
| # Write back sorted entries | |
| with open(log_path, "w", encoding="utf-8") as f: | |
| for qid in sorted_qids: | |
| f.write(json.dumps(entries[qid], ensure_ascii=False) + "\n") | |
| def generate_csv_from_log(log_path: Path, output_path: Path) -> int: | |
| """Generate submission CSV from JSONL log, sorted by qid. | |
| Args: | |
| log_path: Path to the JSONL log file | |
| output_path: Path to the output CSV file | |
| Returns: | |
| Count of entries written to CSV | |
| """ | |
| entries = load_log_entries(log_path) | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| sorted_qids = sort_qids(list(entries.keys())) | |
| with open(output_path, "w", newline="", encoding="utf-8") as f: | |
| writer = csv.DictWriter(f, fieldnames=["qid", "answer"]) | |
| writer.writeheader() | |
| for qid in sorted_qids: | |
| writer.writerow({"qid": qid, "answer": entries[qid]["final_answer"]}) | |
| return len(entries) | |
| def is_rate_limit_error(error: Exception) -> bool: | |
| """Check if error is an API rate limit error.""" | |
| error_str = str(error).lower() | |
| return ( | |
| "429" in error_str or | |
| "too many requests" in error_str or | |
| "rate limit" in error_str or | |
| "quota exceeded" in error_str | |
| ) | |