# data_analysis_env/models.py
# Uploaded by HimanshuSardana2 ("Upload folder using huggingface_hub", commit abb357f).
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
class DataAnalysisAction(BaseModel):
    """A single tool invocation requested by the agent.

    Attributes:
        tool: Name of the tool to execute.
        parameters: Keyword arguments for the tool; empty dict by default.
    """

    tool: str = Field(..., description="Tool name to execute")
    parameters: dict[str, Any] = Field(
        default_factory=dict, description="Tool parameters"
    )

    @field_validator("tool", mode="before")
    @classmethod
    def _coerce_tool(cls, value):
        """Normalize raw ``tool`` input to a string before type validation.

        A dict is unwrapped to its ``"tool"`` entry (presumably to tolerate
        callers that pass a whole action object as ``tool`` -- confirm with
        callers). Afterwards a missing or ``None`` value becomes ``""`` and
        anything else is converted with ``str()``. The dict and non-dict
        paths are therefore normalized the same way.

        Returns:
            The tool name as a ``str`` (possibly empty).
        """
        if isinstance(value, dict):
            value = value.get("tool")
        # Map "no tool" to "" (the same default the dict path always used)
        # instead of letting str(None) fabricate the tool name "None".
        if value is None:
            return ""
        return str(value)
class DataAnalysisObservation(BaseModel):
    """Outcome of one environment step, returned to the agent.

    Carries the episode/reward bookkeeping (``done``, ``reward``), the tool
    result (``success``, ``output``, ``error``), a snapshot of the current
    data (``data_shape``, ``columns``) and the history of tools called so far.
    """

    # --- episode bookkeeping ---
    done: bool = Field(default=False, description="Whether episode is done")
    reward: float = Field(default=0.0, description="Reward for this step")

    # --- tool result ---
    success: bool = Field(
        default=True, description="Whether tool executed successfully"
    )
    output: str = Field(default="", description="Tool output or error message")

    # --- current data snapshot ---
    data_shape: Optional[tuple[int, int]] = Field(
        default=None, description="(rows, columns) of current data"
    )
    columns: list[str] = Field(
        default_factory=list, description="Column names of current data"
    )

    # --- history / diagnostics ---
    tools_used: list[str] = Field(
        default_factory=list, description="History of tools called"
    )
    error: Optional[str] = Field(
        default=None, description="Error message if tool failed"
    )

    @field_validator("data_shape", mode="before")
    @classmethod
    def _coerce_shape(cls, value):
        """Convert a two-element list into a ``(rows, columns)`` tuple.

        JSON has no tuple type, so a shape that went through JSON arrives as
        a list. Any other input is returned untouched for pydantic to
        validate.
        """
        looks_like_pair = isinstance(value, list) and len(value) == 2
        return tuple(value) if looks_like_pair else value
class DataAnalysisState(BaseModel):
    """Per-episode bookkeeping for the data-analysis environment.

    Attributes:
        episode_id: Identifier of the current episode, ``None`` until set.
        task_name: Name of the task being solved (empty by default).
        step_count: Steps taken so far in this episode.
        max_steps: Step budget for the episode (20 by default).
        data_loaded: Whether a dataset has been loaded yet.
    """

    episode_id: Optional[str] = Field(
        default=None, description="Unique episode identifier"
    )
    task_name: str = Field(default="", description="Current task name")

    # Step counter and its budget.
    step_count: int = Field(default=0, description="Number of steps taken")
    max_steps: int = Field(default=20, description="Maximum steps allowed per episode")

    data_loaded: bool = Field(default=False, description="Whether data has been loaded")
# Every tool name the environment exposes, in a fixed order.
# Grouped one "family" per line; str.split() yields a plain list of names.
AVAILABLE_TOOLS: list[str] = """
    load_csv show_data show_columns
    fill_missing remove_duplicates filter_rows select_columns
    group_by calculate sort_by
    get_result merge_datasets
""".split()