from typing import Dict, Tuple
import numpy as np


def embedding_lookup(
    input_ids: np.ndarray,
    embedding_weights: np.ndarray,
    driver=None
) -> np.ndarray:
    """
    Look up embeddings for input tokens.

    Args:
        input_ids: Input token indices of shape (batch_size, sequence_length)
        embedding_weights: Embedding weight matrix of shape (vocab_size, hidden_dim)
        driver: Optional hardware driver for optimized lookup

    Returns:
        Embedded tokens of shape (batch_size, sequence_length, hidden_dim)
    """
    if driver and hasattr(driver, 'embedding_lookup'):
        return driver.embedding_lookup(input_ids, embedding_weights)

    # Fallback to numpy implementation
    batch_size, seq_length = input_ids.shape
    hidden_dim = embedding_weights.shape[1]

    # Reshape input_ids for broadcasting
    input_ids_reshaped = input_ids.reshape(-1)

    # Lookup embeddings
    embeddings = embedding_weights[input_ids_reshaped]

    # Reshape back to (batch_size, sequence_length, hidden_dim)
    return embeddings.reshape(batch_size, seq_length, hidden_dim)
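

# A minimal usage sketch for the numpy fallback path of embedding_lookup().
# The shapes below (vocab_size=10, hidden_dim=4) are illustrative assumptions,
# not values taken from any particular model.
def _example_embedding_lookup() -> None:
    rng = np.random.default_rng(0)
    weights = rng.standard_normal((10, 4)).astype(np.float32)  # (vocab_size, hidden_dim)
    token_ids = np.array([[1, 2, 3], [4, 5, 6]])                # (batch_size, sequence_length)
    out = embedding_lookup(token_ids, weights)                  # driver=None -> numpy fallback
    assert out.shape == (2, 3, 4)                               # (batch, seq, hidden_dim)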


def add_positional_encoding(
    embeddings: np.ndarray,
    max_position: int,
    hidden_dim: int,
    dtype: np.dtype = np.float32,
    driver=None
) -> np.ndarray:
    """
    Add positional encodings to input embeddings.

    Args:
        embeddings: Input embeddings of shape (batch_size, sequence_length, hidden_dim)
        max_position: Maximum sequence length
        hidden_dim: Hidden dimension size
        dtype: Data type for positional encodings
        driver: Optional hardware driver for optimized computation

    Returns:
        Embeddings with positional encoding added
    """
    if driver and hasattr(driver, 'add_positional_encoding'):
        return driver.add_positional_encoding(
            embeddings,
            max_position,
            hidden_dim,
            dtype
        )

    # Fallback to numpy implementation
    batch_size, seq_length, _ = embeddings.shape

    # Create position indices
    position = np.arange(seq_length)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
    )

    # Calculate positional encodings
    pos_encoding = np.zeros((seq_length, hidden_dim), dtype=dtype)
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    pos_encoding[:, 1::2] = np.cos(position * div_term)

    # Add batch dimension and add to embeddings
    pos_encoding = pos_encoding[np.newaxis, :, :]
    return embeddings + pos_encoding[:, :seq_length, :]
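

# A minimal usage sketch for the sinusoidal numpy fallback of
# add_positional_encoding(); the shapes here are illustrative assumptions.
# At position 0 the even channels receive sin(0) = 0 and the odd channels
# cos(0) = 1, which the assertions below check against zero embeddings.
def _example_add_positional_encoding() -> None:
    embeddings = np.zeros((2, 5, 8), dtype=np.float32)  # (batch, seq, hidden_dim)
    encoded = add_positional_encoding(embeddings, max_position=5, hidden_dim=8)
    assert encoded.shape == (2, 5, 8)
    assert np.allclose(encoded[0, 0, 0::2], 0.0)
    assert np.allclose(encoded[0, 0, 1::2], 1.0)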


class EmbeddingState:
    """Tracks and names temporary tensors created in driver memory."""

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Store temporary computation results in driver memory"""
        name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)
        return name

    def free_temp_tensor(self, name: str):
        """Clean up temporary tensors"""
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)
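

# Hedged sketch of the temp-tensor lifecycle. _DictDriver is a hypothetical
# in-memory stand-in for a real driver such as SQLiteMemoryManager; it only
# implements the calls EmbeddingState makes (create_tensor, tensor_exists,
# delete_tensor) and is meant for illustration, not production use.
class _DictDriver:
    def __init__(self):
        self._tensors: Dict[str, np.ndarray] = {}

    def create_tensor(self, name: str, data) -> None:
        self._tensors[name] = np.asarray(data)

    def tensor_exists(self, name: str) -> bool:
        return name in self._tensors

    def delete_tensor(self, name: str) -> None:
        del self._tensors[name]


def _example_embedding_state() -> None:
    driver = _DictDriver()
    state = EmbeddingState(driver, prefix="demo")
    name = state.get_temp_tensor(np.arange(6), "flat")  # -> "demo_temp_0_flat"
    assert driver.tensor_exists(name)
    state.free_temp_tensor(name)                        # temporaries are freed explicitly
    assert not driver.tensor_exists(name)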


class Embedding:
    """
    GPU/DB-backed Embedding layer for NLP/graph models.
    All weights/tensors are stored and accessed via the driver
    (e.g., SQLiteMemoryManager), not Python RAM.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, driver, prefix: str = "embed", init_std: float = 0.02):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.driver = driver
        self.prefix = prefix

        # Create unique names for persistent tensors
        self.weight_name = f"{prefix}_weight"
        self.grad_name = f"{prefix}_grad"

        # Initialize embedding matrix in driver memory if not present
        if not driver.tensor_exists(self.weight_name):
            weights = driver.random_normal(
                (vocab_size, embedding_dim),
                mean=0.0,
                std=init_std
            )
            driver.create_tensor(self.weight_name, weights)

        # Initialize gradient tensor
        driver.create_tensor(
            self.grad_name,
            np.zeros((vocab_size, embedding_dim))
        )

    def forward(
        self,
        indices_name: str,
        training: bool = True
    ) -> str:
        """
        Embed the tokens referenced by a driver tensor, entirely in driver memory.

        Args:
            indices_name: Name of the tensor containing token indices in the driver
            training: If True, cache the indices needed for the backward pass

        Returns:
            Name of the output tensor in the driver
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_fwd")

        # Get shape info from driver
        indices = self.driver.get_tensor(indices_name)
        original_shape = indices.shape

        # Flatten indices in driver memory
        flat_name = state.get_temp_tensor(
            indices.reshape(-1),
            "flat"
        )

        # Gather embeddings in driver memory
        gathered_name = state.get_temp_tensor(
            self.driver.gather(self.weight_name, flat_name),
            "gathered"
        )
        state.free_temp_tensor(flat_name)

        # Reshape to original dimensions + embedding_dim
        output_shape = original_shape + (self.embedding_dim,)
        output_name = state.get_temp_tensor(
            self.driver.reshape(gathered_name, output_shape),
            "output"
        )
        state.free_temp_tensor(gathered_name)

        if training:
            # Store intermediate results needed for backward
            self.save_for_backward(indices_name, original_shape)

        return output_name

    def save_for_backward(self, indices_name: str, shape: Tuple[int, ...]):
        """Save tensors needed for backward pass in driver memory"""
        self.driver.create_tensor(
            f"{self.prefix}_cache_indices",
            self.driver.get_tensor(indices_name)
        )
        self.driver.create_tensor(
            f"{self.prefix}_cache_shape",
            np.array(shape)
        )

    def backward(self, grad_output_name: str) -> None:
        """
        Accumulate weight gradients in driver memory.

        Args:
            grad_output_name: Name of the upstream gradient tensor in the driver
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_bwd")

        # Get cached values from driver
        indices = self.driver.get_tensor(f"{self.prefix}_cache_indices")
        orig_shape = tuple(self.driver.get_tensor(f"{self.prefix}_cache_shape"))

        # Reshape gradient to match gathered shape
        reshaped_grad_name = state.get_temp_tensor(
            self.driver.reshape(grad_output_name, (-1, self.embedding_dim)),
            "reshaped_grad"
        )

        # Use scatter_add to accumulate gradients for each index
        self.driver.scatter_add(
            self.grad_name,        # Accumulate into gradient tensor
            indices.reshape(-1),   # Flattened indices
            reshaped_grad_name     # Reshaped gradients
        )
        state.free_temp_tensor(reshaped_grad_name)

        # Cleanup cached tensors
        self.driver.delete_tensor(f"{self.prefix}_cache_indices")
        self.driver.delete_tensor(f"{self.prefix}_cache_shape")

    def parameters(self) -> Dict[str, str]:
        """Return names of parameter tensors in driver"""
        return {
            "weight": self.weight_name,
            "grad": self.grad_name
        }

    def zero_grad(self) -> None:
        """Reset gradients to zero in driver memory"""
        self.driver.fill(self.grad_name, 0.0)
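

# End-to-end usage sketch for the Embedding layer. _FullDictDriver extends the
# hypothetical _DictDriver stub above with the remaining calls the layer makes
# (get_tensor, random_normal, gather, reshape, scatter_add, fill); the numpy
# semantics chosen here are assumptions about the driver contract inferred from
# the call sites, not the behavior of any real backend.
class _FullDictDriver(_DictDriver):
    def get_tensor(self, name: str) -> np.ndarray:
        return self._tensors[name]

    def random_normal(self, shape, mean: float = 0.0, std: float = 1.0) -> np.ndarray:
        return np.random.normal(mean, std, size=shape)

    def gather(self, weight_name: str, indices_name: str) -> np.ndarray:
        return self._tensors[weight_name][self._tensors[indices_name]]

    def reshape(self, name: str, shape) -> np.ndarray:
        return self._tensors[name].reshape(shape)

    def scatter_add(self, target_name: str, indices: np.ndarray, source_name: str) -> None:
        np.add.at(self._tensors[target_name], indices, self._tensors[source_name])

    def fill(self, name: str, value: float) -> None:
        self._tensors[name][...] = value


def _example_embedding_layer() -> None:
    driver = _FullDictDriver()
    layer = Embedding(vocab_size=10, embedding_dim=4, driver=driver, prefix="demo_embed")

    # Forward: indices live in driver memory and are referenced by name.
    driver.create_tensor("demo_ids", np.array([[1, 2], [3, 1]]))
    out_name = layer.forward("demo_ids", training=True)
    assert driver.get_tensor(out_name).shape == (2, 2, 4)

    # Backward: pretend the upstream gradient is all ones.
    driver.create_tensor("demo_grad_out", np.ones((2, 2, 4)))
    layer.backward("demo_grad_out")
    grad = driver.get_tensor(layer.parameters()["grad"])
    assert grad[1].sum() == 2 * 4  # token 1 appears twice; each occurrence adds a row of ones
    layer.zero_grad()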