| import io |
| import zipfile |
| from pathlib import Path |
| from typing import List, Tuple, Literal, Optional |
| from evaluation.metrics import get_metrics |
| import gradio as gr |
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| from huggingface_hub import hf_hub_download |
| from huggingface_hub.errors import HfHubHTTPError |
|
|
| from model_wrapper import run_Time_RCD |
|
|
# Hugging Face Hub repository from which ``checkpoints.zip`` is downloaded.
REPO_ID = "thu-sail-lab/Time-RCD"


# Checkpoint files (paths relative to the current working directory) whose
# presence ``ensure_checkpoints`` checks before deciding to download.
CHECKPOINT_FILES = [
    f"checkpoints/{stem}.pth"
    for stem in (
        "full_mask_anomaly_head_pretrain_checkpoint_best",
        "dataset_10_20",
        "full_mask_10_20",
        "dataset_15_56",
        "full_mask_15_56",
    )
]
|
|
|
|
def ensure_checkpoints() -> None:
    """Download and unpack the checkpoint archive if any checkpoint is absent.

    Presence is checked relative to the current working directory. If at least
    one entry of ``CHECKPOINT_FILES`` is missing, ``checkpoints.zip`` is
    fetched from ``REPO_ID`` -- first as a model repository and, if the Hub
    answers with an HTTP error, as a dataset repository -- and then extracted
    into the current working directory.
    """
    if all(Path(ckpt).exists() for ckpt in CHECKPOINT_FILES):
        return

    def _fetch_archive(repo_kind: str) -> str:
        # cache_dir keeps the downloaded archive under .cache/hf.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename="checkpoints.zip",
            repo_type=repo_kind,
            cache_dir=".cache/hf",
        )

    try:
        archive_path = _fetch_archive("model")
    except HfHubHTTPError:
        # Presumably the archive may be published as a dataset repo instead.
        archive_path = _fetch_archive("dataset")

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(".")
|
|
|
|
# Directory that contains this module; bundled samples are resolved from here.
BASE_DIR = Path(__file__).resolve().parent
SAMPLE_DATASET_DIR = BASE_DIR / "sample_datasets"


# Where ground-truth labels for a run come from.
LabelSource = Literal["same_file", "separate_file", "none"]
# Column names treated as labels (the loaders lower-case names before matching).
LABEL_COLUMN_CANDIDATES = ("label", "labels")
# Radio-button text shown in the UI -> internal LabelSource value.
LABEL_SOURCE_CHOICES = {
    "Value + label in same file": "same_file",
    "Labels in separate file": "separate_file",
    "No labels provided": "none",
}

# (UI title, file name inside SAMPLE_DATASET_DIR, is_multivariate)
_SAMPLE_SPECS = (
    ("Sample: Univariate SED Medical", "235_SED_id_2_Medical_tr_2499_1st_3840.csv", False),
    ("Sample: Univariate UCR Medical", "353_UCR_id_51_Medical_tr_1875_1st_3198.csv", False),
    ("Sample: Univariate Yahoo WebService", "686_YAHOO_id_136_WebService_tr_500_1st_755.csv", False),
)
# Bundled datasets offered in the UI, keyed by their display title.
SAMPLE_FILES: dict[str, dict[str, object]] = {
    title: {"path": SAMPLE_DATASET_DIR / file_name, "is_multivariate": multivariate}
    for title, file_name, multivariate in _SAMPLE_SPECS
}
|
|
|
|
| def _resolve_path(file_obj) -> Path: |
| """Extract a pathlib.Path from the gradio file object.""" |
| if file_obj is None: |
| raise ValueError("File object is None.") |
|
|
| if isinstance(file_obj, Path): |
| return file_obj |
|
|
| if isinstance(file_obj, str): |
| path = Path(file_obj) |
| if not path.is_absolute(): |
| path = (BASE_DIR / path).resolve() |
| return path |
|
|
| |
| if isinstance(file_obj, dict) and "name" in file_obj: |
| return _resolve_path(file_obj["name"]) |
|
|
| name = getattr(file_obj, "name", None) |
| if not name: |
| raise ValueError("Unable to resolve uploaded file path.") |
| return _resolve_path(name) |
|
|
|
|
| def _load_dataframe(path: Path) -> pd.DataFrame: |
| """Load a dataframe from supported file types.""" |
| if not path.exists(): |
| raise ValueError(f"File not found: {path}. If this is a bundled sample, ensure it exists under {SAMPLE_DATASET_DIR}.") |
|
|
| suffix = path.suffix.lower() |
| if suffix == ".npy": |
| data = np.load(path, allow_pickle=False) |
| if data.ndim == 1: |
| data = data.reshape(-1, 1) |
| if not isinstance(data, np.ndarray): |
| raise ValueError("Loaded .npy data is not a numpy array.") |
| return pd.DataFrame(data) |
|
|
| if suffix not in {".csv", ".txt"}: |
| raise ValueError("Unsupported file type. Please upload a .csv, .txt, or .npy file.") |
|
|
| return pd.read_csv(path) |
|
|
|
|
def _extract_label_column(df: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
    """Split a label column out of ``df`` if one of the candidate names exists.

    Column names are compared case-insensitively against
    ``LABEL_COLUMN_CANDIDATES``, in that tuple's order. Column labels are
    passed through ``str()`` first, so non-string labels (e.g. the integer
    columns of a frame built from a ``.npy`` array) don't crash the lookup.

    Returns:
        ``(features, labels)`` -- ``df`` without the label column and that
        column converted with ``pd.to_numeric`` -- or ``(df, None)`` when no
        label column is present.

    Raises:
        ValueError: if the label column contains values that cannot be
            converted to numbers.
    """
    lower_to_original = {str(col).lower(): col for col in df.columns}
    label_col = None
    for candidate in LABEL_COLUMN_CANDIDATES:
        if candidate in lower_to_original:
            label_col = lower_to_original[candidate]
            break

    if label_col is None:
        return df, None

    label_series = pd.to_numeric(df[label_col], errors="raise")
    feature_df = df.drop(columns=[label_col])
    return feature_df, label_series
|
|
|
|
def _load_label_series(file_obj) -> pd.Series:
    """Load labels from a dedicated label upload.

    The file is read with ``_load_dataframe`` and only its numeric columns are
    considered. A column whose lower-cased name is in
    ``LABEL_COLUMN_CANDIDATES`` is used if present; otherwise the file must
    have exactly one numeric column. The result is renamed to ``"label"``.

    Raises:
        ValueError: if the file has no numeric column, has several numeric
            columns but no recognised label column, or cannot be loaded.
    """
    path = _resolve_path(file_obj)
    df = _load_dataframe(path)
    numeric_df = df.select_dtypes(include=np.number)
    if numeric_df.empty:
        raise ValueError("Uploaded label file does not contain numeric columns.")

    # str() so integer column labels (e.g. from a .npy label file) don't crash.
    lower_to_original = {str(col).lower(): col for col in numeric_df.columns}
    for candidate in LABEL_COLUMN_CANDIDATES:
        if candidate in lower_to_original:
            column = lower_to_original[candidate]
            return pd.to_numeric(numeric_df[column], errors="raise").rename("label")

    if numeric_df.shape[1] > 1:
        raise ValueError(
            "Label file must contain exactly one numeric column or include a column named 'label'."
        )

    series = pd.to_numeric(numeric_df.iloc[:, 0], errors="raise").rename("label")
    return series
|
|
|
|
def load_timeseries(
    value_file,
    feature_columns: List[str] | None,
    label_source: LabelSource,
    label_file=None,
) -> Tuple[pd.DataFrame, np.ndarray, Optional[pd.Series]]:
    """Load the uploaded value file and optional labels; return features/labels.

    Steps:
      1. Keep only the numeric columns of the value file.
      2. Remove an embedded label column (via ``_extract_label_column``) so it
         is never used as a feature, whatever ``label_source`` says.
      3. Obtain labels according to ``label_source``.
      4. Optionally restrict features to ``feature_columns``. Names are matched
         exactly first, then by string form, so ``"0"`` selects integer
         column ``0`` of a frame loaded from a ``.npy`` file.
      5. If feature and label lengths differ, silently truncate both to the
         shorter length.

    Returns:
        ``(feature_df, array, labels)`` where ``array`` is ``feature_df`` as a
        float32 matrix and ``labels`` is a Series or ``None``.

    Raises:
        ValueError: on unreadable input, no numeric/feature columns, missing
            selected columns or labels, or an unsupported ``label_source``.
    """
    value_path = _resolve_path(value_file)
    raw_df = _load_dataframe(value_path)
    feature_df = raw_df.select_dtypes(include=np.number)

    if feature_df.empty:
        raise ValueError("No numeric columns detected. Ensure your value file contains numeric values.")

    # Strip any embedded label column up front so it can never leak into the
    # features, even when labels come from elsewhere (or are not used).
    feature_df, embedded_label = _extract_label_column(feature_df)

    label_series: Optional[pd.Series]
    if label_source == "same_file":
        if embedded_label is None:
            raise ValueError(
                "Label column not found in the uploaded file. Expected a column named 'label'."
            )
        label_series = embedded_label
    elif label_source == "separate_file":
        if label_file is None:
            raise ValueError("Please upload a label file or switch the label source option.")
        label_series = _load_label_series(label_file)
    elif label_source == "none":
        label_series = None
    else:
        raise ValueError(f"Unsupported label source option: {label_source}")

    # A file whose only numeric column is the label leaves nothing to model.
    if feature_df.shape[1] == 0:
        raise ValueError(
            "No feature columns left after removing the label column. "
            "Ensure the value file contains at least one numeric feature column."
        )

    if feature_columns:
        # Textbox input is always str, but .npy-backed frames have integer
        # column labels, so fall back to matching on str(col).
        by_text = {str(col): col for col in feature_df.columns}
        selected = []
        missing = []
        for requested in feature_columns:
            if requested in feature_df.columns:
                selected.append(requested)
            elif str(requested) in by_text:
                selected.append(by_text[str(requested)])
            else:
                missing.append(str(requested))
        if missing:
            raise ValueError(f"Selected columns not found in the value file: {', '.join(missing)}")
        feature_df = feature_df[selected]

    feature_df = feature_df.reset_index(drop=True)

    if label_series is not None:
        label_series = label_series.reset_index(drop=True)
        if len(label_series) != len(feature_df):
            # Keep only the overlapping prefix of features and labels.
            common = min(len(label_series), len(feature_df))
            label_series = label_series.iloc[:common].reset_index(drop=True)
            feature_df = feature_df.iloc[:common, :].reset_index(drop=True)

    # DataFrame.to_numpy always returns a 2-D (rows, n_features) array.
    array = feature_df.to_numpy(dtype=np.float32)
    return feature_df, array, label_series
|
|
|
|
| def _metrics_to_dataframe(metrics: dict[str, float]) -> pd.DataFrame: |
| if not metrics: |
| return pd.DataFrame({"Metric": [], "Value": []}) |
| return pd.DataFrame( |
| { |
| "Metric": list(metrics.keys()), |
| "Value": [round(float(value), 4) for value in metrics.values()], |
| } |
| ) |
|
|
|
|
def infer(
    file_obj,
    is_multivariate: bool,
    window_size: int,
    batch_size: int,
    multi_size: str,
    feature_columns: List[str],
    label_source: LabelSource,
    label_file,
) -> Tuple[str, pd.DataFrame, plt.Figure, pd.DataFrame]:
    """Run Time-RCD on one time series and build every output the UI shows.

    Returns:
        ``(summary, result_df, figure, metrics_df)`` where
          * ``summary`` names the (up to) five highest-scoring row indices;
          * ``result_df`` holds the feature columns, truncated to the length
            shared by features, scores and labels, plus ``anomaly_score``,
            ``anomaly_logit`` and, when labels exist, ``label``;
          * ``figure`` is produced by ``build_plot``;
          * ``metrics_df`` is built from ``get_metrics`` when labels exist,
            otherwise a single informational row.

    Raises:
        ValueError: if loading fails or the feature count contradicts
            ``is_multivariate``.
    """
    ensure_checkpoints()
    numeric_df, array, labels = load_timeseries(
        file_obj, feature_columns or None, label_source=label_source, label_file=label_file
    )

    # Reject a mismatch between the declared data type and the actual data.
    feature_count = 1 if array.ndim <= 1 else array.shape[1]
    if is_multivariate and feature_count == 1:
        raise ValueError(
            "Dataset check: only one feature column found, so please switch the Data type to 'Univariate' or upload a multivariate file with multiple feature columns."
        )
    if feature_count > 1 and not is_multivariate:
        raise ValueError(
            "Dataset check: multiple feature columns detected, so please switch the Data type to 'Multivariate' or provide a univariate file with a single feature column."
        )

    scores, logits = run_Time_RCD(
        array,
        Multi=is_multivariate,
        win_size=window_size,
        batch_size=batch_size,
        random_mask="random_mask",
        size=multi_size,
        device="cpu",
    )
    flat_scores = np.asarray(scores).reshape(-1)
    flat_logits = np.asarray(logits).reshape(-1)

    # Scores, features and labels may differ in length; keep the common prefix.
    lengths = [len(flat_scores), len(numeric_df)]
    if labels is not None:
        lengths.append(len(labels))
    usable = min(lengths)

    result_df = numeric_df.iloc[:usable, :].copy()
    score_series = pd.Series(flat_scores[:usable], index=result_df.index, name="anomaly_score")
    logit_series = pd.Series(flat_logits[:usable], index=result_df.index, name="anomaly_logit")
    result_df["anomaly_score"] = score_series
    result_df["anomaly_logit"] = logit_series

    if labels is None:
        metrics_df = pd.DataFrame({"Metric": ["Info"], "Value": ["Labels not provided; metrics skipped."]})
    else:
        truncated_labels = labels.iloc[:usable]
        result_df["label"] = truncated_labels.to_numpy()
        metrics_df = _metrics_to_dataframe(
            get_metrics(score_series.to_numpy(), truncated_labels.to_numpy())
        )

    top_indices = score_series.nlargest(5).index.tolist()
    if top_indices:
        highlight_message = "Top anomaly indices (by score): " + ", ".join(map(str, top_indices))
    else:
        highlight_message = "No anomalies detected."
    if labels is None:
        highlight_message += " Metrics skipped due to missing labels."

    return highlight_message, result_df, build_plot(result_df), metrics_df
|
|
|
|
def build_plot(result_df: pd.DataFrame) -> plt.Figure:
    """Plot the first feature column against the anomaly score.

    The first column of ``result_df`` other than ``anomaly_score``,
    ``anomaly_logit`` and ``label`` is drawn on the left y-axis. Rows with
    ``label > 0`` (when a ``label`` column exists) are marked on that series.
    ``anomaly_score`` is drawn on a twin right y-axis. If no feature column
    exists, only the score is drawn.

    The returned figure is removed from pyplot's figure registry, so repeated
    calls from a long-running server don't accumulate open figures. The
    ``Figure`` object itself can still be rendered or saved.
    """
    fig, ax_primary = plt.subplots(
        figsize=(12, 4),
        dpi=200,
        constrained_layout=True,
    )
    index = result_df.index
    feature_cols = [
        col for col in result_df.columns if col not in {"anomaly_score", "anomaly_logit", "label"}
    ]

    if feature_cols:
        primary_col = feature_cols[0]
        ax_primary.plot(
            index,
            result_df[primary_col],
            label=f"{primary_col}",
            color="#1f77b4",
            linewidth=1.0,
        )

        # Mark ground-truth anomalies (label > 0) on the feature series.
        if "label" in result_df.columns:
            anomalies = result_df[result_df["label"] > 0]
            if not anomalies.empty:
                ax_primary.scatter(
                    anomalies.index,
                    anomalies[primary_col],
                    label="Label = 1",
                    color="#ff7f0e",
                    marker="o",
                    s=30,
                    alpha=0.85,
                )

    ax_primary.set_xlabel("Index")
    ax_primary.set_ylabel("Value")
    ax_primary.grid(alpha=0.2)

    # Draw the score on a twin y-axis so its scale is independent of the values.
    ax_secondary = ax_primary.twinx()
    ax_secondary.plot(
        index,
        result_df["anomaly_score"],
        label="Anomaly Score",
        color="#d62728",
        linewidth=1.0,
    )
    ax_secondary.set_ylabel("Anomaly Score")

    # Merge both axes' entries into a single legend.
    handles_primary, labels_primary = ax_primary.get_legend_handles_labels()
    handles_secondary, labels_secondary = ax_secondary.get_legend_handles_labels()
    ax_primary.legend(
        handles_primary + handles_secondary,
        labels_primary + labels_secondary,
        loc="upper right",
    )

    # Spacing is handled by constrained_layout. Do not also call
    # fig.tight_layout(): it replaces that layout engine and emits a warning.
    # Detach from pyplot so figures don't pile up across requests.
    plt.close(fig)
    return fig
|
|
|
|
def build_interface() -> gr.Blocks:
    """Define the Gradio UI and wire its events to ``infer``.

    Layout:
      * a dataset picker (bundled samples or "Upload my own");
      * upload and label-source controls;
      * model settings;
      * a run button;
      * four outputs: summary text, score table, plot and metrics table.

    ``ValueError``s raised while loading or checking the data are re-raised
    as ``gr.Error`` so that their message reaches the user in the UI.
    """
    with gr.Blocks(title="Time-RCD Zero-Shot Anomaly Detection") as demo:
        gr.Markdown(
            "# Time-RCD Zero-Shot Anomaly Detection\n"
            "Start with one of the bundled datasets or upload your own time series to run zero-shot anomaly detection."
        )

        bundled_choices = list(SAMPLE_FILES.keys())
        default_choice = bundled_choices[0] if bundled_choices else "Upload my own"

        data_selector = gr.Radio(
            choices=bundled_choices + ["Upload my own"],
            value=default_choice,
            label="Choose dataset",
        )

        with gr.Row():
            # Visible only while "Upload my own" is selected (see _handle_dataset_choice).
            file_input = gr.File(
                label="Upload time series file (.csv, .txt, .npy)",
                file_types=[".csv", ".txt", ".npy"],
                visible=default_choice == "Upload my own",
            )
            label_source = gr.Radio(
                choices=list(LABEL_SOURCE_CHOICES.keys()),
                value="Value + label in same file",
                label="Label source",
            )

        with gr.Row():
            # Visible only when labels come from a separate file (see _toggle_label_file).
            label_file_input = gr.File(
                label="Upload label file (.csv, .txt, .npy)",
                file_types=[".csv", ".txt", ".npy"],
                visible=False,
            )
            column_selector = gr.Textbox(
                label="Columns to use (comma-separated, optional)",
                placeholder="e.g. value,feature_1,feature_2",
            )

        gr.Markdown(
            "Bundled datasets live in the sample_datasets folder next to this app and include labels unless noted. "
            "Select \"Upload my own\" to provide a custom file."
        )

        with gr.Row():
            multivariate = gr.Radio(
                choices=["Univariate", "Multivariate"],
                value=(
                    "Multivariate"
                    if bundled_choices and SAMPLE_FILES[default_choice]["is_multivariate"]
                    else "Univariate"
                ),
                label="Data type",
            )
            window_size_in = gr.Slider(
                minimum=128,
                maximum=20000,
                value=15000,
                step=128,
                label="Window size",
            )
            batch_size_in = gr.Slider(
                minimum=1,
                maximum=128,
                value=16,
                step=1,
                label="Batch size",
            )

        with gr.Row():
            multi_size_in = gr.Radio(
                choices=["full", "small"],
                value="full",
                label="Multivariate model size",
            )

        run_button = gr.Button("Run Inference", variant="primary")

        result_message = gr.Textbox(label="Summary", interactive=False)
        result_dataframe = gr.DataFrame(label="Anomaly Scores", interactive=False)
        plot_output = gr.Plot(label="Series vs. Anomaly Score")
        metrics_output = gr.DataFrame(label="Metrics", interactive=False)

        def _submit(
            data_choice,
            file_obj,
            label_source_choice,
            label_file_obj,
            multivariate_choice,
            win,
            batch,
            size,
            columns_text,
        ):
            """Translate raw widget values into an ``infer`` call."""
            use_sample = data_choice != "Upload my own"
            if use_sample:
                sample_entry = SAMPLE_FILES[data_choice]
                value_obj = sample_entry["path"]
            else:
                value_obj = file_obj

            if value_obj is None:
                raise gr.Error("Please upload a time series file or choose a sample.")

            feature_columns = [col.strip() for col in columns_text.split(",") if col.strip()] if columns_text else []
            is_multi = multivariate_choice == "Multivariate"
            resolved_label_source = LABEL_SOURCE_CHOICES[label_source_choice]
            if resolved_label_source == "separate_file" and label_file_obj is None:
                raise gr.Error("Please upload a label file or change the label source option.")
            try:
                summary, df, fig, metrics = infer(
                    file_obj=value_obj,
                    is_multivariate=is_multi,
                    window_size=int(win),
                    batch_size=int(batch),
                    multi_size=size,
                    feature_columns=feature_columns,
                    label_source=resolved_label_source,
                    label_file=label_file_obj,
                )
            except ValueError as exc:
                # By default Gradio shows only a generic "Error" for ordinary
                # exceptions; gr.Error carries the validation message to the user.
                raise gr.Error(str(exc)) from exc
            return summary, df, fig, metrics

        def _toggle_label_file(option):
            """Show the label upload only for the separate-file label source."""
            return gr.update(visible=option == "Labels in separate file")

        def _handle_dataset_choice(choice):
            """Toggle the upload widget and sync the data type to the chosen sample."""
            show_upload = choice == "Upload my own"
            if choice == "Upload my own":
                multi_update = gr.update()
            else:
                expected_multi = SAMPLE_FILES[choice]["is_multivariate"]
                multi_update = gr.update(value="Multivariate" if expected_multi else "Univariate")
            return gr.update(visible=show_upload), multi_update

        label_source.change(fn=_toggle_label_file, inputs=label_source, outputs=label_file_input)
        data_selector.change(fn=_handle_dataset_choice, inputs=data_selector, outputs=[file_input, multivariate])

        run_button.click(
            fn=_submit,
            inputs=[
                data_selector,
                file_input,
                label_source,
                label_file_input,
                multivariate,
                window_size_in,
                batch_size_in,
                multi_size_in,
                column_selector,
            ],
            outputs=[result_message, result_dataframe, plot_output, metrics_output],
        )

    return demo
|
|
|
|
# Build the Blocks app at import time so that this module exposes ``demo``.
demo: gr.Blocks = build_interface()


if __name__ == "__main__":
    # Serve the UI when the file is executed directly as a script.
    demo.launch()
|
|
|
|