File size: 5,863 Bytes
c4ac745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import json
import os
import shutil

import pandas as pd

try:
    from preprocessor import DataTransformer
    from baselines.ClavaDDPM.preprocess_utils import topological_sort
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    # Fallback: load ``preprocessor.py`` by file path from two directories above
    # this script, and register it in ``sys.modules`` under the name "preprocessor".
    # NOTE(review): ``topological_sort`` is NOT defined on this path; it is unused
    # in this script, but any later use would raise NameError.
    # ``importlib.util`` must be imported explicitly: a bare ``import importlib``
    # does not guarantee the ``util`` submodule is loaded.
    import importlib.util
    import sys
    base_dir = os.path.dirname(__file__)
    full_path = os.path.abspath(os.path.join(base_dir, "..", "..", "preprocessor.py"))
    spec = importlib.util.spec_from_file_location("preprocessor", full_path)
    preprocessor = importlib.util.module_from_spec(spec)
    sys.modules["preprocessor"] = preprocessor
    spec.loader.exec_module(preprocessor)
    DataTransformer = preprocessor.DataTransformer


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line for this script.

    Sub-commands:
      * ``pre`` -- preprocess the raw Olist CSVs; accepts ``--dataset-dir/-d``
        (default ``"data"``) and ``--out-dir/-o`` (default ``"."``).
      * ``desimplify`` -- accepts ``--dataset-dir/-d`` (default ``"data"``).

    Args:
        argv: Argument list to parse (without the program name). ``None``
            (the default) parses ``sys.argv[1:]``, as before.

    Returns:
        The parsed namespace; ``op`` holds the chosen sub-command name.

    Raises:
        SystemExit: on invalid arguments or when no sub-command is given.
            Previously, omitting the sub-command made the script silently do nothing.
    """
    parser = argparse.ArgumentParser(description="Olist dataset pre/post-processing.")
    subparsers = parser.add_subparsers(dest='op', required=True)

    pre_parser = subparsers.add_parser('pre', help="preprocess the raw Olist CSV files")
    pre_parser.add_argument("--dataset-dir", "-d", default="data",
                            help="directory containing olist_<table>_dataset.csv files")
    pre_parser.add_argument("--out-dir", "-o", default=".",
                            help="directory for processor.json and the output CSVs")

    post_parser = subparsers.add_parser('desimplify', help="reverse the simplification step")
    post_parser.add_argument("--dataset-dir", "-d", default="data")
    return parser.parse_args(argv)


# Olist tables handled by the "pre" step. This order is used when building the
# processor dict, when serializing processor.json, and when writing output CSVs.
_TABLE_NAMES = [
    "customers", "geolocation", "order_items", "order_payments", "order_reviews",
    "orders", "products", "sellers",
]


def _load_tables(dataset_dir, table_names):
    """Read ``olist_<name>_dataset.csv`` from *dataset_dir* for every name in *table_names*."""
    return {
        table_name: pd.read_csv(os.path.join(dataset_dir, f"olist_{table_name}_dataset.csv"))
        for table_name in table_names
    }


def _clean_tables(raw):
    """Drop unused columns and rows whose foreign keys point at filtered-out parents.

    Args:
        raw: Mapping of table name -> raw DataFrame, as returned by ``_load_tables``.

    Returns:
        A new dict with one cleaned DataFrame per table name. The input frames
        are not modified; every step below returns a new frame.
    """
    # Collapse geolocation to one row per zip prefix: the group's first state
    # value and the mean latitude/longitude. The city column is dropped.
    geolocation = raw["geolocation"].drop(columns=["geolocation_city"])
    grouped = geolocation.groupby("geolocation_zip_code_prefix")
    state = grouped[["geolocation_state"]].first()
    location = grouped[["geolocation_lat", "geolocation_lng"]].mean()
    geolocation = pd.concat([state, location], axis=1).reset_index(drop=False)
    known_zips = geolocation["geolocation_zip_code_prefix"]

    # Keep only customers/sellers whose zip prefix exists in geolocation.
    customers = raw["customers"].drop(columns=["customer_unique_id", "customer_city"])
    customers = customers[customers["customer_zip_code_prefix"].isin(known_zips)].reset_index(drop=True)
    sellers = raw["sellers"].drop(columns=["seller_city"])
    sellers = sellers[sellers["seller_zip_code_prefix"].isin(known_zips)].reset_index(drop=True)

    orders = raw["orders"]
    orders = orders[orders["customer_id"].isin(customers["customer_id"])].reset_index(drop=True)
    products = raw["products"]

    order_items = raw["order_items"]
    order_items = order_items[order_items["order_id"].isin(orders["order_id"])]
    order_items = order_items[order_items["seller_id"].isin(sellers["seller_id"])].reset_index(drop=True)

    order_payments = raw["order_payments"]
    order_payments = order_payments[order_payments["order_id"].isin(orders["order_id"])].reset_index(drop=True)
    # NOTE: sorting after reset_index leaves a non-contiguous index (kept as-is).
    order_payments = order_payments.sort_values(["order_id", "payment_sequential"])

    order_reviews = raw["order_reviews"].drop(columns=["review_comment_title", "review_comment_message"])
    order_reviews = order_reviews[order_reviews["order_id"].isin(orders["order_id"])].reset_index(drop=True)
    # Keep only the first row for each review_id.
    order_reviews = order_reviews.groupby("review_id").head(1).reset_index(drop=True)

    return {
        "customers": customers,
        "geolocation": geolocation,
        "order_items": order_items,
        "order_payments": order_payments,
        "order_reviews": order_reviews,
        "orders": orders,
        "products": products,
        "sellers": sellers,
    }


def _fit_processors(tables, table_names):
    """Fit one ``DataTransformer`` per table on the cleaned *tables*.

    Parents are fitted before children because each child's ``ref_cols`` reads
    entries from an already-fitted parent's ``columns``.
    """
    processors = {table: DataTransformer() for table in table_names}
    processors["geolocation"].fit(tables["geolocation"], ["geolocation_zip_code_prefix"])
    processors["products"].fit(tables["products"], ["product_id"])
    processors["customers"].fit(tables["customers"], ["customer_id"], {
        "customer_zip_code_prefix": processors["geolocation"].columns["geolocation_zip_code_prefix"],
    })
    processors["sellers"].fit(tables["sellers"], ["seller_id"], {
        "seller_zip_code_prefix": processors["geolocation"].columns["geolocation_zip_code_prefix"],
    })
    processors["orders"].fit(tables["orders"], ["order_id"], {
        "customer_id": processors["customers"].columns["customer_id"],
    })
    processors["order_items"].fit(tables["order_items"], ref_cols={
        "order_id": processors["orders"].columns["order_id"],
        "product_id": processors["products"].columns["product_id"],
        "seller_id": processors["sellers"].columns["seller_id"],
    })
    processors["order_payments"].fit(tables["order_payments"], ref_cols={
        "order_id": processors["orders"].columns["order_id"],
    })
    processors["order_reviews"].fit(tables["order_reviews"], ref_cols={
        "order_id": processors["orders"].columns["order_id"],
    })
    return processors


def _load_or_fit_processors(out_dir, tables, table_names):
    """Reuse ``<out_dir>/processor.json`` if it exists; otherwise fit and save it."""
    processor_path = os.path.join(out_dir, "processor.json")
    if os.path.exists(processor_path):
        with open(processor_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        return {table: DataTransformer.from_dict(loaded[table]) for table in table_names}

    processors = _fit_processors(tables, table_names)
    # Create out_dir before writing; previously this ran only afterwards, so a
    # fresh --out-dir crashed with FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    with open(processor_path, "w", encoding="utf-8") as f:
        json.dump({t: p.to_dict() for t, p in processors.items()}, f, indent=2)
    return processors


def _write_outputs(out_dir, processors, tables, table_names):
    """Write each transformed table to ``preprocessed/`` and mirror it into ``simplified/``."""
    preprocessed_dir = os.path.join(out_dir, "preprocessed")
    os.makedirs(preprocessed_dir, exist_ok=True)  # also creates out_dir
    for table in table_names:
        transformed = processors[table].transform(tables[table])
        transformed.to_csv(os.path.join(preprocessed_dir, f"{table}.csv"), index=False)

    # dirs_exist_ok lets re-runs (which reuse processor.json) succeed instead of
    # raising FileExistsError; existing files in simplified/ are overwritten.
    shutil.copytree(preprocessed_dir, os.path.join(out_dir, "simplified"), dirs_exist_ok=True)


def main():
    """Entry point: dispatch on the sub-command chosen by ``parse_args``.

    ``pre`` runs these steps in order:
      1. Read the raw Olist CSVs from ``--dataset-dir``.
      2. Clean them.
      3. Load or fit the per-table ``DataTransformer`` objects, persisted in
         ``<out-dir>/processor.json``.
      4. Write the transformed tables to ``<out-dir>/preprocessed/`` and copy
         them to ``<out-dir>/simplified/``.
    """
    args = parse_args()
    if args.op == "pre":
        tables = _clean_tables(_load_tables(args.dataset_dir, _TABLE_NAMES))
        processors = _load_or_fit_processors(args.out_dir, tables, _TABLE_NAMES)
        _write_outputs(args.out_dir, processors, tables, _TABLE_NAMES)

    elif args.op == "desimplify":
        # NOTE(review): not implemented; the sub-command is accepted but does nothing.
        pass


if __name__ == "__main__":
    main()