File size: 4,950 Bytes
0b89610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Higher-level LLM-backed services for planning, grading, and prompt rewriting.
"""

from __future__ import annotations

from typing import Any

from rag_optimizer_env.corpus import Chunk
from rag_optimizer_env.llm_runtime import call_json
from rag_optimizer_env.models import RagAction, RagObservation
from rag_optimizer_env.tasks import Task


async def suggest_action(
    observation: RagObservation,
    *,
    selected_chunk_details: list[dict[str, Any]],
    suggested_citations: list[str],
    top_demo_cases: list[str],
) -> dict[str, Any]:
    """Request the next environment action using the ACTION_PLANNER prompt.

    The dumped observation and the three hint collections are forwarded
    unchanged as the user payload of a single ``call_json`` request.

    Args:
        observation: Current observation; serialized with ``model_dump()``.
        selected_chunk_details: Passed through under the same key.
        suggested_citations: Passed through under the same key.
        top_demo_cases: Passed through under the same key.

    Returns:
        The response data validated as a ``RagAction`` and dumped back to a
        dict with ``None``-valued fields excluded.

    Raises:
        Whatever ``call_json`` or ``RagAction.model_validate`` raise; nothing
        is caught here.
    """
    planner_instructions = (
        "You are ACTION_PLANNER for a grounded RAG optimization environment. "
        "Return exactly one valid RagAction JSON object. "
        "Choose among select_chunk, deselect_chunk, compress_chunk, or submit_answer. "
        "Prefer selecting the most relevant evidence first, compress only selected chunks, "
        "and submit a concise grounded answer with inline citations once evidence is sufficient."
    )
    planner_context = {
        "observation": observation.model_dump(),
        "selected_chunk_details": selected_chunk_details,
        "suggested_citations": suggested_citations,
        "top_demo_cases": top_demo_cases,
    }

    response = await call_json(
        system_prompt=planner_instructions,
        user_payload=planner_context,
        temperature=0.0,
        max_output_tokens=220,
    )

    # Validate first so a malformed model reply fails here rather than
    # leaking an arbitrary dict to the caller.
    action = RagAction.model_validate(response.data)
    return action.model_dump(exclude_none=True)


def _unit_interval_score(payload: dict[str, Any], key: str) -> float:
    """Read ``payload[key]`` as a float clamped to the closed range [0, 1].

    A missing key and an explicit ``None`` both yield ``0.0``. Any other
    value goes through ``float()``, so non-numeric values still raise
    ``TypeError``/``ValueError`` exactly as before.
    """
    raw = payload.get(key)
    if raw is None:
        return 0.0
    score = float(raw)
    # ``not score > 0.0`` is deliberately used instead of ``score <= 0.0``:
    # every comparison with NaN is False, so NaN lands in this branch and
    # becomes 0.0. The previous ``max(0.0, min(1.0, nan))`` evaluated to 1.0,
    # silently turning an invalid score into a perfect one.
    if not score > 0.0:
        return 0.0
    return min(score, 1.0)


def _evidence_entry(chunk: Chunk) -> dict[str, Any]:
    """Project a chunk onto the fields that are sent to the grader."""
    return {
        "chunk_id": chunk.chunk_id,
        "domain": chunk.domain,
        "keywords": chunk.keywords,
        "text": chunk.text,
    }


async def judge_answer(
    *,
    task: Task,
    answer: str,
    selected_chunks: list[Chunk],
    required_chunks: list[Chunk],
) -> dict[str, Any]:
    """Grade an answer with the ANSWER_GRADER prompt and normalize the scores.

    Args:
        task: Task whose ``name``, ``difficulty``, ``query``,
            ``required_artifact_ids`` and ``expected_citation_ids`` are sent
            to the grader.
        answer: The answer text to evaluate.
        selected_chunks: Chunks sent as ``selected_evidence``.
        required_chunks: Chunks sent as ``required_evidence``.

    Returns:
        A dict with float scores ``answer_quality``, ``groundedness``,
        ``coverage`` and ``citation_support``, each clamped to [0, 1]
        (missing, ``None`` or NaN values become ``0.0``), plus ``notes`` as a
        stripped string (empty when absent or ``None``).

    Raises:
        TypeError, ValueError: If a score is present but cannot be converted
            with ``float()``.
        Anything raised by ``call_json`` propagates unchanged.
    """
    result = await call_json(
        system_prompt=(
            "You are ANSWER_GRADER for a grounded RAG benchmark. "
            "Evaluate whether the answer addresses the task query, covers the required evidence, "
            "and stays grounded in the provided evidence. "
            "Return JSON with numeric fields in [0,1]: "
            '{"answer_quality": 0.0, "groundedness": 0.0, "coverage": 0.0, "citation_support": 0.0, "notes": "short"}'
        ),
        user_payload={
            "task": {
                "name": task.name,
                "difficulty": task.difficulty,
                "query": task.query,
                "required_artifact_ids": task.required_artifact_ids,
                "expected_citation_ids": task.expected_citation_ids,
            },
            "answer": answer,
            "selected_evidence": [_evidence_entry(chunk) for chunk in selected_chunks],
            "required_evidence": [_evidence_entry(chunk) for chunk in required_chunks],
        },
        temperature=0.0,
        max_output_tokens=180,
    )
    payload = result.data

    # An explicit null would otherwise be rendered as the literal text "None".
    notes = payload.get("notes")
    return {
        "answer_quality": _unit_interval_score(payload, "answer_quality"),
        "groundedness": _unit_interval_score(payload, "groundedness"),
        "coverage": _unit_interval_score(payload, "coverage"),
        "citation_support": _unit_interval_score(payload, "citation_support"),
        "notes": "" if notes is None else str(notes).strip(),
    }


def _stripped_text(value: Any) -> str:
    """Return ``value`` as stripped text, mapping ``None`` to ``""``.

    ``str(None)`` is the four-character string ``"None"``; without this guard
    a null field in the model reply would surface as that literal text.
    """
    return "" if value is None else str(value).strip()


async def rewrite_prompt(
    *,
    prompt: str,
    mode: str,
    target_tokens: int,
    evidence_notes: list[dict[str, str]],
    citation_ids: list[str],
) -> dict[str, Any]:
    """Rewrite a prompt with the PROMPT_COMPRESSOR prompt and normalize the reply.

    Args:
        prompt: The prompt text to rewrite.
        mode: Forwarded to the model under ``"mode"``; not interpreted here.
        target_tokens: Requested size of the rewrite. Also sizes the output
            budget (``target_tokens * 8`` bounded to 220..600) and is the
            fallback for ``estimated_tokens``.
        evidence_notes: Forwarded to the model under the same key.
        citation_ids: Forwarded to the model under the same key.

    Returns:
        A dict with:
        ``optimized_prompt`` - stripped string, empty when absent or null;
        ``estimated_tokens`` - int of at least 1, falling back to
        ``target_tokens`` when absent or null;
        ``citation_ready`` - truthiness of the returned field;
        ``citation_guidance`` - stripped string, or ``None`` when absent,
        null or blank.

    Raises:
        TypeError, ValueError: If ``estimated_tokens`` is present but cannot
            be converted with ``int()``.
        Anything raised by ``call_json`` propagates unchanged.
    """
    result = await call_json(
        system_prompt=(
            "You are PROMPT_COMPRESSOR for grounded prompt optimization. "
            "Rewrite the user's prompt to preserve intent while reducing length and keeping essential constraints. "
            "If evidence notes are provided, use them to keep the rewrite grounded. "
            "Return JSON with exactly these fields: "
            '{"optimized_prompt": "text", "estimated_tokens": 123, "citation_ready": true, "citation_guidance": "short note"}'
        ),
        user_payload={
            "mode": mode,
            "target_tokens": target_tokens,
            "prompt": prompt,
            "evidence_notes": evidence_notes,
            "citation_ids": citation_ids,
        },
        temperature=0.1,
        max_output_tokens=max(220, min(600, target_tokens * 8)),
    )
    payload = result.data

    # Treat an explicit null the same as a missing key instead of letting
    # int(None) raise TypeError.
    raw_estimate = payload.get("estimated_tokens")
    if raw_estimate is None:
        raw_estimate = target_tokens

    return {
        "optimized_prompt": _stripped_text(payload.get("optimized_prompt")),
        "estimated_tokens": max(1, int(raw_estimate)),
        "citation_ready": bool(payload.get("citation_ready", False)),
        # Blank or null guidance collapses to None rather than "" or "None".
        "citation_guidance": _stripped_text(payload.get("citation_guidance")) or None,
    }