File size: 6,632 Bytes
4bd136e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""
Generation and inference utilities with constrained decoding
"""
import torch
from transformers import LogitsProcessor, LogitsProcessorList
from typing import Dict
from config import (
    SYSTEM_MSG, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE,
    GEN_TOP_P, GEN_TOP_K, GEN_NO_REPEAT_NGRAM_SIZE,
    GEN_REPETITION_PENALTY, GEN_END_LOGIT_SLOPE
)


class LengthAwareMotionLogitsProcessor(LogitsProcessor):
    """
    Constrained decoding processor that:
    1. Enforces motion token vocabulary
    2. Controls sequence length (min/soft_target/max)
    3. Biases toward ending at soft_target length

    Expected output grammar per sequence: MOT_BEGIN, then motion tokens,
    then MOT_END. Lengths below count only motion tokens after MOT_BEGIN.
    """

    def __init__(self, prompt_len, mot_begin_id, mot_end_id, motion_ids,
                 hard_min, soft_target, hard_max, end_logit_slope=0.25):
        """
        Args:
            prompt_len: number of prompt tokens; everything past this index
                in ``input_ids`` is treated as generated output.
            mot_begin_id / mot_end_id: special token ids delimiting the motion span.
            motion_ids: iterable of allowed motion-token ids.
            hard_min: minimum motion-token count before MOT_END is allowed.
            soft_target: preferred motion-token count (end bias ramps up past it).
            hard_max: motion-token count at which MOT_END is forced.
            end_logit_slope: logit bonus added to MOT_END per token past soft_target.
        """
        super().__init__()
        self.prompt_len = int(prompt_len)
        self.mot_begin_id = int(mot_begin_id)
        self.mot_end_id = int(mot_end_id)
        self.motion_ids = torch.tensor(sorted(set(int(x) for x in motion_ids)))
        self.motion_plus_end = torch.tensor(
            sorted(set(self.motion_ids.tolist() + [self.mot_end_id]))
        )
        self.hard_min = int(hard_min)
        self.soft_target = int(soft_target)
        self.hard_max = int(hard_max)
        self.end_logit_slope = float(end_logit_slope)
        # Per-device cache of the allowed-index tensors. __call__ runs once per
        # decoding step per batch row; building/transferring these tensors there
        # (as a naive implementation would) is wasted host->device traffic.
        self._device_cache = {}

    def _allowed_on(self, device):
        """Return (begin, end, motion, motion_plus_end) index tensors on *device*,
        creating and caching them on first use for that device."""
        cached = self._device_cache.get(device)
        if cached is None:
            cached = (
                torch.tensor([self.mot_begin_id], device=device),
                torch.tensor([self.mot_end_id], device=device),
                self.motion_ids.to(device),
                self.motion_plus_end.to(device),
            )
            self._device_cache[device] = cached
        return cached

    def __call__(self, input_ids, scores):
        """Mask ``scores`` so only grammar-legal tokens are sampleable, and bias
        MOT_END upward once a sequence exceeds its soft target length."""
        device = scores.device
        bs = scores.size(0)
        begin_t, end_t, motion_t, motion_plus_end_t = self._allowed_on(device)
        # Additive mask: -inf everywhere, 0.0 at allowed token ids.
        mask = torch.full_like(scores, float("-inf"))

        for b in range(bs):
            gen = input_ids[b, self.prompt_len:]

            # No tokens generated yet - must start with MOT_BEGIN
            if gen.numel() == 0:
                mask[b].index_fill_(0, begin_t, 0.0)
                continue

            # Find MOT_BEGIN position; until it appears, keep forcing it.
            begin_pos = (gen == self.mot_begin_id).nonzero(as_tuple=True)[0]
            if begin_pos.numel() == 0:
                mask[b].index_fill_(0, begin_t, 0.0)
                continue

            # Already generated MOT_END - only MOT_END stays legal, so the
            # sequence terminates via generate()'s eos handling.
            if (gen == self.mot_end_id).any():
                mask[b].index_fill_(0, end_t, 0.0)
                continue

            # Count motion tokens after the first MOT_BEGIN
            after_begin = gen[begin_pos[0].item() + 1:]
            cur_len = after_begin.numel()

            # Before minimum length - only allow motion tokens
            if cur_len < self.hard_min:
                mask[b].index_fill_(0, motion_t, 0.0)

            # After maximum length - force end
            elif cur_len >= self.hard_max:
                mask[b].index_fill_(0, end_t, 0.0)

            # Between min and max - allow motion tokens or end
            else:
                mask[b].index_fill_(0, motion_plus_end_t, 0.0)

                # Bias toward ending at soft_target: bonus grows linearly with
                # the overshoot past the target (0 at or below target).
                distance = max(0, cur_len - self.soft_target)
                bias = self.end_logit_slope * float(distance)
                scores[b, self.mot_end_id] = scores[b, self.mot_end_id] + bias

        return scores + mask


def get_len_controls(prompt_text: str, length_stats_by_text: Dict, global_median_len: int):
    """
    Resolve (hard_min, soft_target, hard_max) length bounds for a prompt.

    Uses the prompt's recorded median length when available, otherwise
    falls back to the global median. Bounds are 60%..140% of the median,
    with hard_max kept at least 4 above hard_min.
    """
    stats = length_stats_by_text.get(prompt_text)
    med = global_median_len if stats is None else stats["median"]

    lo = max(1, int(0.6 * med))
    hi = max(lo + 4, int(1.4 * med))
    return lo, med, hi


def generate_t2m(
    model,
    tokenizer,
    prompt_text: str,
    mot_begin_id: int,
    mot_end_id: int,
    motion_token_ids: list,
    length_stats_by_text: Dict,
    global_median_len: int,
    prompt_vocab: Dict = None,
    pid: str = None,
    has_pid: bool = False,
    max_new_tokens: int = None,
    per_prompt_vocab: bool = True
):
    """
    Generate a motion-token sequence for *prompt_text* with constrained decoding.

    Builds a ChatML-style prompt (optionally prefixed with a person-ID token),
    installs a LengthAwareMotionLogitsProcessor tuned to the prompt's length
    statistics, samples from *model*, and returns the assistant reply text with
    special tokens preserved.
    """
    model.eval()
    device = next(model.parameters()).device

    budget = GEN_MAX_NEW_TOKENS if max_new_tokens is None else max_new_tokens

    # Assemble the chat prompt; the PID token is included only when requested.
    pid_tok = f"<PID_{pid}>" if (has_pid and pid is not None) else ""
    user_text = f"<T2M>{pid_tok}\n\n" + prompt_text
    prompt = (
        "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
        + "<|im_start|>user\n" + user_text + "\n<|im_end|>\n"
        + "<|im_start|>assistant\n"
    )

    # Tokenize and record where generated output will start.
    enc = tokenizer(prompt, return_tensors="pt").to(device)
    n_prompt = enc["input_ids"].size(1)

    # Per-prompt length controls (min / soft target / max motion tokens).
    hard_min, soft_tgt, hard_max = get_len_controls(
        prompt_text, length_stats_by_text, global_median_len
    )

    # Optionally restrict the motion vocabulary to this prompt's own tokens.
    if per_prompt_vocab and prompt_vocab:
        allowed_ids = prompt_vocab.get(prompt_text, motion_token_ids)
    else:
        allowed_ids = motion_token_ids

    processors = LogitsProcessorList([
        LengthAwareMotionLogitsProcessor(
            prompt_len=n_prompt,
            mot_begin_id=mot_begin_id,
            mot_end_id=mot_end_id,
            motion_ids=allowed_ids,
            hard_min=hard_min,
            soft_target=soft_tgt,
            hard_max=hard_max,
            end_logit_slope=GEN_END_LOGIT_SLOPE,
        )
    ])

    # Sample under the constrained-decoding processor; cap the token budget a
    # little above hard_max to leave room for the MOT_BEGIN/MOT_END delimiters.
    with torch.no_grad():
        out = model.generate(
            input_ids=enc["input_ids"],
            attention_mask=enc.get("attention_mask"),
            max_new_tokens=min(budget, hard_max + 4),
            do_sample=True,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            top_k=GEN_TOP_K,
            no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM_SIZE,
            repetition_penalty=GEN_REPETITION_PENALTY,
            logits_processor=processors,
            eos_token_id=mot_end_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Extract just the assistant turn from the decoded transcript.
    text = tokenizer.decode(out[0], skip_special_tokens=False)
    return text.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]