SkinTokens / src /server /spec.py
jasongzy's picture
๐Ÿž fix: cpu inference
30cb187
Raw
History Blame Contribute Delete
2.65 kB
from dataclasses import dataclass
from torch import Tensor
from typing import Dict, Optional, List, Tuple
import io
import os
import torch
from ..rig_package.info.asset import Asset
from ..model.tokenrig import TokenRig
# Port and base URL of the primary localhost HTTP endpoint
# (presumably the inference server -- TODO confirm against the server code).
PORT = 59875
SERVER = f"http://localhost:{PORT}"

# Port and base URL of a second localhost endpoint; "BPY" presumably refers to
# a Blender-Python (bpy) worker -- TODO confirm against the server code.
BPY_PORT = 59876
BPY_SERVER = f"http://localhost:{BPY_PORT}"

# Relative directory name for temporary checkpoints; resolved against the
# current working directory by whoever uses it.
TMP_CKPT_DIR = "./tmp_ckpt"
@dataclass
class TensorPacket:
    """Serializable bundle of flags, tensors and assets.

    Every field has a default, so a packet may carry any subset of the
    payload. Original author's note: make sure the packet stays on cpu; use
    ``to_device`` to move the tensor fields explicitly when required.
    """

    validate: bool = False
    know_skeleton: bool = False
    learned_mesh_cond: Optional[Tensor] = None
    cond_latents: Optional[Tensor] = None
    mesh_cond: Optional[Tensor] = None
    vertices: Optional[Tensor] = None
    assets: Optional[List[Asset]] = None
    output_ids: Optional[Tensor] = None
    start_embed_list: Optional[List[Tensor]] = None
    start_tokens_list: Optional[List[List[int]]] = None

    def to_device(self, device):
        """Move every populated tensor field to ``device``, in place.

        ``assets`` and ``start_tokens_list`` are not touched. Fields that are
        ``None`` are skipped. Returns ``None``.
        """
        # Single-tensor fields, processed in declaration order.
        for field_name in (
            "learned_mesh_cond",
            "cond_latents",
            "mesh_cond",
            "vertices",
            "output_ids",
        ):
            current = getattr(self, field_name)
            if current is not None:
                setattr(self, field_name, current.to(device))
        # The only list-of-tensors field: rebuild the list element by element.
        if self.start_embed_list is not None:
            self.start_embed_list = [embed.to(device) for embed in self.start_embed_list]

    @property
    def B(self):
        """Size of dimension 0 of ``learned_mesh_cond``.

        Presumably the batch size -- TODO confirm. Asserts that
        ``learned_mesh_cond`` has been set.
        """
        cond = self.learned_mesh_cond
        assert cond is not None
        return cond.shape[0]

    def to_bytes(self):
        """Serialize this packet with ``object_to_bytes`` and return the bytes."""
        return object_to_bytes(self)

    @classmethod
    def from_bytes(cls, bytes) -> 'TensorPacket':
        """Deserialize a packet from ``bytes`` using ``bytes_to_object``.

        No type check is performed: whatever object was serialized is
        returned as-is.
        """
        return bytes_to_object(bytes)
def object_to_bytes(t):
    """Serialize ``t`` with ``torch.save`` into memory and return the raw bytes."""
    with io.BytesIO() as stream:
        torch.save(t, stream)
        # Copy the contents out before the buffer is closed by the ``with``.
        payload = stream.getvalue()
    return payload
def bytes_to_object(b, map_location=None):
    """Deserialize an object from bytes produced by ``torch.save``.

    Args:
        b: Serialized payload.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        The deserialized object.

    NOTE(security): ``weights_only=False`` makes ``torch.load`` perform full
    unpickling, which can execute arbitrary code. Only pass bytes that come
    from a trusted source.
    """
    stream = io.BytesIO(b)
    return torch.load(stream, weights_only=False, map_location=map_location)
def get_model(
    ckpt_path: str,
    hf_path: Optional[str]=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
) -> TokenRig:
    """Load a ``TokenRig`` from a checkpoint and place it on ``device``.

    Args:
        ckpt_path: Passed to ``TokenRig.load_from_system_checkpoint`` as
            ``checkpoint_path``.
        hf_path: Optional path handed to ``AutoModel.from_pretrained`` with
            ``local_files_only=True``. When given, the state dict of that
            model is loaded into ``model.transformer.model``.
        device: Target device. The default is chosen once, when the function
            is defined: ``"cuda"`` if CUDA is available, otherwise ``"cpu"``.

    Returns:
        The model after ``.to(device)``.
    """
    model = TokenRig.load_from_system_checkpoint(checkpoint_path=ckpt_path)
    if hf_path is None:
        return model.to(device)

    # Imported lazily so transformers is only needed when hf_path is used.
    from transformers import AutoModel

    # flash_attention_2 is requested only when device is exactly the string
    # "cuda"; any other value falls back to sdpa.
    if device == "cuda":
        attn_impl = "flash_attention_2"
    else:
        attn_impl = "sdpa"
    pretrained = AutoModel.from_pretrained(
        hf_path,
        local_files_only=True,
        _attn_implementation=attn_impl,
        torch_dtype=torch.bfloat16,
    )
    # Only the weights of the pretrained model are used.
    model.transformer.model.load_state_dict(pretrained.state_dict())
    return model.to(device)