"""Core pipeline execution logic for the RAG system."""
import asyncio
import csv
import sys
import time
from pathlib import Path
from src.config import BATCH_SIZE, DATA_OUTPUT_DIR
from src.data_processing.answer import normalize_answer
from src.data_processing.formatting import format_choices_display, question_to_state
from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput
from src.graph import get_graph
from src.utils.checkpointing import (
append_log_entry,
consolidate_log_file,
generate_csv_from_log,
is_rate_limit_error,
)
from src.utils.common import sort_qids
from src.utils.ingestion import get_vector_store
from src.utils.logging import log_done, log_pipeline, log_stats, print_log
def sort_questions_by_qid(questions: list[QuestionInput]) -> list[QuestionInput]:
    """Return the questions reordered by natural-sorted qid (q2 before q10)."""
    by_qid = {question.qid: question for question in questions}
    return [by_qid[qid] for qid in sort_qids(list(by_qid.keys()))]
async def run_pipeline_async(
    questions: list[QuestionInput],
    batch_size: int = BATCH_SIZE,
) -> list[PredictionOutput]:
    """Run pipeline for inference (assumes pre-built Vector DB).

    Args:
        questions: List of questions to process
        batch_size: Number of concurrent questions to process

    Returns:
        List of PredictionOutput objects sorted by qid
    """
    log_pipeline("Loading pre-built vector store...")
    get_vector_store()  # warm the store once before fanning out workers

    ordered = sort_questions_by_qid(questions)
    graph = get_graph()
    question_count = len(ordered)
    t0 = time.perf_counter()
    # Bounds the number of graph invocations in flight at any moment.
    gate = asyncio.Semaphore(batch_size)
    outputs: dict[str, PredictionOutput] = {}

    async def handle(q: QuestionInput) -> None:
        async with gate:
            print_log(f"\n[{q.qid}] {q.question}")
            print_log(format_choices_display(q.choices))
            result = await graph.ainvoke(question_to_state(q))
            route = result.get("route", "unknown")
            # Coerce the model output into a valid choice letter ("A" fallback).
            final = normalize_answer(
                answer=result.get("answer", "A"),
                num_choices=len(q.choices),
                question_id=q.qid,
                default="A",
            )
            log_done(f"{q.qid}: {final} (Route: {route})")
            outputs[q.qid] = PredictionOutput(qid=q.qid, answer=final)

    await asyncio.gather(*(handle(q) for q in ordered))

    elapsed = time.perf_counter() - t0
    rate = question_count / elapsed if elapsed > 0 else 0
    log_stats(f"Completed {question_count} questions in {elapsed:.2f}s ({rate:.2f} req/s)")
    return [outputs[qid] for qid in sort_qids(list(outputs.keys()))]
async def run_pipeline_with_checkpointing(
    questions: list[QuestionInput],
    log_path: Path,
    batch_size: int = BATCH_SIZE,
) -> int:
    """Run pipeline with JSONL checkpointing for resume capability.

    Questions are processed in qid order. Results are appended to log file
    immediately for fault tolerance, then consolidated at the end.

    Args:
        questions: List of questions to process (already filtered for unprocessed)
        log_path: Path to JSONL log file for checkpointing
        batch_size: Number of concurrent questions to process

    Returns:
        Count of newly processed questions

    Raises:
        SystemExit: If a rate-limit error is detected; an emergency CSV
            submission is written first, then the process exits with code 0.
    """
    log_pipeline("Loading pre-built vector store...")
    get_vector_store()
    questions = sort_questions_by_qid(questions)
    log_pipeline(f"Processing {len(questions)} questions in qid order...")
    graph = get_graph()
    total = len(questions)
    start_time = time.perf_counter()
    processed_count = 0
    # Limits concurrent graph invocations to batch_size.
    sem = asyncio.Semaphore(batch_size)
    # Set once a rate-limit error is seen; tells all pending tasks to bail out.
    stop_event = asyncio.Event()
    async def process_single_question(q: QuestionInput) -> None:
        nonlocal processed_count
        # Skip before waiting on the semaphore if the pipeline is shutting down.
        if stop_event.is_set():
            return
        async with sem:
            # Re-check after acquiring: the event may have been set while queued.
            if stop_event.is_set():
                return
            print_log(f"\n[{q.qid}] {q.question}")
            print_log(format_choices_display(q.choices))
            state = question_to_state(q)
            try:
                result = await graph.ainvoke(state)
                route = result.get("route", "unknown")
                raw_response = result.get("raw_response", "")
                context = result.get("context", "")
                # Coerce the model output into a valid choice letter ("A" fallback).
                answer = normalize_answer(
                    answer=result.get("answer"),
                    num_choices=len(q.choices),
                    question_id=q.qid,
                    default="A",
                )
                log_entry = InferenceLogEntry(
                    qid=q.qid,
                    question=q.question,
                    choices=q.choices,
                    final_answer=answer,
                    raw_response=raw_response,
                    route=route,
                    retrieved_context=context,
                )
                # Persist immediately so a crash loses at most in-flight work.
                await append_log_entry(log_path, log_entry)
                log_done(f"{q.qid}: {answer} (Route: {route})")
                processed_count += 1
            except Exception as e:
                if is_rate_limit_error(e):
                    # Rate limit is unrecoverable here: stop everything.
                    print_log(f"  [CRITICAL] Rate Limit Detected on {q.qid}: {e}")
                    stop_event.set()
                else:
                    # Any other failure skips just this question; the rest continue.
                    print_log(f"  [Error] Failed to process {q.qid}: {e}")
    tasks = [asyncio.create_task(process_single_question(q)) for q in questions]
    await asyncio.gather(*tasks)
    if stop_event.is_set():
        # Rate-limit shutdown: salvage what was processed into an emergency CSV.
        log_pipeline("!!! PIPELINE STOPPED DUE TO RATE LIMIT !!!")
        log_pipeline("Consolidating logs and generating emergency submission...")
        consolidate_log_file(log_path)
        output_file = DATA_OUTPUT_DIR / "submission_emergency.csv"
        total_entries = generate_csv_from_log(log_path, output_file)
        log_pipeline(f"Saved emergency submission with {total_entries} entries to: {output_file}")
        sys.exit(0)
    log_pipeline("Consolidating log file...")
    consolidate_log_file(log_path)
    elapsed = time.perf_counter() - start_time
    throughput = total / elapsed if elapsed > 0 else 0
    log_stats(f"Processed {processed_count}/{total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
    return processed_count
def save_predictions(
    predictions: list[PredictionOutput],
    output_path: Path,
    ensure_dir: bool = True,
) -> None:
    """Save predictions to CSV file, sorted by qid.

    Args:
        predictions: List of prediction outputs
        output_path: Path to output CSV file
        ensure_dir: If True, create parent directory if it doesn't exist
    """
    if ensure_dir:
        output_path.parent.mkdir(parents=True, exist_ok=True)
    # Natural-sort the qids so e.g. q2 precedes q10 in the output file.
    ordered_qids = sort_qids([p.qid for p in predictions])
    by_qid = {p.qid: p for p in predictions}
    with open(output_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=["qid", "answer"])
        writer.writeheader()
        writer.writerows(
            {"qid": qid, "answer": by_qid[qid].answer} for qid in ordered_qids
        )
    log_pipeline(f"Predictions saved to: {output_path}")