"""
Type definitions for the TriMul task.

Input: Tuple of (input_tensor, mask, weights, config)
  - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
  - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
  - weights: Dictionary containing model weights
  - config: Dictionary containing model configuration parameters

Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
"""

import torch
from typing import Tuple, Dict, Any

# Input type: (input_tensor, mask, weights, config)
input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]

# Output type: output tensor
output_t = torch.Tensor
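

# --- Illustrative sketch (not part of the original definitions) -------------
# One possible way to build a value matching input_t and to check it against
# the shapes documented in the module docstring. Everything below is an
# assumption made for illustration: the helper names, the "proj.weight"
# weights key, and the "dim" config key are hypothetical, since this module
# does not list the actual keys used by the task.
import torch  # already imported above; repeated so the sketch runs standalone


def make_example_input(batch_size: int = 2, seq_len: int = 8, dim: int = 16) -> "input_t":
    """Build a random (input_tensor, mask, weights, config) tuple."""
    input_tensor = torch.randn(batch_size, seq_len, seq_len, dim)
    mask = torch.ones(batch_size, seq_len, seq_len)
    weights = {"proj.weight": torch.randn(dim, dim)}  # hypothetical key
    config = {"dim": dim}  # hypothetical key
    return input_tensor, mask, weights, config


def check_shapes(data: "input_t", output: "output_t") -> None:
    """Assert that the shapes agree with the module docstring."""
    input_tensor, mask, _weights, _config = data
    assert input_tensor.dim() == 4, "expected [batch_size, seq_len, seq_len, dim]"
    batch_size, rows, cols, _dim = input_tensor.shape
    assert rows == cols, "the two seq_len axes must match"
    assert mask.shape == (batch_size, rows, cols), "mask must be [batch_size, seq_len, seq_len]"
    assert output.shape == input_tensor.shape, "output must have the same shape as input_tensor"


# Possible usage, with a zero tensor standing in for a real TriMul output:
#   data = make_example_input()
#   check_shapes(data, torch.zeros_like(data[0]))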