|
|
"""
|
|
|
Model inference using Helium virtual GPU with PyTorch-style loading and execution.
|
|
|
"""
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
import json
|
|
|
import numpy as np
|
|
|
from typing import Dict, List, Optional, Union, Any, Tuple
|
|
|
|
|
|
from helium import HeliumMultiModal
|
|
|
from helium.modality import ModalityType
|
|
|
from helium.tensor_ops import TensorOps
|
|
|
from helium.embedding import Embedding
|
|
|
from helium.positional_encoding import sinusoidal_positional_encoding
|
|
|
from helium.multihead_attention import AttentionConfig, AttentionType
|
|
|
from helium.normalization import NormConfig, NormType
|
|
|
from helium.gelu import gelu
|
|
|
from helium.softmax import softmax
|
|
|
from helium.decoder import DecoderConfig
|
|
|
from safetensors.numpy import save_file, load_file
|
|
|
|
|
|
class HeliumModel:
    """Base model class for the Helium framework.

    Minimal PyTorch-style container: named sub-modules, parameters and
    buffers, a state-dict protocol, train/eval mode switching, and
    virtual-GPU device placement.
    """

    def __init__(self):
        # Named sub-modules (anything exposing state_dict/to_device/train).
        self._modules: Dict[str, Any] = {}
        # Trainable arrays, keyed by name.
        self._parameters: Dict[str, np.ndarray] = {}
        # Non-trainable persistent arrays, keyed by name.
        self._buffers: Dict[str, np.ndarray] = {}
        self.training = False
        self.device_id: Optional[str] = None

    def __call__(self, *args, **kwargs):
        """PyTorch-style convenience: ``model(x)`` delegates to ``forward``.

        Subclasses are expected to define ``forward``; the base class does
        not. (The demo ``main()`` calls ``model(inputs)``, which previously
        raised TypeError because no ``__call__`` existed.)
        """
        return self.forward(*args, **kwargs)

    def load_state_from_db(self, model_key: str, device_id: str) -> None:
        """Load config and weights for ``model_key`` from the device DB.

        Args:
            model_key: primary key into model_configs / model_weights.
            device_id: virtual GPU the state is destined for; recorded on
                the model (placement itself happens in ``to_device``).

        Raises:
            TypeError: if ``model_key`` is not present (``fetchone()``
                returns None and the ``[0]`` subscript fails).
        """
        import duckdb
        from config import get_db_url

        conn = duckdb.connect(get_db_url())
        try:
            config = conn.execute(
                "SELECT config FROM model_configs WHERE model_key = ?",
                [model_key]
            ).fetchone()[0]
            self.config = json.loads(config)

            state_blob = conn.execute(
                "SELECT weights FROM model_weights WHERE model_key = ?",
                [model_key]
            ).fetchone()[0]
        finally:
            # BUG FIX: the original never closed the connection.
            conn.close()

        state_dict = json.loads(state_blob)
        self.load_state_dict(state_dict)
        # BUG FIX: the original accepted device_id but never used it.
        self.device_id = device_id

    def to_device(self, device_id: str) -> "HeliumModel":
        """Move the model and all sub-modules to the given virtual GPU.

        BUG FIX: the original defined ``to_device`` twice; the second
        definition silently shadowed the first and dropped the
        ``self.device_id`` bookkeeping. This merged version records the
        device and returns ``self`` for chaining.
        """
        self.device_id = device_id
        for module in self._modules.values():
            if hasattr(module, "to_device"):
                module.to_device(device_id)
        return self

    def register_module(self, name: str, module: Any) -> None:
        """Register a named sub-module."""
        self._modules[name] = module

    def register_parameter(self, name: str, param: np.ndarray) -> None:
        """Register a named trainable parameter array."""
        self._parameters[name] = param

    def register_buffer(self, name: str, buffer: np.ndarray) -> None:
        """Register a named non-trainable persistent array."""
        self._buffers[name] = buffer

    def state_dict(self) -> Dict[str, Any]:
        """Return a flat dict of all parameters, buffers and module state.

        Sub-module entries are namespaced as ``"<module>.<key>"``.
        """
        state = {}
        state.update(self._parameters)
        state.update(self._buffers)
        for name, module in self._modules.items():
            if hasattr(module, "state_dict"):
                state.update({
                    f"{name}.{k}": v
                    for k, v in module.state_dict().items()
                })
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load state produced by ``state_dict``.

        Dotted keys are routed to the matching sub-module; flat keys
        overwrite existing parameters/buffers. Unknown keys are silently
        ignored (kept from the original behavior — no strict mode).
        """
        for name, param in state_dict.items():
            if "." in name:
                module_name, param_name = name.split(".", 1)
                if module_name in self._modules:
                    if hasattr(self._modules[module_name], "load_state_dict"):
                        self._modules[module_name].load_state_dict({param_name: param})
            else:
                if name in self._parameters:
                    self._parameters[name] = param
                elif name in self._buffers:
                    self._buffers[name] = param

    def train(self, mode: bool = True) -> "HeliumModel":
        """Set training mode recursively; returns self for chaining."""
        self.training = mode
        for module in self._modules.values():
            if hasattr(module, "train"):
                module.train(mode)
        return self

    def eval(self) -> "HeliumModel":
        """Set evaluation mode; returns self for chaining."""
        return self.train(False)
|
|
|
|
|
|
class MultiModalModel(HeliumModel):
    """Multi-modal (text / image / audio) model on the Helium virtual GPU.

    Text is processed in NumPy (embedding + sinusoidal positions + layer
    norm + scaled dot-product self-attention + GELU); image and audio
    batches are delegated to the HeliumMultiModal backend.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_heads: int = 16,
        num_layers: int = 12,
        vocab_size: int = 50257,
        max_seq_len: int = 2048,
        device_id: str = "vgpu0"
    ):
        """Build the model and register its modules, buffers and parameters.

        Args:
            hidden_size: width of the hidden representation.
            num_heads: attention heads (recorded in the attention config).
            num_layers: decoder depth (recorded in the decoder config).
            vocab_size: text vocabulary size.
            max_seq_len: maximum text sequence length.
            device_id: virtual GPU the model lives on.
        """
        super().__init__()

        self.config = {
            "hidden_size": hidden_size,
            "num_heads": num_heads,
            "num_layers": num_layers,
            "vocab_size": vocab_size,
            "max_seq_len": max_seq_len
        }

        # BUG FIX: the original constructed HeliumMultiModal twice and
        # discarded the first instance; build the backend exactly once.
        self.system = HeliumMultiModal(num_tensor_cores=1)
        self.device_id = device_id

        driver = self.system.gpu.tensor_cores[0]

        self.register_module("text_embedding", Embedding(
            vocab_size=vocab_size,
            embedding_dim=hidden_size,
            driver=driver,
            prefix="text_embed"
        ))

        # Precomputed sinusoidal position table; assumed array-like of
        # shape (max_seq_len, hidden_size) -- TODO confirm against
        # helium.positional_encoding.
        pos_enc = sinusoidal_positional_encoding(
            seq_len=max_seq_len,
            hidden_dim=hidden_size,
            driver=driver,
            prefix="pos_enc"
        )
        self.register_buffer("positional_encoding", pos_enc)

        # Config objects are plain attributes, NOT buffers.
        # BUG FIX: the original registered AttentionConfig/NormConfig as
        # buffers, so state_dict() contained non-array objects and
        # save_pretrained() (safetensors save_file) could not serialize it.
        self.decoder_config = DecoderConfig(
            output_modalities=[ModalityType.TEXT],
            hidden_dim=hidden_size,
            num_layers=num_layers,
            num_heads=num_heads,
            intermediate_size=hidden_size * 4,
            max_seq_len={ModalityType.TEXT: max_seq_len},
            vocab_size=vocab_size,
            use_cache=True
        )

        self.attention_config = AttentionConfig(
            attention_type=AttentionType.SELF,
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=hidden_size // num_heads,
            dropout=0.1
        )

        self.norm_config = NormConfig(
            norm_type=NormType.LAYER,
            hidden_size=hidden_size,
            eps=1e-5
        )

        # Fused QKV projection: (3, hidden, hidden), small random init.
        self.register_parameter(
            "qkv_weights",
            np.random.randn(3, hidden_size, hidden_size).astype(np.float32) * 0.02
        )

        # LayerNorm affine parameters.
        self.register_parameter(
            "norm_weight",
            np.ones(hidden_size).astype(np.float32)
        )
        self.register_parameter(
            "norm_bias",
            np.zeros(hidden_size).astype(np.float32)
        )

        # Projection applied to the summed modality features.
        self.register_parameter(
            "fusion_weight",
            np.random.randn(hidden_size, hidden_size).astype(np.float32)
        )

    def forward(
        self,
        input_dict: Dict[str, np.ndarray],
        return_dict: bool = True
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """Run one forward pass over the supplied modalities.

        Args:
            input_dict: mapping of modality name ("text", "image",
                "audio") to its batch; "text" is expected to be integer
                token ids of shape (batch, seq).
            return_dict: if True, return a dict of per-modality features
                (plus "fused_features" when more than one modality is
                present); if False, return the fused feature array.

        Returns:
            Dict of feature arrays, or a single array when
            ``return_dict=False``.
        """
        outputs = {}

        for modality, inputs in input_dict.items():
            if modality == "text":
                embeds = self._modules["text_embedding"](inputs)
                # Positional table assumed sliceable to (seq, hidden) and
                # broadcastable against embeds -- TODO confirm.
                pos_embeds = embeds + self._buffers["positional_encoding"][:inputs.shape[1]]

                # Layer norm over the hidden axis.
                mean = pos_embeds.mean(axis=-1, keepdims=True)
                var = ((pos_embeds - mean) ** 2).mean(axis=-1, keepdims=True)
                hidden = (pos_embeds - mean) / np.sqrt(var + self.norm_config.eps)
                hidden = hidden * self._parameters["norm_weight"] + self._parameters["norm_bias"]

                # Fused QKV projection: (..., seq, 3, hidden); split and
                # drop the length-3 axis so q/k/v are (..., seq, hidden).
                qkv = np.einsum('...d,hdi->...hi', hidden, self._parameters["qkv_weights"])
                q, k, v = (t.squeeze(-2) for t in np.split(qkv, 3, axis=-2))

                # Scaled dot-product self-attention over the sequence axis.
                # BUG FIX: ndarray.transpose(-2, -1) raises on >2-D input,
                # and the un-squeezed q/k/v produced degenerate
                # (..., seq, 1, 1) scores; with squeeze + swapaxes the
                # scores are (..., seq, seq) as intended.
                attn_weights = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(hidden.shape[-1])
                attn_weights = softmax(attn_weights, axis=-1)
                attn_output = np.matmul(attn_weights, v)

                hidden = gelu(attn_output)
                outputs["text_features"] = hidden

            elif modality == "image":
                # Delegate image processing to the Helium backend.
                outputs["image_features"] = self.system.process_batch({
                    ModalityType.IMAGE: inputs
                })

            elif modality == "audio":
                # Delegate audio processing to the Helium backend.
                outputs["audio_features"] = self.system.process_batch({
                    ModalityType.AUDIO: inputs
                })

        # Fuse when several modalities are present. NOTE(review): this
        # assumes the per-modality features share a broadcastable shape.
        if len(outputs) > 1:
            fused = sum(outputs.values())
            fused = fused @ self._parameters["fusion_weight"]
            outputs["fused_features"] = fused
        else:
            # BUG FIX: the original referenced an undefined local
            # (`fusion`) when fewer than two modalities were supplied and
            # return_dict=False (NameError). Fall back to the single
            # modality's features, or None for empty input.
            fused = next(iter(outputs.values())) if outputs else None

        return outputs if return_dict else fused

    def generate(
        self,
        inputs: Union[np.ndarray, Dict[str, np.ndarray]],
        max_length: int = 100,
        **kwargs
    ) -> np.ndarray:
        """Autoregressively generate up to ``max_length`` tokens.

        NOTE(review): this path depends on "pos_encoding" and "decoder"
        modules that ``__init__`` never registers, so it currently raises
        KeyError at runtime; a decoder module must be wired up before use.
        """
        if isinstance(inputs, dict):
            # Multi-modal prompt: condition on the fused features.
            hidden = self.forward(inputs, return_dict=False)
        else:
            # Plain token ids: embed, add positions, run the decoder.
            embeds = self._modules["text_embedding"](inputs)
            pos_embeds = self._modules["pos_encoding"](embeds)
            hidden = self._modules["decoder"](pos_embeds)

        generated = []
        for _ in range(max_length):
            next_token = self._modules["decoder"].predict_next(hidden)
            generated.append(next_token)

            # Feed the new token back in for the next step.
            next_embeds = self._modules["text_embedding"](next_token)
            next_pos = self._modules["pos_encoding"](next_embeds)
            hidden = self._modules["decoder"](next_pos, hidden)

        return np.array(generated)

    def save_pretrained(self, path: str) -> None:
        """Save config.json and model.safetensors under ``path``.

        Requires ``state_dict()`` to contain only numpy arrays (see the
        config-as-attribute fix in ``__init__``).
        """
        os.makedirs(path, exist_ok=True)

        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(self.config, f, indent=2)

        save_file(self.state_dict(), os.path.join(path, "model.safetensors"))

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "openai-oss-20b",
        device_id: str = "vgpu0",
        cache_dir: Optional[str] = None,
        **kwargs
    ) -> "MultiModalModel":
        """Download ``model_id``, stage it in the device DB, and load it.

        Args:
            model_id: model identifier passed to ``download_model``.
            device_id: virtual GPU to place the loaded model on.
            cache_dir: optional download cache directory.
            **kwargs: accepted for interface compatibility; currently
                unused by this loading path.

        BUG FIX: the original contained a second, unreachable loading
        path after the ``return`` statement, referencing undefined names
        (``path``, ``duckdb``); that dead code has been removed.
        """
        from .model_loader import download_model, store_in_device_db
        from config import get_db_url

        local_path = download_model(model_id, cache_dir)

        db_url = get_db_url()
        model_key = store_in_device_db(local_path, db_url)

        model = cls()
        model.load_state_from_db(model_key, device_id)
        model.to_device(device_id)
        return model
|
|
|
|
|
|
def main():
    """Example usage: build a model, run multi-modal inference, generate,
    then save and reload a checkpoint."""
    model = MultiModalModel(
        hidden_size=1024,
        num_heads=16,
        num_layers=12,
        device_id="vgpu0"
    )

    # One batch per modality: token ids, CHW image, mono audio.
    inputs = {
        "text": np.random.randint(0, 50257, (1, 64)),
        "image": np.random.randn(1, 3, 224, 224),
        "audio": np.random.randn(1, 1, 16000)
    }

    model.eval()
    # BUG FIX: the original called model(inputs), but the base class
    # defines no __call__; invoke forward() explicitly.
    outputs = model.forward(inputs)

    print("Output features shapes:")
    for k, v in outputs.items():
        print(f"  {k}: {v.shape}")

    # NOTE(review): generate() relies on decoder/pos_encoding modules
    # that the model may not register -- confirm before relying on it.
    generated = model.generate(inputs, max_length=20)
    print("\nGenerated sequence shape:", generated.shape)

    model.save_pretrained("model_checkpoint")

    # NOTE(review): from_pretrained treats its argument as a model id to
    # download, not a local checkpoint dir -- verify this round-trips.
    loaded_model = MultiModalModel.from_pretrained("model_checkpoint")
    print("\nSuccessfully loaded model from checkpoint")
|
|
|
|
|
|
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()