File size: 3,385 Bytes
7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Optional, Union

import numpy as np

class PositionalEncodingState:
    """Tracks the temporary tensors created in a driver during one encoding run.

    Every tensor created through this object gets a distinct name composed of
    the configured prefix, a running counter, and an optional suffix.
    """

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0  # next index used when naming a temporary tensor

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Create a driver tensor holding *data* and return the name it was stored under."""
        index = self.counter
        self.counter = index + 1
        tensor_name = f"{self.prefix}_temp_{index}_{name_suffix}"
        self.driver.create_tensor(tensor_name, data)
        return tensor_name

    def free_temp_tensor(self, name: str):
        """Delete tensor *name* from the driver; silently skip names that do not exist."""
        if not self.driver.tensor_exists(name):
            return
        self.driver.delete_tensor(name)
            
def _frequency_terms(hidden_dim: int) -> np.ndarray:
    """Return the frequencies exp(-j * ln(10000) / hidden_dim) for j = 0, 2, 4, ... < hidden_dim."""
    if hidden_dim == 0:
        # No columns to fill; avoid the scalar division by zero (and its warning).
        return np.empty(0)
    return np.exp(np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim))


def _numpy_encoding(seq_len: int, hidden_dim: int, div_term: np.ndarray) -> np.ndarray:
    """Compute the encoding table entirely in NumPy (used when no driver is given)."""
    # Outer product: angles[p, k] = p * div_term[k], shape (seq_len, ceil(hidden_dim / 2)).
    angles = np.arange(seq_len)[:, np.newaxis] * div_term
    pe = np.zeros((seq_len, hidden_dim))
    pe[:, 0::2] = np.sin(angles)
    # For odd hidden_dim there is one fewer odd column than even column, so the
    # last frequency has no cosine partner; slice to avoid a broadcast error.
    pe[:, 1::2] = np.cos(angles[:, : hidden_dim // 2])
    return pe


def _driver_encoding(seq_len: int, hidden_dim: int, div_term: np.ndarray, driver, prefix: str) -> str:
    """Compute the encoding table through ``driver`` and return the output tensor's name."""
    state = PositionalEncodingState(driver, prefix)
    temp_names = []  # intermediate tensors; always freed before returning
    pe_name = None
    succeeded = False
    try:
        position_name = state.get_temp_tensor(
            np.arange(seq_len)[:, np.newaxis],
            "position",
        )
        temp_names.append(position_name)
        div_term_name = state.get_temp_tensor(div_term, "div_term")
        temp_names.append(div_term_name)

        # Output tensor: zero-initialised, then filled column by column below.
        pe_name = state.get_temp_tensor(np.zeros((seq_len, hidden_dim)), "pe")

        # position (seq_len, 1) @ div_term (1, n_freq) -> angle table (seq_len, n_freq).
        mul_name = state.get_temp_tensor(
            driver.matmul(
                driver.get_tensor(position_name),
                driver.get_tensor(div_term_name).reshape(1, -1),
            ),
            "multiplied",
        )
        temp_names.append(mul_name)

        # NOTE(review): sin/cos receive the tensor *name*, whereas matmul receives
        # data fetched with get_tensor. Kept as in the original; confirm that the
        # driver API accepts a name here.
        sin_name = state.get_temp_tensor(driver.sin(mul_name), "sin")
        temp_names.append(sin_name)
        cos_name = state.get_temp_tensor(driver.cos(mul_name), "cos")
        temp_names.append(cos_name)

        # Fetch once; only pe_name is written by the scatters below.
        sin_values = driver.get_tensor(sin_name)
        cos_values = driver.get_tensor(cos_name)
        rows = np.arange(seq_len)
        for col in range(0, hidden_dim, 2):
            freq = col // 2
            # Even column `col` gets the sine of frequency `freq` ...
            driver.scatter(
                pe_name,
                np.column_stack((rows, np.full_like(rows, col))),
                sin_values[:, freq],
            )
            # ... and the following odd column, if present, gets its cosine.
            if col + 1 < hidden_dim:
                driver.scatter(
                    pe_name,
                    np.column_stack((rows, np.full_like(rows, col + 1))),
                    cos_values[:, freq],
                )
        succeeded = True
    finally:
        for name in temp_names:
            state.free_temp_tensor(name)
        # On failure the caller never receives pe_name, so do not leave it behind.
        if not succeeded and pe_name is not None:
            state.free_temp_tensor(pe_name)
    return pe_name


def sinusoidal_positional_encoding(
    seq_len: int,
    hidden_dim: int,
    driver=None,
    prefix: str = "pos_enc",
) -> Union[str, np.ndarray]:
    """Build a sinusoidal positional-encoding table of shape (seq_len, hidden_dim).

    Column ``j`` (even) holds ``sin(pos * f_j)`` and column ``j + 1`` holds
    ``cos(pos * f_j)``, where ``f_j = exp(-j * ln(10000) / hidden_dim)``.

    Args:
        seq_len: Number of positions (rows).
        hidden_dim: Encoding width (columns). May be odd; the last column then
            holds only a sine term.
        driver: Optional tensor backend. When ``None`` the table is computed in
            NumPy and returned directly; otherwise the computation is done via
            the driver's tensor operations.
        prefix: Prefix for the names of tensors created in ``driver``.

    Returns:
        The NumPy array when ``driver`` is ``None``; otherwise the name of the
        driver tensor holding the table. Intermediate driver tensors are
        deleted before returning, and nothing is left behind if a step fails.

    Raises:
        ValueError: If ``seq_len`` or ``hidden_dim`` is negative.
    """
    if seq_len < 0 or hidden_dim < 0:
        raise ValueError(
            f"seq_len and hidden_dim must be non-negative, got {seq_len} and {hidden_dim}"
        )

    div_term = _frequency_terms(hidden_dim)
    if driver is None:
        return _numpy_encoding(seq_len, hidden_dim, div_term)
    return _driver_encoding(seq_len, hidden_dim, div_term, driver, prefix)