File size: 2,727 Bytes
d02bacd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field


class RewardBreakdown(BaseModel):
    """Components of the reward attached to an observation.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Both reward components are required.
    immediate: float
    cumulative: float
    # Optional; left as None unless a terminal grader value is supplied
    # (presumably only on the final step -- TODO confirm with the env).
    terminal_grader: Optional[float] = Field(default=None)


class CoSAction(BaseModel):
    """A single action request.

    Unlike the other models in this module, unknown keys are silently
    dropped (``extra="ignore"``) instead of being rejected.
    """

    model_config = ConfigDict(extra="ignore")

    # Required; must be one of the listed action kinds.
    action_type: Literal["consult", "ask", "summarize", "submit", "noop"]
    # All remaining fields are optional. Nothing in this model ties them to
    # action_type, so any combination passes validation here.
    expert_id: Optional[Literal["analyst", "finance", "hr", "strategy"]] = Field(
        default=None
    )
    sub_question_id: Optional[str] = Field(default=None)
    notes: Optional[str] = Field(default=None)


class ExpertReport(BaseModel):
    """Structured report attributed to one expert (``expert_id``).

    Unknown keys are rejected (``extra="forbid"``). Every list field
    defaults to a fresh empty list; ``memo`` and ``score`` default to None.
    """

    model_config = ConfigDict(extra="forbid")

    expert_id: Literal["analyst", "finance", "hr", "strategy"]
    title: str
    summary: str
    # Spelled with typing.Union rather than ``float | int | str``. The module
    # uses ``from __future__ import annotations``, so pydantic evaluates this
    # annotation string at runtime, and PEP 604 unions cannot be evaluated on
    # Python < 3.10 without the optional ``eval_type_backport`` package.
    metrics: Dict[str, Union[float, int, str]] = Field(default_factory=dict)
    bullet_points: List[str] = Field(default_factory=list)
    issues: List[str] = Field(default_factory=list)
    citations: List[str] = Field(default_factory=list)
    memory_citations: List[str] = Field(default_factory=list)
    memory_snippets: List[str] = Field(default_factory=list)
    memo: Optional[str] = None
    score: Optional[float] = None


class Brief(BaseModel):
    """A brief: summary, metrics, recommendations, HR memo, consulted experts.

    Unknown keys are rejected (``extra="forbid"``). Only ``summary`` is
    required; everything else has an empty default.
    """

    model_config = ConfigDict(extra="forbid")

    summary: str
    # typing.Union instead of ``float | int | str``: with postponed
    # annotations pydantic evaluates this string at runtime, and PEP 604
    # unions fail to evaluate on Python < 3.10 without ``eval_type_backport``.
    metrics: Dict[str, Union[float, int, str]] = Field(default_factory=dict)
    recommendations: List[str] = Field(default_factory=list)
    hr_memo: str = ""
    consulted_experts: List[str] = Field(default_factory=list)


class CoSObservation(BaseModel):
    """Observation of the current episode state.

    Unknown keys are rejected (``extra="forbid"``). Required fields are
    ``done``, ``instruction``, ``task_name`` and ``task_difficulty``; all
    others have defaults.
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE: field order is significant for serialization -- keep it stable.
    done: bool
    reward: float = Field(default=0.0)
    instruction: str
    history: List[str] = Field(default_factory=list)
    issues: List[str] = Field(default_factory=list)
    data_quality_score: float = Field(default=0.0)

    # Task description.
    task_name: str
    task_difficulty: str
    rag_enabled: bool = Field(default=False)

    # Step budget and progress.
    max_steps: int = Field(default=12)
    step_count: int = Field(default=0)

    # Expert roster and interaction so far. The roster default is rebuilt
    # per instance by the factory, so instances never share one list.
    available_experts: List[str] = Field(
        default_factory=lambda: ["analyst", "finance", "hr", "strategy"]
    )
    consulted_experts: List[str] = Field(default_factory=list)
    expert_reports: Dict[str, ExpertReport] = Field(default_factory=dict)
    current_brief: Optional[Brief] = Field(default=None)

    # Reward details; terminal_grader_score presumably mirrors
    # reward_breakdown.terminal_grader -- TODO confirm with the env.
    reward_breakdown: Optional[RewardBreakdown] = Field(default=None)
    terminal_grader_score: Optional[float] = Field(default=None)


class CoSState(BaseModel):
    """State record for a single episode, identified by ``episode_id``.

    Unknown keys are rejected (``extra="forbid"``). Only ``episode_id`` and
    ``task_name`` are required.
    """

    model_config = ConfigDict(extra="forbid")

    # Identity.
    episode_id: str
    task_name: str

    # Step counter and status flags.
    step_count: int = Field(default=0)
    done: bool = Field(default=False)
    rag_enabled: bool = Field(default=False)

    # Expert interaction so far.
    consulted_experts: List[str] = Field(default_factory=list)
    expert_reports: Dict[str, ExpertReport] = Field(default_factory=dict)
    current_brief: Optional[Brief] = Field(default=None)

    # Stored value only; nothing in this model accumulates it.
    cumulative_reward: float = Field(default=0.0)