| import torch | |
| def set_kv_buffer_kernel( | |
| k_cache: torch.Tensor, | |
| v_cache: torch.Tensor, | |
| loc: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| fallback: bool = False, | |
| ): | |
| try: | |
| if fallback: | |
| raise RuntimeError("Fallback to torch implementation") | |
| torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v) | |
| except RuntimeError: # ok, fallback to torch implementation | |
| k_cache[loc] = k | |
| v_cache[loc] = v | |
| def weak_ref_tensor(tensor): | |
| return ( | |
| torch.ops.sgl_kernel.weak_ref_tensor(tensor) | |
| if isinstance(tensor, torch.Tensor) | |
| else tensor | |
| ) | |