""" Generation and inference utilities with constrained decoding """ import torch from transformers import LogitsProcessor, LogitsProcessorList from typing import Dict from config import ( SYSTEM_MSG, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE, GEN_TOP_P, GEN_TOP_K, GEN_NO_REPEAT_NGRAM_SIZE, GEN_REPETITION_PENALTY, GEN_END_LOGIT_SLOPE ) class LengthAwareMotionLogitsProcessor(LogitsProcessor): """ Constrained decoding processor that: 1. Enforces motion token vocabulary 2. Controls sequence length (min/soft_target/max) 3. Biases toward ending at soft_target length """ def __init__(self, prompt_len, mot_begin_id, mot_end_id, motion_ids, hard_min, soft_target, hard_max, end_logit_slope=0.25): super().__init__() self.prompt_len = int(prompt_len) self.mot_begin_id = int(mot_begin_id) self.mot_end_id = int(mot_end_id) self.motion_ids = torch.tensor(sorted(set(int(x) for x in motion_ids))) self.motion_plus_end = torch.tensor( sorted(set(list(self.motion_ids.tolist()) + [self.mot_end_id])) ) self.hard_min = int(hard_min) self.soft_target = int(soft_target) self.hard_max = int(hard_max) self.end_logit_slope = float(end_logit_slope) def __call__(self, input_ids, scores): device = scores.device bs = scores.size(0) mask = torch.full_like(scores, float("-inf")) for b in range(bs): gen = input_ids[b, self.prompt_len:] # No tokens generated yet - must start with MOT_BEGIN if gen.numel() == 0: allowed = torch.tensor([self.mot_begin_id], device=device) mask[b].index_fill_(0, allowed, 0.0) continue # Find MOT_BEGIN position begin_pos = (gen == self.mot_begin_id).nonzero(as_tuple=True)[0] if begin_pos.numel() == 0: allowed = torch.tensor([self.mot_begin_id], device=device) mask[b].index_fill_(0, allowed, 0.0) continue # Already generated MOT_END - force EOS if (gen == self.mot_end_id).any(): allowed = torch.tensor([self.mot_end_id], device=device) mask[b].index_fill_(0, allowed, 0.0) continue # Count motion tokens after MOT_BEGIN after_begin = gen[begin_pos[0].item() + 1:] cur_len = after_begin.numel() # Before minimum length - only allow motion tokens if cur_len < self.hard_min: allowed = self.motion_ids.to(device) mask[b].index_fill_(0, allowed, 0.0) # After maximum length - force end elif cur_len >= self.hard_max: allowed = torch.tensor([self.mot_end_id], device=device) mask[b].index_fill_(0, allowed, 0.0) # Between min and max - allow motion tokens or end else: allowed = self.motion_plus_end.to(device) mask[b].index_fill_(0, allowed, 0.0) # Bias toward ending at soft_target distance = max(0, cur_len - self.soft_target) bias = self.end_logit_slope * float(distance) scores[b, self.mot_end_id] = scores[b, self.mot_end_id] + bias return scores + mask def get_len_controls(prompt_text: str, length_stats_by_text: Dict, global_median_len: int): """ Get length controls (min/soft_target/max) for a given prompt """ s = length_stats_by_text.get(prompt_text) if s is None: med = global_median_len else: med = s["median"] hard_min = max(1, int(0.6 * med)) soft_tgt = med hard_max = max(hard_min + 4, int(1.4 * med)) return hard_min, soft_tgt, hard_max def generate_t2m( model, tokenizer, prompt_text: str, mot_begin_id: int, mot_end_id: int, motion_token_ids: list, length_stats_by_text: Dict, global_median_len: int, prompt_vocab: Dict = None, pid: str = None, has_pid: bool = False, max_new_tokens: int = None, per_prompt_vocab: bool = True ): """ Generate motion sequence from text prompt with constrained decoding """ model.eval() device = next(model.parameters()).device if max_new_tokens is None: max_new_tokens = GEN_MAX_NEW_TOKENS # Build prompt pid_tok = "" if has_pid and pid is not None: pid_tok = f"" user_text = f"{pid_tok}\n\n" + prompt_text prompt = ( "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n" + "<|im_start|>user\n" + user_text + "\n<|im_end|>\n" + "<|im_start|>assistant\n" ) # Tokenize inputs = tokenizer(prompt, return_tensors="pt").to(device) prompt_len = inputs["input_ids"].size(1) # Get length controls hard_min, soft_tgt, hard_max = get_len_controls( prompt_text, length_stats_by_text, global_median_len ) # Get allowed motion tokens if per_prompt_vocab and prompt_vocab: allowed_motion_ids = prompt_vocab.get(prompt_text, motion_token_ids) else: allowed_motion_ids = motion_token_ids # Setup constrained decoding processors = LogitsProcessorList([ LengthAwareMotionLogitsProcessor( prompt_len=prompt_len, mot_begin_id=mot_begin_id, mot_end_id=mot_end_id, motion_ids=allowed_motion_ids, hard_min=hard_min, soft_target=soft_tgt, hard_max=hard_max, end_logit_slope=GEN_END_LOGIT_SLOPE, ) ]) # Generate with torch.no_grad(): out = model.generate( input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"), max_new_tokens=min(max_new_tokens, hard_max + 4), do_sample=True, temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, top_k=GEN_TOP_K, no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM_SIZE, repetition_penalty=GEN_REPETITION_PENALTY, logits_processor=processors, eos_token_id=mot_end_id, pad_token_id=tokenizer.eos_token_id, ) # Decode decoded = tokenizer.decode(out[0], skip_special_tokens=False) reply = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0] return reply