Spaces:
Running
Running
| """ | |
| models.py β Canonical Pydantic models for ScholarEnv v0.4. | |
| Authors: Nensi Pansuriya, Krushna Parmar, Ishita Bhojani | |
| Design: | |
| 1. AnyAction uses Annotated discriminated union on 'task' field. | |
| 2. ScholarObservation covers all 4 tasks with Optional fields. | |
| 3. CitationAction supports Task 4 (citation_verification). | |
| 4. No circular imports β models.py imports nothing from server/. | |
| References: | |
| PRS β arxiv 2512.07478 | |
| PBRS β Ng, Harada & Russell 1999 | |
| AdaRFT β arxiv 2504.05520 | |
| Veri-R1 β arxiv 2510.01932 (Task 4 design) | |
| """ | |
| from __future__ import annotations | |
| from enum import Enum | |
| from typing import Annotated, Literal, Optional, Union | |
| from pydantic import BaseModel, Field | |
# ── Enums ─────────────────────────────────────────────────────────────────────
class EpisodeStatus(str, Enum):
    # Two-state episode lifecycle flag. Mixing in ``str`` means each member is
    # itself a string equal to its value, e.g. ``EpisodeStatus.DONE == "done"``,
    # so it can be compared against (and serialised as) the raw value.
    ACTIVE = "active"
    DONE = "done"
# ── Actions ───────────────────────────────────────────────────────────────────
class FormattingAction(BaseModel):
    """Task 1: submit the fully reformatted manuscript."""

    # Tag used by the AnyAction discriminated union. It has a default so the
    # model can be built directly without passing it.
    task: Literal["formatting_compliance"] = "formatting_compliance"

    # Required (no default): the agent's entire rewritten manuscript.
    formatted_text: str = Field(
        description="Complete reformatted manuscript as a single string.",
    )
class ScholarAction(BaseModel):
    """Tasks 2 & 3: navigate the paper or submit findings."""

    # Tag used by the AnyAction discriminated union. It has no default, so the
    # caller must state which of the two tasks the action targets.
    task: Literal["internal_consistency", "claim_evidence_audit"]

    # Required verb for this step.
    action_type: Literal[
        "query_section", "check_table", "extract_claims", "submit_findings"
    ]

    # Payload fields. All are optional at the schema level.
    # NOTE(review): the names suggest query_section -> section_name,
    # check_table -> table_id, submit_findings -> findings; confirm in server/.
    section_name: Optional[str] = None
    table_id: Optional[str] = None
    findings: Optional[list[dict]] = None
class CitationAction(BaseModel):
    """Task 4: verify citations in paper reference list."""

    # Tag used by the AnyAction discriminated union; defaulted for direct use.
    task: Literal["citation_verification"] = "citation_verification"

    action_type: Literal[
        "check_citation",   # → returns citation_data in obs
        "submit_verdicts",  # → final grade, done=True
    ]

    # Payload for check_citation.
    citation_id: Optional[str] = Field(
        default=None,
        description="Reference ID for check_citation, e.g. 'ref_1'",
    )

    # Payload for submit_verdicts. Each entry is an untyped dict whose expected
    # keys are spelled out only in the description below.
    verdicts: Optional[list[dict]] = Field(
        default=None,
        description="For submit_verdicts. Each dict: citation_id, status (valid|ghost|misattributed), issue, confidence.",
    )
# Discriminated union: validation (e.g. FastAPI request-body parsing) reads the
# 'task' field and selects the matching action model from its Literal tag.
AnyAction = Annotated[
    Union[
        FormattingAction,
        ScholarAction,
        CitationAction,
    ],
    Field(discriminator="task"),
]
# ── Observation ───────────────────────────────────────────────────────────────
class ScholarObservation(BaseModel):
    """Unified observation returned by reset() and every step()."""

    # -- Common to every task ------------------------------------------------
    # The three identifiers are required; everything else has a default.
    task_id: str
    task_description: str
    paper_id: str
    step_count: int = 0
    max_steps: int = 3
    cumulative_score: float = 0.0
    feedback: str = ""
    hint: str = ""

    # -- Task 1 (formatting_compliance) only --------------------------------
    manuscript_text: Optional[str] = Field(
        default=None,
        description="Badly-formatted manuscript (Task 1 initial observation).",
    )
    style_guide: Optional[dict] = Field(
        default=None,
        description="IEEE style rule config.",
    )

    # -- Tasks 2 & 3: navigation state --------------------------------------
    # default_factory gives every observation its own fresh list.
    available_sections: list[str] = Field(default_factory=list)
    available_tables: list[str] = Field(default_factory=list)
    current_section_content: Optional[str] = None
    current_table_content: Optional[dict] = None
    extracted_claims: Optional[list[dict]] = None
    findings_so_far: list[dict] = Field(default_factory=list)

    # -- Task 4: citation verification --------------------------------------
    available_references: list[dict] = Field(
        default_factory=list,
        description="Task 4: list of {id, citation_number, raw} dicts.",
    )
    citation_data: Optional[dict] = Field(
        default=None,
        description="Task 4: returned after check_citation action.",
    )
# ── Reward (logging / documentation only) ─────────────────────────────────────
class ScholarReward(BaseModel):
    """Full reward breakdown — logged in step info dict."""

    # NOTE: pydantic uses this class docstring as the JSON-schema "description",
    # so it must stay clean text. It previously carried a mis-encoded "β"
    # instead of an em dash.

    # The only required field; the Field bounds make validation reject values
    # outside the closed interval [0, 1].
    total: float = Field(ge=0.0, le=1.0)

    # Task 1 — PRS stages (unbounded, default 0.0).
    stage_1_score: float = 0.0
    stage_2_score: float = 0.0
    stage_3_score: float = 0.0

    # Tasks 2 & 3 — F-beta and its components.
    # These carry no bounds, so e.g. a negative shaping bonus is accepted.
    f_beta: float = 0.0
    precision: float = 0.0
    recall: float = 0.0
    evidence_specificity: float = 0.0
    coverage_bonus: float = 0.0
    shaping_bonus: float = 0.0

    # Task 4 — citation verification.
    precision_valid: float = 0.0
    recall_ghost: float = 0.0

    # Free-form per-rule scores (str -> float). Presumably keyed by rule name;
    # NOTE(review): confirm which task(s) populate this.
    rule_breakdown: dict[str, float] = Field(default_factory=dict)