energy_forecasting / utils /preprocessing.py
kawaiipeace's picture
update model
570d1fd
# utils/preprocessing.py
import pandas as pd
def load_and_process_data(file_path, is_multivariate, keep_datetime_column_for_darts=False):
    """Load a CSV file and shape it for univariate or multivariate forecasting.

    The first column that already has a datetime dtype, or whose name contains
    "date" or "time" (case-insensitive), is treated as the time axis: it is
    parsed with ``pd.to_datetime``, rows whose timestamp cannot be parsed are
    dropped, and the column becomes the index of the frame. If no such column
    exists the frame keeps the index produced by ``pd.read_csv``.

    Args:
        file_path: Path or file-like object; passed straight to
            ``pd.read_csv``.
        is_multivariate: When true, all remaining columns are returned as
            loaded. When false, only the first numeric column is kept.
        keep_datetime_column_for_darts: Accepted for backward compatibility.
            It has no effect on the result: in the univariate case the
            detected datetime column is kept as the index whether the flag is
            set or not.

    Returns:
        tuple: ``(df, n_columns)`` where ``df`` is the processed DataFrame and
        ``n_columns`` is ``df.shape[1]``.

    Raises:
        ValueError: If ``is_multivariate`` is false and the data contains no
            numeric column.
    """
    df = pd.read_csv(file_path)

    # Auto-detect the time column: first column that is datetime-typed or
    # whose name hints at a date/time.
    time_col = next(
        (
            col
            for col in df.columns
            if pd.api.types.is_datetime64_any_dtype(df[col])
            or "date" in col.lower()
            or "time" in col.lower()
        ),
        None,
    )

    if time_col is not None:
        # Unparseable values are coerced to NaT so they can be dropped instead
        # of aborting the whole load.
        df[time_col] = pd.to_datetime(df[time_col], errors="coerce")
        df = df.dropna(subset=[time_col]).set_index(time_col)

    if not is_multivariate:
        # "number" covers every integer/float width (e.g. uint64, float32),
        # not only int64/float64; boolean and text columns are still excluded.
        numeric_cols = df.select_dtypes(include="number").columns
        if numeric_cols.empty:
            raise ValueError("No numeric column found for univariate forecast.")
        # The original code had two identical branches keyed on
        # keep_datetime_column_for_darts; either way the first numeric column
        # is kept and the datetime information stays in the index.
        df = df[[numeric_cols[0]]]

    return df, df.shape[1]