drbh
commited on
Commit
·
4080f9c
1
Parent(s):
b0d3c12
fix: adjust types
Browse files- flash_attn/flash_api.cpp +102 -32
flash_attn/flash_api.cpp
CHANGED
|
@@ -1507,45 +1507,61 @@ mha_fwd(const at::Tensor &q, // batch_size x seqle
|
|
| 1507 |
float softcap_float = static_cast<float>(softcap);
|
| 1508 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1509 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1510 |
-
|
| 1511 |
return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
| 1512 |
}
|
| 1513 |
|
| 1514 |
std::vector<at::Tensor>
|
| 1515 |
-
mha_varlen_fwd(
|
| 1516 |
-
const at::Tensor &k,
|
| 1517 |
-
const at::Tensor &v,
|
| 1518 |
-
const
|
| 1519 |
-
const at::Tensor &cu_seqlens_q,
|
| 1520 |
-
const at::Tensor &cu_seqlens_k,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1521 |
const int64_t max_seqlen_q,
|
| 1522 |
const int64_t max_seqlen_k,
|
| 1523 |
const double p_dropout,
|
| 1524 |
const double softmax_scale,
|
| 1525 |
-
bool
|
|
|
|
| 1526 |
const int64_t window_size_left,
|
| 1527 |
const int64_t window_size_right,
|
| 1528 |
const double softcap,
|
| 1529 |
const bool return_softmax,
|
| 1530 |
-
const
|
| 1531 |
-
|
| 1532 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
| 1533 |
-
|
| 1534 |
// Prepare the optional arguments as non-const references.
|
| 1535 |
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
| 1536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1537 |
if (!out.has_value()){
|
| 1538 |
out = torch::empty_like(q);
|
| 1539 |
}
|
| 1540 |
-
|
| 1541 |
// Convert double to float and int64_t to int.
|
| 1542 |
float p_dropout_float = static_cast<float>(p_dropout);
|
| 1543 |
float softmax_scale_float = static_cast<float>(softmax_scale);
|
| 1544 |
float softcap_float = static_cast<float>(softcap);
|
|
|
|
|
|
|
| 1545 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1546 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1547 |
-
|
| 1548 |
-
return FLASH_NAMESPACE::mha_varlen_fwd(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1549 |
}
|
| 1550 |
|
| 1551 |
std::vector<at::Tensor>
|
|
@@ -1570,7 +1586,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
| 1570 |
std::optional<at::Tensor> &rng_state) {
|
| 1571 |
|
| 1572 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
| 1573 |
-
|
| 1574 |
// Prepare the optional arguments as non-const references.
|
| 1575 |
std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
|
| 1576 |
std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
|
|
@@ -1584,7 +1600,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
| 1584 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1585 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1586 |
|
| 1587 |
-
return FLASH_NAMESPACE::mha_bwd(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1588 |
}
|
| 1589 |
|
| 1590 |
|
|
@@ -1595,12 +1619,17 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
| 1595 |
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
| 1596 |
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
| 1597 |
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
|
|
|
|
|
|
|
|
|
| 1598 |
const at::Tensor &cu_seqlens_q, // batch_size + 1
|
| 1599 |
const at::Tensor &cu_seqlens_k, // batch_size + 1
|
|
|
|
| 1600 |
const int64_t max_seqlen_q,
|
| 1601 |
const int64_t max_seqlen_k,
|
| 1602 |
const double p_dropout,
|
| 1603 |
const double softmax_scale,
|
|
|
|
| 1604 |
const bool is_causal,
|
| 1605 |
const int64_t window_size_left,
|
| 1606 |
const int64_t window_size_right,
|
|
@@ -1608,17 +1637,36 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
| 1608 |
const bool deterministic,
|
| 1609 |
std::optional<at::Generator> gen_,
|
| 1610 |
std::optional<at::Tensor> &rng_state) {
|
| 1611 |
-
|
| 1612 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
| 1613 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1614 |
// Convert double to float and int64_t to int.
|
| 1615 |
float p_dropout_float = static_cast<float>(p_dropout);
|
| 1616 |
float softmax_scale_float = static_cast<float>(softmax_scale);
|
| 1617 |
float softcap_float = static_cast<float>(softcap);
|
|
|
|
|
|
|
| 1618 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1619 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1620 |
|
| 1621 |
-
return FLASH_NAMESPACE::mha_varlen_bwd(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1622 |
}
|
| 1623 |
|
| 1624 |
std::vector<at::Tensor>
|
|
@@ -1643,25 +1691,47 @@ mha_fwd_kvcache(const at::Tensor &q, // batch
|
|
| 1643 |
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
| 1644 |
const int64_t num_splits
|
| 1645 |
) {
|
| 1646 |
-
|
| 1647 |
-
// Prepare the optional arguments as
|
| 1648 |
-
std::optional<at::Tensor> k = k_.has_value() ? std::optional<at::Tensor>(
|
| 1649 |
-
std::optional<at::Tensor> v = v_.has_value() ? std::optional<at::Tensor>(
|
| 1650 |
-
std::optional<at::Tensor> seqlens_k = seqlens_k_.has_value() ? std::optional<at::Tensor>(
|
| 1651 |
-
std::optional<at::Tensor> rotary_cos = rotary_cos_.has_value() ? std::optional<at::Tensor>(
|
| 1652 |
-
std::optional<at::Tensor> rotary_sin = rotary_sin_.has_value() ? std::optional<at::Tensor>(
|
| 1653 |
-
std::optional<at::Tensor> cache_batch_idx = cache_batch_idx_.has_value() ? std::optional<at::Tensor>(
|
| 1654 |
-
std::optional<at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<at::Tensor>(
|
|
|
|
|
|
|
| 1655 |
std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
|
| 1656 |
std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
|
| 1657 |
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
| 1658 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1659 |
// Convert double to float and int64_t to int.
|
| 1660 |
float softmax_scale_float = static_cast<float>(softmax_scale);
|
| 1661 |
float softcap_float = static_cast<float>(softcap);
|
| 1662 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1663 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1664 |
int num_splits_int = static_cast<int>(num_splits);
|
| 1665 |
-
|
| 1666 |
-
return FLASH_NAMESPACE::mha_fwd_kvcache(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1667 |
}
|
|
|
|
| 1507 |
float softcap_float = static_cast<float>(softcap);
|
| 1508 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1509 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1510 |
+
|
| 1511 |
return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
| 1512 |
}
|
| 1513 |
|
| 1514 |
std::vector<at::Tensor>
|
| 1515 |
+
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
| 1516 |
+
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
|
| 1517 |
+
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
|
| 1518 |
+
const std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
| 1519 |
+
const at::Tensor &cu_seqlens_q, // b+1
|
| 1520 |
+
const at::Tensor &cu_seqlens_k, // b+1
|
| 1521 |
+
const std::optional<at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
| 1522 |
+
const std::optional<const at::Tensor> &leftpad_k_, // batch_size
|
| 1523 |
+
const std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
| 1524 |
+
const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
| 1525 |
const int64_t max_seqlen_q,
|
| 1526 |
const int64_t max_seqlen_k,
|
| 1527 |
const double p_dropout,
|
| 1528 |
const double softmax_scale,
|
| 1529 |
+
const bool zero_tensors,
|
| 1530 |
+
const bool is_causal,
|
| 1531 |
const int64_t window_size_left,
|
| 1532 |
const int64_t window_size_right,
|
| 1533 |
const double softcap,
|
| 1534 |
const bool return_softmax,
|
| 1535 |
+
const std::optional<at::Generator> gen_) {
|
|
|
|
| 1536 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
|
|
|
| 1537 |
// Prepare the optional arguments as non-const references.
|
| 1538 |
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
| 1539 |
+
std::optional<at::Tensor> seqused_k = seqused_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(seqused_k_.value())) : std::nullopt;
|
| 1540 |
+
std::optional<const at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<const at::Tensor>(leftpad_k_.value()) : std::nullopt;
|
| 1541 |
+
std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
|
| 1542 |
+
std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
|
| 1543 |
+
|
| 1544 |
if (!out.has_value()){
|
| 1545 |
out = torch::empty_like(q);
|
| 1546 |
}
|
|
|
|
| 1547 |
// Convert double to float and int64_t to int.
|
| 1548 |
float p_dropout_float = static_cast<float>(p_dropout);
|
| 1549 |
float softmax_scale_float = static_cast<float>(softmax_scale);
|
| 1550 |
float softcap_float = static_cast<float>(softcap);
|
| 1551 |
+
int max_seqlen_q_int = static_cast<int>(max_seqlen_q);
|
| 1552 |
+
int max_seqlen_k_int = static_cast<int>(max_seqlen_k);
|
| 1553 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1554 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1555 |
+
|
| 1556 |
+
return FLASH_NAMESPACE::mha_varlen_fwd(
|
| 1557 |
+
const_cast<at::Tensor &>(q), k, v, out,
|
| 1558 |
+
cu_seqlens_q, cu_seqlens_k,
|
| 1559 |
+
seqused_k, leftpad_k, block_table, alibi_slopes,
|
| 1560 |
+
max_seqlen_q_int, max_seqlen_k_int,
|
| 1561 |
+
p_dropout_float, softmax_scale_float,
|
| 1562 |
+
zero_tensors, is_causal,
|
| 1563 |
+
window_size_left_int, window_size_right_int,
|
| 1564 |
+
softcap_float, return_softmax, gen);
|
| 1565 |
}
|
| 1566 |
|
| 1567 |
std::vector<at::Tensor>
|
|
|
|
| 1586 |
std::optional<at::Tensor> &rng_state) {
|
| 1587 |
|
| 1588 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
| 1589 |
+
|
| 1590 |
// Prepare the optional arguments as non-const references.
|
| 1591 |
std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
|
| 1592 |
std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
|
|
|
|
| 1600 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1601 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1602 |
|
| 1603 |
+
return FLASH_NAMESPACE::mha_bwd(
|
| 1604 |
+
const_cast<at::Tensor &>(dout),
|
| 1605 |
+
q, k, v, out, softmax_lse,
|
| 1606 |
+
dq, dk, dv, alibi_slopes,
|
| 1607 |
+
p_dropout_float, softmax_scale_float,
|
| 1608 |
+
is_causal,
|
| 1609 |
+
window_size_left_int, window_size_right_int,
|
| 1610 |
+
softcap_float, deterministic,
|
| 1611 |
+
gen, rng_state);
|
| 1612 |
}
|
| 1613 |
|
| 1614 |
|
|
|
|
| 1619 |
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
| 1620 |
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
| 1621 |
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
| 1622 |
+
const std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
| 1623 |
+
const std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
| 1624 |
+
const std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
| 1625 |
const at::Tensor &cu_seqlens_q, // batch_size + 1
|
| 1626 |
const at::Tensor &cu_seqlens_k, // batch_size + 1
|
| 1627 |
+
const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
| 1628 |
const int64_t max_seqlen_q,
|
| 1629 |
const int64_t max_seqlen_k,
|
| 1630 |
const double p_dropout,
|
| 1631 |
const double softmax_scale,
|
| 1632 |
+
const bool zero_tensors,
|
| 1633 |
const bool is_causal,
|
| 1634 |
const int64_t window_size_left,
|
| 1635 |
const int64_t window_size_right,
|
|
|
|
| 1637 |
const bool deterministic,
|
| 1638 |
std::optional<at::Generator> gen_,
|
| 1639 |
std::optional<at::Tensor> &rng_state) {
|
| 1640 |
+
|
| 1641 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
| 1642 |
+
|
| 1643 |
+
// Prepare the optional arguments as non-const references.
|
| 1644 |
+
std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
|
| 1645 |
+
std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
|
| 1646 |
+
std::optional<at::Tensor> dv = dv_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dv_.value())) : std::nullopt;
|
| 1647 |
+
std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
|
| 1648 |
+
|
| 1649 |
// Convert double to float and int64_t to int.
|
| 1650 |
float p_dropout_float = static_cast<float>(p_dropout);
|
| 1651 |
float softmax_scale_float = static_cast<float>(softmax_scale);
|
| 1652 |
float softcap_float = static_cast<float>(softcap);
|
| 1653 |
+
int max_seqlen_q_int = static_cast<int>(max_seqlen_q);
|
| 1654 |
+
int max_seqlen_k_int = static_cast<int>(max_seqlen_k);
|
| 1655 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1656 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1657 |
|
| 1658 |
+
return FLASH_NAMESPACE::mha_varlen_bwd(
|
| 1659 |
+
const_cast<at::Tensor &>(dout),
|
| 1660 |
+
q, k, v, out, softmax_lse,
|
| 1661 |
+
dq, dk, dv,
|
| 1662 |
+
cu_seqlens_q, cu_seqlens_k,
|
| 1663 |
+
alibi_slopes,
|
| 1664 |
+
max_seqlen_q_int, max_seqlen_k_int,
|
| 1665 |
+
p_dropout_float, softmax_scale_float,
|
| 1666 |
+
zero_tensors, is_causal,
|
| 1667 |
+
window_size_left_int, window_size_right_int,
|
| 1668 |
+
softcap_float, deterministic,
|
| 1669 |
+
gen, rng_state);
|
| 1670 |
}
|
| 1671 |
|
| 1672 |
std::vector<at::Tensor>
|
|
|
|
| 1691 |
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
| 1692 |
const int64_t num_splits
|
| 1693 |
) {
|
| 1694 |
+
|
| 1695 |
+
// Prepare the optional arguments as const references where needed
|
| 1696 |
+
std::optional<const at::Tensor> k = k_.has_value() ? std::optional<const at::Tensor>(k_.value()) : std::nullopt;
|
| 1697 |
+
std::optional<const at::Tensor> v = v_.has_value() ? std::optional<const at::Tensor>(v_.value()) : std::nullopt;
|
| 1698 |
+
std::optional<const at::Tensor> seqlens_k = seqlens_k_.has_value() ? std::optional<const at::Tensor>(seqlens_k_.value()) : std::nullopt;
|
| 1699 |
+
std::optional<const at::Tensor> rotary_cos = rotary_cos_.has_value() ? std::optional<const at::Tensor>(rotary_cos_.value()) : std::nullopt;
|
| 1700 |
+
std::optional<const at::Tensor> rotary_sin = rotary_sin_.has_value() ? std::optional<const at::Tensor>(rotary_sin_.value()) : std::nullopt;
|
| 1701 |
+
std::optional<const at::Tensor> cache_batch_idx = cache_batch_idx_.has_value() ? std::optional<const at::Tensor>(cache_batch_idx_.value()) : std::nullopt;
|
| 1702 |
+
std::optional<const at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<const at::Tensor>(leftpad_k_.value()) : std::nullopt;
|
| 1703 |
+
|
| 1704 |
+
// For non-const tensors
|
| 1705 |
std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
|
| 1706 |
std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
|
| 1707 |
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
| 1708 |
+
|
| 1709 |
+
if (!out.has_value()){
|
| 1710 |
+
out = torch::empty_like(q);
|
| 1711 |
+
}
|
| 1712 |
+
|
| 1713 |
// Convert double to float and int64_t to int.
|
| 1714 |
float softmax_scale_float = static_cast<float>(softmax_scale);
|
| 1715 |
float softcap_float = static_cast<float>(softcap);
|
| 1716 |
int window_size_left_int = static_cast<int>(window_size_left);
|
| 1717 |
int window_size_right_int = static_cast<int>(window_size_right);
|
| 1718 |
int num_splits_int = static_cast<int>(num_splits);
|
| 1719 |
+
|
| 1720 |
+
return FLASH_NAMESPACE::mha_fwd_kvcache(
|
| 1721 |
+
const_cast<at::Tensor &>(q),
|
| 1722 |
+
kcache, vcache,
|
| 1723 |
+
k, v,
|
| 1724 |
+
seqlens_k,
|
| 1725 |
+
rotary_cos, rotary_sin,
|
| 1726 |
+
cache_batch_idx,
|
| 1727 |
+
leftpad_k,
|
| 1728 |
+
block_table, alibi_slopes,
|
| 1729 |
+
out,
|
| 1730 |
+
softmax_scale_float,
|
| 1731 |
+
is_causal,
|
| 1732 |
+
window_size_left_int, window_size_right_int,
|
| 1733 |
+
softcap_float,
|
| 1734 |
+
is_rotary_interleaved,
|
| 1735 |
+
num_splits_int
|
| 1736 |
+
);
|
| 1737 |
}
|