Kernels
aiter-kernels / build /torch-rocm /utils /common_utils.py
kernels-bot's picture
Uploaded using `kernel-builder`.
2976eec verified
Raw
History Blame Contribute Delete
1.04 kB
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
from typing import List
import torch
import triton
import json
def prev_power_of_2(x: int) -> int:
    """Round ``x`` down to a power of two.

    Built on ``triton.next_power_of_2``: assuming that helper returns the
    smallest power of two that is >= ``x``, this returns the largest power of
    two that is <= ``x``. A value that is already a power of two comes back
    unchanged.

    NOTE(review): the result for ``x <= 0`` depends entirely on
    ``triton.next_power_of_2``; confirm before relying on it.
    """
    rounded_up = triton.next_power_of_2(x)
    # Rounding up either landed exactly on x (keep it) or overshot, in which
    # case the previous power of two is half of it.
    if rounded_up <= x:
        return rounded_up
    return rounded_up // 2
# Candidate sequence-length buckets consulted by autotune_max_seq_len when
# USE_RUNTIME_MAX_SEQ_LEN is False. Empty by default, in which case
# autotune_max_seq_len returns 1. Presumably populated by callers elsewhere,
# sorted ascending — TODO confirm.
STATIC_MAX_SEQ_LENS: List[int] = []
# When True, autotune_max_seq_len ignores STATIC_MAX_SEQ_LENS and instead
# rounds the runtime length down to a power of two (via prev_power_of_2).
USE_RUNTIME_MAX_SEQ_LEN: bool = False
def autotune_max_seq_len(runtime_max_seq_len: int) -> int:
    """Map a runtime max sequence length onto a coarse bucket value.

    Presumably (per the name) the result is used as an autotuning key, so
    collapsing many runtime lengths onto a few values limits re-tuning —
    confirm against callers.

    Args:
        runtime_max_seq_len: The maximum sequence length observed at runtime.

    Returns:
        * If ``USE_RUNTIME_MAX_SEQ_LEN`` is true: ``runtime_max_seq_len``
          rounded down to a power of two via ``prev_power_of_2``.
        * Otherwise, if ``STATIC_MAX_SEQ_LENS`` is empty: ``1``.
        * Otherwise: the smallest entry of ``STATIC_MAX_SEQ_LENS`` that is
          ``>= runtime_max_seq_len``, or the largest entry when no entry is
          large enough.
    """
    # Both module-level settings are only read here, so no ``global``
    # declaration is needed.
    if USE_RUNTIME_MAX_SEQ_LEN:
        return prev_power_of_2(runtime_max_seq_len)
    if not STATIC_MAX_SEQ_LENS:
        return 1
    # min/max rather than a first-match scan so the answer does not depend on
    # STATIC_MAX_SEQ_LENS being sorted ascending; for a sorted list the
    # result is the same as a scan that returns the first covering entry.
    covering = (
        max_len
        for max_len in STATIC_MAX_SEQ_LENS
        if max_len >= runtime_max_seq_len
    )
    return min(covering, default=max(STATIC_MAX_SEQ_LENS))
def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
    """Return a tensor whose innermost dimension has unit stride.

    Only the last dimension is inspected. If ``x.stride(-1) == 1``, ``x`` is
    returned as-is (no copy), even when outer dimensions are strided. The
    result therefore need not satisfy ``x.is_contiguous()``. Otherwise
    ``x.contiguous()`` is returned.

    A 0-dim tensor has no last dimension, so ``x.stride(-1)`` cannot be
    queried on it. It holds a single element and is returned unchanged.

    Args:
        x: Input tensor.

    Returns:
        ``x`` itself, or ``x.contiguous()``.
    """
    # Guard 0-dim tensors first: stride(-1) has no dimension to refer to.
    if x.dim() == 0 or x.stride(-1) == 1:
        return x
    return x.contiguous()
def serialize_dict(d: dict) -> str:
    """Encode ``d`` as a JSON string using ``json.dumps`` default settings."""
    encoded = json.dumps(d)
    return encoded
def deserialize_str(s: str) -> dict:
    """Decode the JSON text ``s`` with ``json.loads``.

    The annotation says ``dict``, but no check is made: the result is
    whatever ``json.loads`` produces for ``s``.
    """
    decoded = json.loads(s)
    return decoded