# ruGPT3XL-8k / triton_utils.py
# evilfreelancer's picture
# Upload folder using huggingface_hub
# bec1270 verified
"""Optional torch.compile Inductor (Triton-backed kernels on CUDA) for RuGPT-3 XL."""
from typing import Optional
import torch
from torch import nn
def triton_runtime_available() -> bool:
    """Check for a CUDA device and an importable ``triton`` package.

    Returns:
        ``True`` only when ``torch.cuda.is_available()`` reports true and
        ``import triton`` succeeds; ``False`` in every other case.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        try:
            # Imported purely as a probe; the module itself is not used here.
            import triton  # noqa: F401
        except ImportError:
            # CUDA is available but triton could not be imported.
            return False
        else:
            return True
    # Without CUDA the triton import is never attempted.
    return False
def compile_rugpt3xl_for_triton(
    model: nn.Module,
    mode: str = "max-autotune",
    fullgraph: bool = False,
    dynamic: Optional[bool] = None,
) -> nn.Module:
    """Wrap ``model`` with ``torch.compile`` and return the result.

    No ``backend`` argument is passed, so ``torch.compile`` uses its default
    backend (Inductor, which emits Triton kernels for many ops on CUDA).
    The intent is to change only how the model is executed on GPU, not what
    it computes.

    Args:
        model: Module handed to ``torch.compile``.
        mode: Forwarded unchanged as the ``mode`` argument.
        fullgraph: Forwarded unchanged as the ``fullgraph`` argument.
        dynamic: Forwarded as the ``dynamic`` argument; ``None`` is replaced
            by ``True`` before the call.

    Returns:
        Whatever ``torch.compile`` returns for ``model``.

    Raises:
        RuntimeError: If the installed ``torch`` has no ``compile`` attribute.
    """
    compile_supported = hasattr(torch, "compile")
    if not compile_supported:
        raise RuntimeError("torch.compile requires PyTorch 2.0+")

    # An unspecified ``dynamic`` is treated as a request for dynamic shapes.
    use_dynamic = True if dynamic is None else dynamic

    compile_kwargs = {
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": use_dynamic,
    }
    return torch.compile(model, **compile_kwargs)