# energy_forecasting/models/hf_models.py
# Last commit 570d1fd ("update model") by kawaiipeace.
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import root_mean_squared_error, r2_score
from transformers import AutoModelForTimeSeriesForecasting
import pandas as pd
def _point_forecast(output):
    """Flatten a ``model.generate`` result into a 1-D NumPy point forecast.

    transformers' probabilistic forecasters (e.g. Informer, Autoformer) return
    an output object whose ``sequences`` attribute holds sampled trajectories
    shaped ``(batch, num_samples, prediction_length[, input_size])``; those
    samples are averaged over the sample axis. A plain tensor result is
    flattened as-is.
    """
    sequences = getattr(output, "sequences", None)
    if sequences is not None:
        output = sequences.mean(dim=1)
    return output.cpu().numpy().flatten()


def _plot_forecast(actual, forecast, names, title):
    """Plot held-out actuals against the forecast and return the Figure.

    ``actual`` and ``forecast`` are ``(horizon, n_series)`` arrays; ``names``
    labels each series and is only shown when there is more than one.
    """
    fig, ax = plt.subplots()
    single = actual.shape[1] == 1
    for col, name in enumerate(names):
        suffix = "" if single else f" ({name})"
        (line,) = ax.plot(actual[:, col], label=f"Actual{suffix}")
        style = {"linestyle": "dashed"}
        if not single:
            # Reuse the actual line's colour so each series' pair is easy to match.
            style["color"] = line.get_color()
        ax.plot(forecast[:, col], label=f"Forecast{suffix}", **style)
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Value")
    return fig


def run_hf_forecast(df, model_name, is_multivariate, horizon):
    """Backtest a pretrained Hugging Face forecaster on the tail of ``df``.

    The last ``horizon`` rows of ``df`` are held out as ground truth. The model
    is conditioned on the ``context_length`` rows immediately before them, and
    its forecast for the held-out steps is plotted and scored against them.

    Args:
        df: pandas Series/DataFrame of numeric values ordered oldest to newest;
            every column is fed to the model.
        model_name: ``"Informer"``, ``"Autoformer"`` or ``"TimesNet"``.
        is_multivariate: Accepted for interface compatibility; not used here.
        horizon: Number of trailing steps to forecast and evaluate (>= 1).

    Returns:
        ``(forecast, fig, metrics)``: the forecast flattened to a 1-D NumPy
        array, a matplotlib Figure of actual vs. forecast, and a string
        ``"RMSE: ..., R2: ..."`` computed over all held-out values pooled
        across series.

    Raises:
        ValueError: for an unsupported ``model_name``, a non-positive
            ``horizon``, too few rows for the model's context window, or a
            forecast whose size does not match the held-out window.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id_map = {
        "Informer": "kashif/informer_model",
        "Autoformer": "kashif/autoformer_model",
        "TimesNet": "kashif/timesnet_model",
    }
    model_id = model_id_map.get(model_name)
    if model_id is None:
        raise ValueError("Unsupported model")

    values = np.asarray(df.values, dtype="float32")
    if values.ndim == 1:  # a Series yields a 1-D array; treat it as one column
        values = values.reshape(-1, 1)
    if horizon < 1:
        raise ValueError(f"horizon must be a positive integer, got {horizon!r}")
    if len(values) <= horizon:
        raise ValueError(
            f"need more than horizon={horizon} rows to hold out a test window, "
            f"got {len(values)}"
        )

    # NOTE(review): confirm the installed transformers exposes this Auto class
    # and that these Hub repos exist; TimesNet in particular is not a
    # transformers architecture as far as this reviewer knows.
    model = AutoModelForTimeSeriesForecasting.from_pretrained(model_id).to(device)
    # `or 48` also covers configs that define context_length as None.
    context_len = getattr(model.config, "context_length", None) or 48

    # Condition only on rows *before* the held-out window. Feeding the last
    # rows would make the model forecast steps past the end of the data, and
    # the scores would then compare mismatched time steps.
    history, actual = values[:-horizon], values[-horizon:]
    if len(history) < context_len:
        raise ValueError(
            f"{model_name} needs {context_len} context rows plus {horizon} "
            f"held-out rows, but df has only {len(values)} rows"
        )
    input_tensor = torch.tensor(history[-context_len:][np.newaxis]).to(device)

    with torch.no_grad():
        # NOTE(review): transformers' Informer/Autoformer `generate` also takes
        # past/future time features and reads the horizon from
        # config.prediction_length; confirm this call matches the loaded model.
        output = model.generate(input_tensor, prediction_length=horizon)
    forecast = _point_forecast(output)

    if forecast.size != actual.size:
        raise ValueError(
            f"{model_name} returned {forecast.size} forecast values; expected "
            f"{actual.size} ({horizon} steps x {actual.shape[1]} series)"
        )

    names = [str(c) for c in getattr(df, "columns", [])]
    if len(names) != actual.shape[1]:
        names = [str(i) for i in range(actual.shape[1])]
    # Assumes the flat forecast is laid out (time, series) -- the same order the
    # pooled metric comparison below relies on; confirm for multivariate models.
    fig = _plot_forecast(
        actual, forecast.reshape(actual.shape), names, f"{model_name} Forecast"
    )

    actual_flat = actual.flatten()
    rmse = root_mean_squared_error(actual_flat, forecast)
    r2 = r2_score(actual_flat, forecast)
    metrics = f"RMSE: {rmse:.3f}, R2: {r2:.3f}"
    return forecast, fig, metrics