"""
Type definitions for the TriMul task.

Input: a tuple of (input_tensor, mask, weights, config), where
- input_tensor: input tensor of shape [batch_size, seq_len, seq_len, dim]
- mask: mask tensor of shape [batch_size, seq_len, seq_len]
- weights: dictionary mapping parameter names to model weight tensors
- config: dictionary of model configuration parameters, keyed by name
Output: output tensor of shape [batch_size, seq_len, seq_len, dim]
"""
import torch
from typing import Tuple, Dict, Any
# Input type: a 4-tuple (input_tensor, mask, weights, config), in that order.
#   input_tensor: torch.Tensor
#   mask:         torch.Tensor
#   weights:      dict mapping str names to torch.Tensor values
#   config:       dict mapping str names to arbitrary (Any) values
input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]

# Output type: a single output tensor.
output_t = torch.Tensor