"""
Shared Utilities
================
Common helpers used across the pipeline.
"""

import logging
import os

logger = logging.getLogger(__name__)


def get_device(preferred: str = "auto") -> str:
    """Determine the best available device for PyTorch.
    
    Auto-detects a CUDA GPU and falls back to CPU.
    
    Args:
        preferred: 'auto' (use CUDA if available, else CPU), 'cuda', or 'cpu'.
    
    Returns:
        Device string: 'cuda' or 'cpu'.
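    
    Example (illustrative; `model` stands in for any torch.nn.Module):
        device = get_device("auto")
        model.to(device)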
    """
    if preferred == "cpu":
        return "cpu"
    
    try:
        import torch
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            vram_mb = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024
            logger.info(f"GPU detected: {gpu_name} ({vram_mb:.0f} MB VRAM)")
            return "cuda"
        else:
            if preferred == "cuda":
                logger.warning("CUDA requested but not available, falling back to CPU")
            return "cpu"
    except ImportError:
        logger.warning("PyTorch is not installed; defaulting to CPU")
        return "cpu"


def get_available_ram_gb() -> float:
    """Get available system RAM in GB.
    
    Uses psutil if available, then falls back to OS-level checks
    (Windows GlobalMemoryStatusEx, POSIX sysconf), and returns a
    conservative default if detection fails.
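    
    Example (illustrative numbers):
        ram_gb = get_available_ram_gb()   # e.g. 11.7 on a 16 GB machine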
    """
    # Try psutil first (most accurate)
    try:
        import psutil
        return psutil.virtual_memory().available / (1024 ** 3)
    except ImportError:
        pass
    
    # Windows fallback: use ctypes
    try:
        import ctypes
        class MEMORYSTATUSEX(ctypes.Structure):
            _fields_ = [
                ("dwLength", ctypes.c_ulong),
                ("dwMemoryLoad", ctypes.c_ulong),
                ("ullTotalPhys", ctypes.c_ulonglong),
                ("ullAvailPhys", ctypes.c_ulonglong),
                ("ullTotalPageFile", ctypes.c_ulonglong),
                ("ullAvailPageFile", ctypes.c_ulonglong),
                ("ullTotalVirtual", ctypes.c_ulonglong),
                ("ullAvailVirtual", ctypes.c_ulonglong),
                ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
            ]
        stat = MEMORYSTATUSEX()
        stat.dwLength = ctypes.sizeof(stat)
        ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))
        return stat.ullAvailPhys / (1024 ** 3)
    except Exception:
        pass
    
    # POSIX fallback: sysconf exposes page size and free page count on Linux
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        avail_pages = os.sysconf("SC_AVPHYS_PAGES")
        return (page_size * avail_pages) / (1024 ** 3)
    except (AttributeError, ValueError, OSError):
        pass
    
    # Last resort: assume 8 GB total, ~4 GB available (conservative)
    logger.warning("Could not detect available RAM, assuming 4 GB available")
    return 4.0


def get_safe_train_samples(total_examples: int) -> int:
    """Determine a safe number of training samples based on available RAM.
    
    CUAD training is RAM-intensive because:
    - Each example contains a ~54K char contract context
    - 22K examples × 54K chars = ~1.2 GB just for raw strings
    - Column conversion + tokenization roughly triples peak memory
    - Sliding window tokenization expands 22K examples → ~1.1M features
    
    Memory estimates (approximate):
        - 1000 samples → ~1 GB peak RAM
        - 5000 samples → ~3 GB peak RAM
        - 10000 samples → ~5 GB peak RAM
        - 22000 samples → ~10 GB peak RAM
    
    Args:
        total_examples: Total number of available training examples.
    
    Returns:
        Safe number of samples to use.
    """
    available_gb = get_available_ram_gb()
    
    # Reserve ~3 GB for OS + Python + PyTorch model + overhead
    usable_gb = max(available_gb - 3.0, 1.0)
    
    # ~500 MB per 1000 samples during peak tokenization
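    # e.g. 8.0 GB available -> usable_gb = 5.0 -> 10,000 samples, matching the
    # ~5 GB peak estimate for 10,000 samples in the docstring above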
    safe_samples = int(usable_gb * 2000)
    
    # Clamp to actual dataset size
    safe_samples = min(safe_samples, total_examples)
    
    # Floor of 500 samples (training below this is not meaningful)
    safe_samples = max(safe_samples, min(500, total_examples))
    
    if safe_samples < total_examples:
        logger.warning(
            f"Available RAM: {available_gb:.1f} GB β†’ limiting to {safe_samples} samples "
            f"(out of {total_examples}) to prevent crashes. "
            f"Use --max_train_samples {total_examples} to force all."
        )
    else:
        logger.info(
            f"Available RAM: {available_gb:.1f} GB β†’ using all {total_examples} samples"
        )
    
    return safe_samples
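

if __name__ == "__main__":
    # Smoke test (illustrative): running this module directly prints what the
    # current machine reports. 22000 mirrors the full CUAD training set size
    # used in the estimates above.
    logging.basicConfig(level=logging.INFO)
    print("device:", get_device())
    print(f"available RAM: {get_available_ram_gb():.1f} GB")
    print("safe samples:", get_safe_train_samples(22000))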