File size: 1,033 Bytes
c78c2fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import sys

# Put the parent of this file's directory at the front of sys.path. This must
# run before the two project imports below.
# NOTE(review): presumably the parent directory is the project root that holds
# the `server` package and the `models` module -- confirm against repo layout.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from server.data_wrangler_environment import DataWranglerEnvironment
from models import DataWranglerAction

def test_level_4():
    """Run a scripted level-4 solution and check that it is graded highly.

    The test sets the ``TASK_LEVEL`` environment variable to ``"4"``, builds
    and resets a ``DataWranglerEnvironment``, then performs three actions
    (regex extraction on ``amount``, cast of ``amount`` to float, group-by on
    ``customer_category`` summing ``amount``) before submitting. It asserts
    that the reward on the observation returned by the submit step is above
    0.8.

    ``TASK_LEVEL`` is restored to its previous state when the test finishes
    (pass or fail) so the setting does not leak into other tests that run in
    the same process.
    """
    # Remember the prior value (None means "was not set") so it can be restored.
    previous_level = os.environ.get("TASK_LEVEL")
    os.environ["TASK_LEVEL"] = "4"
    try:
        env = DataWranglerEnvironment()
        env.reset()

        # 1. regex to extract number.
        # Strings: "$500.5" -> pattern "^\$?([0-9.]+)"
        env.step(DataWranglerAction(action_type="extract_regex", target_column="amount", new_name="amount", regex_pattern=r"^\$?([0-9.]+)"))

        # 2. Cast amount to float
        env.step(DataWranglerAction(action_type="cast_type", target_column="amount", cast_to="float"))

        # 3. Group by customer_category and sum amount
        env.step(DataWranglerAction(action_type="group_by_aggregate", target_column="customer_category", agg_column="amount", agg_func="sum"))

        obs = env.step(DataWranglerAction(action_type="submit"))
        # Original author's expectation: the resulting DataFrame should match
        # the reference exactly, so the grade should be high.
        assert obs.reward > 0.8, f"expected reward > 0.8, got {obs.reward!r}"
    finally:
        # Undo the environment-variable change regardless of the outcome.
        if previous_level is None:
            os.environ.pop("TASK_LEVEL", None)
        else:
            os.environ["TASK_LEVEL"] = previous_level