# Hugging Face Space status: Running
# app.py
import os
from dataclasses import dataclass

import einops
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from huggingface_hub import snapshot_download
from visionts import VisionTSpp, freq_to_seasonality_list
# ========================
# Configuration
# ========================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
REPO_ID = "Lefei/VisionTSpp"
LOCAL_DIR = "./hf_models/VisionTSpp"
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
ARCH = 'mae_base'

# Download the model repository once; later starts reuse the local copy.
if not os.path.exists(CKPT_PATH):
    os.makedirs(LOCAL_DIR, exist_ok=True)
    print("Downloading model from Hugging Face Hub...")
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
    # Fail fast with a clear message if the repo does not contain the expected
    # checkpoint file, instead of an obscure error later inside model loading.
    if not os.path.exists(CKPT_PATH):
        raise FileNotFoundError(
            f"Checkpoint {CKPT_PATH!r} not found after downloading {REPO_ID!r}; "
            "check the checkpoint file name in the repository."
        )
# Load the forecasting model once at import time; every request reuses it.
model = VisionTSpp(
    ARCH,
    ckpt_path=CKPT_PATH,
    quantile=True,
    clip_input=True,
    complete_no_clip=False,
    color=True,
).to(DEVICE)
# The model is only used for inference (under torch.no_grad()), so switch
# dropout/normalization layers to inference behavior. Assumes VisionTSpp is a
# torch.nn.Module, as the .to(DEVICE) call suggests.
model.eval()
print(f"Model loaded on {DEVICE}")

# Per-channel ImageNet mean/std, used to undo image normalization when the
# model's input/reconstructed images are rendered.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
# ========================
# Preset datasets
# ========================
# Display name -> CSV download URL. The first word of each display name is also
# used as the local cache file name by load_preset_data.
# NOTE(review): the Illness and Weather URLs point into the ETDataset repo,
# which may only ship the ETT-small files — confirm these URLs resolve.
PRESET_DATASETS = {
    # ETT-small files
    "ETTm1 (15-min)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTm1.csv",
    "ETTh1 (1-hour)": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv",
    # Other benchmarks
    "Illness": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/illness.csv",
    "Weather": "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/weather.csv",
}

# Directory holding downloaded preset CSVs; created eagerly at import time.
PRESET_DIR = "./preset_data"
os.makedirs(PRESET_DIR, exist_ok=True)
def load_preset_data(name):
    """Return preset dataset ``name`` as a DataFrame, caching it under PRESET_DIR.

    The first call downloads the CSV from ``PRESET_DATASETS[name]`` and stores
    a local copy named after the first word of ``name``. Later calls read that
    copy instead of downloading again.

    Raises:
        KeyError: if ``name`` is not a key of PRESET_DATASETS.
        Whatever ``pd.read_csv`` raises if the download or parsing fails.
    """
    url = PRESET_DATASETS[name]
    path = os.path.join(PRESET_DIR, f"{name.split(' ')[0]}.csv")
    if os.path.exists(path):
        return pd.read_csv(path)

    df = pd.read_csv(url)
    # Write to a unique temporary file and atomically rename it into place, so
    # an interrupted write never leaves a truncated CSV that later calls would
    # treat as a valid cache, and concurrent first calls do not share a file.
    import tempfile

    fd, tmp_path = tempfile.mkstemp(dir=PRESET_DIR, suffix=".csv.tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="") as fh:
            df.to_csv(fh, index=False)
        os.replace(tmp_path, path)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return df
# ========================
# Visualization helpers
# ========================
def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
    """Render a normalized model image as a matplotlib figure.

    The image is split into ``cur_nvars`` equal horizontal strips. Strip ``i``
    shows only channel ``cur_color_list[i]``, de-normalized with the ImageNet
    mean/std and scaled to 0-255. All other channels of that strip stay 0, and
    any leftover bottom rows (height not divisible by ``cur_nvars``) stay black.

    Args:
        image: tensor indexed as (rows, cols, channel), channel last, as the
            ``[rows, :, channel]`` indexing below requires.
        title: figure title.
        cur_nvars: number of variables, i.e. horizontal strips.
        cur_color_list: channel index for each variable. Defaults to
            ``i % 3``, the same assignment predict_at_index uses.

    Returns:
        The matplotlib Figure, already closed so pyplot does not display it.
    """
    if cur_color_list is None:
        cur_color_list = [i % 3 for i in range(cur_nvars)]

    canvas = torch.zeros_like(image)
    strip_height = image.shape[0] // cur_nvars
    for var_idx in range(cur_nvars):
        channel = cur_color_list[var_idx]
        rows = slice(var_idx * strip_height, (var_idx + 1) * strip_height)
        canvas[rows, :, channel] = (
            image[rows, :, channel] * imagenet_std[channel] + imagenet_mean[channel]
        ) * 255
    canvas = torch.clamp(canvas, 0, 255).cpu().int()

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(canvas.numpy())
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.close(fig)
    return fig
def visual_ts_with_quantiles(true, pred_median, pred_quantiles, lookback_len_visual=300, pred_len=96,
                             quantile_colors=None, quantile_levels=None):
    """Plot ground truth, median forecast and quantile bands, one row per variable.

    Args:
        true: array or tensor of shape [T, nvars]: the lookback window followed
            by the ``pred_len`` ground-truth rows.
        pred_median: array or tensor with ``nvars`` columns. Its last
            ``pred_len`` rows are drawn as the median forecast; any leading
            rows are ignored.
        pred_quantiles: list of arrays or tensors with ``nvars`` columns. The
            last ``pred_len`` rows of each are used as the upper edge of a
            band. The list passed in is not modified.
        lookback_len_visual: maximum number of history rows shown before the
            forecast start; ``None`` shows the whole history.
        pred_len: forecast horizon; must be positive.
        quantile_colors: fill colors, applied to the bands in order. Bands
            beyond ``len(quantile_colors)`` are not drawn.
        quantile_levels: optional quantile level for each entry of
            ``pred_quantiles``. When given, bands are sorted by level and
            labelled with it; otherwise they keep list order and are labelled
            by position.

    Returns:
        The matplotlib Figure, already closed so pyplot does not display it.

    Raises:
        ValueError: if ``pred_len`` is not positive, or ``quantile_levels``
            does not have one entry per element of ``pred_quantiles``.
    """
    def _to_numpy(arr):
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    if pred_len <= 0:
        raise ValueError("pred_len must be positive")

    true = _to_numpy(true)
    pred_median = _to_numpy(pred_median)
    # Build a new list rather than writing numpy arrays back into the caller's list.
    quantile_arrays = [_to_numpy(q) for q in pred_quantiles]

    if quantile_levels is None:
        bands = [(f'Quantile band {k + 1}', q) for k, q in enumerate(quantile_arrays)]
    else:
        if len(quantile_levels) != len(quantile_arrays):
            raise ValueError("quantile_levels must have one entry per element of pred_quantiles")
        ordered = sorted(zip(quantile_levels, quantile_arrays), key=lambda pair: pair[0])
        bands = [(f'Quantile {level:.1f}', q) for level, q in ordered]

    # Quantile colors (outer to inner).
    if quantile_colors is None:
        quantile_colors = ['lightblue', 'skyblue', 'deepskyblue']

    nvars = true.shape[1]
    total_len = true.shape[0]
    lookback_len = total_len - pred_len
    # First history row to draw, so at most lookback_len_visual history points show.
    start = 0
    if lookback_len_visual is not None:
        start = max(0, lookback_len - int(lookback_len_visual))
    x_history = np.arange(start, total_len)
    x_forecast = np.arange(lookback_len, total_len)
    # Take the last pred_len rows so full-length and horizon-only arrays both line up.
    median_forecast = pred_median[-pred_len:]

    FIG_WIDTH = 12
    FIG_HEIGHT_PER_VAR = 1.8
    FONT_S = 10
    fig, axes = plt.subplots(
        nrows=nvars, ncols=1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True,
        squeeze=False, gridspec_kw={'height_ratios': [1] * nvars},
    )
    axes = axes[:, 0]  # squeeze=False always yields a 2-D array; keep the single column

    for i, ax in enumerate(axes):
        ax.plot(x_history, true[start:, i], label='Ground Truth', color='gray', linewidth=2)
        ax.plot(x_forecast, median_forecast[:, i], label='Prediction (Median)', color='blue', linewidth=2)
        for (label, q_arr), color in zip(bands, quantile_colors):
            upper = q_arr[-pred_len:, i]
            # Only one edge per quantile is available, so the other edge is
            # mirrored around the median (symmetry assumption).
            lower = 2 * median_forecast[:, i] - upper
            ax.fill_between(x_forecast, lower, upper, color=color, alpha=0.5, label=label)
        y_min, y_max = ax.get_ylim()
        ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
    fig.subplots_adjust(hspace=0)
    plt.close(fig)
    return fig
# ========================
# Prediction result container (for reuse)
# ========================
@dataclass(eq=False)  # eq=False keeps identity equality and hashability, as before
class PredictionResult:
    """Everything one forecast run produces for the UI."""

    ts_fig: object  # figure with series, median forecast and quantile bands
    input_img_fig: object  # figure of the model's input image
    recon_img_fig: object  # figure of the model's reconstructed image
    csv_path: str  # path of the written prediction CSV
    total_samples: int  # number of valid window start indices
def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
    """Forecast one sliding window of ``df`` with the global VisionTS++ model.

    After sorting by ``date``, rows ``[index, index + context_len)`` are the
    model input. The following ``pred_len`` rows are the ground truth shown
    next to the forecast.

    Args:
        df: DataFrame with a ``date`` column plus one numeric column per
            variable. Non-numeric columns are ignored. ``df`` is not modified.
        index: start row of the window, ``0 <= index < total_samples``.
        context_len: number of history rows fed to the model (> 0).
        pred_len: forecast horizon (> 0).
        freq: frequency string passed to ``freq_to_seasonality_list``.

    Returns:
        PredictionResult holding the forecast figure, the input/reconstructed
        image figures, the CSV path and the number of valid window starts.

    Raises:
        ValueError: for non-positive lengths, a missing or unparseable
            ``date`` column, no numeric columns, too little data, or an
            out-of-range ``index``.
    """
    # === Input validation ===
    if context_len <= 0 or pred_len <= 0:
        raise ValueError("❌ 历史长度和预测长度必须为正整数。")
    if 'date' not in df.columns:
        raise ValueError("❌ 数据集必须包含名为 'date' 的时间列。")
    df = df.copy()  # do not mutate the caller's DataFrame
    try:
        df['date'] = pd.to_datetime(df['date'])
    except (ValueError, TypeError, OverflowError) as err:
        raise ValueError("❌ 'date' 列格式无法解析为时间,请检查日期格式。") from err
    df = df.sort_values('date').set_index('date')

    # Only numeric columns can be normalized and fed to the model.
    numeric = df.select_dtypes(include=np.number)
    if numeric.shape[1] == 0:
        raise ValueError("❌ 数据集中没有可用于预测的数值列。")
    data = numeric.to_numpy(dtype=np.float64)
    nvars = data.shape[1]

    total_samples = len(data) - context_len - pred_len + 1
    if total_samples <= 0:
        raise ValueError(f"数据太短,至少需要 {context_len + pred_len} 行,当前只有 {len(data)} 行。")
    if not 0 <= index < total_samples:
        raise ValueError(f"索引越界,最大允许索引为 {total_samples - 1}")

    # Z-score normalization using statistics from the first 70% of rows
    # (presumably mirroring a train split).
    train_len = int(len(data) * 0.7)
    x_mean = data[:train_len].mean(axis=0, keepdims=True)
    x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
    data_norm = (data - x_mean) / x_std

    start = index
    context_end = start + context_len
    window_end = context_end + pred_len
    x = data_norm[start:context_end]

    periodicity_list = freq_to_seasonality_list(freq)
    periodicity = periodicity_list[0] if periodicity_list else 1
    # Variable i is rendered into image channel i % 3.
    color_list = [i % 3 for i in range(nvars)]

    model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
    x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE)  # add a batch dimension
    with torch.no_grad():
        y_pred, input_image, reconstructed_image, _, _ = model.forward(
            x_tensor, export_image=True, color_list=color_list
        )
    # With quantile=True the prediction unpacks into (median, list of quantile forecasts).
    pred_median, pred_quantiles = y_pred

    # Undo the normalization. [0] drops the batch dimension
    # (assumes [1, pred_len, nvars] outputs — TODO confirm).
    pred_med_orig = pred_median[0].cpu().numpy() * x_std + x_mean
    pred_quants_orig = [q[0].cpu().numpy() * x_std + x_mean for q in pred_quantiles]

    # History and ground truth come straight from the raw data, which avoids a
    # normalize/denormalize round trip.
    history = data[start:context_end]
    full_true = data[start:window_end]
    full_pred_med = np.concatenate([history, pred_med_orig], axis=0)

    # === Visualization ===
    ts_fig = visual_ts_with_quantiles(
        true=full_true,
        pred_median=full_pred_med,
        pred_quantiles=pred_quants_orig,
        lookback_len_visual=context_len,
        pred_len=pred_len,
    )
    input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
    recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)

    # === Save CSV: true values then predictions, one column per variable ===
    os.makedirs("outputs", exist_ok=True)
    csv_path = "outputs/prediction_result.csv"
    time_index = df.index[start:window_end]
    combined = np.concatenate([full_true, full_pred_med], axis=1)  # [T, 2*nvars]
    col_names = [f"True_Var{i+1}" for i in range(nvars)] + [f"Pred_Var{i+1}" for i in range(nvars)]
    pd.DataFrame(combined, index=time_index, columns=col_names).to_csv(csv_path)

    return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples)
# ========================
# Gradio callback
# ========================
def run_forecast(data_source, upload_file, index, context_len, pred_len, freq):
    """Gradio callback: load the selected data and forecast the window at ``index``.

    Args:
        data_source: a preset dataset name, or "Upload CSV" to use ``upload_file``.
        upload_file: the gr.File value. This is a path string in Gradio 4, an
            object with a ``.name`` path in Gradio 3, or ``None`` when nothing
            was uploaded.
        index: sample index from the slider.
        context_len: history length from a numeric widget (converted with int()).
        pred_len: forecast horizon from a numeric widget (converted with int()).
        freq: frequency string forwarded to predict_at_index.

    Returns:
        ``(ts_fig, input_img_fig, recon_img_fig, csv_path, slider_update)``.
        The slider update sets the slider's maximum to the last valid index.
        Any exception is printed with its traceback and rendered as an error
        figure in all three plot slots; ``csv_path`` is then ``None`` and the
        slider is left unchanged.
    """
    try:
        if data_source == "Upload CSV":
            if upload_file is None:
                raise ValueError("请上传一个 CSV 文件")
            df = pd.read_csv(getattr(upload_file, "name", upload_file))
        else:
            df = load_preset_data(data_source)
        result = predict_at_index(df, int(index), context_len=int(context_len), pred_len=int(pred_len), freq=freq)
    except Exception as e:  # UI boundary: show the error instead of failing the callback
        import traceback

        traceback.print_exc()  # keep the full traceback in the server log
        fig_err, ax = plt.subplots(figsize=(6, 4))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center', wrap=True)
        ax.axis('off')
        plt.close(fig_err)
        return fig_err, fig_err, fig_err, None, gr.update()

    max_index = result.total_samples - 1
    return (
        result.ts_fig,
        result.input_img_fig,
        result.recon_img_fig,
        result.csv_path,
        gr.update(maximum=max_index, value=min(index, max_index)),
    )
# ========================
# Gradio UI
# ========================
with gr.Blocks(title="VisionTS++ 高级预测平台") as demo:
    gr.Markdown("# 🕰️ VisionTS++ 多变量时间序列预测平台")
    gr.Markdown("""
- ✅ 支持预设数据集或本地上传
- ✅ 上传规则:必须是 `.csv`,且包含 `date` 列
- ✅ 显示多分位数预测区间
- ✅ 支持下载预测结果
- ✅ 滑动样本实时预测
""")
    with gr.Row():
        with gr.Column(scale=2):
            data_source = gr.Dropdown(
                label="选择数据源",
                choices=["ETTm1 (15-min)", "ETTh1 (1-hour)", "Illness", "Weather", "Upload CSV"],
                value="ETTm1 (15-min)",
            )
            upload_file = gr.File(label="上传 CSV 文件", file_types=['.csv'], visible=False)
            context_len = gr.Number(label="历史长度", value=960)
            pred_len = gr.Number(label="预测长度", value=394)
            freq = gr.Textbox(label="频率 (如 15Min, H)", value="15Min")
            sample_index = gr.Slider(label="样本索引", minimum=0, maximum=100, step=1, value=0)
        with gr.Column(scale=3):
            ts_plot = gr.Plot(label="时间序列预测(含分位数区间)")
            with gr.Row():
                input_img_plot = gr.Plot(label="Input Image")
                recon_img_plot = gr.Plot(label="Reconstructed Image")
            download_csv = gr.File(label="下载预测结果")
    btn = gr.Button("🚀 初始运行")

    def toggle_upload(choice):
        """Show the file picker only when "Upload CSV" is selected."""
        return gr.update(visible=choice == "Upload CSV")

    data_source.change(fn=toggle_upload, inputs=data_source, outputs=upload_file)

    forecast_inputs = [data_source, upload_file, sample_index, context_len, pred_len, freq]
    forecast_outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index]

    # Both the run button and slider changes trigger a forecast.
    btn.click(fn=run_forecast, inputs=forecast_inputs, outputs=forecast_outputs)
    # NOTE: run_forecast also writes sample_index (to update its maximum). If that
    # clamps the value, this handler may fire once more with the clamped index.
    sample_index.change(fn=run_forecast, inputs=forecast_inputs, outputs=forecast_outputs)

    def run_example(source, upload, ctx_len, horizon, frequency):
        """Run one example row, starting from sample index 0."""
        return run_forecast(source, upload, 0, ctx_len, horizon, frequency)

    # Examples: clicking a row fills the inputs AND runs the forecast
    # (run_on_click + outputs), matching the "click to run" label. Without
    # outputs/run_on_click, fn would never be called on click.
    gr.Examples(
        examples=[
            ["ETTm1 (15-min)", None, 960, 394, "15Min"],
            ["Illness", None, 36, 24, "D"],
        ],
        inputs=[data_source, upload_file, context_len, pred_len, freq],
        outputs=forecast_outputs,
        fn=run_example,
        run_on_click=True,
        label="点击运行示例",
    )

if __name__ == "__main__":
    demo.launch()