# sre-openenv / models.py
# Last commit by Dragonfire146 (4f7c4d5):
#   Fix TypeError: asdict() should be called on dataclass instances
"""
Data models for the SRE OpenEnv environment.
Adaptive implementation that uses Pydantic BaseModel in OpenEnv 0.1 (Python 3.11)
and dataclasses in legacy environments (Python 3.9) to ensure cross-version
compatibility and avoid inheritance conflicts.
"""
from __future__ import annotations
import os
import sys
from dataclasses import dataclass, field, asdict
from typing import Literal, List, Optional, Any, TypeVar, Union, Dict
from pydantic import BaseModel, Field
def _is_pydantic_model(cls: Any) -> bool:
    """Return True if *cls* is a class derived from Pydantic's ``BaseModel``."""
    # isinstance(cls, type) guards issubclass(), which raises TypeError for
    # non-class arguments.
    return isinstance(cls, type) and issubclass(cls, BaseModel)


# Resolve the OpenEnv base classes: prefer the modern ``openenv.core``
# package (v0.1), then the legacy ``openenv_core`` shim, then empty
# placeholders. IS_PYDANTIC records whether the resolved bases are Pydantic
# models; it selects how the SRE models below are declared.
try:
    from openenv.core.env_server import Action, Observation, State
except ImportError:
    try:
        from openenv_core.env_server import Action, Observation, State
    except ImportError:
        # Fallback if neither is available: plain base classes, so the SRE
        # models are declared as ordinary dataclasses.
        class Action:
            pass

        class Observation:
            pass

        class State:
            pass

        IS_PYDANTIC = False
    else:
        # Legacy releases usually ship dataclass bases, but detect rather
        # than assume so the declaration style below always matches the
        # actual base classes.
        IS_PYDANTIC = _is_pydantic_model(Action)
else:
    IS_PYDANTIC = _is_pydantic_model(Action)
# Concrete SRE models, declared on top of whichever base classes were
# resolved above. Both branches declare the same fields in the same order and
# differ only in the field mechanism (Pydantic ``Field`` vs.
# ``dataclasses.field``), so each matches its base classes.
# Field order is significant: it fixes the positional ``__init__`` signature
# of the dataclasses and the key order of serialized dicts.
if IS_PYDANTIC:
    # --- Pydantic v2 implementation (Modern) ---
    # NOTE: documented with comments rather than class docstrings, because
    # Pydantic uses a model's docstring as its JSON-schema description.
    class SREAction(Action):
        # Requested operation; declared as one of these two literals
        # (enforced by Pydantic validation in this branch).
        action_type: Literal["run_shell", "patch_file"] = "run_shell"
        # presumably read when action_type == "run_shell" — confirm in the env
        command: str = ""
        # presumably read when action_type == "patch_file" — confirm in the env
        file_path: str = ""
        content: str = ""

    # Outcome of one step, as returned to the agent.
    class SREObservation(Observation):
        stdout: str = ""
        stderr: str = ""
        exit_code: int = 0
        # NOTE(review): presumably set when stdout/stderr were cut short — confirm
        truncated: bool = False
        message: str = ""
        # Necessary for legacy server extraction
        reward: Optional[float] = 0.0
        done: bool = False

    # Per-episode bookkeeping for the environment.
    class SREState(State):
        episode_id: str = ""
        step_count: int = 0
        task_id: str = ""
        task_name: str = ""
        description: str = ""
        difficulty: str = ""
        max_steps: int = 30
        is_done: bool = False
        current_reward: float = 0.0
        # default_factory gives every instance its own list.
        action_history: List[str] = Field(default_factory=list)
else:
    # --- Dataclass implementation (Legacy) ---
    # NOTE: dataclasses do not validate annotations at runtime, so e.g. the
    # Literal on action_type is not enforced in this branch.
    @dataclass
    class SREAction(Action):
        action_type: Literal["run_shell", "patch_file"] = "run_shell"
        command: str = ""
        file_path: str = ""
        content: str = ""
        def __post_init__(self):
            # Compatibility for legacy initialization
            # NOTE(review): this no-op overrides any __post_init__ defined by a
            # dataclass base and does not chain to it — confirm the base has none.
            pass

    # Outcome of one step, as returned to the agent.
    @dataclass
    class SREObservation(Observation):
        stdout: str = ""
        stderr: str = ""
        exit_code: int = 0
        truncated: bool = False
        message: str = ""
        # Necessary for legacy server extraction
        reward: Optional[float] = 0.0
        done: bool = False

    # Per-episode bookkeeping for the environment.
    @dataclass
    class SREState(State):
        episode_id: str = ""
        step_count: int = 0
        task_id: str = ""
        task_name: str = ""
        description: str = ""
        difficulty: str = ""
        max_steps: int = 30
        is_done: bool = False
        current_reward: float = 0.0
        # Mutable default must use default_factory for dataclasses.
        action_history: List[str] = field(default_factory=list)
def to_dict(obj: Any) -> Dict[str, Any]:
    """
    Convert a Pydantic model, dataclass instance, or plain object to a dict.

    Polymorphic helper for cross-version compatibility with openenv-core,
    whose base types are Pydantic models in some versions and dataclasses
    in others.

    Args:
        obj: The value to convert. Must be an instance, not a class.

    Returns:
        * ``asdict(obj)`` (recursive) for a dataclass instance;
        * ``obj`` itself if it is already a ``dict``;
        * ``model_dump()`` (Pydantic v2) or ``.dict()`` (v1) for a BaseModel;
        * for other objects with a ``__dict__``: ``dict(obj)`` if they are a
          mapping or an iterable of key/value pairs, otherwise a shallow copy
          of their attribute namespace (``vars(obj)``);
        * ``obj`` unchanged for values without a ``__dict__`` (e.g. ``int``,
          ``str``, ``None``).

    Raises:
        TypeError: if ``obj`` is a class rather than an instance.
    """
    if isinstance(obj, type):
        # Dataclass *classes* also carry __dataclass_fields__, but asdict()
        # only accepts instances; fail with a clear message instead of the
        # confusing error asdict() would raise.
        raise TypeError(
            f"to_dict() expects an instance, got class {obj.__name__}"
        )
    if hasattr(obj, "__dataclass_fields__"):
        return asdict(obj)
    if isinstance(obj, dict):
        return obj
    if isinstance(obj, BaseModel):
        # Pydantic v2 uses model_dump, v1 uses dict
        return obj.model_dump() if hasattr(obj, "model_dump") else obj.dict()
    if hasattr(obj, "__dict__"):
        try:
            # Mapping-like objects or iterables of (key, value) pairs.
            return dict(obj)
        except (TypeError, ValueError):
            # Ordinary objects are not iterable as pairs: use their attributes.
            return dict(vars(obj))
    # Fallback for other types (no __dict__): returned as-is.
    return obj
# NOTE: Dict (used by to_dict) is already imported from typing at the top of this module.