# Uploaded by danieldk (HF Staff), commit f25859d: "Build uploaded using `kernels`."
"""
Type definitions for TriMul task.
Input: Tuple of (input_tensor, mask, weights, config)
- input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
- mask: Mask tensor of shape [batch_size, seq_len, seq_len]
- weights: Dictionary mapping parameter names to weight tensors
- config: Dictionary of model configuration parameters (values of any type)
Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
"""
import torch
from typing import Tuple, Dict, Any
# Input type: (input_tensor, mask, weights, config)
input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
# Output type: output tensor
output_t = torch.Tensor
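The aliases above only describe the data's types; they do not enforce the shapes listed in the docstring. Below is a minimal sketch of how a caller might check a value against those documented shapes before running a kernel. The helpers `is_valid_input` and `is_valid_output` are hypothetical and are not part of the original task definition. Because the actual weight names and config keys are not given here, the example passes empty dictionaries for both.

```python
import torch
from typing import Any, Dict, Tuple

# Same aliases as the task definition, repeated so this sketch runs on its own.
input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
output_t = torch.Tensor


def is_valid_input(data: input_t) -> bool:
    """Hypothetical helper: True if `data` matches the documented input shapes."""
    if len(data) != 4:
        return False
    input_tensor, mask, weights, config = data
    # input_tensor must be [batch_size, seq_len, seq_len, dim].
    if not isinstance(input_tensor, torch.Tensor) or input_tensor.dim() != 4:
        return False
    batch_size, seq_len, seq_len_2, _dim = input_tensor.shape
    if seq_len != seq_len_2:
        return False
    # mask must be [batch_size, seq_len, seq_len].
    if not isinstance(mask, torch.Tensor):
        return False
    if tuple(mask.shape) != (batch_size, seq_len, seq_len):
        return False
    # weights maps str -> Tensor; config maps str -> any value.
    if not all(isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in weights.items()):
        return False
    return all(isinstance(k, str) for k in config)


def is_valid_output(out: output_t, data: input_t) -> bool:
    """Hypothetical helper: True if `out` has the same shape as the input tensor."""
    return isinstance(out, torch.Tensor) and out.shape == data[0].shape


# Usage: random data with the documented shapes (weights/config left empty).
batch_size, seq_len, dim = 2, 8, 16
x = torch.randn(batch_size, seq_len, seq_len, dim)
mask = torch.ones(batch_size, seq_len, seq_len)
data = (x, mask, {}, {})
print(is_valid_input(data))  # True
print(is_valid_output(torch.zeros_like(x), data))  # True
```

The helpers return a boolean rather than raising, so a test harness can decide for itself whether a mismatch should be an error or just a skipped case.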