"""Core pipeline execution logic for the RAG system."""
import asyncio
import csv
import sys
import time
from pathlib import Path
from src.config import BATCH_SIZE, DATA_OUTPUT_DIR
from src.data_processing.answer import normalize_answer
from src.data_processing.formatting import format_choices_display, question_to_state
from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput
from src.graph import get_graph
from src.utils.checkpointing import (
append_log_entry,
consolidate_log_file,
generate_csv_from_log,
is_rate_limit_error,
)
from src.utils.common import sort_qids
from src.utils.ingestion import get_vector_store
from src.utils.logging import log_done, log_pipeline, log_stats, print_log


def sort_questions_by_qid(questions: list[QuestionInput]) -> list[QuestionInput]:
    """Sort questions by qid using natural sorting ("q2" before "q10")."""
qid_to_question = {q.qid: q for q in questions}
sorted_qids = sort_qids(list(qid_to_question.keys()))
return [qid_to_question[qid] for qid in sorted_qids]


async def run_pipeline_async(
questions: list[QuestionInput],
batch_size: int = BATCH_SIZE,
) -> list[PredictionOutput]:
"""Run pipeline for inference (assumes pre-built Vector DB).
Args:
questions: List of questions to process
batch_size: Number of concurrent questions to process
Returns:
List of PredictionOutput objects sorted by qid
"""
log_pipeline("Loading pre-built vector store...")
    get_vector_store()  # Warm up: load the vector store before any questions are processed.
questions = sort_questions_by_qid(questions)
graph = get_graph()
total = len(questions)
start_time = time.perf_counter()
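    # A semaphore caps how many questions are in flight at once.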
sem = asyncio.Semaphore(batch_size)
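    # All tasks run on the same event-loop thread, so this shared dict needs no lock.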
results: dict[str, PredictionOutput] = {}
async def process_single_question(q: QuestionInput) -> None:
async with sem:
print_log(f"\n[{q.qid}] {q.question}")
print_log(format_choices_display(q.choices))
state = question_to_state(q)
result = await graph.ainvoke(state)
answer = result.get("answer", "A")
route = result.get("route", "unknown")
num_choices = len(q.choices)
normalized_answer = normalize_answer(
answer=answer,
num_choices=num_choices,
question_id=q.qid,
default="A",
)
log_done(f"{q.qid}: {normalized_answer} (Route: {route})")
results[q.qid] = PredictionOutput(qid=q.qid, answer=normalized_answer)
tasks = [process_single_question(q) for q in questions]
await asyncio.gather(*tasks)
elapsed = time.perf_counter() - start_time
throughput = total / elapsed if elapsed > 0 else 0
log_stats(f"Completed {total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
sorted_qids = sort_qids(list(results.keys()))
return [results[qid] for qid in sorted_qids]


async def run_pipeline_with_checkpointing(
questions: list[QuestionInput],
log_path: Path,
batch_size: int = BATCH_SIZE,
) -> int:
"""Run pipeline with JSONL checkpointing for resume capability.
Questions are processed in qid order. Results are appended to log file
immediately for fault tolerance, then consolidated at the end.
Args:
questions: List of questions to process (already filtered for unprocessed)
log_path: Path to JSONL log file for checkpointing
batch_size: Number of concurrent questions to process
Returns:
Count of newly processed questions
"""
log_pipeline("Loading pre-built vector store...")
get_vector_store()
questions = sort_questions_by_qid(questions)
log_pipeline(f"Processing {len(questions)} questions in qid order...")
graph = get_graph()
total = len(questions)
start_time = time.perf_counter()
processed_count = 0
sem = asyncio.Semaphore(batch_size)
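    # Set when a rate-limit error is detected; workers check it before and
    # after acquiring the semaphore and bail out instead of issuing more calls.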
stop_event = asyncio.Event()
async def process_single_question(q: QuestionInput) -> None:
nonlocal processed_count
if stop_event.is_set():
return
async with sem:
if stop_event.is_set():
return
print_log(f"\n[{q.qid}] {q.question}")
print_log(format_choices_display(q.choices))
state = question_to_state(q)
try:
result = await graph.ainvoke(state)
route = result.get("route", "unknown")
raw_response = result.get("raw_response", "")
context = result.get("context", "")
answer = normalize_answer(
answer=result.get("answer"),
num_choices=len(q.choices),
question_id=q.qid,
default="A",
)
log_entry = InferenceLogEntry(
qid=q.qid,
question=q.question,
choices=q.choices,
final_answer=answer,
raw_response=raw_response,
route=route,
retrieved_context=context,
)
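                # Append immediately so a crash loses at most the in-flight work.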
await append_log_entry(log_path, log_entry)
log_done(f"{q.qid}: {answer} (Route: {route})")
processed_count += 1
except Exception as e:
                if is_rate_limit_error(e):
                    # A rate limit affects every request, so stop all workers.
                    print_log(f" [CRITICAL] Rate limit detected on {q.qid}: {e}")
                    stop_event.set()
else:
print_log(f" [Error] Failed to process {q.qid}: {e}")
tasks = [asyncio.create_task(process_single_question(q)) for q in questions]
await asyncio.gather(*tasks)
if stop_event.is_set():
log_pipeline("!!! PIPELINE STOPPED DUE TO RATE LIMIT !!!")
log_pipeline("Consolidating logs and generating emergency submission...")
consolidate_log_file(log_path)
output_file = DATA_OUTPUT_DIR / "submission_emergency.csv"
total_entries = generate_csv_from_log(log_path, output_file)
log_pipeline(f"Saved emergency submission with {total_entries} entries to: {output_file}")
sys.exit(0)
log_pipeline("Consolidating log file...")
consolidate_log_file(log_path)
elapsed = time.perf_counter() - start_time
    throughput = processed_count / elapsed if elapsed > 0 else 0
log_stats(f"Processed {processed_count}/{total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
return processed_count


def save_predictions(
predictions: list[PredictionOutput],
output_path: Path,
ensure_dir: bool = True,
) -> None:
"""Save predictions to CSV file, sorted by qid.
Args:
predictions: List of prediction outputs
output_path: Path to output CSV file
ensure_dir: If True, create parent directory if it doesn't exist
"""
if ensure_dir:
output_path.parent.mkdir(parents=True, exist_ok=True)
sorted_qids = sort_qids([p.qid for p in predictions])
pred_dict = {p.qid: p for p in predictions}
with open(output_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=["qid", "answer"])
writer.writeheader()
for qid in sorted_qids:
writer.writerow({"qid": qid, "answer": pred_dict[qid].answer})
log_pipeline(f"Predictions saved to: {output_path}")