| import os |
| import logging |
| from pathlib import Path |
| from pyspark.sql import SparkSession |
| from pyspark.sql.functions import col, month, dayofweek, dayofyear, when, lag, avg, stddev, last, lit |
| from pyspark.sql.window import Window |
| import duckdb |
|
|
# Configure the root logger once at import time: INFO and above, each record
# prefixed with a timestamp and its severity.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
def get_spark_session():
    """Return the Spark session used for feature engineering.

    Builds (or reuses, via ``getOrCreate``) a session named
    ``FMCG_Feature_Engineering`` with 4g of driver memory and Arrow-based
    pandas conversion switched on.
    """
    builder = SparkSession.builder.appName("FMCG_Feature_Engineering")
    session = (
        builder
        .config("spark.driver.memory", "4g")
        # Arrow is enabled for the pandas <-> Spark DataFrame conversions.
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        .getOrCreate()
    )
    return session
|
|
# Oil price used when a date has no earlier known price to carry forward.
_FALLBACK_OIL_PRICE = 50.0


def _load_raw_tables(db_path):
    """Read the five raw tables from the DuckDB file into pandas DataFrames."""
    queries = {
        "sales": "SELECT * FROM raw_sales ORDER BY store_nbr, family, date",
        "stores": "SELECT * FROM stores",
        "oil": "SELECT * FROM oil",
        "holidays": "SELECT * FROM holidays_events",
        "transactions": "SELECT * FROM transactions",
    }
    con = duckdb.connect(db_path)
    try:
        return {name: con.execute(sql).df() for name, sql in queries.items()}
    finally:
        # Close the connection even when one of the queries fails.
        con.close()


def _forward_fill_oil(oil_df, sales_dates):
    """Return the oil table extended to all sales dates, prices carried forward."""
    # A full outer join gives every sales date a row (with a null price when the
    # oil table has no entry for it) while keeping the earlier oil rows that
    # supply the value to carry forward.
    spine = oil_df.join(sales_dates, on="date", how="full")
    # Running window over the whole (small) table in date order.
    # NOTE(review): relies on missing prices arriving as nulls rather than NaN
    # after the pandas -> Spark conversion - confirm for non-Arrow sessions.
    running = Window.orderBy("date").rowsBetween(
        Window.unboundedPreceding, Window.currentRow
    )
    return spine.withColumn(
        "dcoilwtico", last("dcoilwtico", ignorenulls=True).over(running)
    )


def build_features(spark, db_path: str, output_parquet: str):
    """
    Reads raw data from DuckDB (simulating Snowflake),
    uses PySpark to build scalable time-series features,
    and saves the engineered features to Parquet.

    Args:
        spark: Active Spark session. It is stopped once the Parquet output
            has been written.
        db_path: Path of the DuckDB database holding the ``raw_sales``,
            ``stores``, ``oil``, ``holidays_events`` and ``transactions``
            tables.
        output_parquet: Destination path; existing output is overwritten.
    """
    logging.info("Starting PySpark feature engineering...")

    raw = _load_raw_tables(db_path)

    df = spark.createDataFrame(raw["sales"])
    # "type" is renamed so the store attribute has an unambiguous name after
    # the join.
    stores_df = spark.createDataFrame(raw["stores"]).withColumnRenamed("type", "store_type")
    oil_df = spark.createDataFrame(raw["oil"])
    holidays_df = spark.createDataFrame(raw["holidays"])
    transactions_df = spark.createDataFrame(raw["transactions"])

    # Forward-fill over the union of oil and sales dates, so a sales date that
    # is absent from the oil table inherits the last known price instead of
    # dropping straight to the fallback constant.
    oil_df = _forward_fill_oil(oil_df, df.select("date").distinct())

    # Enrich sales with store attributes, oil price and transaction counts.
    df = df.join(stores_df, on="store_nbr", how="left")
    df = df.join(oil_df, on="date", how="left").fillna({"dcoilwtico": _FALLBACK_OIL_PRICE})
    df = df.join(transactions_df, on=["date", "store_nbr"], how="left").fillna({"transactions": 0})

    # Flag dates with a national holiday that was not transferred.
    # "== False" builds a Spark column expression; it is not a Python comparison.
    national_holidays = holidays_df.filter(
        (col("locale") == "National") & (col("transferred") == False)
    ).select("date").distinct().withColumn("is_holiday", lit(1))

    df = df.join(national_holidays, on="date", how="left").fillna({"is_holiday": 0})

    # Calendar features. Spark's dayofweek is 1 = Sunday ... 7 = Saturday.
    df = df.withColumn("month", month(col("date"))) \
        .withColumn("day_of_week", dayofweek(col("date"))) \
        .withColumn("day_of_year", dayofyear(col("date"))) \
        .withColumn("is_weekend", when(col("day_of_week").isin([1, 7]), 1).otherwise(0))

    # One time-ordered series per (store, product family). Lags and rolling
    # frames below count rows, not calendar days.
    windowSpec = Window.partitionBy("store_nbr", "family").orderBy("date")

    df = df.withColumn("sales_lag_1", lag("sales", 1).over(windowSpec)) \
        .withColumn("sales_lag_7", lag("sales", 7).over(windowSpec)) \
        .withColumn("sales_lag_28", lag("sales", 28).over(windowSpec)) \
        .withColumn("transactions_lag_1", lag("transactions", 1).over(windowSpec)) \
        .withColumn("transactions_lag_7", lag("transactions", 7).over(windowSpec))

    # Rolling frames end at the previous row, so the current row's own sales
    # value never feeds its own features.
    windowSpec_7 = windowSpec.rowsBetween(-7, -1)
    windowSpec_28 = windowSpec.rowsBetween(-28, -1)

    df = df.withColumn("rolling_mean_7", avg("sales").over(windowSpec_7)) \
        .withColumn("rolling_std_7", stddev("sales").over(windowSpec_7)) \
        .withColumn("rolling_mean_28", avg("sales").over(windowSpec_28))

    # Drops every row containing a null in any column; this includes the first
    # 28 rows of each series, which have no sales_lag_28 yet.
    df = df.dropna()

    logging.info("Writing engineered features to %s...", output_parquet)

    df.write.mode("overwrite").parquet(output_parquet)

    logging.info("Feature engineering complete.")
    spark.stop()
|
|
if __name__ == "__main__":
    # parents[2] is the directory two levels above the one containing this file.
    root = Path(__file__).resolve().parents[2]
    data_dir = root / "data"
    warehouse_file = data_dir / "warehouse.duckdb"
    features_file = data_dir / "processed" / "features.parquet"

    # Make sure the output folder exists before Spark writes into it.
    features_file.parent.mkdir(parents=True, exist_ok=True)

    build_features(get_spark_session(), str(warehouse_file), str(features_file))
|
|