vietqa-api / src /utils /checkpointing.py
quanho114
Deploy VietQA API
ebb8326
"""Checkpointing utilities for resumable pipeline execution."""
import asyncio
import csv
import json
from pathlib import Path
from src.data_processing.models import InferenceLogEntry
from src.utils.common import sort_qids
# Module-level asyncio lock; append_log_entry holds it while appending a line
# to the JSONL log so that concurrent coroutines do not interleave writes.
_jsonl_lock = asyncio.Lock()
def load_log_entries(log_path: Path) -> dict[str, dict]:
    """Load all log entries from a JSONL file as a dictionary.

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects (a bare number, string, list or null) are skipped, as are objects
    without a "qid" key. When the same qid appears on several lines, the
    entry from the last such line wins.

    Args:
        log_path: Path to the JSONL log file

    Returns:
        Dictionary mapping qid to full entry data; empty if the file does
        not exist.
    """
    entries: dict[str, dict] = {}
    if not log_path.exists():
        return entries

    with open(log_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                # Corrupt or partially written line: ignore it.
                continue
            # A line can be valid JSON without being an object; checking the
            # type first avoids a TypeError on `"qid" in entry` / entry["qid"].
            if isinstance(entry, dict) and "qid" in entry:
                entries[entry["qid"]] = entry
    return entries
def load_processed_qids(log_path: Path) -> set[str]:
    """Return the question IDs already recorded in the JSONL log.

    Args:
        log_path: Path to the JSONL log file

    Returns:
        Set of the qids found by load_log_entries for this file
    """
    entries_by_qid = load_log_entries(log_path)
    # Iterating a dict yields its keys, i.e. the qids.
    return set(entries_by_qid)
async def append_log_entry(log_path: Path, entry: InferenceLogEntry) -> None:
    """Append one entry to the JSONL log as a single line.

    The write happens while holding the module-level asyncio lock, so
    coroutines calling this concurrently append one at a time. The parent
    directory is created if it is missing.

    NOTE(review): the guard is an asyncio.Lock, which coordinates coroutines
    rather than OS threads — confirm callers do not rely on cross-thread
    safety.

    Args:
        log_path: Path to the JSONL log file
        entry: InferenceLogEntry to append
    """
    async with _jsonl_lock:
        parent_dir = log_path.parent
        parent_dir.mkdir(parents=True, exist_ok=True)
        with log_path.open("a", encoding="utf-8") as log_file:
            serialized = entry.model_dump_json()
            log_file.write(serialized + "\n")
def consolidate_log_file(log_path: Path) -> None:
    """Consolidate and sort the log file by qid.

    Reads all entries via load_log_entries (which keeps the latest entry per
    qid), orders them with sort_qids, and rewrites the file. Does nothing if
    the file is missing or yields no entries.

    The sorted content is written to a sibling temporary file that then
    replaces the original, so an error or crash during the write does not
    leave the checkpoint log truncated.

    Args:
        log_path: Path to the JSONL log file
    """
    if not log_path.exists():
        return
    entries = load_log_entries(log_path)
    if not entries:
        return

    sorted_qids = sort_qids(list(entries.keys()))

    # Same directory as the target so the final rename stays on one filesystem.
    tmp_path = log_path.with_name(log_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(entries[qid], ensure_ascii=False) + "\n"
                for qid in sorted_qids
            )
        tmp_path.replace(log_path)
    finally:
        # After a successful replace the temp file is gone and this is a
        # no-op; after a failure it removes the partial temp file.
        tmp_path.unlink(missing_ok=True)
def generate_csv_from_log(log_path: Path, output_path: Path) -> int:
    """Write a submission CSV built from the JSONL log, ordered by qid.

    The CSV has a header row with the columns "qid" and "answer"; each
    answer is taken from the entry's "final_answer" field. The output
    directory is created if it is missing.

    Args:
        log_path: Path to the JSONL log file
        output_path: Path to the output CSV file

    Returns:
        Number of entries loaded from the log
    """
    entries = load_log_entries(log_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    ordered_qids = sort_qids(list(entries))

    with output_path.open("w", newline="", encoding="utf-8") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=["qid", "answer"])
        writer.writeheader()
        writer.writerows(
            {"qid": qid, "answer": entries[qid]["final_answer"]}
            for qid in ordered_qids
        )

    return len(entries)
def is_rate_limit_error(error: Exception) -> bool:
    """Report whether an error's message looks like an API rate-limit failure.

    The check is a case-insensitive substring match of str(error) against a
    fixed set of markers.

    Args:
        error: Exception to inspect

    Returns:
        True if any marker occurs in the lowercased message, else False
    """
    message = str(error).lower()
    markers = ("429", "too many requests", "rate limit", "quota exceeded")
    return any(marker in message for marker in markers)