af6e330
1
2
3
def identity_with_cast(q, k, v, offset: int = 0):
    """Return ``(q, k, v)`` with ``q`` and ``k`` converted to ``v``'s dtype.

    No other transformation is applied: ``v`` is returned as the very same
    object that was passed in.

    Args:
        q: Object exposing ``.to(dtype)``; converted to ``v.dtype``.
        k: Object exposing ``.to(dtype)``; converted to ``v.dtype``.
        v: Object exposing ``.dtype``; returned unchanged.
        offset: Accepted but ignored.
            NOTE(review): presumably kept so this function matches the call
            signature of other q/k transforms — confirm against callers.

    Returns:
        A 3-tuple ``(q.to(v.dtype), k.to(v.dtype), v)``.
    """
    target_dtype = v.dtype
    q_cast = q.to(target_dtype)
    k_cast = k.to(target_dtype)
    return q_cast, k_cast, v