Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| IPAD VAD Training Interface on HuggingFace Spaces with ZeroGPU | |
| Updated version with integrated training infrastructure | |
| """ | |
# IMPORTANT: Clear Python cache first to avoid loading stale modules.
# This runs at import time, before the project modules (train_hf, dataset,
# IPAD.*) are imported further down, so they are recompiled from source.
# Paths are relative to the current working directory. Cleanup is
# best-effort: failures to delete individual entries are ignored.
import shutil
from pathlib import Path

# Collect the matches up front: rglob() walks the tree lazily, and deleting
# directories while that walk is still in progress is fragile.
for pycache in list(Path('.').rglob('__pycache__')):
    shutil.rmtree(pycache, ignore_errors=True)
for pyc in list(Path('.').rglob('*.pyc')):
    try:
        pyc.unlink(missing_ok=True)
    except OSError:  # e.g. permission denied; stay best-effort like rmtree above
        pass
print("🧹 Cache cleared - loading fresh modules")
| import gradio as gr | |
| import torch | |
| import os | |
| import json | |
| from datetime import datetime | |
| import zipfile | |
| from huggingface_hub import hf_hub_download, HfApi | |
| import subprocess | |
| import sys | |
| from typing import Optional, Dict | |
| # Import training infrastructure | |
| from train_hf import IPADTrainer | |
| from dataset import download_and_extract_dataset, DEVICE_NAMES, SYNTHETIC_DEVICES | |
| import spaces # ZeroGPU decorator | |
# Global state shared by the Gradio callbacks below.
# DATASET_PATH stays None until the IPAD dataset has been downloaded and
# extracted; callbacks check it before training.
DATASET_PATH: Optional[Path] = None
# Directory handed to the trainer for checkpoints and scanned by the UI for
# *.pth files; created eagerly when the module is imported.
CHECKPOINT_DIR = Path("./checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)
def setup_dataset(progress=gr.Progress()) -> str:
    """Download and extract the IPAD dataset from the HF Hub.

    Gradio callback for the "Download Dataset" button. The extracted location
    is stored in the module-level ``DATASET_PATH`` so that later callbacks can
    reuse it; if it is already set and exists on disk, nothing is downloaded.

    Args:
        progress: Gradio progress tracker (injected by Gradio when used as an
            event handler).

    Returns:
        A human-readable status message for the UI textbox. Errors are
        reported in the message (and their traceback printed to the logs)
        rather than raised.
    """
    global DATASET_PATH

    # Short-circuit before reporting any download progress.
    if DATASET_PATH is not None and Path(DATASET_PATH).exists():
        return f"✅ Dataset already available at {DATASET_PATH}"

    progress(0, desc="Downloading dataset...")
    try:
        # Normalise to Path: other callbacks call ``.exists()`` on this value.
        DATASET_PATH = Path(download_and_extract_dataset(cache_dir="./cache"))
    except Exception as e:  # UI boundary: report instead of crashing the app
        import traceback
        traceback.print_exc()
        return f"❌ Error: {e}"

    progress(1.0, desc="Complete!")
    return f"✅ Dataset downloaded and extracted to {DATASET_PATH}\n🚀 Ready for training!"
# Request GPU for 1 minute
@spaces.GPU(duration=60)
def quick_gpu_test() -> Dict:
    """Smoke-test GPU access with a single forward pass of the IPAD model.

    Builds a ``VST`` model (``mem_dim=2000``, ``shrink_thres=0.0025``), moves
    it to CUDA and feeds it one random ``(1, 3, 16, 256, 256)`` tensor. No
    dataset is required.

    Returns:
        A JSON-serialisable dict for the ``gr.JSON`` panel. It always contains
        ``status``, ``message``, ``gpu_available`` and ``gpu_name``. On success
        it also holds the shapes of the model outputs ``output``, ``att`` and
        ``recon_index``, plus the current CUDA memory usage in GB. Exceptions
        are reported in the dict (traceback printed to the logs), not raised.
    """
    try:
        # Imported lazily so an import problem is shown in the UI instead of
        # crashing the app at startup.
        from IPAD.model.video_swin_transformer import VST

        if not torch.cuda.is_available():
            return {
                "status": "⚠️ Warning",
                "message": "No GPU available",
                "gpu_available": False,
                "gpu_name": "None",
            }
        gpu_name = torch.cuda.get_device_name(0)

        # eval(): this is an inference-only smoke test.
        model = VST(mem_dim=2000, shrink_thres=0.0025).cuda().eval()
        dummy_input = torch.randn(1, 3, 16, 256, 256, device="cuda")
        with torch.no_grad():
            output = model(dummy_input)

        return {
            "status": "✅ Success",
            "message": "GPU test passed!",
            "gpu_available": True,
            "gpu_name": gpu_name,
            "output_shape": str(output['output'].shape),
            "attention_shape": str(output['att'].shape),
            "period_shape": str(output['recon_index'].shape),
            "memory_allocated_gb": f"{torch.cuda.memory_allocated() / 1e9:.2f}",
            "memory_reserved_gb": f"{torch.cuda.memory_reserved() / 1e9:.2f}",
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        gpu_available = torch.cuda.is_available()
        try:
            gpu_name = torch.cuda.get_device_name(0) if gpu_available else "None"
        except Exception:  # the device query itself may be what failed
            gpu_name = "Unknown"
        return {
            "status": "❌ Error",
            "message": str(e),
            "gpu_available": gpu_available,
            "gpu_name": gpu_name,
        }
# Request GPU for 1 hour
@spaces.GPU(duration=3600)
def train_quick_baseline(
    device_name: str = "S01",
    epochs: int = 10,
    batch_size: int = 4,
    lr: float = 1e-4,
    progress=gr.Progress(),
) -> str:
    """Run a short baseline training job (10 epochs by default) as a smoke test.

    Downloads the dataset first if ``DATASET_PATH`` is unset or missing on
    disk. Then trains an ``IPADTrainer`` with ``mem_dim=2000``, with WandB
    logging and HF Hub upload disabled.

    Args:
        device_name: IPAD device split to train on (e.g. ``"S01"``).
        epochs: Number of training epochs (coerced to ``int``).
        batch_size: Training batch size (coerced to ``int``).
        lr: Learning rate (coerced to ``float``).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A multi-line summary for the UI, or an error message. Exceptions are
        not propagated; their tracebacks are printed to the Space logs.
    """
    global DATASET_PATH
    import traceback

    # Auto-download dataset if not available
    if DATASET_PATH is None or not Path(DATASET_PATH).exists():
        progress(0, desc="Dataset not found, downloading...")
        try:
            DATASET_PATH = Path(download_and_extract_dataset(cache_dir="./cache"))
        except Exception as e:
            traceback.print_exc()
            return f"❌ Error downloading dataset: {e}"
        progress(0.05, desc="Dataset ready, starting training...")

    # Keep the progress bar monotonic (never drop back below the download step).
    progress(0.05, desc="Initializing trainer...")
    try:
        # Gradio sliders / JSON configs may deliver floats or strings.
        epochs = int(epochs)
        batch_size = int(batch_size)
        lr = float(lr)

        trainer = IPADTrainer(
            device_name=device_name,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            mem_dim=2000,
            checkpoint_dir=str(CHECKPOINT_DIR),
            wandb_project=None,  # Disable wandb for quick test
            hf_repo=None,  # Disable auto-upload for quick test
        )
        progress(0.1, desc="Loading dataset...")
        trainer.train(str(DATASET_PATH))
        progress(1.0, desc="Training complete!")

        # Newest checkpoint for this device by mtime; files from earlier runs
        # match the same pattern, so this may predate the current run.
        checkpoints = list(CHECKPOINT_DIR.glob(f"{device_name}_*.pth"))
        latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime, default=None)

        lines = [
            "✅ Quick baseline training complete!",
            "📋 Configuration:",
            f"- Device: {device_name}",
            f"- Epochs: {epochs}",
            f"- Batch Size: {batch_size}",
            f"- Learning Rate: {lr}",
            "💾 Checkpoint:",
            f"- {latest_checkpoint.name if latest_checkpoint else 'No checkpoint saved'}",
            "🎯 Next Steps:",
            "1. Review training metrics",
            "2. Run full 200-epoch training",
            "3. Evaluate on test set",
        ]
        return "\n".join(lines)
    except Exception as e:
        traceback.print_exc()
        return f"❌ Training failed: {e}\n\nPlease check the logs for details."
# Request GPU for 2 hours
@spaces.GPU(duration=7200)
def train_full_baseline(
    device_name: str = "S01",
    epochs: int = 200,
    batch_size: int = 4,
    lr: float = 1e-4,
    mem_dim: int = 2000,
    enable_wandb: bool = False,
    enable_hf_upload: bool = True,
    progress=gr.Progress(),
) -> str:
    """Run the full baseline training job (200 epochs by default).

    Requires the dataset to have been downloaded already (``DATASET_PATH`` set
    and present on disk); otherwise an error message is returned immediately.

    Args:
        device_name: IPAD device split to train on (e.g. ``"S01"``).
        epochs: Number of training epochs (coerced to ``int``).
        batch_size: Training batch size (coerced to ``int``).
        lr: Learning rate (coerced to ``float``).
        mem_dim: Memory-bank size passed to the trainer (coerced to ``int``).
        enable_wandb: Log to the ``ipad-vad`` WandB project when True.
        enable_hf_upload: Pass ``MSherbinii/ipad-vad-checkpoints`` as the HF
            repo to the trainer when True.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A multi-line summary for the UI, or an error message. Exceptions are
        not propagated; their tracebacks are printed to the Space logs.
    """
    import traceback

    if DATASET_PATH is None or not Path(DATASET_PATH).exists():
        return "❌ Error: Dataset not downloaded. Please download dataset first."

    progress(0, desc="Initializing full training...")
    try:
        # Gradio sliders may deliver floats; the trainer needs plain numbers.
        epochs = int(epochs)
        batch_size = int(batch_size)
        lr = float(lr)
        mem_dim = int(mem_dim)

        trainer = IPADTrainer(
            device_name=device_name,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            mem_dim=mem_dim,
            checkpoint_dir=str(CHECKPOINT_DIR),
            wandb_project="ipad-vad" if enable_wandb else None,
            hf_repo="MSherbinii/ipad-vad-checkpoints" if enable_hf_upload else None,
        )
        progress(0.05, desc="Loading dataset...")
        trainer.train(str(DATASET_PATH))
        progress(1.0, desc="Training complete!")

        # All checkpoints for this device (including earlier runs) and the
        # newest one by mtime.
        checkpoints = list(CHECKPOINT_DIR.glob(f"{device_name}_*.pth"))
        latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime, default=None)

        # NOTE(review): these two lines echo the requested settings; whether
        # the upload/logging actually succeeded is decided inside IPADTrainer.
        hub_status = (
            "✅ Uploaded to MSherbinii/ipad-vad-checkpoints"
            if enable_hf_upload
            else "❌ Upload disabled"
        )
        wandb_status = "✅ Logged to ipad-vad project" if enable_wandb else "❌ Logging disabled"

        lines = [
            "✅ Full baseline training complete!",
            "📋 Configuration:",
            f"- Device: {device_name}",
            f"- Epochs: {epochs}",
            f"- Batch Size: {batch_size}",
            f"- Learning Rate: {lr}",
            f"- Memory Dimension: {mem_dim}",
            "💾 Checkpoints:",
            f"- Total saved: {len(checkpoints)}",
            f"- Latest: {latest_checkpoint.name if latest_checkpoint else 'None'}",
            "☁️ HuggingFace Hub:",
            f"- {hub_status}",
            "📈 WandB Logging:",
            f"- {wandb_status}",
            "🎯 Expected Performance:",
            f"- Target AUC for {device_name}: Check baseline results table",
            "- Paper baseline avg: 68.6%",
        ]
        return "\n".join(lines)
    except Exception as e:
        traceback.print_exc()
        return f"❌ Training failed: {e}\n\nPlease check the logs for details."
def list_checkpoints() -> str:
    """Render the ``*.pth`` files in ``CHECKPOINT_DIR`` as a Markdown list.

    Files are listed in sorted path order, each with its size (bytes / 2**20,
    labelled "MB") and its last-modified local time at minute precision.
    A file that disappears between listing and ``stat()`` (e.g. removed by a
    concurrent job) is skipped instead of crashing the panel.

    Returns:
        Markdown text for the checkpoint panel, or a placeholder message when
        no checkpoints are present.
    """
    checkpoints = sorted(CHECKPOINT_DIR.glob("*.pth"))
    if not checkpoints:
        return "📁 No checkpoints found"

    lines = ["💾 **Available Checkpoints:**", ""]
    for ckpt in checkpoints:
        try:
            info = ckpt.stat()  # single stat() for both size and mtime
        except FileNotFoundError:
            continue
        size_mb = info.st_size / (1024 * 1024)
        modified = datetime.fromtimestamp(info.st_mtime).strftime("%Y-%m-%d %H:%M")
        lines.append(f"- `{ckpt.name}` ({size_mb:.1f} MB, modified {modified})")
    return "\n".join(lines) + "\n"
# Gradio Interface
# Tabs: Setup (dataset download + GPU smoke test), Quick Test, Full Training,
# Checkpoints and Documentation. Each button is wired to one callback above.
with gr.Blocks(title="IPAD VAD Training on ZeroGPU", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏭 IPAD: Industrial Process Anomaly Detection Training")
    gr.Markdown("Train video anomaly detection models on ZeroGPU with the IPAD dataset")

    with gr.Tab("📥 Setup"):
        gr.Markdown("## 1️⃣ Download Dataset from HF Hub")
        gr.Markdown("Downloads the 8.3GB IPAD dataset. **This only needs to be done once** - the dataset is cached.")
        download_btn = gr.Button("📥 Download Dataset", variant="primary", size="lg")
        download_output = gr.Textbox(label="Download Status", lines=4)
        download_btn.click(setup_dataset, outputs=download_output)

        gr.Markdown("---")
        gr.Markdown("## 2️⃣ Test GPU Access")
        gr.Markdown("Verify that ZeroGPU is working and the model loads correctly. **No dataset required.**")
        test_btn = gr.Button("🧪 Run GPU Test", variant="secondary")
        test_output = gr.JSON(label="GPU Test Results")
        test_btn.click(quick_gpu_test, outputs=test_output)

    with gr.Tab("⚡ Quick Test (10 epochs)"):
        gr.Markdown("## Quick Baseline Test")
        gr.Markdown("Train for 10 epochs to verify everything works. Takes ~10-15 minutes.")
        with gr.Row():
            quick_device = gr.Dropdown(
                choices=SYNTHETIC_DEVICES,
                value="S01",
                label="Device",
            )
            quick_epochs = gr.Slider(5, 50, value=10, step=5, label="Epochs")
        with gr.Row():
            quick_batch = gr.Slider(1, 8, value=4, step=1, label="Batch Size")
            quick_lr = gr.Number(value=1e-4, label="Learning Rate", precision=6)
        quick_train_btn = gr.Button("🚀 Start Quick Training", variant="primary", size="lg")
        quick_output = gr.Textbox(label="Training Results", lines=15)
        quick_train_btn.click(
            train_quick_baseline,
            inputs=[quick_device, quick_epochs, quick_batch, quick_lr],
            outputs=quick_output,
        )

    with gr.Tab("🎯 Full Training (200 epochs)"):
        gr.Markdown("## Full Baseline Training")
        gr.Markdown("Complete 200-epoch training to match paper results. Takes ~2-3 hours.")
        with gr.Row():
            full_device = gr.Dropdown(
                choices=SYNTHETIC_DEVICES,
                value="S01",
                label="Training Device",
            )
            full_epochs = gr.Slider(50, 300, value=200, step=10, label="Epochs")
        with gr.Row():
            full_batch = gr.Slider(1, 8, value=4, step=1, label="Batch Size")
            full_lr = gr.Number(value=1e-4, label="Learning Rate", precision=6)
        with gr.Row():
            full_mem_dim = gr.Slider(500, 2000, value=2000, step=100, label="Memory Dimension")
            full_wandb = gr.Checkbox(value=False, label="Enable WandB Logging")
            full_hf_upload = gr.Checkbox(value=True, label="Upload to HF Hub")
        full_train_btn = gr.Button("🚀 Start Full Training", variant="primary", size="lg")
        full_output = gr.Textbox(label="Training Results", lines=20)
        full_train_btn.click(
            train_full_baseline,
            inputs=[full_device, full_epochs, full_batch, full_lr, full_mem_dim, full_wandb, full_hf_upload],
            outputs=full_output,
        )

    with gr.Tab("💾 Checkpoints"):
        gr.Markdown("## Checkpoint Management")
        refresh_btn = gr.Button("🔄 Refresh Checkpoint List")
        # Initial value is computed once, when the UI is built at startup.
        checkpoint_list = gr.Markdown(value=list_checkpoints())
        refresh_btn.click(list_checkpoints, outputs=checkpoint_list)
        gr.Markdown("### Checkpoint Info")
        gr.Markdown(
            """
            - Checkpoints are saved every 10 epochs
            - Best model (lowest val loss) is automatically selected
            - Files are in PyTorch `.pth` format
            - Can be loaded with `torch.load(checkpoint_path)`
            """
        )

    with gr.Tab("📚 Documentation"):
        gr.Markdown(
            """
            ## IPAD VAD Training Guide

            ### Quick Start
            1. **Download Dataset**: Go to "Setup" tab and download the IPAD dataset (once)
            2. **GPU Test**: Verify GPU access in "Setup" tab
            3. **Quick Test**: Train for 10 epochs in "Quick Test" tab to verify setup
            4. **Full Training**: Launch 200-epoch training in "Full Training" tab

            ### Hardware
            - **GPU**: NVIDIA H200 (via ZeroGPU)
            - **VRAM**: 80GB HBM3
            - **Duration**: 1-2 hours per full training session

            ### Model Architecture
            - **Encoder**: Video Swin Transformer (768-dim features)
            - **Memory**: 2000-dimensional learnable memory bank
            - **Period Module**: 200-class temporal position classifier
            - **Decoder**: I3D-based 3D decoder

            ### Expected Baseline Results (200 epochs)

            | Device | AUC (%) | Device | AUC (%) |
            |--------|---------|--------|---------|
            | S01 | 69.5 | S07 | 60.6 |
            | S02 | 63.9 | S08 | 85.6 |
            | S03 | 70.6 | S09 | 71.2 |
            | S04 | 58.3 | S10 | 62.2 |
            | S05 | 86.2 | S11 | 60.9 |
            | S06 | 61.2 | S12 | 67.1 |
            | **Avg** | **68.6** | | |

            ### Training Configuration
            - **Batch Size**: 4 (default, can increase with more VRAM)
            - **Learning Rate**: 1e-4 (Adam optimizer)
            - **Clip Length**: 16 frames
            - **Frame Size**: 256×256 pixels
            - **Mixed Precision**: FP16 (automatic)

            ### Loss Function
            ```
            Total Loss = Reconstruction Loss
                       + 0.0002 × Entropy Loss
                       + 0.02 × Period Loss
            ```

            ### Resources
            - [Paper](https://arxiv.org/abs/2404.15033)
            - [Dataset](https://huggingface.co/datasets/MSherbinii/ipad-industrial-anomaly)
            - [Original Code](https://github.com/LJF1113/IPAD)
            - [Checkpoints](https://huggingface.co/MSherbinii/ipad-vad-checkpoints)

            ### Next Steps (SOTA Improvements)
            After baseline reproduction:

            1. **Modern Transformer**: Replace Video Swin → MViTv2 (+2-4% AUC)
            2. **Diffusion Decoder**: Add diffusion-based reconstruction (+3-5% AUC)
            3. **Enhanced Memory**: GWN regularization (+1-3% AUC)

            **Target**: 75-80% average AUC (vs 68.6% baseline)
            """
        )

    # Refresh the checkpoint panel on every page load so checkpoints written
    # after startup appear without pressing the refresh button.
    demo.load(list_checkpoints, outputs=checkpoint_list)
if __name__ == "__main__":
    # Auto-start training if flag file exists.
    # The flag may be empty (e.g. created with `touch`) or hold a JSON object
    # such as {"device": "S01", "epochs": 10, "batch_size": 4, "lr": 1e-4};
    # missing keys fall back to those defaults.
    autostart_flag = Path("./AUTOSTART_TRAINING")
    if autostart_flag.exists():
        print("🚀 AUTO-START: Training flag detected, starting training...")
        try:
            # Read configuration from flag file (empty file -> defaults).
            raw_config = autostart_flag.read_text().strip()
            config = json.loads(raw_config) if raw_config else {}
            if not isinstance(config, dict):
                raise ValueError(
                    f"expected a JSON object in {autostart_flag}, got {type(config).__name__}"
                )
            device = config.get("device", "S01")
            epochs = config.get("epochs", 10)
            batch_size = config.get("batch_size", 4)
            lr = config.get("lr", 1e-4)
            print(f"📋 Configuration: Device={device}, Epochs={epochs}")

            # Remove flag to prevent re-running on every restart
            autostart_flag.unlink()

            # Download dataset first (normalised to Path; callbacks call .exists()).
            print("📥 Downloading dataset...")
            DATASET_PATH = Path(download_and_extract_dataset(cache_dir="./cache"))
            print(f"✅ Dataset ready at {DATASET_PATH}")

            # Train in a daemon thread so demo.launch() below is not blocked.
            import threading

            def run_training():
                print(f"🏋️ Starting training on {device} for {epochs} epochs...")
                result = train_quick_baseline(device, epochs, batch_size, lr)
                print(f"📊 Training result:\n{result}")

            training_thread = threading.Thread(target=run_training, daemon=True)
            training_thread.start()
            print("✅ Training started in background!")
        except Exception as e:
            print(f"❌ Auto-start failed: {e}")

    demo.launch(server_name="0.0.0.0", server_port=7860)