Spaces:
Running
Running
| """CLI: Generation quality annotation tool. | |
| For each annotated question in eval_set.json, runs the full RAG pipeline | |
| and presents the generated answer for human scoring. | |
| Scores: | |
| - Faithfulness: y/n (does the answer match the retrieved context?) | |
| - Relevance: 1-5 (how well does it address the question?) | |
| - Citation accuracy: y/n (do [1], [2] markers support the claims?) | |
| Usage: | |
| python scripts/annotate_generation.py | |
| python scripts/annotate_generation.py --backend groq | |
| python scripts/annotate_generation.py --force # re-score already scored | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.config import PROJECT_ROOT, get_config | |
| from src.generation.llm_backend_base import LLMBackend | |
| from src.generation.rag_engine import RAGEngine | |
| from src.ingestion.embeddings import EmbeddingGenerator | |
| from src.retrieval.pipeline import RetrievalPipeline | |
| from src.retrieval.reranker import CrossEncoderReranker | |
| from src.storage.chroma_store import ChromaStore | |
| from src.storage.sqlite_db import SQLiteDB | |
# Script-wide logging: timestamped INFO-level lines on stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
# Shared annotation file; presumably produced by scripts/annotate.py first
# (see the error hints in main()) and extended in place by this script.
EVAL_SET_PATH = PROJECT_ROOT / "data" / "eval_set.json"
def load_eval_set() -> list[dict]:
    """Read the eval set from disk; an absent file yields an empty list."""
    if not EVAL_SET_PATH.exists():
        return []
    return json.loads(EVAL_SET_PATH.read_text(encoding="utf-8"))
def save_eval_set(eval_set: list[dict]) -> None:
    """Write the eval set back to disk as pretty-printed UTF-8 JSON."""
    payload = json.dumps(eval_set, indent=2, ensure_ascii=False)
    EVAL_SET_PATH.write_text(payload, encoding="utf-8")
def make_llm_backend(backend_name: str, config) -> LLMBackend:
    """Build the LLM backend: Groq when requested, Ollama otherwise.

    Exits the process when the Groq backend is requested without an API key.
    """
    if backend_name != "groq":
        from src.generation.ollama_backend import OllamaBackend

        return OllamaBackend(host=config.ollama_host)

    from src.generation.groq_backend import GroqBackend

    if not config.groq_api_key:
        logger.error("GROQ_API_KEY not set")
        sys.exit(1)
    return GroqBackend(api_key=config.groq_api_key)
def init_rag_engine(config, backend_name: str) -> RAGEngine:
    """Wire storage, embedding, reranking, and the LLM into a ready RAGEngine."""
    retrieval = RetrievalPipeline(
        db=SQLiteDB(config.sqlite_db_path),
        chroma_store=ChromaStore(config.chroma_db_path),
        embedding_generator=EmbeddingGenerator(config.embedding_model),
        reranker=CrossEncoderReranker(config.reranker_model),
    )
    retrieval.build_index()
    return RAGEngine(retrieval, make_llm_backend(backend_name, config))
def display_answer(entry: dict, answer: str, sources: list[dict]):
    """Print one entry's question metadata, generated answer, and sources."""
    heavy_rule = "=" * 60
    light_rule = "β" * 60
    print("\n" + heavy_rule)
    print(f" Question [{entry['id']}]: {entry['question']}")
    print(f" Type: {entry.get('type', '?')}")
    keywords = entry.get("expected_keywords", [])
    if keywords:
        print(f" Expected keywords: {', '.join(keywords)}")
    print(light_rule)
    print(f" Relevant chunks: {len(entry.get('relevant_chunk_ids', []))}")
    print(light_rule)
    print("\n === Generated Answer ===\n")
    print(answer)
    if sources:
        print("\n === Sources ===")
        for idx, src in enumerate(sources, 1):
            title = src.get("title", "?")
            print(f" [{idx}] {title} ({src.get('venue', '?')}, {src.get('year', '?')})")
    print("\n" + light_rule)
def prompt_yn(label: str) -> bool | None:
    """Ask a yes/no question; True/False for y/n, None when the user quits."""
    answers = {"y": True, "n": False, "q": None}
    while True:
        raw = input(f" {label} (y/n/q): ").strip().lower()
        if raw in answers:
            return answers[raw]
        print(" Invalid. Use y/n/q.")
def prompt_score(label: str, min_val: int = 1, max_val: int = 5) -> int | None:
    """Ask for an integer in [min_val, max_val]; None when the user quits."""
    while True:
        raw = input(f" {label} ({min_val}-{max_val}/q): ").strip().lower()
        if raw == "q":
            return None
        try:
            score = int(raw)
        except ValueError:
            print(" Invalid. Enter a number or 'q'.")
            continue
        if min_val <= score <= max_val:
            return score
        print(f" Must be between {min_val} and {max_val}.")
def annotate_entry(entry: dict, engine: RAGEngine) -> bool:
    """Generate an answer for one entry and collect the three human scores.

    Mutates entry["generation_scores"] in place on completion.
    Returns False as soon as the user quits any prompt, True otherwise.
    """
    print(f"\nGenerating answer for [{entry['id']}]...")
    response = engine.query(question=entry["question"], top_k=5)
    display_answer(entry, response.answer, response.sources)

    # Score dimensions in display order; each prompt returns None on quit.
    prompts = (
        ("faithfulness", lambda: prompt_yn("Faithfulness β does the answer match the context?")),
        ("relevance", lambda: prompt_score("Relevance β how well does it address the question?")),
        ("citation_accuracy", lambda: prompt_yn("Citation accuracy β do [1], [2] markers support the claims?")),
    )
    scores: dict = {}
    for key, ask in prompts:
        value = ask()
        if value is None:
            return False
        scores[key] = value

    scores["model"] = response.model
    scores["answer"] = response.answer
    scores["scored_at"] = datetime.now(timezone.utc).isoformat()
    entry["generation_scores"] = scores
    print(
        f" -> Scored: faithful={scores['faithfulness']}, "
        f"relevance={scores['relevance']}, citations={scores['citation_accuracy']}"
    )
    return True
def _build_arg_parser() -> argparse.ArgumentParser:
    """Define the CLI options for the annotation script."""
    parser = argparse.ArgumentParser(description="Annotate RAG generation quality")
    parser.add_argument(
        "--backend",
        choices=["ollama", "groq"],
        default=None,
        help="LLM backend (default: from LLM_BACKEND env var)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Re-score entries that already have generation_scores",
    )
    return parser


def main():
    """Score generation quality for every retrieval-annotated eval entry."""
    args = _build_arg_parser().parse_args()
    config = get_config()
    backend_name = args.backend or config.llm_backend

    eval_set = load_eval_set()
    if not eval_set:
        print(f"No annotations found at {EVAL_SET_PATH}")
        print("Run: python scripts/annotate.py first")
        sys.exit(1)

    # Generation scoring only makes sense where retrieval ground truth exists.
    annotated = [entry for entry in eval_set if entry.get("relevant_chunk_ids")]
    if not annotated:
        print("No entries with retrieval annotations. Run scripts/annotate.py first.")
        sys.exit(1)

    print("\n=== Generation Quality Annotation ===")
    print(f"Entries with retrieval annotations: {len(annotated)}")
    print(f"Backend: {backend_name}\n")

    engine = init_rag_engine(config, backend_name)
    for entry in annotated:
        if entry.get("generation_scores") and not args.force:
            print(f" [{entry['id']}] already scored β skipping (use --force to redo)")
            continue
        if not annotate_entry(entry, engine):
            # User quit mid-session: persist what was scored so far.
            save_eval_set(eval_set)
            print("\nAnnotation paused. Progress saved.")
            return

    save_eval_set(eval_set)
    print(f"\nDone. All generation scores saved to {EVAL_SET_PATH}")


if __name__ == "__main__":
    main()