Spaces:
Sleeping
Sleeping
| """ | |
| Higher-level LLM-backed services for planning, grading, and prompt rewriting. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| from rag_optimizer_env.corpus import Chunk | |
| from rag_optimizer_env.llm_runtime import call_json | |
| from rag_optimizer_env.models import RagAction, RagObservation | |
| from rag_optimizer_env.tasks import Task | |
async def suggest_action(
    observation: RagObservation,
    *,
    selected_chunk_details: list[dict[str, Any]],
    suggested_citations: list[str],
    top_demo_cases: list[str],
) -> dict[str, Any]:
    """Ask the ACTION_PLANNER model for the next environment action.

    The observation is serialized with ``model_dump()`` and sent to
    ``call_json`` together with the caller-supplied hints. The reply is
    validated against ``RagAction`` before it is handed back.

    Args:
        observation: Current environment observation.
        selected_chunk_details: Forwarded unchanged under the
            ``"selected_chunk_details"`` payload key.
        suggested_citations: Forwarded unchanged under the
            ``"suggested_citations"`` payload key.
        top_demo_cases: Forwarded unchanged under the ``"top_demo_cases"``
            payload key.

    Returns:
        The validated action as a plain dict, with ``None``-valued fields
        omitted.

    Raises:
        Anything raised by ``call_json`` or by ``RagAction.model_validate``
        propagates to the caller; nothing is caught here.
    """
    planner_instructions = (
        "You are ACTION_PLANNER for a grounded RAG optimization environment. "
        "Return exactly one valid RagAction JSON object. "
        "Choose among select_chunk, deselect_chunk, compress_chunk, or submit_answer. "
        "Prefer selecting the most relevant evidence first, compress only selected chunks, "
        "and submit a concise grounded answer with inline citations once evidence is sufficient."
    )
    planner_input = {
        "observation": observation.model_dump(),
        "selected_chunk_details": selected_chunk_details,
        "suggested_citations": suggested_citations,
        "top_demo_cases": top_demo_cases,
    }

    reply = await call_json(
        system_prompt=planner_instructions,
        user_payload=planner_input,
        temperature=0.0,
        max_output_tokens=220,
    )

    # Validate first so a malformed action fails here rather than downstream.
    action = RagAction.model_validate(reply.data)
    return action.model_dump(exclude_none=True)
def _grader_chunk_payload(chunk: Chunk) -> dict[str, Any]:
    """Serialize the chunk fields that are shown to the grader model."""
    return {
        "chunk_id": chunk.chunk_id,
        "domain": chunk.domain,
        "keywords": chunk.keywords,
        "text": chunk.text,
    }


def _clamp_unit_score(raw: Any) -> float:
    """Coerce a grader score to a float clamped to ``[0.0, 1.0]``.

    ``None`` (missing key or explicit JSON ``null``) and NaN both map to
    ``0.0``. Values that ``float()`` cannot convert still raise.
    """
    if raw is None:
        return 0.0
    score = float(raw)
    # NaN fails every ordered comparison, so the naive
    # ``max(0.0, min(1.0, nan))`` evaluates to 1.0 and would report an
    # undefined score as perfect. Testing ``score > 0.0`` sends NaN, zero and
    # negatives to the floor instead.
    if not score > 0.0:
        return 0.0
    return min(score, 1.0)


async def judge_answer(
    *,
    task: Task,
    answer: str,
    selected_chunks: list[Chunk],
    required_chunks: list[Chunk],
) -> dict[str, Any]:
    """Grade an answer with the ANSWER_GRADER model.

    The task metadata, the answer, and both evidence sets are sent to
    ``call_json``. The four numeric scores in the reply are clamped to
    ``[0.0, 1.0]``.

    Args:
        task: Task being graded; its name, difficulty, query, required
            artifact ids and expected citation ids are sent to the grader.
        answer: Candidate answer text.
        selected_chunks: Chunks included as ``"selected_evidence"``.
        required_chunks: Chunks included as ``"required_evidence"``.

    Returns:
        A dict with float scores ``answer_quality``, ``groundedness``,
        ``coverage`` and ``citation_support`` (each within ``[0.0, 1.0]``)
        plus a stripped ``notes`` string. Missing, ``null`` or NaN scores
        become ``0.0``; missing or ``null`` notes become ``""``.

    Raises:
        ValueError, TypeError: If a score is present but cannot be converted
            by ``float()``.
        Anything raised by ``call_json`` propagates unchanged.
    """
    result = await call_json(
        system_prompt=(
            "You are ANSWER_GRADER for a grounded RAG benchmark. "
            "Evaluate whether the answer addresses the task query, covers the required evidence, "
            "and stays grounded in the provided evidence. "
            "Return JSON with numeric fields in [0,1]: "
            '{"answer_quality": 0.0, "groundedness": 0.0, "coverage": 0.0, "citation_support": 0.0, "notes": "short"}'
        ),
        user_payload={
            "task": {
                "name": task.name,
                "difficulty": task.difficulty,
                "query": task.query,
                "required_artifact_ids": task.required_artifact_ids,
                "expected_citation_ids": task.expected_citation_ids,
            },
            "answer": answer,
            "selected_evidence": [
                _grader_chunk_payload(chunk) for chunk in selected_chunks
            ],
            "required_evidence": [
                _grader_chunk_payload(chunk) for chunk in required_chunks
            ],
        },
        temperature=0.0,
        max_output_tokens=180,
    )
    payload = result.data

    verdict: dict[str, Any] = {
        field: _clamp_unit_score(payload.get(field))
        for field in ("answer_quality", "groundedness", "coverage", "citation_support")
    }
    # ``str(None)`` would leak the literal text "None" into the notes.
    notes = payload.get("notes")
    verdict["notes"] = "" if notes is None else str(notes).strip()
    return verdict
async def rewrite_prompt(
    *,
    prompt: str,
    mode: str,
    target_tokens: int,
    evidence_notes: list[dict[str, str]],
    citation_ids: list[str],
) -> dict[str, Any]:
    """Compress a prompt with the PROMPT_COMPRESSOR model.

    Args:
        prompt: Original prompt text to rewrite.
        mode: Forwarded unchanged under the ``"mode"`` payload key.
        target_tokens: Requested size of the rewrite. Also used as the
            fallback for ``estimated_tokens`` and to size the output budget
            (``target_tokens * 8``, bounded to 220..600).
        evidence_notes: Forwarded unchanged under ``"evidence_notes"``.
        citation_ids: Forwarded unchanged under ``"citation_ids"``.

    Returns:
        A dict with:
            ``optimized_prompt``: stripped text (``""`` if absent or ``null``).
            ``estimated_tokens``: int of at least 1; falls back to
                ``target_tokens`` when absent or ``null``.
            ``citation_ready``: truthiness of the reply field.
            ``citation_guidance``: stripped text, or ``None`` when empty,
                absent or ``null``.

    Raises:
        ValueError, TypeError: If ``estimated_tokens`` is present but cannot
            be converted by ``int()``.
        Anything raised by ``call_json`` propagates unchanged.
    """
    result = await call_json(
        system_prompt=(
            "You are PROMPT_COMPRESSOR for grounded prompt optimization. "
            "Rewrite the user's prompt to preserve intent while reducing length and keeping essential constraints. "
            "If evidence notes are provided, use them to keep the rewrite grounded. "
            "Return JSON with exactly these fields: "
            '{"optimized_prompt": "text", "estimated_tokens": 123, "citation_ready": true, "citation_guidance": "short note"}'
        ),
        user_payload={
            "mode": mode,
            "target_tokens": target_tokens,
            "prompt": prompt,
            "evidence_notes": evidence_notes,
            "citation_ids": citation_ids,
        },
        temperature=0.1,
        max_output_tokens=max(220, min(600, target_tokens * 8)),
    )
    payload = result.data

    def _text(key: str) -> str:
        # Treat an explicit JSON null like a missing key; ``str(None)`` would
        # otherwise yield the literal (and truthy) text "None".
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    estimated = payload.get("estimated_tokens")
    if estimated is None:
        estimated = target_tokens

    return {
        "optimized_prompt": _text("optimized_prompt"),
        "estimated_tokens": max(1, int(estimated)),
        "citation_ready": bool(payload.get("citation_ready", False)),
        "citation_guidance": _text("citation_guidance") or None,
    }