File size: 2,338 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""

Memory Optimization Module Test

Quick smoke test of the memory_optimization module: tensor pooling, shared-model loading, memory statistics, and cleanup.

"""
import torch
import sys
import os

# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from memory_optimization import (
    UnifiedMemoryManager,
    MemoryOptimizationConfig,
    get_shared_model,
    get_tensor,
    return_tensor,
    get_memory_stats,
    clear_memory
)

def test_memory_optimization():
    """Smoke-test the memory optimization module end to end.

    Exercises, in order, on a single ``UnifiedMemoryManager``:

    1. tensor pooling (``get_tensor`` / ``return_tensor``),
    2. shared-model loading -- best effort: if the model cannot be obtained
       (e.g. no network / weights not cached) a warning is printed and the
       test continues,
    3. memory statistics (``get_memory_stats``),
    4. cleanup (``clear_all_memory``).

    Uses CUDA when ``torch.cuda.is_available()``, otherwise CPU. Progress is
    printed to stdout; failures surface as ``AssertionError`` (or whatever
    the manager raises), so this works both as a script and under pytest.
    """
    print("=" * 70)
    print("Testing Memory Optimization Module")
    print("=" * 70)

    # Initialize config
    config = MemoryOptimizationConfig(
        use_shared_model=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print(f"\n[CONFIG] Device: {config.device}")
    print(f"[CONFIG] Shared Model: {config.use_shared_model}")

    # Initialize manager
    manager = UnifiedMemoryManager(config)
    print("\n[OK] UnifiedMemoryManager initialized")

    shape = (10, 1024)

    # Test tensor pooling. Results are asserted, not just printed, so a
    # regression fails the test instead of being reported as success.
    print("\n[TEST] Tensor Pooling...")
    tensor1 = manager.get_tensor(shape, dtype=torch.float32)
    assert isinstance(tensor1, torch.Tensor), f"expected a Tensor, got {type(tensor1)!r}"
    assert tuple(tensor1.shape) == shape, f"unexpected shape {tuple(tensor1.shape)}"
    assert tensor1.dtype == torch.float32, f"unexpected dtype {tensor1.dtype}"
    print(f"  [OK] Created tensor: {tensor1.shape}, device: {tensor1.device}")

    manager.return_tensor(tensor1)
    print("  [OK] Returned tensor to pool")

    tensor2 = manager.get_tensor(shape, dtype=torch.float32)
    assert isinstance(tensor2, torch.Tensor), f"expected a Tensor, got {type(tensor2)!r}"
    assert tuple(tensor2.shape) == shape, f"unexpected shape {tuple(tensor2.shape)}"
    assert tensor2.dtype == torch.float32, f"unexpected dtype {tensor2.dtype}"
    print(f"  [OK] Retrieved tensor from pool: {tensor2.shape}")

    # Keep get/return calls balanced, then drop our local references so they
    # do not keep the tensors alive through the cleanup step below.
    manager.return_tensor(tensor2)
    del tensor1, tensor2

    # Test shared model (if available)
    print("\n[TEST] Shared Model...")
    try:
        # This will use shared Qwen model if available
        model = manager.get_shared_model("Qwen/Qwen3-0.6B", "transformer")
    except Exception as e:
        # Deliberately best effort: the model may be unavailable offline.
        print(f"  [WARN] Could not get shared model: {e}")
    else:
        print(f"  [OK] Got shared model: {type(model).__name__}")
        # Release the local reference so clear_all_memory() is not defeated
        # by this function still holding the model.
        del model

    # Test memory stats
    print("\n[TEST] Memory Stats...")
    stats = manager.get_memory_stats()
    print(f"  [OK] Got memory stats: {len(stats)} categories")

    # Test cleanup
    print("\n[TEST] Memory Cleanup...")
    manager.clear_all_memory()
    print("  [OK] Memory cleared")

    print("\n" + "=" * 70)
    print("[SUCCESS] All tests passed!")
    print("=" * 70)

if __name__ == "__main__":
    test_memory_optimization()