File size: 3,670 Bytes
9eb0831
 
 
 
 
 
 
 
 
 
 
 
4f7c4d5
 
9eb0831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f7c4d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Data models for the SRE OpenEnv environment.

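Adaptive implementation: when the OpenEnv base classes are Pydantic models
(OpenEnv 0.1, Python 3.11), the SRE models subclass them as Pydantic models.
Otherwise they are declared as dataclasses. This covers legacy openenv_core on
Python 3.9 and the case where no OpenEnv package is installed at all. The split
keeps the module compatible across versions and avoids mixing the two model
systems in one inheritance chain.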
"""

from __future__ import annotations

import os
import sys
from dataclasses import dataclass, field, asdict
from typing import Literal, List, Optional, Any, TypeVar, Union, Dict
from pydantic import BaseModel, Field

# Resolve the OpenEnv base classes, preferring the newest package layout:
#   1. ``openenv.core``  -- OpenEnv 0.1, whose bases may be Pydantic models;
#   2. ``openenv_core``  -- the legacy package / shim, treated as non-Pydantic;
#   3. empty placeholder classes when neither package can be imported.
# IS_PYDANTIC decides which flavour of the SRE models is declared below.
try:
    from openenv.core.env_server import Action, Observation, State

    # issubclass() raises TypeError when Action is not a class; that case is
    # handled exactly like a failed import.
    IS_PYDANTIC = issubclass(Action, BaseModel)
except (ImportError, TypeError):
    IS_PYDANTIC = False
    try:
        from openenv_core.env_server import Action, Observation, State
    except ImportError:
        # Neither package is importable: fall back to bare base classes.
        class Action:
            pass

        class Observation:
            pass

        class State:
            pass


if IS_PYDANTIC:
    # --- Pydantic v2 implementation (modern OpenEnv 0.1) ---

    # Action submitted by the agent for one environment step.
    class SREAction(Action):
        # Operation to perform; Pydantic rejects values outside the Literal.
        action_type: Literal["run_shell", "patch_file"] = "run_shell"
        # Shell command; presumably only read for "run_shell" actions.
        command: str = ""
        # Target path and new file body; presumably only read for
        # "patch_file" actions -- confirm against the environment.
        file_path: str = ""
        content: str = ""

    # Result of executing one action.
    class SREObservation(Observation):
        stdout: str = ""
        stderr: str = ""
        exit_code: int = 0
        # Presumably set when captured output was cut short -- TODO confirm.
        truncated: bool = False
        message: str = ""

        # Necessary for legacy server extraction
        reward: Optional[float] = 0.0
        done: bool = False

    # Per-episode bookkeeping for the environment.
    class SREState(State):
        episode_id: str = ""
        step_count: int = 0
        task_id: str = ""
        task_name: str = ""
        description: str = ""
        difficulty: str = ""
        max_steps: int = 30
        is_done: bool = False
        current_reward: float = 0.0
        # default_factory gives every instance its own list.
        action_history: List[str] = Field(default_factory=list)

else:
    # --- Dataclass implementation (legacy openenv_core / no OpenEnv) ---
    # Mirrors the Pydantic models field-for-field. Note that dataclasses do
    # not enforce the ``action_type`` Literal at runtime.

    # Action submitted by the agent for one environment step.
    @dataclass
    class SREAction(Action):
        action_type: Literal["run_shell", "patch_file"] = "run_shell"
        command: str = ""
        file_path: str = ""
        content: str = ""

        def __post_init__(self) -> None:
            # Chain to the base class's hook when it defines one. A bare no-op
            # override here would silently skip any initialization the legacy
            # Action dataclass performs in its own __post_init__.
            parent_post_init = getattr(super(), "__post_init__", None)
            if parent_post_init is not None:
                parent_post_init()

    # Result of executing one action.
    @dataclass
    class SREObservation(Observation):
        stdout: str = ""
        stderr: str = ""
        exit_code: int = 0
        truncated: bool = False
        message: str = ""

        # Necessary for legacy server extraction
        reward: Optional[float] = 0.0
        done: bool = False

    # Per-episode bookkeeping for the environment.
    @dataclass
    class SREState(State):
        episode_id: str = ""
        step_count: int = 0
        task_id: str = ""
        task_name: str = ""
        description: str = ""
        difficulty: str = ""
        max_steps: int = 30
        is_done: bool = False
        current_reward: float = 0.0
        # default_factory gives every instance its own list.
        action_history: List[str] = field(default_factory=list)

def to_dict(obj: Any) -> Dict[str, Any]:
    """
    Convert a Pydantic model, dataclass instance, mapping or plain object to a dict.

    Polymorphic helper that lets callers serialise the models in this module
    without knowing whether the Pydantic or the dataclass implementation is
    active, which is useful for cross-version compatibility with openenv-core.

    Args:
        obj: The value to convert.

    Returns:
        - Pydantic models: ``model_dump()`` (Pydantic v2) or ``dict()`` (v1).
        - Dataclass instances: ``dataclasses.asdict(obj)`` (recursive copy).
        - ``dict`` instances: the same object, returned as-is (not copied).
        - Other mappings: a new ``dict`` built from their items.
        - Other objects with a ``__dict__``: a shallow copy of their instance
          attributes.
        - Anything else: ``obj`` itself, unchanged. This case does not return
          a dict despite the annotation; it is kept for backward compatibility.
    """
    from collections.abc import Mapping

    if isinstance(obj, BaseModel):
        # Pydantic v2 uses model_dump, v1 uses dict
        return obj.model_dump() if hasattr(obj, "model_dump") else obj.dict()
    if hasattr(obj, "__dataclass_fields__"):
        return asdict(obj)
    if isinstance(obj, dict):
        return obj
    if isinstance(obj, Mapping):
        return dict(obj)
    if hasattr(obj, "__dict__"):
        # dict(obj) would try to iterate obj as key/value pairs and raise
        # TypeError for ordinary objects; copy the attribute namespace instead.
        return dict(vars(obj))
    # Fallback for other types (primitives, objects using __slots__, ...)
    return obj

# Note: Dict is already imported from typing at the top of this module.