Spaces:
Running
Running
| """Core pipeline execution logic for the RAG system.""" | |
| import asyncio | |
| import csv | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from src.config import BATCH_SIZE, DATA_OUTPUT_DIR | |
| from src.data_processing.answer import normalize_answer | |
| from src.data_processing.formatting import format_choices_display, question_to_state | |
| from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput | |
| from src.graph import get_graph | |
| from src.utils.checkpointing import ( | |
| append_log_entry, | |
| consolidate_log_file, | |
| generate_csv_from_log, | |
| is_rate_limit_error, | |
| ) | |
| from src.utils.common import sort_qids | |
| from src.utils.ingestion import get_vector_store | |
| from src.utils.logging import log_done, log_pipeline, log_stats, print_log | |
def sort_questions_by_qid(questions: list[QuestionInput]) -> list[QuestionInput]:
    """Return the questions reordered by natural qid ordering.

    Builds a qid-keyed lookup, sorts the qids with the shared natural-sort
    helper, then reassembles the list in that order.
    """
    by_qid = {question.qid: question for question in questions}
    ordered_qids = sort_qids(list(by_qid.keys()))
    return [by_qid[qid] for qid in ordered_qids]
async def run_pipeline_async(
    questions: list[QuestionInput],
    batch_size: int = BATCH_SIZE,
) -> list[PredictionOutput]:
    """Run pipeline for inference (assumes pre-built Vector DB).

    Args:
        questions: List of questions to process
        batch_size: Number of concurrent questions to process

    Returns:
        List of PredictionOutput objects sorted by qid
    """
    log_pipeline("Loading pre-built vector store...")
    get_vector_store()

    questions = sort_questions_by_qid(questions)
    graph = get_graph()
    total = len(questions)
    started_at = time.perf_counter()

    # Semaphore caps how many graph invocations are in flight at once.
    limiter = asyncio.Semaphore(batch_size)
    predictions: dict[str, PredictionOutput] = {}

    async def _answer_one(question: QuestionInput) -> None:
        async with limiter:
            print_log(f"\n[{question.qid}] {question.question}")
            print_log(format_choices_display(question.choices))

            outcome = await graph.ainvoke(question_to_state(question))
            route = outcome.get("route", "unknown")

            final_answer = normalize_answer(
                answer=outcome.get("answer", "A"),
                num_choices=len(question.choices),
                question_id=question.qid,
                default="A",
            )
            log_done(f"{question.qid}: {final_answer} (Route: {route})")
            # asyncio is single-threaded, so the shared dict write is safe.
            predictions[question.qid] = PredictionOutput(
                qid=question.qid, answer=final_answer
            )

    await asyncio.gather(*(_answer_one(q) for q in questions))

    elapsed = time.perf_counter() - started_at
    throughput = total / elapsed if elapsed > 0 else 0
    log_stats(f"Completed {total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")

    ordered_qids = sort_qids(list(predictions.keys()))
    return [predictions[qid] for qid in ordered_qids]
async def run_pipeline_with_checkpointing(
    questions: list[QuestionInput],
    log_path: Path,
    batch_size: int = BATCH_SIZE,
) -> int:
    """Run pipeline with JSONL checkpointing for resume capability.

    Questions are processed in qid order. Results are appended to log file
    immediately for fault tolerance, then consolidated at the end. If a
    rate-limit error is detected, a stop flag drains the remaining tasks,
    partial results are written to an emergency submission CSV, and the
    process exits.

    Args:
        questions: List of questions to process (already filtered for unprocessed)
        log_path: Path to JSONL log file for checkpointing
        batch_size: Number of concurrent questions to process

    Returns:
        Count of newly processed questions
    """
    log_pipeline("Loading pre-built vector store...")
    get_vector_store()

    questions = sort_questions_by_qid(questions)
    log_pipeline(f"Processing {len(questions)} questions in qid order...")

    graph = get_graph()
    total = len(questions)
    start_time = time.perf_counter()
    processed_count = 0
    sem = asyncio.Semaphore(batch_size)
    stop_event = asyncio.Event()

    async def process_single_question(q: QuestionInput) -> None:
        nonlocal processed_count
        # Bail out early if a rate limit already triggered shutdown.
        if stop_event.is_set():
            return
        async with sem:
            # Re-check: the flag may have been set while waiting on the semaphore.
            if stop_event.is_set():
                return
            print_log(f"\n[{q.qid}] {q.question}")
            print_log(format_choices_display(q.choices))
            state = question_to_state(q)
            try:
                result = await graph.ainvoke(state)
                route = result.get("route", "unknown")
                raw_response = result.get("raw_response", "")
                context = result.get("context", "")
                answer = normalize_answer(
                    answer=result.get("answer"),
                    num_choices=len(q.choices),
                    question_id=q.qid,
                    default="A",
                )
                log_entry = InferenceLogEntry(
                    qid=q.qid,
                    question=q.question,
                    choices=q.choices,
                    final_answer=answer,
                    raw_response=raw_response,
                    route=route,
                    retrieved_context=context,
                )
                # Append immediately so progress survives a crash mid-run.
                await append_log_entry(log_path, log_entry)
                log_done(f"{q.qid}: {answer} (Route: {route})")
                processed_count += 1
            except Exception as e:
                if is_rate_limit_error(e):
                    # Halt all remaining work; partial results are salvaged below.
                    print_log(f" [CRITICAL] Rate Limit Detected on {q.qid}: {e}")
                    stop_event.set()
                else:
                    # Deliberate best-effort: log the failure and keep going
                    # with the remaining questions.
                    print_log(f" [Error] Failed to process {q.qid}: {e}")

    tasks = [asyncio.create_task(process_single_question(q)) for q in questions]
    await asyncio.gather(*tasks)

    if stop_event.is_set():
        log_pipeline("!!! PIPELINE STOPPED DUE TO RATE LIMIT !!!")
        log_pipeline("Consolidating logs and generating emergency submission...")
        consolidate_log_file(log_path)
        output_file = DATA_OUTPUT_DIR / "submission_emergency.csv"
        total_entries = generate_csv_from_log(log_path, output_file)
        log_pipeline(f"Saved emergency submission with {total_entries} entries to: {output_file}")
        # Exit 0 so wrapping scripts treat the emergency CSV as a usable result.
        sys.exit(0)

    log_pipeline("Consolidating log file...")
    consolidate_log_file(log_path)
    elapsed = time.perf_counter() - start_time
    throughput = total / elapsed if elapsed > 0 else 0
    log_stats(f"Processed {processed_count}/{total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
    return processed_count
def save_predictions(
    predictions: list[PredictionOutput],
    output_path: Path,
    ensure_dir: bool = True,
) -> None:
    """Save predictions to CSV file, sorted by qid.

    Args:
        predictions: List of prediction outputs
        output_path: Path to output CSV file
        ensure_dir: If True, create parent directory if it doesn't exist
    """
    if ensure_dir:
        output_path.parent.mkdir(parents=True, exist_ok=True)

    by_qid = {pred.qid: pred for pred in predictions}
    ordered_qids = sort_qids([pred.qid for pred in predictions])

    with open(output_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=["qid", "answer"])
        writer.writeheader()
        writer.writerows(
            {"qid": qid, "answer": by_qid[qid].answer} for qid in ordered_qids
        )

    log_pipeline(f"Predictions saved to: {output_path}")