# VisionTSpp / app.py
# (Hugging Face Space file header — uploader: Lefei, commit message "update app.py",
#  commit c1f4164 verified, raw size 12.2 kB)
# app.py
import os
import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import einops
from huggingface_hub import snapshot_download
from visionts import VisionTSpp, freq_to_seasonality_list
# ========================
# Configuration
# ========================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
REPO_ID = "Lefei/VisionTSpp"
LOCAL_DIR = "./hf_models/VisionTSpp"
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
ARCH = 'mae_base'

# Download the checkpoint from the Hugging Face Hub on first run only;
# later runs reuse the local copy.
if not os.path.exists(CKPT_PATH):
    os.makedirs(LOCAL_DIR, exist_ok=True)
    print("Downloading model from Hugging Face Hub...")
    # NOTE(review): `local_dir_use_symlinks` is deprecated (and ignored) in
    # recent huggingface_hub releases; kept for older versions — confirm the
    # pinned hub version before removing.
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
# ========================
# Model loading
# ========================
model = VisionTSpp(
    ARCH,
    ckpt_path=CKPT_PATH,
    quantile=True,
    clip_input=True,
    complete_no_clip=False,
    color=True
).to(DEVICE)
print(f"Model loaded on {DEVICE}")

# ImageNet channel statistics, used to de-normalize exported model images
# before display.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
# ========================
# Preset datasets
# ========================
# Display name -> remote CSV URL.
PRESET_DATASETS = {
    "ETTm1 (15-min)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv",
    "ETTh1 (1-hour)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv",
    "Illness": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/illness.csv",
    "Weather": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/weather.csv"
}

# Local cache directory for downloaded preset CSVs.
PRESET_DIR = "./preset_data"
os.makedirs(PRESET_DIR, exist_ok=True)
def load_preset_data(name):
    """Return a preset dataset as a DataFrame, caching the CSV locally.

    On first access the CSV is fetched from its remote URL and written under
    PRESET_DIR; subsequent calls read the cached file instead.
    """
    cache_file = os.path.join(PRESET_DIR, f"{name.split(' ')[0]}.csv")
    if os.path.exists(cache_file):
        return pd.read_csv(cache_file)
    frame = pd.read_csv(PRESET_DATASETS[name])
    frame.to_csv(cache_file, index=False)
    return frame
# ========================
# Visualization helpers
# ========================
def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
    """Render a normalized (H, W, 3) image tensor as a matplotlib figure.

    Each of the `cur_nvars` variables occupies one horizontal band of the
    image; only its assigned RGB channel is de-normalized (ImageNet stats)
    and drawn — the other channels stay zero.

    Args:
        image: torch.Tensor, assumed (H, W, 3) and ImageNet-normalized —
            confirm against the model's `export_image` output.
        title: Title drawn above the image.
        cur_nvars: Number of variables stacked vertically in the image.
        cur_color_list: Per-variable channel index (0/1/2). Defaults to
            cycling through R, G, B when omitted.

    Returns:
        The matplotlib Figure (already closed; the caller owns rendering).
    """
    if cur_color_list is None:
        # Bug fix: the original indexed the None default (`cur_color_list[i]`)
        # and crashed whenever the argument was omitted; cycle RGB instead,
        # matching how callers in this file build their color lists.
        cur_color_list = [i % 3 for i in range(cur_nvars)]
    cur_image = torch.zeros_like(image)
    height_per_var = image.shape[0] // cur_nvars
    for i in range(cur_nvars):
        cur_color = cur_color_list[i]
        band = slice(i * height_per_var, (i + 1) * height_per_var)
        # Undo ImageNet normalization for this variable's channel only.
        cur_image[band, :, cur_color] = \
            (image[band, :, cur_color] * imagenet_std[cur_color] + imagenet_mean[cur_color]) * 255
    cur_image = torch.clamp(cur_image, 0, 255).cpu().int()
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cur_image.numpy())
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.close(fig)  # close so Gradio renders the returned figure only once
    return fig
def visual_ts_with_quantiles(true, pred_median, pred_quantiles,
                             lookback_len_visual=300, pred_len=96,
                             quantile_colors=None, quantile_levels=None):
    """Plot ground truth, median forecast and quantile bands per variable.

    Args:
        true: [T, nvars] array/tensor — history followed by horizon truth.
        pred_median: [T, nvars] array/tensor — history followed by the
            median forecast.
        pred_quantiles: list of [T, nvars] arrays/tensors, one per quantile.
        lookback_len_visual: kept for interface compatibility (unused here,
            as in the original).
        pred_len: horizon length; the last `pred_len` rows are the forecast.
        quantile_colors: band colors from outermost to innermost.
        quantile_levels: optional list of quantile levels aligned with
            `pred_quantiles`; when given, bands are sorted by level and
            labeled with it.

    Returns:
        A matplotlib Figure with one stacked subplot per variable.
    """
    if isinstance(true, torch.Tensor):
        true = true.cpu().numpy()
    if isinstance(pred_median, torch.Tensor):
        pred_median = pred_median.cpu().numpy()
    # Convert to numpy WITHOUT mutating the caller's list (the original
    # rewrote `pred_quantiles` entries in place).
    pred_quantiles = [q.cpu().numpy() if isinstance(q, torch.Tensor) else q
                      for q in pred_quantiles]
    nvars = true.shape[1]
    FIG_WIDTH = 12
    FIG_HEIGHT_PER_VAR = 1.8
    FONT_S = 10
    fig, axes = plt.subplots(
        nrows=nvars, ncols=1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True,
        gridspec_kw={'height_ratios': [1] * nvars}
    )
    if nvars == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for nrows == 1
    lookback_len = true.shape[0] - pred_len
    # Quantile band colors (outermost to innermost).
    if quantile_colors is None:
        quantile_colors = ['lightblue', 'skyblue', 'deepskyblue']
    # Bug fix: the original read `PREDS.quantiles` here, but no `PREDS` exists
    # anywhere in this module, so drawing any quantile band raised NameError.
    # Use the new optional `quantile_levels` argument when supplied, otherwise
    # fall back to index-based legend labels.
    if quantile_levels is not None:
        ordered = sorted(zip(quantile_levels, pred_quantiles), key=lambda pair: pair[0])
        band_labels = [f'Quantile {q:.1f}' for q, _ in ordered]
        band_preds = [p for _, p in ordered]
    else:
        band_preds = pred_quantiles
        band_labels = [f'Quantile band {i + 1}' for i in range(len(pred_quantiles))]
    # Loop-invariant pieces hoisted out of the per-variable loop.
    horizon_x = np.arange(lookback_len, len(true))
    base = pred_median[lookback_len:]
    for i, ax in enumerate(axes):
        ax.plot(true[:, i], label='Ground Truth', color='gray', linewidth=2)
        ax.plot(horizon_x, pred_median[lookback_len:, i],
                label='Prediction (Median)', color='blue', linewidth=2)
        # Draw quantile bands (outermost to innermost).
        for (pred_q, band_label), color in zip(zip(band_preds, band_labels), quantile_colors):
            upper = pred_q[lookback_len:]
            lower = 2 * base - upper  # symmetric-band assumption around the median
            ax.fill_between(
                horizon_x,
                lower[:, i], upper[:, i],
                color=color, alpha=0.5, label=band_label
            )
        y_min, y_max = ax.get_ylim()
        # Dashed vertical line marking the history/forecast boundary.
        ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
    plt.subplots_adjust(hspace=0)
    plt.close(fig)  # close so Gradio renders the returned figure only once
    return fig
# ========================
# Prediction result container (shared between UI callbacks)
# ========================
class PredictionResult:
    """Bundle of everything a single forecast produces."""

    def __init__(self, ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples):
        # Time-series figure with quantile bands.
        self.ts_fig = ts_fig
        # Figures of the model's input image and its reconstruction.
        self.input_img_fig = input_img_fig
        self.recon_img_fig = recon_img_fig
        # Path of the exported prediction CSV.
        self.csv_path = csv_path
        # Number of valid sliding-window start indices in the dataset.
        self.total_samples = total_samples
def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
    """Run one forecast for the sliding-window sample starting at `index`.

    Args:
        df: DataFrame with a 'date' column plus one numeric column per variable.
        index: 0-based start offset of the context window.
        context_len: Number of history steps fed to the model.
        pred_len: Number of future steps to predict.
        freq: Pandas-style frequency string used to pick the seasonality.

    Returns:
        PredictionResult holding the figures, the CSV path and the number of
        valid sliding-window samples.

    Raises:
        ValueError: if the 'date' column is missing/unparsable, the data is
            too short, or `index` is out of range.
    """
    # === Input validation ===
    if 'date' not in df.columns:
        raise ValueError("❌ 数据集必须包含名为 'date' 的时间列。")
    # Bug fix: work on a copy — the original converted the 'date' column in
    # place, mutating the caller's DataFrame on every invocation.
    df = df.copy()
    try:
        df['date'] = pd.to_datetime(df['date'])
    except Exception:
        raise ValueError("❌ 'date' 列格式无法解析为时间,请检查日期格式。")
    df = df.sort_values('date').set_index('date')
    data = df.values
    nvars = data.shape[1]
    total_samples = len(data) - context_len - pred_len + 1
    if total_samples <= 0:
        raise ValueError(f"数据太短,至少需要 {context_len + pred_len} 行,当前只有 {len(data)} 行。")
    if index >= total_samples:
        raise ValueError(f"索引越界,最大允许索引为 {total_samples - 1}")

    # Normalize with statistics from the first 70% of the series (train split).
    train_len = int(len(data) * 0.7)
    x_mean = data[:train_len].mean(axis=0, keepdims=True)
    x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8  # avoid div-by-zero
    data_norm = (data - x_mean) / x_std

    start_idx = index
    x = data_norm[start_idx:start_idx + context_len]
    y_true = data_norm[start_idx + context_len:start_idx + context_len + pred_len]

    periodicity_list = freq_to_seasonality_list(freq)
    periodicity = periodicity_list[0] if periodicity_list else 1
    color_list = [i % 3 for i in range(nvars)]  # one RGB channel per variable
    model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)

    x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        y_pred, input_image, reconstructed_image, nvars_out, color_list_out = model.forward(
            x_tensor, export_image=True, color_list=color_list
        )
    pred_median, pred_quantiles = y_pred  # median tensor + list of quantile tensors

    # De-normalize everything back to the original scale.
    y_true_orig = y_true * x_std + x_mean
    pred_med_orig = pred_median[0].cpu().numpy() * x_std + x_mean
    pred_quants_orig = [q[0].cpu().numpy() * x_std + x_mean for q in pred_quantiles]

    # Full sequences (history + horizon) for plotting.
    full_true = np.concatenate([x * x_std + x_mean, y_true_orig], axis=0)
    full_pred_med = np.concatenate([x * x_std + x_mean, pred_med_orig], axis=0)

    # === Visualization ===
    ts_fig = visual_ts_with_quantiles(
        true=full_true,
        pred_median=full_pred_med,
        pred_quantiles=pred_quants_orig,
        lookback_len_visual=context_len,
        pred_len=pred_len
    )
    input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
    recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)

    # === Export CSV ===
    os.makedirs("outputs", exist_ok=True)
    csv_path = "outputs/prediction_result.csv"
    time_index = df.index[start_idx:start_idx + context_len + pred_len]
    combined = np.concatenate([full_true, full_pred_med], axis=1)  # [T, 2*nvars]
    col_names = [f"True_Var{i+1}" for i in range(nvars)] + [f"Pred_Var{i+1}" for i in range(nvars)]
    result_df = pd.DataFrame(combined, index=time_index, columns=col_names)
    result_df.to_csv(csv_path)
    return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
# ========================
# Gradio interface callbacks
# ========================
def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
    """Gradio callback: load the selected data, run a forecast, build outputs.

    Returns a 5-tuple matching the registered outputs: time-series figure,
    input-image figure, reconstruction figure, CSV path, and a slider update
    that caps the sample index at the number of valid windows.
    On any failure, the same error message is rendered into all three plot
    slots so the UI stays consistent.
    """
    try:
        # Bug fix: data loading used to happen OUTSIDE the try block, so a bad
        # upload or a failed preset download crashed the callback instead of
        # rendering the error figure below.
        if data_source == "Upload CSV":
            if upload_file is None:
                raise ValueError("请上传一个 CSV 文件")
            df = pd.read_csv(upload_file.name)
        else:
            df = load_preset_data(data_source)
        result = predict_at_index(df, int(index), context_len=int(context_len), pred_len=int(pred_len), freq=freq)
        return (
            result.ts_fig,
            result.input_img_fig,
            result.recon_img_fig,
            result.csv_path,
            # int(index): Gradio sliders may deliver floats.
            gr.update(maximum=result.total_samples - 1, value=min(int(index), result.total_samples - 1))
        )
    except Exception as e:
        # Render the error text as a figure so every plot output stays valid.
        fig_err = plt.figure(figsize=(6, 4))
        plt.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center', wrap=True)
        plt.axis('off')
        plt.close(fig_err)
        return fig_err, fig_err, fig_err, None, gr.update()
# ========================
# Gradio UI
# ========================
with gr.Blocks(title="VisionTS++ 高级预测平台") as demo:
    gr.Markdown("# 🕰️ VisionTS++ 多变量时间序列预测平台")
    gr.Markdown("""
- ✅ 支持预设数据集或本地上传
- ✅ 上传规则:必须是 `.csv`,且包含 `date` 列
- ✅ 显示多分位数预测区间
- ✅ 支持下载预测结果
- ✅ 滑动样本实时预测
""")
    with gr.Row():
        # Left column: data source and forecasting parameters.
        with gr.Column(scale=2):
            data_source = gr.Dropdown(
                label="选择数据源",
                choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
                value="ETTm1 (15-min)"
            )
            upload_file = gr.File(label="上传 CSV 文件", file_types=['.csv'], visible=False)
            context_len = gr.Number(label="历史长度", value=960)
            pred_len = gr.Number(label="预测长度", value=394)
            freq = gr.Textbox(label="频率 (如 15Min, H)", value="15Min")
            sample_index = gr.Slider(label="样本索引", minimum=0, maximum=100, step=1, value=0)
        # Right column: figures and the CSV download.
        with gr.Column(scale=3):
            ts_plot = gr.Plot(label="时间序列预测(含分位数区间)")
            with gr.Row():
                input_img_plot = gr.Plot(label="Input Image")
                recon_img_plot = gr.Plot(label="Reconstructed Image")
            download_csv = gr.File(label="下载预测结果")

    btn = gr.Button("🚀 初始运行")

    def toggle_upload(choice):
        # Show the upload widget only when "Upload CSV" is selected.
        return gr.update(visible=choice == "Upload CSV")

    data_source.change(fn=toggle_upload, inputs=data_source, outputs=upload_file)

    # Both the run button and any slider movement trigger the same forecast
    # with identical inputs/outputs, so the wiring is shared.
    _forecast_wiring = dict(
        fn=run_forecast,
        inputs=[data_source, upload_file, sample_index, context_len, pred_len, freq],
        outputs=[ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index],
    )
    btn.click(**_forecast_wiring)
    sample_index.change(**_forecast_wiring)

    # Clickable example configurations (index fixed to 0).
    gr.Examples(
        examples=[
            ["ETTm1 (15-min)", None, 960, 394, "15Min"],
            ["Illness", None, 36, 24, "D"]
        ],
        inputs=[data_source, upload_file, context_len, pred_len, freq],
        fn=lambda a, b, c, d, e: run_forecast(a, b, 0, c, d, e),
        label="点击运行示例"
    )

demo.launch()