from typing import Dict, Tuple
import numpy as np


def embedding_lookup(
    input_ids: np.ndarray,
    embedding_weights: np.ndarray,
    driver=None
) -> np.ndarray:
    """
    Look up embeddings for input tokens.

    Args:
        input_ids: Input token indices of shape (batch_size, sequence_length)
        embedding_weights: Embedding weight matrix of shape (vocab_size, hidden_dim)
        driver: Optional hardware driver for optimized lookup

    Returns:
        Embedded tokens of shape (batch_size, sequence_length, hidden_dim)
    """
    if driver and hasattr(driver, 'embedding_lookup'):
        return driver.embedding_lookup(input_ids, embedding_weights)
        
    # Fallback to numpy implementation
    batch_size, seq_length = input_ids.shape
    hidden_dim = embedding_weights.shape[1]
    
    # Reshape input_ids for broadcasting
    input_ids_reshaped = input_ids.reshape(-1)
    
    # Lookup embeddings
    embeddings = embedding_weights[input_ids_reshaped]
    
    # Reshape back to (batch_size, sequence_length, hidden_dim)
    return embeddings.reshape(batch_size, seq_length, hidden_dim)
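

# Hedged usage sketch of the numpy fallback path (no driver supplied);
# the toy vocabulary and token ids below are illustrative only.
def _demo_embedding_lookup() -> np.ndarray:
    vocab = np.arange(12, dtype=np.float32).reshape(6, 2)  # (vocab_size=6, hidden_dim=2)
    ids = np.array([[0, 5], [2, 2]])                       # (batch_size=2, seq_length=2)
    out = embedding_lookup(ids, vocab)
    assert out.shape == (2, 2, 2)
    assert np.array_equal(out[0, 1], vocab[5])             # token 5 maps to row 5 of the table
    return out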


def add_positional_encoding(
    embeddings: np.ndarray,
    max_position: int,
    hidden_dim: int,
    dtype: np.dtype = np.float32,
    driver=None
) -> np.ndarray:
    """
    Add sinusoidal positional encodings to input embeddings.

    Args:
        embeddings: Input embeddings of shape (batch_size, sequence_length, hidden_dim)
        max_position: Maximum sequence length
        hidden_dim: Hidden dimension size
        dtype: Data type for positional encodings
        driver: Optional hardware driver for optimized computation

    Returns:
        Embeddings with positional encoding added
    """
    if driver and hasattr(driver, 'add_positional_encoding'):
        return driver.add_positional_encoding(
            embeddings,
            max_position,
            hidden_dim,
            dtype
        )
        
    # Fallback to numpy implementation
    batch_size, seq_length, _ = embeddings.shape
    
    # Create position indices
    position = np.arange(seq_length)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
    )
    
    # Calculate positional encodings
    pos_encoding = np.zeros((seq_length, hidden_dim), dtype=dtype)
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    # Trim the cosine half so an odd hidden_dim does not cause a shape mismatch
    pos_encoding[:, 1::2] = np.cos(position * div_term)[:, : hidden_dim // 2]
    
    # Add batch dimension and add to embeddings
    pos_encoding = pos_encoding[np.newaxis, :, :]
    return embeddings + pos_encoding[:, :seq_length, :]
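

# Hedged usage sketch of the sinusoidal numpy fallback; the all-zero
# embeddings below are illustrative only.
def _demo_positional_encoding() -> np.ndarray:
    emb = np.zeros((1, 3, 4), dtype=np.float32)  # (batch, seq, hidden)
    out = add_positional_encoding(emb, max_position=3, hidden_dim=4)
    # Position 0 encodes to sin(0)=0 and cos(0)=1 in alternating slots.
    assert np.allclose(out[0, 0], [0.0, 1.0, 0.0, 1.0])
    return out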


class EmbeddingState:
    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0
        
    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Store temporary computation results in driver memory"""
        name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)
        return name
        
    def free_temp_tensor(self, name: str):
        """Clean up temporary tensors"""
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)

class Embedding:
    """
    GPU/DB-backed Embedding layer for NLP/graph models.

    All weights/tensors are stored and accessed via the driver
    (e.g., SQLiteMemoryManager), not in Python RAM.
    """
    def __init__(self, vocab_size: int, embedding_dim: int, driver, prefix: str = "embed", init_std: float = 0.02):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.driver = driver
        self.prefix = prefix
        
        # Create unique names for persistent tensors
        self.weight_name = f"{prefix}_weight"
        self.grad_name = f"{prefix}_grad"
        
        # Initialize embedding matrix in driver memory if not present
        if not driver.tensor_exists(self.weight_name):
            weights = driver.random_normal(
                (vocab_size, embedding_dim),
                mean=0.0,
                std=init_std
            )
            driver.create_tensor(self.weight_name, weights)
            # Initialize gradient tensor
            driver.create_tensor(
                self.grad_name,
                np.zeros((vocab_size, embedding_dim))
            )
            
    def forward(
        self,
        indices_name: str,
        training: bool = True
    ) -> str:
        """
        Run the embedding lookup entirely in driver memory.

        Args:
            indices_name: Name of the tensor in the driver containing the indices.
            training: If True, cache the inputs needed by backward().

        Returns:
            Name of the output tensor in the driver.
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_fwd")
        
        # Get shape info from driver
        indices = self.driver.get_tensor(indices_name)
        original_shape = indices.shape
        
        # Flatten indices in driver memory
        flat_name = state.get_temp_tensor(
            indices.reshape(-1),
            "flat"
        )
        
        # Gather embeddings in driver memory
        gathered_name = state.get_temp_tensor(
            self.driver.gather(self.weight_name, flat_name),
            "gathered"
        )
        state.free_temp_tensor(flat_name)
        
        # Reshape to original dimensions + embedding_dim
        output_shape = original_shape + (self.embedding_dim,)
        output_name = state.get_temp_tensor(
            self.driver.reshape(gathered_name, output_shape),
            "output"
        )
        state.free_temp_tensor(gathered_name)
        
        if training:
            # Store intermediate results needed for backward
            self.save_for_backward(indices_name, original_shape)
            
        return output_name
        
    def save_for_backward(self, indices_name: str, shape: Tuple[int, ...]):
        """Save tensors needed for backward pass in driver memory"""
        self.driver.create_tensor(
            f"{self.prefix}_cache_indices",
            self.driver.get_tensor(indices_name)
        )
        self.driver.create_tensor(
            f"{self.prefix}_cache_shape",
            np.array(shape)
        )
        
    def backward(self, grad_output_name: str) -> None:
        """
        Accumulate parameter gradients in driver memory.

        Args:
            grad_output_name: Name of the upstream gradient tensor in the driver.
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_bwd")
        
        # Get cached values from driver
        indices = self.driver.get_tensor(f"{self.prefix}_cache_indices")
        orig_shape = tuple(self.driver.get_tensor(f"{self.prefix}_cache_shape"))
        
        # Reshape gradient to match gathered shape
        reshaped_grad_name = state.get_temp_tensor(
            self.driver.reshape(grad_output_name, (-1, self.embedding_dim)),
            "reshaped_grad"
        )
        
        # Use scatter_add to accumulate gradients for each index
        self.driver.scatter_add(
            self.grad_name,  # Accumulate into gradient tensor
            indices.reshape(-1),  # Flattened indices
            reshaped_grad_name  # Reshaped gradients
        )
        
        state.free_temp_tensor(reshaped_grad_name)
        
        # Cleanup cached tensors
        self.driver.delete_tensor(f"{self.prefix}_cache_indices")
        self.driver.delete_tensor(f"{self.prefix}_cache_shape")
        
    def parameters(self) -> Dict[str, str]:
        """Return names of parameter tensors in driver"""
        return {
            "weight": self.weight_name,
            "grad": self.grad_name
        }
        
    def zero_grad(self) -> None:
        """Reset gradients to zero in driver memory"""
        self.driver.fill(self.grad_name, 0.0)
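

# --- Usage sketch (assumptions flagged below) ---
# A minimal in-memory stand-in for the storage driver, backed by a plain dict
# rather than the SQLiteMemoryManager mentioned above. The method names simply
# mirror the calls made by Embedding/EmbeddingState; the real driver API may
# differ.
class _DictDriver:
    def __init__(self):
        self._tensors: Dict[str, np.ndarray] = {}

    def tensor_exists(self, name: str) -> bool:
        return name in self._tensors

    def create_tensor(self, name: str, data) -> None:
        self._tensors[name] = np.asarray(data)

    def get_tensor(self, name: str) -> np.ndarray:
        return self._tensors[name]

    def delete_tensor(self, name: str) -> None:
        self._tensors.pop(name, None)

    def random_normal(self, shape, mean: float = 0.0, std: float = 1.0) -> np.ndarray:
        return np.random.normal(mean, std, size=shape)

    def gather(self, weight_name: str, indices_name: str) -> np.ndarray:
        return self._tensors[weight_name][self._tensors[indices_name]]

    def reshape(self, name: str, shape) -> np.ndarray:
        return self._tensors[name].reshape(shape)

    def scatter_add(self, target_name: str, indices: np.ndarray, source_name: str) -> None:
        np.add.at(self._tensors[target_name], indices, self._tensors[source_name])

    def fill(self, name: str, value: float) -> None:
        self._tensors[name][...] = value


if __name__ == "__main__":
    # Round-trip a tiny batch through the layer using the stand-in driver.
    drv = _DictDriver()
    layer = Embedding(vocab_size=10, embedding_dim=4, driver=drv)
    drv.create_tensor("tokens", np.array([[1, 2], [3, 4]]))
    out_name = layer.forward("tokens")
    print(drv.get_tensor(out_name).shape)         # (2, 2, 4)
    drv.create_tensor("grad_out", np.ones((2, 2, 4)))
    layer.backward("grad_out")
    print(drv.get_tensor(layer.grad_name).sum())  # 16.0 = 4 tokens x 4 dims x 1.0
    layer.zero_grad()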