File size: 2,560 Bytes
0683cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from models import SQLAction, SQLObservation, SQLState


class SQLTutorEnv(EnvClient[SQLAction, SQLObservation, SQLState]):
    """
    Client for the SQL Tutor environment.

    Converts between the typed models (``SQLAction``, ``SQLObservation``,
    ``SQLState``) and the plain-dict payloads exchanged with the server.
    Fields missing from a server payload fall back to the defaults written
    out in ``_parse_result`` / ``_parse_state``.

    Usage:
        # Connect to a running HF Space
        env = SQLTutorEnv(base_url="https://your-space.hf.space")

        # Or load locally from Hub
        env = SQLTutorEnv.from_hub("your-username/sql-tutor-env")

        # NOTE(review): reset()/step() are provided by EnvClient; they are
        # presumably returning the StepResult built by _parse_result —
        # confirm against the installed openenv version.
        result = env.reset()
        result = env.step(SQLAction(action_type="submit_fix", sql_query="SELECT ..."))
    """

    def __init__(self, base_url: str, **kwargs):
        """
        Create a client bound to ``base_url``.

        Args:
            base_url: Address of the environment server.
            **kwargs: Forwarded unchanged to ``EnvClient.__init__``.
        """
        super().__init__(base_url=base_url, **kwargs)

    def _step_payload(self, action: SQLAction) -> Dict:
        """
        Serialize an action into the dict sent with a step request.

        Args:
            action: Action whose ``action_type`` and ``sql_query`` are sent.

        Returns:
            A dict with exactly the keys ``action_type`` and ``sql_query``.
        """
        return {
            "action_type": action.action_type,
            "sql_query": action.sql_query,
        }

    def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
        """
        Build a ``StepResult`` from a server response payload.

        Args:
            payload: Response dict; the observation fields are read from its
                ``observation`` entry, ``reward`` and ``done`` from the top
                level.

        Returns:
            A ``StepResult`` wrapping the parsed ``SQLObservation``. A missing
            or null ``observation`` yields an observation made of defaults.
        """
        # dict.get only applies its default when the key is ABSENT; a payload
        # carrying `"observation": null` would otherwise leave None here and
        # crash on the first field lookup below.
        obs_data = payload.get("observation") or {}
        observation = SQLObservation(
            broken_query=obs_data.get("broken_query", ""),
            schema_description=obs_data.get("schema_description", ""),
            task_description=obs_data.get("task_description", ""),
            execution_result=obs_data.get("execution_result", ""),
            is_correct=obs_data.get("is_correct", False),
            # No default: an absent hint stays None.
            hint=obs_data.get("hint"),
            steps_taken=obs_data.get("steps_taken", 0),
            max_steps=obs_data.get("max_steps", 5),
            hints_used=obs_data.get("hints_used", 0),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> SQLState:
        """
        Build an ``SQLState`` from a server state payload.

        Args:
            payload: Flat dict of state fields; absent keys fall back to the
                defaults listed below.

        Returns:
            The populated ``SQLState``.
        """
        return SQLState(
            challenge_id=payload.get("challenge_id", ""),
            broken_query=payload.get("broken_query", ""),
            correct_query=payload.get("correct_query", ""),
            schema_sql=payload.get("schema_sql", ""),
            schema_description=payload.get("schema_description", ""),
            task_description=payload.get("task_description", ""),
            # `or []` also covers an explicit null, so hints is always a list.
            hints=payload.get("hints") or [],
            steps_taken=payload.get("steps_taken", 0),
            max_steps=payload.get("max_steps", 5),
            hints_used=payload.get("hints_used", 0),
            is_resolved=payload.get("is_resolved", False),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )