#!/usr/bin/env python3
"""
IPAD VAD Training Interface on HuggingFace Spaces with ZeroGPU
Updated version with integrated training infrastructure
"""
# IMPORTANT: Clear Python cache first to avoid loading stale modules
import shutil
from pathlib import Path
for pycache in Path('.').rglob('__pycache__'):
    shutil.rmtree(pycache, ignore_errors=True)
for pyc in Path('.').rglob('*.pyc'):
    pyc.unlink(missing_ok=True)
print("🧹 Cache cleared - loading fresh modules")
import json
from datetime import datetime
from typing import Dict

import spaces  # ZeroGPU decorator (imported early, before any CUDA initialization)
import gradio as gr
import torch

# Training infrastructure
from train_hf import IPADTrainer
from dataset import download_and_extract_dataset, SYNTHETIC_DEVICES
# Global state
DATASET_PATH = None
CHECKPOINT_DIR = Path("./checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)
def setup_dataset(progress=gr.Progress()) -> str:
    """Download and extract the IPAD dataset from the HF Hub."""
    global DATASET_PATH
    if DATASET_PATH and DATASET_PATH.exists():
        return f"✅ Dataset already available at {DATASET_PATH}"
    progress(0, desc="Downloading dataset...")
    try:
        DATASET_PATH = download_and_extract_dataset(cache_dir="./cache")
        progress(1.0, desc="Complete!")
        return f"✅ Dataset downloaded and extracted to {DATASET_PATH}\n📊 Ready for training!"
    except Exception as e:
        return f"❌ Error: {str(e)}"
@spaces.GPU(duration=60)  # Request GPU for 1 minute
def quick_gpu_test() -> Dict:
    """Quick test to verify GPU access and model loading."""
    try:
        from IPAD.model.video_swin_transformer import VST

        # Check GPU
        gpu_available = torch.cuda.is_available()
        gpu_name = torch.cuda.get_device_name(0) if gpu_available else "None"
        if not gpu_available:
            return {
                "status": "⚠️ Warning",
                "message": "No GPU available",
                "gpu_available": False,
                "gpu_name": "None"
            }

        # Load model
        model = VST(mem_dim=2000, shrink_thres=0.0025)
        model = model.cuda()

        # Create dummy input
        dummy_input = torch.randn(1, 3, 16, 256, 256).cuda()

        # Forward pass
        with torch.no_grad():
            output = model(dummy_input)

        result = {
            "status": "✅ Success",
            "message": "GPU test passed!",
            "gpu_available": True,
            "gpu_name": gpu_name,
            "output_shape": str(output['output'].shape),
            "attention_shape": str(output['att'].shape),
            "period_shape": str(output['recon_index'].shape),
            "memory_allocated_gb": f"{torch.cuda.memory_allocated() / 1e9:.2f}",
            "memory_reserved_gb": f"{torch.cuda.memory_reserved() / 1e9:.2f}"
        }
        return result
    except Exception as e:
        return {
            "status": "❌ Error",
            "message": str(e),
            "gpu_available": torch.cuda.is_available(),
            "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None"
        }
@spaces.GPU(duration=3600)  # Request GPU for 1 hour
def train_quick_baseline(
    device_name: str = "S01",
    epochs: int = 10,
    batch_size: int = 4,
    lr: float = 1e-4,
    progress=gr.Progress()
) -> str:
    """Quick baseline training (10 epochs for testing)."""
    global DATASET_PATH

    # Auto-download dataset if not available
    if DATASET_PATH is None or not DATASET_PATH.exists():
        progress(0, desc="Dataset not found, downloading...")
        try:
            DATASET_PATH = download_and_extract_dataset(cache_dir="./cache")
            progress(0.05, desc="Dataset ready, starting training...")
        except Exception as e:
            return f"❌ Error downloading dataset: {str(e)}"

    progress(0, desc="Initializing trainer...")
    try:
        # Create trainer
        trainer = IPADTrainer(
            device_name=device_name,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            mem_dim=2000,
            checkpoint_dir=str(CHECKPOINT_DIR),
            wandb_project=None,  # Disable wandb for quick test
            hf_repo=None  # Disable auto-upload for quick test
        )

        progress(0.1, desc="Loading dataset...")

        # Train
        trainer.train(str(DATASET_PATH))
        progress(1.0, desc="Training complete!")

        # Get latest checkpoint
        checkpoints = list(CHECKPOINT_DIR.glob(f"{device_name}_*.pth"))
        latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime) if checkpoints else None

        result = f"""
✅ Quick baseline training complete!
📊 Configuration:
- Device: {device_name}
- Epochs: {epochs}
- Batch Size: {batch_size}
- Learning Rate: {lr}
💾 Checkpoint:
- {latest_checkpoint.name if latest_checkpoint else 'No checkpoint saved'}
🎯 Next Steps:
1. Review training metrics
2. Run full 200-epoch training
3. Evaluate on test set
"""
        return result
    except Exception as e:
        return f"❌ Training failed: {str(e)}\n\nPlease check the logs for details."
@spaces.GPU(duration=7200)  # Request GPU for 2 hours
def train_full_baseline(
    device_name: str = "S01",
    epochs: int = 200,
    batch_size: int = 4,
    lr: float = 1e-4,
    mem_dim: int = 2000,
    enable_wandb: bool = False,
    enable_hf_upload: bool = True,
    progress=gr.Progress()
) -> str:
    """Full baseline training (200 epochs)."""
    global DATASET_PATH
    if DATASET_PATH is None or not DATASET_PATH.exists():
        return "❌ Error: Dataset not downloaded. Please download the dataset first (Setup tab)."

    progress(0, desc="Initializing full training...")
    try:
        # Create trainer
        trainer = IPADTrainer(
            device_name=device_name,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            mem_dim=mem_dim,
            checkpoint_dir=str(CHECKPOINT_DIR),
            wandb_project="ipad-vad" if enable_wandb else None,
            hf_repo="MSherbinii/ipad-vad-checkpoints" if enable_hf_upload else None
        )

        progress(0.05, desc="Loading dataset...")

        # Train
        trainer.train(str(DATASET_PATH))
        progress(1.0, desc="Training complete!")

        # Get final checkpoint
        checkpoints = list(CHECKPOINT_DIR.glob(f"{device_name}_*.pth"))
        latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime) if checkpoints else None

        result = f"""
✅ Full baseline training complete!
📊 Configuration:
- Device: {device_name}
- Epochs: {epochs}
- Batch Size: {batch_size}
- Learning Rate: {lr}
- Memory Dimension: {mem_dim}
💾 Checkpoints:
- Total saved: {len(checkpoints)}
- Latest: {latest_checkpoint.name if latest_checkpoint else 'None'}
☁️ HuggingFace Hub:
- {'✅ Uploaded to MSherbinii/ipad-vad-checkpoints' if enable_hf_upload else '❌ Upload disabled'}
📈 WandB Logging:
- {'✅ Logged to ipad-vad project' if enable_wandb else '❌ Logging disabled'}
🎯 Expected Performance:
- Target AUC for {device_name}: check the baseline results table
- Paper baseline avg: 68.6%
"""
        return result
    except Exception as e:
        return f"❌ Training failed: {str(e)}\n\nPlease check the logs for details."
def list_checkpoints() -> str:
    """List all saved checkpoints."""
    checkpoints = sorted(CHECKPOINT_DIR.glob("*.pth"))
    if not checkpoints:
        return "📁 No checkpoints found"
    result = "💾 **Available Checkpoints:**\n\n"
    for ckpt in checkpoints:
        size_mb = ckpt.stat().st_size / (1024 * 1024)
        modified = datetime.fromtimestamp(ckpt.stat().st_mtime).strftime("%Y-%m-%d %H:%M")
        result += f"- `{ckpt.name}` ({size_mb:.1f} MB, modified {modified})\n"
    return result
# Gradio Interface
with gr.Blocks(title="IPAD VAD Training on ZeroGPU", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏭 IPAD: Industrial Process Anomaly Detection Training")
    gr.Markdown("Train video anomaly detection models on ZeroGPU with the IPAD dataset")

    with gr.Tab("📥 Setup"):
        gr.Markdown("## 1️⃣ Download Dataset from HF Hub")
        gr.Markdown("Downloads the 8.3GB IPAD dataset. **This only needs to be done once**; the dataset is cached.")
        download_btn = gr.Button("📥 Download Dataset", variant="primary", size="lg")
        download_output = gr.Textbox(label="Download Status", lines=4)
        download_btn.click(setup_dataset, outputs=download_output)

        gr.Markdown("---")
        gr.Markdown("## 2️⃣ Test GPU Access")
        gr.Markdown("Verify that ZeroGPU is working and the model loads correctly. **No dataset required.**")
        test_btn = gr.Button("🧪 Run GPU Test", variant="secondary")
        test_output = gr.JSON(label="GPU Test Results")
        test_btn.click(quick_gpu_test, outputs=test_output)

    with gr.Tab("⚡ Quick Test (10 epochs)"):
        gr.Markdown("## Quick Baseline Test")
        gr.Markdown("Train for 10 epochs to verify everything works. Takes ~10-15 minutes.")
        with gr.Row():
            quick_device = gr.Dropdown(
                choices=SYNTHETIC_DEVICES,
                value="S01",
                label="Device"
            )
            quick_epochs = gr.Slider(5, 50, value=10, step=5, label="Epochs")
        with gr.Row():
            quick_batch = gr.Slider(1, 8, value=4, step=1, label="Batch Size")
            quick_lr = gr.Number(value=1e-4, label="Learning Rate", precision=6)
        quick_train_btn = gr.Button("🚀 Start Quick Training", variant="primary", size="lg")
        quick_output = gr.Textbox(label="Training Results", lines=15)
        quick_train_btn.click(
            train_quick_baseline,
            inputs=[quick_device, quick_epochs, quick_batch, quick_lr],
            outputs=quick_output
        )

    with gr.Tab("🎯 Full Training (200 epochs)"):
        gr.Markdown("## Full Baseline Training")
        gr.Markdown("Complete 200-epoch training to match the paper's results. Takes up to ~2 hours (the GPU is requested for 7200 seconds).")
        with gr.Row():
            full_device = gr.Dropdown(
                choices=SYNTHETIC_DEVICES,
                value="S01",
                label="Training Device"
            )
            full_epochs = gr.Slider(50, 300, value=200, step=10, label="Epochs")
        with gr.Row():
            full_batch = gr.Slider(1, 8, value=4, step=1, label="Batch Size")
            full_lr = gr.Number(value=1e-4, label="Learning Rate", precision=6)
        with gr.Row():
            full_mem_dim = gr.Slider(500, 2000, value=2000, step=100, label="Memory Dimension")
            full_wandb = gr.Checkbox(value=False, label="Enable WandB Logging")
            full_hf_upload = gr.Checkbox(value=True, label="Upload to HF Hub")
        full_train_btn = gr.Button("🚀 Start Full Training", variant="primary", size="lg")
        full_output = gr.Textbox(label="Training Results", lines=20)
        full_train_btn.click(
            train_full_baseline,
            inputs=[full_device, full_epochs, full_batch, full_lr, full_mem_dim, full_wandb, full_hf_upload],
            outputs=full_output
        )

    with gr.Tab("💾 Checkpoints"):
        gr.Markdown("## Checkpoint Management")
        refresh_btn = gr.Button("🔄 Refresh Checkpoint List")
        checkpoint_list = gr.Markdown(value=list_checkpoints())
        refresh_btn.click(list_checkpoints, outputs=checkpoint_list)

        gr.Markdown("### Checkpoint Info")
        gr.Markdown("""
- Checkpoints are saved every 10 epochs
- The best model (lowest validation loss) is selected automatically
- Files are in PyTorch `.pth` format
- Can be loaded with `torch.load(checkpoint_path)`
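
A minimal loading sketch (assumptions: the file stores either a raw `state_dict` or a dict with a `model_state_dict` key, and the filename is hypothetical):

```python
import torch

ckpt = torch.load("checkpoints/S01_epoch_200.pth", map_location="cpu")
# Unwrap a trainer-style wrapper if present; otherwise use it as the state_dict
state_dict = ckpt.get("model_state_dict", ckpt)
model.load_state_dict(state_dict)  # model built as in the Documentation tab's sketch
```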
""")
with gr.Tab("πŸ“Š Documentation"):
gr.Markdown("""
## IPAD VAD Training Guide
### Quick Start
1. **Download Dataset**: Go to "Setup" tab and download the IPAD dataset (once)
2. **GPU Test**: Verify GPU access in "Setup" tab
3. **Quick Test**: Train for 10 epochs in "Quick Test" tab to verify setup
4. **Full Training**: Launch 200-epoch training in "Full Training" tab
### Hardware
- **GPU**: NVIDIA H200 (via ZeroGPU)
- **VRAM**: 80GB HBM3
- **Duration**: 1-2 hours per full training session
### Model Architecture
- **Encoder**: Video Swin Transformer (768-dim features)
- **Memory**: 2000-dimensional learnable memory bank
- **Period Module**: 200-class temporal position classifier
- **Decoder**: I3D-based 3D decoder
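
The shapes below mirror this app's own `quick_gpu_test` (the output dict keys come from there; exact shapes depend on the VST implementation, so treat this as a sketch):

```python
import torch
from IPAD.model.video_swin_transformer import VST  # same import used by quick_gpu_test

model = VST(mem_dim=2000, shrink_thres=0.0025).cuda()
x = torch.randn(1, 3, 16, 256, 256).cuda()  # (batch, channels, frames, height, width)
with torch.no_grad():
    out = model(x)
# out['output']      -> reconstructed clip
# out['att']         -> memory attention weights
# out['recon_index'] -> period (temporal position) logits
```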
### Expected Baseline Results (200 epochs)
| Device | AUC (%) | Device | AUC (%) |
|--------|---------|--------|---------|
| S01 | 69.5 | S07 | 60.6 |
| S02 | 63.9 | S08 | 85.6 |
| S03 | 70.6 | S09 | 71.2 |
| S04 | 58.3 | S10 | 62.2 |
| S05 | 86.2 | S11 | 60.9 |
| S06 | 61.2 | S12 | 67.1 |
| **Avg** | **68.6** | | |
### Training Configuration
- **Batch Size**: 4 (default, can increase with more VRAM)
- **Learning Rate**: 1e-4 (Adam optimizer)
- **Clip Length**: 16 frames
- **Frame Size**: 256×256 pixels
- **Mixed Precision**: FP16 (automatic)
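
A minimal FP16 training step as described above (a sketch using standard `torch.cuda.amp`; `loader` and the reconstruction loss are illustrative, not the trainer's actual code):

```python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for clips in loader:                         # clips: (B, 3, 16, 256, 256)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # forward pass runs in FP16 where safe
        out = model(clips.cuda())
        loss = torch.nn.functional.mse_loss(out['output'], clips.cuda())
    scaler.scale(loss).backward()            # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)                   # unscale gradients and apply the update
    scaler.update()                          # adapt the scale factor for the next step
```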
### Loss Function
```
Total Loss = Reconstruction Loss
           + 0.0002 × Entropy Loss
           + 0.02 × Period Loss
```
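
In PyTorch terms the combination might look like this (a sketch only; the tensors reuse the `quick_gpu_test` dict keys, and `period_labels` is assumed to hold class indices for the 200-way period head):

```python
import torch.nn.functional as F

recon_loss = F.mse_loss(out['output'], clips)                              # reconstruction
entropy_loss = (-out['att'] * (out['att'] + 1e-12).log()).sum(-1).mean()   # memory attention entropy
period_loss = F.cross_entropy(out['recon_index'], period_labels)           # 200-class period prediction
total_loss = recon_loss + 2e-4 * entropy_loss + 0.02 * period_loss
```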
### Resources
- [Paper](https://arxiv.org/abs/2404.15033)
- [Dataset](https://huggingface.co/datasets/MSherbinii/ipad-industrial-anomaly)
- [Original Code](https://github.com/LJF1113/IPAD)
- [Checkpoints](https://huggingface.co/MSherbinii/ipad-vad-checkpoints)

### Next Steps (SOTA Improvements)
After baseline reproduction:
1. **Modern Transformer**: Replace Video Swin → MViTv2 (+2-4% AUC)
2. **Diffusion Decoder**: Add diffusion-based reconstruction (+3-5% AUC)
3. **Enhanced Memory**: GWN regularization (+1-3% AUC)

**Target**: 75-80% average AUC (vs. the 68.6% baseline)
""")
if __name__ == "__main__":
    # Auto-start training if a flag file exists
    autostart_flag = Path("./AUTOSTART_TRAINING")
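    # Expected flag-file contents (JSON), e.g. {"device": "S01", "epochs": 10};
    # both keys are optional and fall back to the defaults read below.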
    if autostart_flag.exists():
        print("🚀 AUTO-START: Training flag detected, starting training...")
        try:
            # Read configuration from the flag file
            config = json.loads(autostart_flag.read_text())
            device = config.get("device", "S01")
            epochs = config.get("epochs", 10)
            print(f"📊 Configuration: Device={device}, Epochs={epochs}")

            # Remove the flag to prevent re-running on every restart
            autostart_flag.unlink()

            # Download the dataset first
            print("📥 Downloading dataset...")
            DATASET_PATH = download_and_extract_dataset(cache_dir="./cache")
            print(f"✅ Dataset ready at {DATASET_PATH}")

            # Start training in a background thread
            import threading

            def run_training():
                print(f"🏋️ Starting training on {device} for {epochs} epochs...")
                result = train_quick_baseline(device, epochs, 4, 1e-4)
                print(f"📊 Training result:\n{result}")

            training_thread = threading.Thread(target=run_training, daemon=True)
            training_thread.start()
            print("✅ Training started in background!")
        except Exception as e:
            print(f"❌ Auto-start failed: {e}")

    demo.launch(server_name="0.0.0.0", server_port=7860)