Upload main.py
Browse files
main.py
CHANGED
|
@@ -38,6 +38,7 @@ if not os.environ.get("TLM_DATA_DIR"):
|
|
| 38 |
os.environ["TLM_DATA_DIR"] = "/tmp/tlm_data"
|
| 39 |
|
| 40 |
# Select GPU if available
|
|
|
|
| 41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
logger.info(f"Using device: {device}")
|
| 43 |
|
|
@@ -47,15 +48,6 @@ try:
|
|
| 47 |
if torch.cuda.is_available():
|
| 48 |
torch.cuda.empty_cache()
|
| 49 |
logger.info(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
|
| 50 |
-
|
| 51 |
-
# Apply PyTorch device fixes
|
| 52 |
-
if hasattr(torch, 'set_default_device'):
|
| 53 |
-
torch.set_default_device("cpu")
|
| 54 |
-
_original_device = torch.device
|
| 55 |
-
def patched_device(device_str=None):
|
| 56 |
-
return _original_device("cpu") if device_str is None else _original_device("cpu")
|
| 57 |
-
torch.device = patched_device
|
| 58 |
-
logger.info("✅ Applied PyTorch device attribute fix")
|
| 59 |
except Exception as e:
|
| 60 |
logger.warning(f"Error with PyTorch setup: {e}")
|
| 61 |
|
|
|
|
| 38 |
os.environ["TLM_DATA_DIR"] = "/tmp/tlm_data"
|
| 39 |
|
| 40 |
# Select GPU if available
|
| 41 |
+
import torch
|
| 42 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
logger.info(f"Using device: {device}")
|
| 44 |
|
|
|
|
| 48 |
if torch.cuda.is_available():
|
| 49 |
torch.cuda.empty_cache()
|
| 50 |
logger.info(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
except Exception as e:
|
| 52 |
logger.warning(f"Error with PyTorch setup: {e}")
|
| 53 |
|