#!/usr/bin/env python3
"""

Trigger GPU training through Gradio interface

Uses gradio_client to call the training endpoint

"""
import time
from datetime import datetime

print("="*70)
print("πŸš€ IPAD VAD Training Trigger")
print("="*70)
print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print()

# Method 1: Direct function call (since we're in the same process)
print("[Method 1] Direct function call (fastest)")
print("-" * 70)

try:
    # Import the training function directly
    from train_hf import IPADTrainer

    print("βœ… Imported IPADTrainer successfully")
    print()

    # Create trainer with quick test parameters
    # Using 1 epoch for smoke test on CPU, will do full training on GPU
    print("πŸ“‹ Configuration:")
    print("   Device: S01 (Conveyor Belt)")
    print("   Epochs: 1 (smoke test on CPU)")
    print("   Batch Size: 2 (reduced for CPU)")
    print("   Learning Rate: 1e-4")
    print("   Memory Dimension: 2000")
    print("   ⚠️  Note: This is a CPU smoke test. Full GPU training needs Gradio interface.")
    print()

    trainer = IPADTrainer(
        device_name="S01",
        epochs=1,  # Just 1 epoch to verify training works
        batch_size=2,  # Reduced for CPU
        lr=1e-4,
        mem_dim=2000,
        checkpoint_dir="./checkpoints",
        wandb_project=None,  # Disable wandb for quick test
        hf_repo=None  # Disable HF upload for quick test
    )

    print("βœ… Trainer initialized")
    print()

    # Check CUDA availability
    import torch
    print(f"πŸ” Checking GPU availability...")
    print(f"   CUDA Available: {torch.cuda.is_available()}")
    print(f"   Device Count: {torch.cuda.device_count()}")
    if torch.cuda.is_available():
        print(f"   Device Name: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("   ⚠️  No GPU detected - this will run on CPU (very slow)")
        print("   ⚠️  ZeroGPU allocation only works through Gradio @spaces.GPU decorator")
    print()
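    # For reference only: ZeroGPU hands out a GPU inside a function decorated
    # with @spaces.GPU in the Gradio app, not in a plain script like this one.
    # A minimal sketch of such a handler (the name `run_training` and the
    # duration value are illustrative assumptions, not this repo's actual API;
    # the trainer arguments mirror the call above):
    #
    #     import spaces
    #
    #     @spaces.GPU(duration=900)  # request a GPU for up to 15 minutes
    #     def run_training(device_name: str, epochs: int) -> str:
    #         trainer = IPADTrainer(device_name=device_name, epochs=epochs,
    #                               batch_size=8, lr=1e-4, mem_dim=2000,
    #                               checkpoint_dir="./checkpoints")
    #         trainer.train("/app/cache/IPAD_dataset")
    #         return "training finished"
    #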

    # Start training
    dataset_path = "/app/cache/IPAD_dataset"
    print(f"πŸ‹οΈ  Starting training...")
    print(f"   Dataset: {dataset_path}")
    print(f"   This will take ~10-15 minutes on GPU, several hours on CPU")
    print()
    print("="*70)
    print()

    # Train
    start_time = time.time()
    trainer.train(dataset_path)
    end_time = time.time()

    print()
    print("="*70)
    print(f"βœ… Training completed in {(end_time - start_time) / 60:.1f} minutes!")
    print("="*70)

    # Check checkpoints
    from pathlib import Path
    checkpoint_dir = Path("./checkpoints")
    checkpoints = list(checkpoint_dir.glob("S01_*.pth"))

    if checkpoints:
        print()
        print("πŸ’Ύ Checkpoints saved:")
        for ckpt in sorted(checkpoints):
            size_mb = ckpt.stat().st_size / (1024 * 1024)
            print(f"   - {ckpt.name} ({size_mb:.1f} MB)")
    else:
        print()
        print("⚠️  No checkpoints found - check logs for errors")

except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()
    print()
    print("="*70)
    print("πŸ’‘ Troubleshooting:")
    print("   1. Check GPU availability (might need @spaces.GPU decorator)")
    print("   2. Check dataset path exists")
    print("   3. Check logs for detailed error messages")
    print("="*70)

print()
print("="*70)
print("🏁 Training trigger script finished")
print("="*70)