File size: 5,318 Bytes
6ac92cf
8dde6c4
 
 
 
 
 
 
 
 
6ac92cf
8dde6c4
 
 
 
 
 
6ac92cf
 
8dde6c4
 
6ac92cf
 
 
 
8dde6c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ac92cf
8dde6c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
models.py β€” Canonical Pydantic models for ScholarEnv v0.4.

Authors: Nensi Pansuriya, Krushna Parmar, Ishita Bhojani

Design:
  1. AnyAction uses Annotated discriminated union on 'task' field.
  2. ScholarObservation covers all 4 tasks with Optional fields.
  3. CitationAction supports Task 4 (citation_verification).
  4. No circular imports β€” models.py imports nothing from server/.

References:
  PRS  β€” arxiv 2512.07478
  PBRS β€” Ng, Harada & Russell 1999
  AdaRFT β€” arxiv 2504.05520
  Veri-R1 β€” arxiv 2510.01932 (Task 4 design)
"""
from __future__ import annotations

from enum import Enum
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field


# ── Enums ─────────────────────────────────────────────────────────────────────

class EpisodeStatus(str, Enum):
    ACTIVE = "active"
    DONE   = "done"


# ── Actions ───────────────────────────────────────────────────────────────────

class FormattingAction(BaseModel):
    """Task 1: submit the fully reformatted manuscript."""
    task: Literal["formatting_compliance"] = "formatting_compliance"
    formatted_text: str = Field(
        description="Complete reformatted manuscript as a single string."
    )


class ScholarAction(BaseModel):
    """Tasks 2 & 3: navigate the paper or submit findings."""
    task: Literal["internal_consistency", "claim_evidence_audit"]
    action_type: Literal[
        "query_section",
        "check_table",
        "extract_claims",
        "submit_findings",
    ]
    section_name: Optional[str] = Field(default=None)
    table_id:     Optional[str] = Field(default=None)
    findings:     Optional[list[dict]] = Field(default=None)


class CitationAction(BaseModel):
    """Task 4: verify citations in paper reference list."""
    task: Literal["citation_verification"] = "citation_verification"
    action_type: Literal[
        "check_citation",   # β†’ returns citation_data in obs
        "submit_verdicts",  # β†’ final grade, done=True
    ]
    citation_id: Optional[str] = Field(
        default=None,
        description="Reference ID for check_citation, e.g. 'ref_1'"
    )
    verdicts: Optional[list[dict]] = Field(
        default=None,
        description=(
            "For submit_verdicts. Each dict: "
            "citation_id, status (valid|ghost|misattributed), issue, confidence."
        ),
    )


# Discriminated union β€” FastAPI deserialises on the 'task' field
AnyAction = Annotated[
    Union[FormattingAction, ScholarAction, CitationAction],
    Field(discriminator="task"),
]


# ── Observation ───────────────────────────────────────────────────────────────

class ScholarObservation(BaseModel):
    """Unified observation returned by reset() and every step()."""

    # Always present
    task_id:          str
    task_description: str
    paper_id:         str
    step_count:       int   = 0
    max_steps:        int   = 3
    cumulative_score: float = 0.0
    feedback:         str   = ""
    hint:             str   = ""

    # Task 1 only
    manuscript_text: Optional[str] = Field(
        default=None,
        description="Badly-formatted manuscript (Task 1 initial observation)."
    )
    style_guide: Optional[dict] = Field(
        default=None,
        description="IEEE style rule config."
    )

    # Tasks 2 & 3 β€” navigation
    available_sections:      list[str]          = Field(default_factory=list)
    available_tables:        list[str]          = Field(default_factory=list)
    current_section_content: Optional[str]      = None
    current_table_content:   Optional[dict]     = None
    extracted_claims:        Optional[list[dict]] = None
    findings_so_far:         list[dict]          = Field(default_factory=list)

    # Task 4 β€” citation verification
    available_references: list[dict] = Field(
        default_factory=list,
        description="Task 4: list of {id, citation_number, raw} dicts."
    )
    citation_data: Optional[dict] = Field(
        default=None,
        description="Task 4: returned after check_citation action."
    )


# ── Reward (logging / documentation only) ────────────────────────────────────

class ScholarReward(BaseModel):
    """Full reward breakdown β€” logged in step info dict."""
    total: float = Field(ge=0.0, le=1.0)
    # Task 1 β€” PRS stages
    stage_1_score: float = 0.0
    stage_2_score: float = 0.0
    stage_3_score: float = 0.0
    # Tasks 2 & 3 β€” F-beta
    f_beta:               float = 0.0
    precision:            float = 0.0
    recall:               float = 0.0
    evidence_specificity: float = 0.0
    coverage_bonus:       float = 0.0
    shaping_bonus:        float = 0.0
    # Task 4 β€” citation
    precision_valid:      float = 0.0
    recall_ghost:         float = 0.0
    rule_breakdown: dict[str, float] = Field(default_factory=dict)