File size: 2,663 Bytes
7b7db64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from dotenv import load_dotenv
import torch

# Read key=value pairs from a local .env file into os.environ so the
# os.getenv('USE_CUDA', ...) lookups below can see them.
load_dotenv()

class Config:
    """Resolve the compute device and compute type from the environment.

    Reads the ``USE_CUDA`` environment variable once at construction time
    and derives:

    * ``use_cuda``     -- whether CUDA was requested (truthy env value)
    * ``device``       -- ``"cuda"`` if requested *and* available, else ``"cpu"``
    * ``compute_type`` -- ``"float16"`` on CUDA, ``"int8"`` on CPU

    Diagnostic lines are printed during ``__init__`` so the selected
    configuration is visible in logs.
    """

    # Case-insensitive values of USE_CUDA treated as "enabled".
    # Generalized from the original 'true'-only check; 'true' still works,
    # so existing deployments are unaffected.
    _TRUTHY = frozenset({"true", "1", "yes", "on"})

    def __init__(self):
        # USE_CUDA defaults to disabled when unset.
        self.use_cuda = os.getenv('USE_CUDA', 'false').strip().lower() in self._TRUTHY

        # Device depends on both the request and actual CUDA availability.
        self.device = self._get_device()

        # Compute type follows the resolved device.
        self.compute_type = self._get_compute_type()

        # Plain string here: the original used an f-string with no placeholders.
        print("🔧 Config initialized:")
        print(f"   USE_CUDA environment variable: {os.getenv('USE_CUDA', 'false')}")
        print(f"   CUDA available: {torch.cuda.is_available()}")
        print(f"   Selected device: {self.device}")
        print(f"   Compute type: {self.compute_type}")

    def _get_device(self):
        """Return ``"cuda"`` when requested and available, else ``"cpu"``.

        Warns (via print) when CUDA was requested but is not available,
        then falls back to CPU.
        """
        if self.use_cuda:
            if torch.cuda.is_available():
                return "cuda"
            # Requested but unavailable: warn once and fall through to CPU.
            print("⚠️  Warning: CUDA requested but not available, falling back to CPU")
        return "cpu"

    def _get_compute_type(self):
        """Return the compute type best suited to the resolved device."""
        # float16 is efficient on CUDA hardware; int8 quantization on CPU.
        return "float16" if self.device == "cuda" else "int8"

    def get_whisper_config(self):
        """Return the ``device``/``compute_type`` kwargs for a Whisper model."""
        return {
            "device": self.device,
            "compute_type": self.compute_type
        }

    def get_transformers_device(self):
        """Return the device index expected by transformers pipelines.

        ``0`` selects the first CUDA device; ``-1`` selects the CPU.
        """
        return 0 if self.device == "cuda" else -1

    def get_device_info(self):
        """Return a dict describing the resolved device configuration.

        When CUDA is available, also includes device count, the name and
        total memory of device 0 (``None`` for the latter two if the count
        is zero).
        """
        info = {
            "device": self.device,
            "compute_type": self.compute_type,
            "cuda_available": torch.cuda.is_available(),
            "use_cuda_requested": self.use_cuda
        }

        if torch.cuda.is_available():
            # device_count() can in principle be 0 even when CUDA is
            # "available"; guard the per-device queries accordingly.
            has_device = torch.cuda.device_count() > 0
            info.update({
                "cuda_device_count": torch.cuda.device_count(),
                "cuda_device_name": torch.cuda.get_device_name(0) if has_device else None,
                "cuda_memory_total": torch.cuda.get_device_properties(0).total_memory if has_device else None
            })

        return info

# Module-level shared instance: constructed at import time, so the
# diagnostic prints in Config.__init__ run once on first import.
config = Config()