Spaces:
Sleeping
Sleeping
File size: 5,318 Bytes
6ac92cf 8dde6c4 6ac92cf 8dde6c4 6ac92cf 8dde6c4 6ac92cf 8dde6c4 6ac92cf 8dde6c4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | """
models.py — Canonical Pydantic models for ScholarEnv v0.4.
Authors: Nensi Pansuriya, Krushna Parmar, Ishita Bhojani
Design:
1. AnyAction uses Annotated discriminated union on 'task' field.
2. ScholarObservation covers all 4 tasks with Optional fields.
3. CitationAction supports Task 4 (citation_verification).
4. No circular imports — models.py imports nothing from server/.
References:
PRS — arXiv 2512.07478
PBRS — Ng, Harada & Russell 1999
AdaRFT — arXiv 2504.05520
Veri-R1 — arXiv 2510.01932 (Task 4 design)
"""
from __future__ import annotations
from enum import Enum
from typing import Annotated, Literal, Optional, Union
from pydantic import BaseModel, Field
# ── Enums ─────────────────────────────────────────────────────────────────────
class EpisodeStatus(str, Enum):
    """Status values for an episode: still running or finished.

    The ``str`` mixin makes every member compare equal to its plain string
    value (``EpisodeStatus.DONE == "done"``).
    """

    ACTIVE = "active"
    DONE = "done"
# ── Actions ───────────────────────────────────────────────────────────────────
class FormattingAction(BaseModel):
    """Task 1 action: hand in the fully reformatted manuscript.

    ``task`` is pinned to ``"formatting_compliance"`` and defaults to that
    value, so a payload only has to supply ``formatted_text``.
    """

    # Tag identifying this action type; only one value is accepted.
    task: Literal["formatting_compliance"] = "formatting_compliance"

    # Required: the Ellipsis marks the field as having no default.
    formatted_text: str = Field(
        ...,
        description="Complete reformatted manuscript as a single string.",
    )
class ScholarAction(BaseModel):
    """Tasks 2 and 3 action: move around the paper or submit findings.

    ``task`` has no default and must be one of the two task names listed
    below. The three payload fields are all optional at the model level; this
    class does not enforce which of them accompanies which ``action_type``.
    """

    # Required tag: which of the two audit tasks this action belongs to.
    task: Literal["internal_consistency", "claim_evidence_audit"]

    # Required: the operation to perform.
    action_type: Literal["query_section", "check_table", "extract_claims", "submit_findings"]

    # Optional payloads, all defaulting to None.
    # NOTE(review): presumably section_name pairs with "query_section",
    # table_id with "check_table" and findings with "submit_findings" —
    # confirm against the server-side handler.
    section_name: Optional[str] = None
    table_id: Optional[str] = None
    findings: Optional[list[dict]] = None
class CitationAction(BaseModel):
    """Task 4 action: verify citations in the paper's reference list.

    ``task`` is pinned to ``"citation_verification"`` and defaults to that
    value. ``citation_id`` and ``verdicts`` are both optional at the model
    level; their descriptions state which ``action_type`` each is meant for.
    """

    task: Literal["citation_verification"] = "citation_verification"

    # check_citation  -> returns citation_data in the observation
    # submit_verdicts -> final grade, done=True
    action_type: Literal["check_citation", "submit_verdicts"]

    citation_id: Optional[str] = Field(
        description="Reference ID for check_citation, e.g. 'ref_1'",
        default=None,
    )
    verdicts: Optional[list[dict]] = Field(
        description=(
            "For submit_verdicts. "
            "Each dict: citation_id, status (valid|ghost|misattributed), "
            "issue, confidence."
        ),
        default=None,
    )
# Discriminated union: the value of each payload's 'task' field selects which
# action model it is validated against (FastAPI deserialises on that field).
AnyAction = Annotated[
    Union[FormattingAction, ScholarAction, CitationAction],
    Field(discriminator="task"),
]
# ── Observation ───────────────────────────────────────────────────────────────
class ScholarObservation(BaseModel):
    """Single observation schema returned by reset() and by every step().

    One model serves all tasks: the first group of fields is always present,
    while the task-specific groups default to ``None`` or an empty list.
    """

    # --- Always present ---------------------------------------------------
    # The three identifiers are required; the rest carry defaults.
    task_id: str
    task_description: str
    paper_id: str
    step_count: int = 0
    max_steps: int = 3
    cumulative_score: float = 0.0
    feedback: str = ""
    hint: str = ""

    # --- Task 1 only ------------------------------------------------------
    manuscript_text: Optional[str] = Field(
        description="Badly-formatted manuscript (Task 1 initial observation).",
        default=None,
    )
    style_guide: Optional[dict] = Field(
        description="IEEE style rule config.",
        default=None,
    )

    # --- Tasks 2 & 3: navigation ------------------------------------------
    # List fields use default_factory so each instance gets a fresh list.
    available_sections: list[str] = Field(default_factory=list)
    available_tables: list[str] = Field(default_factory=list)
    current_section_content: Optional[str] = Field(default=None)
    current_table_content: Optional[dict] = Field(default=None)
    extracted_claims: Optional[list[dict]] = Field(default=None)
    findings_so_far: list[dict] = Field(default_factory=list)

    # --- Task 4: citation verification ------------------------------------
    available_references: list[dict] = Field(
        description="Task 4: list of {id, citation_number, raw} dicts.",
        default_factory=list,
    )
    citation_data: Optional[dict] = Field(
        description="Task 4: returned after check_citation action.",
        default=None,
    )
# ── Reward (logging / documentation only) ─────────────────────────────────────
class ScholarReward(BaseModel):
    """Full reward breakdown, logged in the step info dict.

    ``total`` is required and constrained to the closed interval [0.0, 1.0].
    Every component score defaults to 0.0 and ``rule_breakdown`` to an empty
    dict.
    """

    # Required overall reward; validation rejects values outside [0, 1].
    total: float = Field(..., ge=0.0, le=1.0)

    # --- Task 1: PRS stages -----------------------------------------------
    stage_1_score: float = 0.0
    stage_2_score: float = 0.0
    stage_3_score: float = 0.0

    # --- Tasks 2 & 3: F-beta ----------------------------------------------
    f_beta: float = 0.0
    precision: float = 0.0
    recall: float = 0.0
    evidence_specificity: float = 0.0
    coverage_bonus: float = 0.0
    shaping_bonus: float = 0.0

    # --- Task 4: citation -------------------------------------------------
    precision_valid: float = 0.0
    recall_ghost: float = 0.0

    # Per-rule scores keyed by rule name; a fresh dict per instance.
    rule_breakdown: dict[str, float] = Field(default_factory=dict)
|