0xZohar commited on
Commit
cfbc796
·
verified ·
1 Parent(s): 1c69fb1

Add missing code/cube3d/inference/logits_postprocesses.py

Browse files
code/cube3d/inference/logits_postprocesses.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
def top_p_filtering(logits, top_p: float = 1.0):
    """
    Apply nucleus (top-p) filtering to a logits tensor, in-place.

    Tokens outside the smallest set whose cumulative probability stays
    within `top_p` are masked to -inf; the single most likely token is
    always kept so the distribution never becomes empty.

    Args:
        logits (torch.Tensor): logits of shape [..., vocab_size]; modified
            in-place when top_p < 1.0.
        top_p (float, optional): cumulative-probability threshold. Values
            >= 1.0 disable filtering entirely.

    Returns:
        torch.Tensor: the (possibly masked) logits tensor.
    """
    if top_p >= 1.0:
        # Nothing to filter; hand the tensor straight back.
        return logits

    ranked, order = torch.sort(logits, dim=-1, descending=True)
    cum_probs = ranked.softmax(dim=-1).cumsum(dim=-1)

    # Mark every sorted position whose cumulative mass already exceeds the
    # threshold, but unconditionally keep the highest-probability token.
    drop_sorted = cum_probs > top_p
    drop_sorted[..., 0] = False

    # Map the sorted-order mask back onto the original vocabulary order.
    drop = drop_sorted.scatter(-1, order, drop_sorted)
    logits.masked_fill_(drop, -torch.inf)

    return logits
30
+
31
+
32
def process_logits(
    logits,
    top_p: float = None,
):
    """
    Select the next token id from `logits`.

    If `top_p` is None, the highest-probability token is chosen (greedy,
    deterministic decoding). Otherwise the logits are nucleus-filtered with
    the given threshold, converted to probabilities with softmax, and one
    token is drawn with `torch.multinomial` (stochastic decoding).

    Args:
        logits (torch.Tensor): logits of shape [..., vocab_size].
        top_p (float, optional): cumulative-probability threshold for
            nucleus sampling; None selects the argmax instead.

    Returns:
        torch.Tensor: selected token index with a trailing dim of size 1.
    """
    if top_p is None:
        # Greedy path: deterministic argmax over the vocabulary dimension.
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        # BUG FIX: previously hard-coded `top_p=0.9`, silently ignoring the
        # caller-supplied threshold; now forwards the actual argument.
        logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
    return next_id
59
+
60
def process_logits_assembly(
    logits,
    #tokens_num: int = 19,
    top_p: float = None,
    pos_id: int = 0,
    stride: int = 0
):
    """
    Select the next token for an assembly-token stream whose positions cycle
    through distinct vocabulary sub-ranges.

    Depending on `pos_id % stride`, only a prefix of the vocabulary
    dimension is considered (data / rotation / y / x / z sub-ranges). If
    `top_p` is None, the argmax over the active sub-range is returned;
    otherwise the sub-range logits are top-p filtered, softmaxed, and one
    token is sampled with `torch.multinomial`.

    Args:
        logits (torch.Tensor): logits of shape [B, vocab_size].
        top_p (float, optional): nucleus-sampling threshold; None selects
            the argmax (deterministic generation).
        pos_id (int, optional): position index in the generated sequence;
            `pos_id % stride` picks the active vocabulary sub-range.
        stride (int, optional): length of one positional cycle.
            NOTE(review): the default of 0 makes `pos_id % stride` raise
            ZeroDivisionError, so callers must always pass a positive
            stride — confirm the intended minimum (>= 3 from the branch
            structure below).

    Returns:
        torch.Tensor: selected token index; shape [B] on the argmax path,
        [B, 1] on the sampling path.
    """
    # Sizes of the vocabulary sub-ranges; presumably data tokens, x/y/z
    # coordinate bins, and rotation bins for assembly generation — TODO
    # confirm against the tokenizer that produced these logits.
    dat_num = 604
    x_num = 213
    y_num = 217
    z_num = 529
    rot_num = 24

    # x = x_num
    # xy = x_num + y_num + rot_num
    # xyz = x_num + y_num + z_num + rot_num

    if top_p is None:
        # Greedy path: argmax restricted to the sub-range active at pos_id.
        # The stride>3 / stride>4 guards skip the data/rotation slots when
        # the cycle is too short to contain them.
        if pos_id % stride==0 and stride>3:
            next_id = logits[:, :dat_num+1].argmax(dim=-1) # [B]
        elif pos_id % stride==1 and stride>4:
            next_id = logits[:, :rot_num+1].argmax(dim=-1)
        elif pos_id % stride==(stride-3):
            next_id = logits[:, :y_num+1].argmax(dim=-1)
        elif pos_id % stride==(stride-2):
            next_id = logits[:, :x_num+1].argmax(dim=-1)
        elif pos_id % stride==(stride-1):
            next_id = logits[:, :z_num+1].argmax(dim=-1)
        # NOTE(review): for stride < 3 no branch matches and `next_id` is
        # unbound (UnboundLocalError at the return) — confirm callers never
        # use such strides.
    else:
        # Sampling path: first slice the logits down to the active
        # sub-range (same branch structure as the greedy path) ...
        if pos_id % stride == 0 and stride > 3:
            logits = logits[:, :dat_num+1]
        elif pos_id % stride == 1 and stride > 4:
            logits = logits[:, :rot_num+1]
        elif pos_id % stride == (stride-3):
            logits = logits[:, :y_num+1]
        elif pos_id % stride == (stride-2):
            logits = logits[:, :x_num+1]
        elif pos_id % stride == (stride-1):
            logits = logits[:, :z_num+1]

        # ... then nucleus-filter, normalize, and sample one token per row.
        logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
    return next_id