from copy import deepcopy
from typing import Dict, List, Union

import numpy as np
import torch
from torch import nn, FloatTensor, LongTensor
from torch.nn.functional import pad
from transformers import AutoModelForCausalLM, AutoConfig, LogitsProcessor, LogitsProcessorList

from .spec import ModelSpec, ModelInput
from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
from ..tokenizer.spec import TokenizerSpec, DetokenizeOutput


class VocabSwitchingLogitsProcessor(LogitsProcessor):
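    """Masks invalid next tokens during autoregressive decoding.

    At every step the tokenizer is asked which token ids may legally follow the
    sequence generated so far (prefixed with the fixed start tokens); the logits
    of all other ids are set to -inf so the skeleton token grammar is respected.
    """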

    def __init__(self, tokenizer: TokenizerSpec, start_tokens: LongTensor):
        self.tokenizer = tokenizer
        self.start_tokens = start_tokens
        assert start_tokens.ndim == 1

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        for batch_idx, sequence in enumerate(input_ids):
            # Start from an all -inf mask, then open up only the ids the tokenizer
            # reports as valid continuations of (start_tokens + sequence).
            mask = torch.full_like(scores[batch_idx], float('-inf'))
            sequence = torch.cat([self.start_tokens, sequence])
            tokens = self.tokenizer.next_posible_token(ids=sequence.detach().cpu().numpy())
            mask[tokens] = 0
            scores[batch_idx] = scores[batch_idx] + mask
        return scores


class UniRigAR(ModelSpec):
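    """Autoregressive skeleton generation model.

    A mesh (point cloud) encoder produces conditioning latents that are projected
    into the embedding space of a causal language model; the transformer then
    predicts skeleton tokens autoregressively, constrained to the tokenizer's grammar.
    """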

    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
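        """Collate a list of ModelInput into padded token arrays.

        Sequences are right-padded with the pad token to the longest sequence in
        the batch; the attention mask is 1 over real tokens and 0 over padding.
        Returns an empty dict per sample when no joints (labels) are available.
        """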
        if batch[0].joints is None:
            return [{} for _ in range(len(batch))]
        max_length = 0
        for b in batch:
            max_length = max(max_length, b.tokens.shape[0])
        res = [{
            'input_ids': np.pad(b.tokens, (0, max_length - b.tokens.shape[0]), 'constant', constant_values=b.pad),
            'attention_mask': np.pad(np.ones(b.tokens.shape[0]), (0, max_length - b.tokens.shape[0]), 'constant', constant_values=0.),
        } for b in batch]
        return res

    def __init__(self, llm, mesh_encoder, **kwargs):
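        """
        Args:
            llm: dict-like config for the causal LM backbone; it is copied, its
                vocab_size is overwritten with the tokenizer's, and the result is
                forwarded to AutoConfig.from_pretrained. It is also expected to
                expose hidden_size.
            mesh_encoder: keyword arguments forwarded to get_mesh_encoder.
            **kwargs: must provide 'tokenizer' (a TokenizerSpec).
        """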
        super().__init__()
        self.tokenizer: TokenizerSpec = kwargs.get('tokenizer')
        self.vocab_size = self.tokenizer.vocab_size

        # Causal LM backbone with its vocabulary resized to the skeleton tokenizer.
        _d = llm.copy()
        _d['vocab_size'] = self.tokenizer.vocab_size
        llm_config = AutoConfig.from_pretrained(**_d)
        llm_config.torch_dtype = torch.float32
        llm_config.pre_norm = True
        self.transformer = AutoModelForCausalLM.from_config(config=llm_config)

        self.hidden_size = llm.hidden_size

        # Mesh encoder that turns the sampled point cloud into conditioning latents.
        self.mesh_encoder = get_mesh_encoder(**mesh_encoder)

        if isinstance(self.mesh_encoder, (MAP_MESH_ENCODER.michelangelo, MAP_MESH_ENCODER.michelangelo_encoder)):
            # Project encoder latents into the transformer's embedding space.
            self.output_proj = nn.Linear(self.mesh_encoder.width, self.hidden_size)
        else:
            raise NotImplementedError()

    def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor:
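        """Encode the point cloud (positions and normals) into conditioning embeddings.

        Accepts either a single sample of shape (num_points, 3) or a batch of shape
        (batch, num_points, 3) and returns latents projected to the transformer's
        hidden size.
        """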
        assert not torch.isnan(vertices).any()
        assert not torch.isnan(normals).any()
        if isinstance(self.mesh_encoder, (MAP_MESH_ENCODER.michelangelo, MAP_MESH_ENCODER.michelangelo_encoder)):
            if vertices.ndim == 3:
                shape_embed, latents, token_num, pre_pc = self.mesh_encoder.encode_latents(pc=vertices, feats=normals)
            else:
                shape_embed, latents, token_num, pre_pc = self.mesh_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0))
            latents = self.output_proj(latents)
            return latents
        else:
            raise NotImplementedError()

    def training_step(self, batch: Dict) -> Dict[str, FloatTensor]:
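        """Compute the training losses for one batch.

        The mesh latents are prepended to the token embeddings, the transformer
        predicts the next token at every position, and two losses are returned:
        a cross-entropy loss over the full vocabulary and an auxiliary distance
        loss over the discrete coordinate bins.
        """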
        cond = self.encode_mesh_cond(vertices=batch['vertices'], normals=batch['normals']).to(dtype=self.transformer.dtype)
        B = cond.shape[0]
        input_ids: LongTensor = batch['input_ids']
        inputs_embeds = self.transformer.get_input_embeddings()(input_ids).to(dtype=self.transformer.dtype)

        # Prepend the mesh-conditioning latents to the token embeddings and
        # extend the attention mask accordingly.
        inputs_embeds = torch.concat([cond, inputs_embeds], dim=1)
        attention_mask = batch['attention_mask']
        attention_mask = pad(attention_mask, (cond.shape[1], 0, 0, 0), value=1.)
        output = self.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        # Keep only the logits over the token part of the sequence.
        logit = output.logits[:, cond.shape[1]:].reshape(B, -1, self.vocab_size)

        device = logit.device
        logit = logit[:, :-1]
        num_discrete = self.tokenizer.num_discrete
        s = torch.nn.functional.softmax(logit, dim=-1)

        # Auxiliary distance loss: for positions whose target is a discrete
        # coordinate token, penalize probability mass placed on bins far from
        # the target bin (distance normalized by num_discrete, scaled by 1/50).
        label = input_ids[:, 1:].clone()
        mask = label < num_discrete
        dis = torch.arange(num_discrete, device=device).view(1, 1, -1)
        dis = (dis - label.unsqueeze(2).repeat(1, 1, num_discrete)).type(torch.float32) / num_discrete
        dis_loss = (s[:, :, :num_discrete] * torch.abs(dis))[mask].sum() / 50

        # Ignore padded positions in the cross-entropy loss.
        label[attention_mask[:, cond.shape[1] + 1:] == 0] = -100

        assert not torch.isnan(logit).any(), logit
        ce_loss = nn.functional.cross_entropy(logit.permute(0, 2, 1), label)
        return {
            'ce_loss': ce_loss,
            'dis_loss': dis_loss,
        }

    def forward(self, data: Dict):
        return self.training_step(batch=data)

    @torch.no_grad()
    def generate(
        self,
        vertices: FloatTensor,
        normals: FloatTensor,
        cls: Union[str, None]=None,
        **kwargs,
    ) -> DetokenizeOutput:
        '''
        Generate a skeleton token sequence for a single mesh and detokenize it.
        Batched input is not supported.
        '''
        cond = self.encode_mesh_cond(vertices=vertices, normals=normals).to(dtype=self.transformer.dtype)

        # Optional class token (skeleton type) follows the BOS token.
        start_tokens = [self.tokenizer.bos]
        if cls is not None:
            start_tokens.append(self.tokenizer.cls_name_to_token(cls=cls))
        start_tokens = torch.tensor(start_tokens).to(cond.device)
        start_embed = self.transformer.get_input_embeddings()(
            start_tokens.unsqueeze(0)
        ).to(dtype=self.transformer.dtype)
        cond = torch.cat([cond, start_embed], dim=1)

        processor = VocabSwitchingLogitsProcessor(
            tokenizer=self.tokenizer,
            start_tokens=start_tokens,
        )
        results = self.transformer.generate(
            inputs_embeds=cond,
            bos_token_id=self.tokenizer.bos,
            eos_token_id=self.tokenizer.eos,
            pad_token_id=self.tokenizer.pad,
            logits_processor=LogitsProcessorList([processor]),
            **kwargs,
        )

        # The prompt is not part of `results` (only inputs_embeds were given),
        # so re-prepend the start tokens before detokenizing.
        output_ids = results[0, :]
        for token in reversed(start_tokens):
            output_ids = pad(output_ids, (1, 0), value=token)
        output_ids = output_ids.detach().cpu().numpy()

        res = self.tokenizer.detokenize(ids=output_ids)
        return res

    def predict_step(self, batch: Dict, no_cls: bool=False):
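        """Run generation for every sample in the batch.

        The skeleton class passed to `generate` is resolved from `generate_kwargs`:
        `no_cls` disables it, `assign_cls` forces a specific class, `use_dir_cls`
        derives it from the first directory of the sample path, and otherwise the
        class stored in the batch is used.
        """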
        vertices: FloatTensor = batch['vertices']
        normals : FloatTensor = batch['normals']
        paths   : List[str]   = batch['path']
        cls = batch['cls']
        generate_kwargs = deepcopy(batch['generate_kwargs'])

        # Generation-time options travel inside generate_kwargs; pop them so
        # only arguments meant for the transformer's generate() remain.
        no_cls = generate_kwargs.get('no_cls', False)
        use_dir_cls = generate_kwargs.get('use_dir_cls', False)
        assign_cls = generate_kwargs.get('assign_cls', None)

        generate_kwargs.pop('no_cls', None)
        generate_kwargs.pop('use_dir_cls', None)
        generate_kwargs.pop('assign_cls', None)

        if vertices.dim() == 2:
            vertices = vertices.unsqueeze(0)
            normals = normals.unsqueeze(0)
        outputs = []
        for i in range(vertices.shape[0]):
            if no_cls:
                _cls = None
            elif assign_cls is not None:
                _cls = assign_cls
                print(f"UniRigAR: Using assigned skeleton type: {_cls}")
            elif use_dir_cls:
                _cls = paths[i].removeprefix('./').split('/')[0]
            else:
                _cls = cls[i]
            res = self.generate(vertices=vertices[i], normals=normals[i], cls=_cls, **generate_kwargs)
            outputs.append(res)
        return outputs