File size: 4,969 Bytes
89ca667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import logging
from pathlib import Path
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, month, dayofweek, dayofyear, when, lag, avg, stddev, last, lit
from pyspark.sql.window import Window
import duckdb

# Configure the root logger once at import time: INFO level and above, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_spark_session():
    """Return the Spark session used for feature engineering.

    The session is named "FMCG_Feature_Engineering" and is requested with
    4g of driver memory and Arrow-based pandas conversion enabled.
    ``getOrCreate`` hands back an already-active session when one exists.
    """
    settings = {
        "spark.driver.memory": "4g",
        "spark.sql.execution.arrow.pyspark.enabled": "true",
    }

    builder = SparkSession.builder.appName("FMCG_Feature_Engineering")
    for option, value in settings.items():
        builder = builder.config(option, value)

    return builder.getOrCreate()

def build_features(spark, db_path: str, output_parquet: str) -> None:
    """
    Build time-series sales features with PySpark and write them to Parquet.

    Reads the ``raw_sales``, ``stores``, ``oil``, ``holidays_events`` and
    ``transactions`` tables from the DuckDB database at ``db_path``, joins
    them, derives calendar, lag and rolling-window features per
    (``store_nbr``, ``family``) series, drops every row that still contains a
    null, and overwrites ``output_parquet`` with the result.

    Args:
        spark: Active SparkSession used to build and write the DataFrames.
            NOTE: the session is stopped by this function once the write
            succeeds, so it cannot be reused by the caller afterwards.
        db_path: Path to the DuckDB database file to read from.
        output_parquet: Destination path for the Parquet output; any existing
            data at that path is overwritten.

    Raises:
        Any exception raised by DuckDB while querying or by Spark while
        transforming/writing is propagated. The DuckDB connection is closed
        in either case.
    """
    logging.info("Starting PySpark feature engineering...")

    # 1. Pull every source table out of DuckDB as a pandas DataFrame.
    # try/finally guarantees the connection is released even when one of the
    # queries fails (e.g. a missing table).
    con = duckdb.connect(db_path)
    try:
        sales_pdf = con.execute("SELECT * FROM raw_sales ORDER BY store_nbr, family, date").df()
        stores_pdf = con.execute("SELECT * FROM stores").df()
        oil_pdf = con.execute("SELECT * FROM oil").df()
        holidays_pdf = con.execute("SELECT * FROM holidays_events").df()
        transactions_pdf = con.execute("SELECT * FROM transactions").df()
    finally:
        con.close()

    df = spark.createDataFrame(sales_pdf)
    # Renamed to "store_type" so the store attribute carries an unambiguous
    # name in the joined frame.
    stores_df = spark.createDataFrame(stores_pdf).withColumnRenamed("type", "store_type")
    oil_df = spark.createDataFrame(oil_pdf)
    holidays_df = spark.createDataFrame(holidays_pdf)
    transactions_df = spark.createDataFrame(transactions_pdf)

    # 1.5 Forward-fill the oil price on the small oil table, before the join,
    # so the fill does not require a global sort of the large sales table.
    # The constant "_dummy" partition key gives the window an explicit
    # partition and silences Spark's "no partition defined" WindowExec warning.
    oil_df = oil_df.withColumn("_dummy", lit(1))
    window_oil = (
        Window.partitionBy("_dummy")
        .orderBy("date")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    oil_df = oil_df.withColumn(
        "dcoilwtico", last("dcoilwtico", ignorenulls=True).over(window_oil)
    ).drop("_dummy")

    # 2. Join auxiliary tables. Left joins keep every sales row; values that
    # are still missing after the join get a fixed fallback.
    df = df.join(stores_df, on="store_nbr", how="left")
    # Dates absent from the oil table (and leading nulls the forward-fill
    # cannot reach) fall back to 50.0.
    df = df.join(oil_df, on="date", how="left").fillna({"dcoilwtico": 50.0})
    df = df.join(transactions_df, on=["date", "store_nbr"], how="left").fillna({"transactions": 0})

    # One row per national holiday date that was not transferred.
    # `== False` is a Spark Column comparison, not a Python truthiness test.
    national_holidays = (
        holidays_df.filter(
            (col("locale") == "National") & (col("transferred") == False)  # noqa: E712
        )
        .select("date")
        .distinct()
        .withColumn("is_holiday", lit(1))
    )

    df = df.join(national_holidays, on="date", how="left").fillna({"is_holiday": 0})

    # 3. Time-based features.
    # Spark's dayofweek() is 1 = Sunday ... 7 = Saturday, so {1, 7} is the weekend.
    df = (
        df.withColumn("month", month(col("date")))
        .withColumn("day_of_week", dayofweek(col("date")))
        .withColumn("day_of_year", dayofyear(col("date")))
        .withColumn("is_weekend", when(col("day_of_week").isin([1, 7]), 1).otherwise(0))
    )

    # 4. Window for lag and rolling features: one series per store and
    # product family, ordered by date.
    windowSpec = Window.partitionBy("store_nbr", "family").orderBy("date")

    # 5. Lag features. Lags count ROWS within the series, so "1 day ago" /
    # "7 days ago" holds only if each series has exactly one row per day
    # (assumption about the input data; not enforced here).
    df = (
        df.withColumn("sales_lag_1", lag("sales", 1).over(windowSpec))
        .withColumn("sales_lag_7", lag("sales", 7).over(windowSpec))
        .withColumn("sales_lag_28", lag("sales", 28).over(windowSpec))
        .withColumn("transactions_lag_1", lag("transactions", 1).over(windowSpec))
        .withColumn("transactions_lag_7", lag("transactions", 7).over(windowSpec))
    )

    # 6. Rolling features over the previous 7 / 28 rows. The frames end at -1,
    # so the current row's sales never leak into its own features.
    windowSpec_7 = windowSpec.rowsBetween(-7, -1)
    windowSpec_28 = windowSpec.rowsBetween(-28, -1)

    df = (
        df.withColumn("rolling_mean_7", avg("sales").over(windowSpec_7))
        .withColumn("rolling_std_7", stddev("sales").over(windowSpec_7))
        .withColumn("rolling_mean_28", avg("sales").over(windowSpec_28))
    )

    # 7. Drop rows with nulls to keep the dataset clean for ML. This removes
    # the warm-up rows at the start of each series where lag/rolling values
    # are undefined; note that dropna() with no subset also drops a row that
    # has a null in ANY other column.
    df = df.dropna()

    logging.info("Writing engineered features to %s...", output_parquet)
    # 8. Save engineered features back to storage (Parquet stands in for
    # writing back to Snowflake/ADLS).
    df.write.mode("overwrite").parquet(output_parquet)

    logging.info("Feature engineering complete.")
    spark.stop()

if __name__ == "__main__":
    # The project root is the grandparent of the directory holding this script.
    data_dir = Path(__file__).resolve().parents[2] / "data"
    db_path = str(data_dir / "warehouse.duckdb")
    features_path = data_dir / "processed" / "features.parquet"

    # Ensure the output's parent directory exists before Spark writes to it.
    features_path.parent.mkdir(parents=True, exist_ok=True)

    build_features(get_spark_session(), db_path, str(features_path))