# Testing / config.py
# Sidak Singh
# question boundary works
# 7b7db64
import os
from dotenv import load_dotenv
import torch
# Load environment variables from .env file. This runs at import time,
# before the module-level Config instance reads USE_CUDA.
load_dotenv()
class Config:
    """Configuration class for device and model settings.

    The device is selected from the ``USE_CUDA`` environment variable and
    the CUDA availability reported by ``torch``.

    Attributes:
        use_cuda: Whether CUDA was requested through ``USE_CUDA``.
        device: ``"cuda"`` when requested and available, otherwise ``"cpu"``.
        compute_type: ``"float16"`` on CUDA, ``"int8"`` on CPU.
    """

    # Spellings of USE_CUDA (after stripping whitespace and lower-casing)
    # that are treated as a request for CUDA.
    _TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"})

    def __init__(self) -> None:
        """Read ``USE_CUDA``, resolve device settings and print a summary."""
        # Default to 'false' when the variable is unset. Surrounding
        # whitespace and common truthy spellings ("1", "yes", "on") are
        # accepted so that e.g. USE_CUDA=1 does not silently select the CPU.
        raw_flag = os.getenv('USE_CUDA', 'false')
        self.use_cuda = raw_flag.strip().lower() in self._TRUTHY_VALUES

        # Device depends on both the request and actual CUDA availability.
        self.device = self._get_device()

        # Compute type is derived from the resolved device.
        self.compute_type = self._get_compute_type()

        print("🔧 Config initialized:")
        print(f" USE_CUDA environment variable: {raw_flag}")
        print(f" CUDA available: {torch.cuda.is_available()}")
        print(f" Selected device: {self.device}")
        print(f" Compute type: {self.compute_type}")

    def _get_device(self) -> str:
        """Return ``"cuda"`` if requested and available, else ``"cpu"``.

        Prints a warning when CUDA was requested but is not available.
        """
        if not self.use_cuda:
            return "cpu"
        if torch.cuda.is_available():
            return "cuda"
        print("⚠️ Warning: CUDA requested but not available, falling back to CPU")
        return "cpu"

    def _get_compute_type(self) -> str:
        """Return the compute type matching ``self.device``."""
        if self.device == "cuda":
            return "float16"  # More efficient for CUDA
        return "int8"  # More efficient for CPU

    def get_whisper_config(self) -> dict:
        """Return the ``device`` / ``compute_type`` settings for Whisper."""
        return {
            "device": self.device,
            "compute_type": self.compute_type,
        }

    def get_transformers_device(self) -> int:
        """Return the device index for transformers (RoBERTa, etc.).

        Returns:
            ``0`` (first CUDA device) when running on CUDA, ``-1`` for CPU.
        """
        return 0 if self.device == "cuda" else -1

    def get_device_info(self) -> dict:
        """Return detailed device information.

        Returns:
            A dict with ``device``, ``compute_type``, ``cuda_available`` and
            ``use_cuda_requested``. When CUDA is available it also contains
            ``cuda_device_count``, ``cuda_device_name`` and
            ``cuda_memory_total``; the last two are ``None`` if the device
            count is zero.
        """
        cuda_available = torch.cuda.is_available()
        info = {
            "device": self.device,
            "compute_type": self.compute_type,
            "cuda_available": cuda_available,
            "use_cuda_requested": self.use_cuda,
        }

        if cuda_available:
            device_count = torch.cuda.device_count()
            # Only query device 0 when at least one device is reported.
            has_device = device_count > 0
            info.update({
                "cuda_device_count": device_count,
                "cuda_device_name": (
                    torch.cuda.get_device_name(0) if has_device else None
                ),
                "cuda_memory_total": (
                    torch.cuda.get_device_properties(0).total_memory
                    if has_device else None
                ),
            })

        return info
# Create global config instance at import time. Config.__init__ reads
# USE_CUDA and prints a summary of the selected device and compute type.
config = Config()