0xZohar commited on
Commit
2d80856
·
verified ·
1 Parent(s): 51f3cca

Upload code/cube3d/inference/engine.py

Browse files
Files changed (1) hide show
  1. code/cube3d/inference/engine.py +924 -0
code/cube3d/inference/engine.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from peft import LoraConfig, get_peft_model, TaskType
5
+ from tqdm import tqdm
6
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
7
+
8
+ from cube3d.config import get_mapping_paths
9
+ from cube3d.inference.logits_postprocesses import process_logits, process_logits_assembly
10
+ from cube3d.inference.utils import load_config, load_model_weights, parse_structured, load_model_weights_adaption
11
+ from cube3d.training.process_single_ldr import logits2ldr, load_mappings, logits2flatldrp, logits2flatldrpr, ids2flatldrpr
12
+ from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
13
+ from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer
14
+ from cube3d.model.transformers.cache import Cache
15
+
16
+
17
+ class Engine:
18
def __init__(
    self,
    config_path: str,
    gpt_ckpt_path: str,
    shape_ckpt_path: str,
    save_gpt_ckpt_path: str,
    device: torch.device,
):
    """
    Initializes the inference engine with the given configuration and checkpoint paths.
    Args:
        config_path (str): Path to the configuration file.
        gpt_ckpt_path (str): Path to the GPT model checkpoint file.
        shape_ckpt_path (str): Path to the shape model checkpoint file.
        save_gpt_ckpt_path (str): Forwarded unchanged as the third argument of
            `load_model_weights_adaption`. Presumably the location where the
            adapted GPT checkpoint is written -- TODO confirm against that helper.
        device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda').
    Attributes:
        cfg (dict): Loaded configuration from the config file.
        device (torch.device): The device to run the models on.
        gpt_model (DualStreamRoformer): The GPT model, replaced by whatever
            `load_model_weights_adaption` returns, in eval mode on `device`.
        shape_model (OneDAutoEncoder): The shape model initialized and loaded with weights.
        text_model (CLIPTextModelWithProjection): The text model initialized from a pretrained model.
        text_tokenizer (CLIPTokenizerFast): The tokenizer for the text model.
        stride (int): Fixed to 5. Used by `run_apt` as the number of generation
            steps per group (position modulo `stride` selects the embedding table).
        given (int): Fixed to 0. `run_apt` feeds the provided ids instead of
            sampled ids for the first `given * stride` steps.
        max_new_tokens (int): Number of generation steps, `311 * stride`.
        min_id (int): Lower bound (inclusive) of the logit slice used for sampling.
        max_id (int): Upper bound (exclusive) of the logit slice, taken from
            `shape_model.cfg.num_codes`.
    """

    self.cfg = load_config(config_path)
    self.device = device

    self.gpt_model = DualStreamRoformer(
        parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model)
    )
    # The plain loader below was replaced by the adaption loader; kept for reference.
    # load_model_weights(
    #     self.gpt_model,
    #     gpt_ckpt_path,
    # )
    # NOTE: the return value is re-assigned, so the helper may hand back a
    # different (e.g. wrapped) module than the one passed in.
    self.gpt_model = load_model_weights_adaption(
        self.gpt_model,
        gpt_ckpt_path,
        save_gpt_ckpt_path
    )

    self.gpt_model = self.gpt_model.eval().to(self.device)

    self.shape_model = OneDAutoEncoder(
        parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model)
    )
    load_model_weights(
        self.shape_model,
        shape_ckpt_path,
    )
    self.shape_model = self.shape_model.eval().to(self.device)

    # copy vq codebook to gpt
    # This must run after both models are loaded and moved to the device: the
    # projected codebook overwrites the first rows of the GPT token embedding.
    with torch.no_grad():
        codebook = self.shape_model.bottleneck.block.get_codebook()
        codebook = self.gpt_model.shape_proj(codebook).detach()
        self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook
    self.text_model = CLIPTextModelWithProjection.from_pretrained(
        self.cfg.text_model_pretrained_model_name_or_path,
        force_download=False,
        device_map=self.device,
    ).eval()
    self.text_tokenizer = CLIPTokenizerFast.from_pretrained(
        self.cfg.text_model_pretrained_model_name_or_path,
        #force_download=False,
    )

    # Generation layout constants consumed by run_apt / t2l.
    self.stride = 5
    self.given = 0
    self.max_new_tokens = 311*self.stride #self.shape_model.cfg.num_encoder_latents
    self.min_id = 0
    self.max_id = self.shape_model.cfg.num_codes
93
+
94
@torch.inference_mode()
def prepare_conditions_with_bbox(
    self,
    cond: torch.Tensor,
    bounding_box_tensor: Optional[torch.Tensor] = None,
):
    """
    Appends one bounding-box embedding to the condition sequence.

    When the GPT model exposes a `bbox_proj` attribute, the bounding box is
    projected, given a sequence axis via `unsqueeze(dim=1)`, and concatenated
    to `cond` along dimension 1. A model without `bbox_proj` leaves `cond`
    untouched.

    Args:
        cond (torch.Tensor): Condition embeddings; dimension 0 is the batch and
            dimension 1 is the axis the bounding-box embedding is appended to.
        bounding_box_tensor (Optional[torch.Tensor], optional): Input for
            `bbox_proj`. If None, a zero tensor of shape (batch, 3) with the
            dtype of `cond` is created on `self.device`.

    Returns:
        torch.Tensor: `cond` with the bounding-box embedding appended, or the
        very same `cond` object when the model has no `bbox_proj`.
    """
    if hasattr(self.gpt_model, "bbox_proj"):
        box = bounding_box_tensor
        if box is None:
            # No box supplied: fall back to an all-zero box per batch row.
            box = torch.zeros(
                (cond.shape[0], 3), dtype=cond.dtype, device=self.device
            )
        projected = self.gpt_model.bbox_proj(box)
        return torch.cat([cond, projected.unsqueeze(dim=1)], dim=1)

    return cond
127
+
128
@torch.inference_mode()
def prepare_conditions_with_bboxs(
    self,
    cond: torch.Tensor,
    bounding_box_tensor: Optional[torch.Tensor] = None,
):
    """
    Appends a bounding-box embedding to the condition sequence, expanding it
    over the batch.

    Same as `prepare_conditions_with_bbox`, except that the projected box is
    expanded to `cond.shape[0]` rows before concatenation, so a single box can
    serve every row of `cond`.

    Args:
        cond (torch.Tensor): Condition embeddings; dimension 0 is the batch and
            dimension 1 is the axis the bounding-box embedding is appended to.
        bounding_box_tensor (Optional[torch.Tensor], optional): Input for
            `bbox_proj`. If None, a zero tensor of shape (batch, 3) with the
            dtype of `cond` is created on `self.device`.

    Returns:
        torch.Tensor: `cond` with the bounding-box embedding appended, or the
        very same `cond` object when the model has no `bbox_proj`.
    """
    if hasattr(self.gpt_model, "bbox_proj"):
        rows = cond.shape[0]
        box = bounding_box_tensor
        if box is None:
            box = torch.zeros((rows, 3), dtype=cond.dtype, device=self.device)

        projected = self.gpt_model.bbox_proj(box).unsqueeze(dim=1)
        # expand() broadcasts a leading dimension of 1 to the batch size.
        per_row = projected.expand(rows, -1, -1)
        return torch.cat([cond, per_row], dim=1)

    return cond
162
+
163
@torch.inference_mode()
def prepare_inputs(
    self,
    prompts: list[str],
    guidance_scale: float,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Builds the BOS embedding and the condition tensor for a batch of prompts.
    Args:
        prompts (list[str]): Prompt strings to encode with `run_clip`.
        guidance_scale (float): If greater than 0.0, an unconditional branch
            (empty prompts, all-zero bounding box) is appended along the batch
            dimension and the BOS embedding is duplicated to match.
        bounding_box_xyz (Optional[Tuple[float]], optional): Bounding box passed
            through `torch.tensor` and `torch.atleast_2d`. If None, no box is
            handed to `prepare_conditions_with_bbox`.
    Returns:
        tuple: A tuple containing:
            - embed (torch.Tensor): The BOS embedding from `encode_input`,
              stacked twice along dim 0 when guidance is active.
            - cond (torch.Tensor): Conditional embeddings, followed along dim 0
              by the unconditional ones when guidance is active.
    """

    text_cond = self.run_clip(prompts)

    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        bos_embed = self.encode_input(text_cond, self.gpt_model.shape_bos_id)

    cond_bbox = None
    uncond_bbox = None
    if bounding_box_xyz is not None:
        cond_bbox = torch.atleast_2d(torch.tensor(bounding_box_xyz)).to(self.device)
        uncond_bbox = torch.zeros_like(cond_bbox).to(self.device)

    cond = self.prepare_conditions_with_bbox(text_cond, cond_bbox)
    if not guidance_scale > 0.0:
        return bos_embed, cond

    # Classifier-free guidance: conditional rows first, unconditional rows second.
    doubled_embed = torch.cat([bos_embed, bos_embed], dim=0)
    empty_prompts = [""] * len(prompts)
    uncond = self.prepare_conditions_with_bbox(
        self.run_clip(empty_prompts), uncond_bbox
    )
    return doubled_embed, torch.cat([cond, uncond], dim=0)
204
+
205
@torch.inference_mode()
def run_clip(self, text_inputs):
    """
    Tokenizes text, runs the CLIP text model and maps the result through the GPT text encoder.
    Args:
        text_inputs (str or List[str]): The input text or list of texts to be processed.
    Returns:
        torch.Tensor: Output of `gpt_model.encode_text` applied to either the
        pooled text embedding (with a sequence axis inserted at dim 1) or the
        last hidden state, depending on `gpt_model.cfg.use_pooled_text_embed`.
    """
    tokenizer = self.text_tokenizer
    tokens = tokenizer(
        text_inputs,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        on_device = {
            name: tensor.to(self.device) for name, tensor in tokens.items()
        }
        # The text encoder runs with autocast disabled, i.e. in its own precision.
        with torch.autocast(device_type=self.device.type, enabled=False):
            encoded = self.text_model(**on_device)

        use_pooled = self.gpt_model.cfg.use_pooled_text_embed
        if use_pooled:
            features = encoded.text_embeds.unsqueeze(1)  # [bs, 1, 512]
        else:
            features = encoded.last_hidden_state  # [bs, 77, 512]
        projected = self.gpt_model.encode_text(features)

    return projected
233
+
234
@torch.inference_mode()
def encode_input(self, inputs: torch.Tensor, bos: int):
    """
    Embeds one BOS token per batch row of `inputs`.
    Args:
        inputs (torch.Tensor): Only its first dimension (the batch size) is read.
        bos (int): The beginning of sequence token ID.
    Returns:
        torch.Tensor: Result of `gpt_model.encode_token` on a (batch, 1) long
        tensor filled with `bos`.
    """

    batch = inputs.shape[0]
    bos_ids = torch.full(
        (batch, 1),
        fill_value=bos,
        dtype=torch.long,
        device=self.device,
    )
    return self.gpt_model.encode_token(bos_ids)
255
+
256
@torch.inference_mode()
def run_gpt(
    self,
    prompts: list[str],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Autoregressively generates `max_new_tokens` token ids from the given prompts.
    Args:
        prompts (list[str]): A list of input prompts to condition generation on.
        use_kv_cache (bool): Whether to use key-value caching for faster generation.
        guidance_scale (float, optional): Classifier-free guidance weight. When
            greater than 0.0 the batch is doubled (conditional + unconditional)
            and logits are combined as `(1 + g) * cond - g * uncond`. Default is 3.0.
        top_p (float, optional): Forwarded to `process_logits` (cumulative
            probability threshold for nucleus sampling; None selects argmax
            according to the original documentation).
        bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box for generation
            as (x, y, z) dimensions. If None, no bounding box is passed on.
    Returns:
        torch.Tensor: The generated token ids, concatenated along dimension 1.
    """
    embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)

    output_ids = []
    batch_size, input_seq_len, dim = embed.shape
    max_seq_len = input_seq_len + self.max_new_tokens
    # Pre-allocated sequence buffer; slot `i + input_seq_len` receives the
    # embedding of the token sampled at step `i`.
    embed_buffer = torch.zeros(
        (batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device
    )
    embed_buffer[:, :input_seq_len, :].copy_(embed)
    cond_len = cond.shape[1]
    kv_cache = None
    if use_kv_cache:
        kv_cache = self.gpt_model.init_kv_cache(
            batch_size,
            cond_len,
            self.max_new_tokens + 1,  # +1 for the BOS token
            torch.bfloat16,
            embed.device,
        )
    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        for i in tqdm(range(self.max_new_tokens), desc="generating"):
            curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
            logits = self.gpt_model(
                embed_buffer,
                cond,
                kv_cache=kv_cache,
                curr_pos_id=curr_pos_id if use_kv_cache else None,
                decode=(i > 0) if use_kv_cache else False,
            )
            # With a KV cache only the current position is returned; without
            # it the full sequence is recomputed and position `i` is selected.
            if use_kv_cache:
                logits = logits[:, 0, ...]
            else:
                logits = logits[:, i, ...]

            logits = logits[..., self.min_id : self.max_id]

            if guidance_scale > 0.0:
                logits, uncond_logits = logits.float().chunk(2, dim=0)
                # Fix: `gamma` was referenced but never assigned (its only
                # assignment was commented out), so every guided call raised
                # NameError. A constant weight is used, matching the variant
                # kept in run_apt; the disabled alternative was a linear decay:
                #   guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
                gamma = guidance_scale
                logits = (1 + gamma) * logits - gamma * uncond_logits
            next_id = process_logits(
                logits,
                top_p=top_p,
            )
            output_ids.append(next_id)
            next_embed = self.gpt_model.encode_token(next_id)
            if guidance_scale > 0.0:
                # Feed the same sampled token to both guidance branches.
                next_embed = torch.cat([next_embed, next_embed], dim=0)
            embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))

    return torch.cat(output_ids, dim=1)
335
+
336
@torch.inference_mode()
def run_apt(
    self,
    prompts: list[str],
    inputs_ids: list[torch.Tensor],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Generates an interleaved id sequence of `max_new_tokens` steps, `stride` steps per group.

    Within each group of `stride` steps the position modulo `stride` selects
    which embedding table (`dte`, `rte`, `yte`, `xte`, `zte`) embeds the
    sampled id before it is written back into the sequence buffer.

    Args:
        prompts (list[str]): Prompts; only used to obtain the BOS embedding.
        inputs_ids: Indexed as a 3-D tensor `[batch, position, field]` despite
            the annotation. Fields 0, -6, -5, -4, -3 are read as r / dat / x /
            y / z ids and -1 marks padding. The tensor is MODIFIED IN PLACE:
            padded positions receive the model's `rot_num` / `dat_num` /
            `x_num` / `y_num` / `z_num` values.
        use_kv_cache (bool): Ignored; this method always runs without a KV cache.
        guidance_scale (float, optional): Only forwarded to `prepare_inputs`;
            no guidance is applied to the logits here.
        top_p (float, optional): Ignored; sampling always uses `top_p=0.9`.
        bounding_box_xyz (optional): If not None it must be a tensor (`.to()`
            is called on it). None is accepted and treated as "no box".
    Returns:
        tuple: `(ids, inputs_ids)` where `ids` is the per-step output of
        `process_logits_assembly` concatenated along dimension 0 and
        `inputs_ids` is the in-place modified input.
    """
    embed, _ = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)

    # Keep only the first row (prepare_inputs doubles the batch under guidance).
    embed = embed[0][None]

    output_ids = []
    batch_size, input_seq_len, dim = embed.shape
    max_seq_len = input_seq_len + self.max_new_tokens
    embed_buffer = torch.zeros(
        (batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device
    )
    embed_buffer[:, :input_seq_len, :].copy_(embed)
    kv_cache = None
    # KV caching is deliberately disabled for this path.
    use_kv_cache = False

    with torch.no_grad():
        attention_mask = inputs_ids != -1

        # First padded position per row.
        # NOTE(review): argmax returns 0 for a row without any padding, which
        # marks the whole row as padded -- confirm inputs are always padded.
        cut_idx = (attention_mask == False)[:, :, 0].int().argmax(dim=1)
        positions = torch.arange(inputs_ids.shape[1], device=inputs_ids.device)[None, :]
        is_padding = positions >= cut_idx[:, None]

        dat_id = inputs_ids[:, :, -6].long()
        dat_id = torch.where(is_padding, self.gpt_model.dat_num, dat_id)

        r_id = torch.where(is_padding, self.gpt_model.rot_num, inputs_ids[:, :, 0])
        x_id = torch.where(is_padding, self.gpt_model.x_num, inputs_ids[:, :, -5])
        y_id = torch.where(is_padding, self.gpt_model.y_num, inputs_ids[:, :, -4])
        z_id = torch.where(is_padding, self.gpt_model.z_num, inputs_ids[:, :, -3])

        # Write the padded-out ids back so the caller sees the same values.
        inputs_ids[:, :, 0] = r_id.clone()
        inputs_ids[:, :, -6] = dat_id.clone()
        inputs_ids[:, :, -5] = x_id.clone()
        inputs_ids[:, :, -4] = y_id.clone()
        inputs_ids[:, :, -3] = z_id.clone()

        inputs_embeds = self.gpt_model.dte(dat_id)

        # Fix: the default bounding_box_xyz=None used to crash on `.to(...)`.
        bbox = bounding_box_xyz
        if bbox is not None:
            bbox = bbox.to(inputs_embeds.device)
        inputs_embeds = self.prepare_conditions_with_bboxs(inputs_embeds, bbox)

        # NOTE(review): the condition built above is discarded; the BOS
        # embedding is what is actually passed as `cond`. Kept as in the original.
        inputs_embeds = embed  # bos

    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        for i in tqdm(range(self.max_new_tokens), desc="generating"):
            curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
            logits = self.gpt_model(
                embed=embed_buffer,
                cond=inputs_embeds,
                kv_cache=kv_cache,
                curr_pos_id=curr_pos_id if use_kv_cache else None,
                decode=(i > 0) if use_kv_cache else False,
            )
            if use_kv_cache:
                logits = logits[:, 0, ...]
            else:
                logits = logits[:, i, ...]

            logits = logits[..., self.min_id : self.max_id]

            next_id = process_logits_assembly(
                logits,
                top_p=0.9,
                pos_id=curr_pos_id,
                stride=self.stride
            )

            # Plain-int equivalents of the former tensor comparisons on
            # curr_pos_id (which always holds `i`).
            slot = i % self.stride
            use_given = i < self.given * self.stride
            piece = max(0, min(i // self.stride, 309))

            # NOTE(review): every slot is covered for stride == 5; for other
            # strides some slots match no branch and `next_embed` would be
            # stale or undefined.
            if slot == 0 and self.stride > 3:
                if use_given:
                    next_embed = self.gpt_model.dte(dat_id[0, piece])[None]
                else:
                    next_embed = self.gpt_model.dte(next_id)
            elif slot == 1 and self.stride > 4:
                if use_given:
                    next_embed = self.gpt_model.rte(r_id[0, piece])[None]
                else:
                    next_embed = self.gpt_model.rte(next_id)
            elif slot == (self.stride - 3):
                if use_given:
                    next_embed = self.gpt_model.yte(y_id[0, piece])[None]
                else:
                    next_embed = self.gpt_model.yte(next_id)
            elif slot == (self.stride - 2):
                if use_given:
                    next_embed = self.gpt_model.xte(x_id[0, piece])[None]
                else:
                    next_embed = self.gpt_model.xte(next_id)
            elif slot == (self.stride - 1):
                if use_given:
                    next_embed = self.gpt_model.zte(z_id[0, piece])[None]
                else:
                    next_embed = self.gpt_model.zte(next_id)

            output_ids.append(next_id)
            embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))

    return torch.cat(output_ids, dim=0), inputs_ids
505
+
506
@torch.inference_mode()
def run_shape_decode(
    self,
    output_ids: torch.Tensor,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
):
    """
    Decodes generated ids into latents and extracts the mesh geometry.
    Args:
        output_ids (torch.Tensor): Generated ids. Only the first
            `num_encoder_latents` entries along dimension 1 are used, and that
            slice is clamped IN PLACE to `[0, num_codes - 1]`.
        resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.0.
        chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
    Returns:
        The first element of what `shape_model.extract_geometry` returns
        (documented upstream as the mesh vertices and faces).
    """
    model_cfg = self.shape_model.cfg
    num_latents = model_cfg.num_encoder_latents

    # In-place clamp on a view: the caller's tensor is modified as well.
    head = output_ids[:, :num_latents, ...]
    head.clamp_(0, model_cfg.num_codes - 1)
    shape_ids = head.view(-1, model_cfg.num_encoder_latents)

    latents = self.shape_model.decode_indices(shape_ids)
    geometry = self.shape_model.extract_geometry(
        latents,
        resolution_base=resolution_base,
        chunk_size=chunk_size,
        use_warp=True,
    )
    mesh_v_f, _ = geometry
    return mesh_v_f
535
+
536
@torch.inference_mode()
def t2s(
    self,
    prompts: list[str],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Generates a 3D mesh from text prompts using a GPT model and shape decoder.
    Args:
        prompts (list[str]): A list of text prompts to guide the generation.
        use_kv_cache (bool): Whether to use key-value caching for the GPT model.
        guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
        resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
        chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
        top_p (float, optional): The cumulative probability threshold for nucleus sampling,
            forwarded to `run_gpt`.
        bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh
            as (x, y, z) dimensions. If None, no bounding box is passed on.
    Returns:
        mesh_v_f: The value returned by `run_shape_decode` for the generated ids.
    """
    # Fix: this used to call `run_apt` with these five arguments. `run_apt`
    # takes `inputs_ids` as its second parameter and returns a tuple, so every
    # argument was shifted by one and the result could not be decoded.
    # `run_gpt` has exactly this signature and returns a single id tensor.
    output_ids = self.run_gpt(
        prompts, use_kv_cache, guidance_scale, top_p, bounding_box_xyz
    )

    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
    return mesh_v_f
571
+
572
@torch.inference_mode()
def t2l(
    self,
    prompts: list[str],
    inputs_ids: list[torch.Tensor],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
    idx: int = 0
):
    """
    Generates an LDR file from text prompts via `run_apt` and `ids2flatldrpr`.
    Args:
        prompts (list[str]): A list of text prompts to guide the generation.
        inputs_ids: Input ids forwarded to `run_apt`.
        use_kv_cache (bool): Forwarded to `run_apt`.
        guidance_scale (float, optional): Forwarded to `run_apt`. Default is 3.0.
        resolution_base (float, optional): Not used by this method.
        chunk_size (int, optional): Not used by this method.
        top_p (float, optional): Forwarded to `run_apt`.
        bounding_box_xyz (optional): Forwarded to `run_apt`.
        idx (int, optional): Sample index embedded in the output file name.
    Returns:
        ldr: Whatever `ids2flatldrpr` returns; the file is written to
        `outputs/sample/test_{given}_drp_{idx}_r512_uncond.ldr`.
    """
    generation = self.run_apt(
        prompts, inputs_ids, use_kv_cache, guidance_scale, top_p, bounding_box_xyz
    )
    output_ids, inputs_ids = generation

    # Use config-based paths (works in both local and HF Space environments)
    mapping_paths = get_mapping_paths("subset_self")
    forward_path, inverse_path = mapping_paths
    # The mappings are loaded but not consumed below; kept for parity with
    # the original behaviour (the load may still raise on a missing file).
    label_mapping, label_inverse_mapping = load_mappings(forward_path, inverse_path)

    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        generated_np = output_ids.float().cpu().detach().numpy()
        reference_np = inputs_ids[0].cpu().detach().numpy()
        ldr = ids2flatldrpr(
            generated_np,
            reference_np,
            self.stride,
            self.given,
            output_file=f"outputs/sample/test_{self.given}_drp_{idx}_r512_uncond.ldr",
        )
    return ldr
614
+
615
+
616
+ class EngineFast(Engine):
617
def __init__(
    self,
    config_path: str,
    gpt_ckpt_path: str,
    shape_ckpt_path: str,
    device: torch.device,
    save_gpt_ckpt_path: Optional[str] = None,
):
    """
    Initializes the CUDA-graph inference engine with the given configuration and checkpoint paths.
    Args:
        config_path (str): Path to the configuration file.
        gpt_ckpt_path (str): Path to the GPT checkpoint file.
        shape_ckpt_path (str): Path to the shape checkpoint file.
        device (torch.device): The device to run the inference on; must be a CUDA device.
        save_gpt_ckpt_path (Optional[str], optional): Forwarded to
            `Engine.__init__`, which requires it. Added with a default so the
            existing four-argument call keeps working.
            NOTE(review): confirm that `load_model_weights_adaption` accepts None.
    Raises:
        AssertionError: If `device` is not a CUDA device.
    """

    assert (
        device.type == "cuda"
    ), "EngineFast is only supported on cuda devices, please use Engine on non-cuda devices"

    # Fix: Engine.__init__ takes five arguments; the previous four-argument
    # call passed `device` in the `save_gpt_ckpt_path` slot and raised
    # TypeError for the missing `device`.
    super().__init__(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        save_gpt_ckpt_path=save_gpt_ckpt_path,
        device=device,
    )

    # CUDA Graph params
    self.graph = torch.cuda.CUDAGraph()
    # Placeholders; the real buffers are allocated during warmup.
    self.embed_buffer = torch.Tensor()
    self.cond_buffer = torch.Tensor()
    self.logits_buffer = torch.Tensor()
    self.curr_pos_id = torch.tensor([0], dtype=torch.long, device=self.device)
    self.kv_cache: list[Cache] = []

    self._warmup_and_capture_graph()
648
+
649
def _warmup_and_capture_graph(self):
    """
    Warms up the model by running a series of forward passes and captures the CUDA graph for efficient execution.
    This method performs the following steps:
    1. Prepares the input embeddings and conditions using a warmup prompt
       (guidance scale 3.0, no bounding box).
    2. Allocates `embed_buffer` and `cond_buffer`, the tensors the captured
       graph reads from.
    3. Initializes the key-value cache for the GPT model (float32).
    4. Runs one eager prefill pass followed by nine decode passes
       (positions 1..9) as warmup.
    5. Captures one decode-mode forward pass into `self.graph` on a side
       stream; its output is stored in `logits_buffer`.
    """

    warmup_prompt = "A cube"
    embed, cond = self.prepare_inputs([warmup_prompt], guidance_scale=3.0)

    batch_size, input_seq_len, dim = embed.shape
    max_seq_len = input_seq_len + self.max_new_tokens
    self.embed_buffer = torch.zeros(
        (batch_size, max_seq_len, dim), dtype=embed.dtype, device=self.device
    )
    self.embed_buffer[:, :input_seq_len, :].copy_(embed)

    self.cond_buffer = torch.empty_like(cond)
    self.cond_buffer.copy_(cond)
    cond_len = self.cond_buffer.shape[1]

    # Initialize kv_cache for the first time
    # NOTE(review): float32 here, whereas Engine.run_gpt allocates its cache
    # in bfloat16 and runs under autocast -- both autocast blocks below are
    # commented out in this class.
    self.kv_cache = self.gpt_model.init_kv_cache(
        batch_size,
        cond_len,
        self.max_new_tokens + 1,  # +1 for the BOS token
        #torch.bfloat16,
        torch.float32,
        self.device,
    )

    num_warmup_passes = 10

    #with torch.autocast(self.device.type, dtype=torch.bfloat16):
    # Pass 0 is the eager prefill; its logits are discarded.
    self._set_curr_pos_id(0)
    _ = self._prefill_and_return_logits()

    # Remaining warmup passes run in decode mode at increasing positions.
    for x in range(1, num_warmup_passes):
        self._set_curr_pos_id(x)
        self.logits_buffer = self.gpt_model(
            embed=self.embed_buffer,
            cond=self.cond_buffer,
            kv_cache=self.kv_cache,
            curr_pos_id=self.curr_pos_id,
            decode=True,
        )

    # Capture a single decode step. The tensors referenced here
    # (embed_buffer, cond_buffer, kv_cache, curr_pos_id) are the ones the
    # graph is recorded against, so later updates are done in place.
    side_stream = torch.cuda.Stream(device=self.device)
    with torch.cuda.graph(self.graph, stream=side_stream):
        #with torch.autocast(self.device.type, dtype=torch.bfloat16):
        self.logits_buffer = self.gpt_model(
            embed=self.embed_buffer,
            cond=self.cond_buffer,
            kv_cache=self.kv_cache,
            curr_pos_id=self.curr_pos_id,
            decode=True,
            #decode=False, #? should be false
        )
711
+
712
def _reset_kv_cache(self):
    """
    Zeroes every key/value tensor held in `self.kv_cache`.
    Each cache entry's `key_states` and `value_states` are cleared in place
    with `zero_()`, so the tensor objects themselves stay the same.
    """

    for entry in self.kv_cache:
        for state in (entry.key_states, entry.value_states):
            state.zero_()
723
+
724
def _prefill_and_return_logits(self) -> torch.Tensor:
    """
    Clears the key-value cache and runs one eager, non-decode forward pass.
    Returns:
        torch.Tensor: The model output at index 0 of dimension 1
        (`logits[:, 0, ...]`).
    """

    self._reset_kv_cache()

    # Prefill is always eager
    forward_kwargs = {
        "embed": self.embed_buffer,
        "cond": self.cond_buffer,
        "kv_cache": self.kv_cache,
        "curr_pos_id": self.curr_pos_id,
        "decode": False,
    }
    logits = self.gpt_model(**forward_kwargs)
    return logits[:, 0, ...]
745
+
746
+ def _set_curr_pos_id(self, pos: int):
747
+ """
748
+ Set the current position ID.
749
+ This method updates the `curr_pos_id` attribute with the given position.
750
+ Args:
751
+ pos (int): The position ID to set.
752
+ """
753
+
754
+ self.curr_pos_id.copy_(
755
+ torch.tensor([pos], dtype=torch.long, device=self.device)
756
+ )
757
+
758
def run_gpt(
    self,
    prompts: list[str],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Generate token IDs for a single prompt.

    The first token is produced from an eager prefill pass; every later
    token is produced by updating the position id and replaying
    `self.graph`, then reading `self.logits_buffer`. Each sampled token is
    encoded and written back into `self.embed_buffer` for the next step.

    Args:
        prompts (list[str]): Input prompts. Exactly one prompt is supported.
        use_kv_cache (bool): Accepted for interface compatibility; not read
            by this method.
        guidance_scale (float, optional): Guidance scale. When greater than
            0, the logits are split into two halves along dim 0 and combined
            as `(1 + gamma) * cond - gamma * uncond`. `gamma` equals
            `guidance_scale` for the first token and is scaled by
            `(max_new_tokens - i) / max_new_tokens` at step `i` afterwards.
            Default is 3.0.
        top_p (float, optional): Forwarded to `process_logits`.
        bounding_box_xyz (Tuple[float] | None, optional): Forwarded to
            `prepare_inputs`.
    Returns:
        torch.Tensor: Integer tensor of shape
            `(batch_size // 2, max_new_tokens)` holding the generated token
            IDs, where `batch_size` is the first dimension of the embedding
            returned by `prepare_inputs`.
    Raises:
        AssertionError: If `prompts` does not contain exactly one prompt, or
            if the prepared `cond` does not match `cond_buffer` in shape.
    """
    embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
    assert len(prompts) == 1, "batch size > 1 not support for EngineFast"

    batch_size, input_seq_len, _ = embed.shape

    # Stage the prepared inputs into the persistent buffers.
    self.embed_buffer.zero_()
    self.embed_buffer[:, :input_seq_len, :].copy_(embed)
    assert self.cond_buffer.shape == cond.shape
    self.cond_buffer.copy_(cond)

    output_ids = torch.zeros(
        (batch_size // 2, self.max_new_tokens), dtype=torch.int, device=self.device
    )

    def emit_token(raw_logits, step):
        # Restrict to the token-id window, apply guidance, sample, and
        # store both the id and its embedding for the next step.
        scores = raw_logits[..., self.min_id : self.max_id]
        if guidance_scale > 0.0:
            cond_scores, uncond_scores = scores.float().chunk(2, dim=0)
            if step == 0:
                gamma = guidance_scale
            else:
                gamma = (
                    guidance_scale * (self.max_new_tokens - step) / self.max_new_tokens
                )
            scores = (1 + gamma) * cond_scores - gamma * uncond_scores
        token_id = process_logits(scores, top_p=top_p)

        output_ids[:, step] = token_id.squeeze()
        token_embed = self.gpt_model.encode_token(token_id)
        token_embed = token_embed.repeat(2, 1, 1)
        self.embed_buffer[:, step + input_seq_len, :].copy_(token_embed.squeeze(1))

    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        # First token: eager prefill.
        self._set_curr_pos_id(0)
        emit_token(self._prefill_and_return_logits(), 0)

        # Remaining tokens: replay the captured decode graph.
        for step in tqdm(range(1, self.max_new_tokens), desc="generating"):
            self._set_curr_pos_id(step)
            self.graph.replay()
            emit_token(self.logits_buffer[:, 0, ...], step)

    return output_ids
837
+
838
def run_apt(
    self,
    prompts: list[str],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Generate a sequence of per-step outputs for a single prompt.

    Follows the same prefill-then-graph-replay loop as `run_gpt`, but each
    step's guided logits are passed to `process_logits_assembly(logits, i)`
    and the result is stored directly, both in the returned tensor and in
    `self.embed_buffer` as the next input. No autocast context is used.

    Args:
        prompts (list[str]): Input prompts. Exactly one prompt is supported.
        use_kv_cache (bool): Accepted for interface compatibility; not read
            by this method.
        guidance_scale (float, optional): Guidance scale. When greater than
            0, the logits are split into two halves along dim 0 and combined
            as `(1 + gamma) * cond - gamma * uncond`, with `gamma` equal to
            `guidance_scale` at every step. Default is 3.0.
        top_p (float, optional): Accepted for interface compatibility; not
            read by this method.
        bounding_box_xyz (Tuple[float] | None, optional): Forwarded to
            `prepare_inputs`.
    Returns:
        torch.Tensor: Integer tensor of shape
            `(batch_size // 2, max_new_tokens, embed.shape[2])`, where
            `embed` is the embedding returned by `prepare_inputs` and
            `batch_size` is its first dimension.
    Raises:
        AssertionError: If `prompts` does not contain exactly one prompt, or
            if the prepared `cond` does not match `cond_buffer` in shape.
    """
    embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
    assert len(prompts) == 1, "batch size > 1 not support for EngineFast"

    batch_size, input_seq_len, _ = embed.shape

    # Stage the prepared inputs into the persistent buffers.
    self.embed_buffer.zero_()
    self.embed_buffer[:, :input_seq_len, :].copy_(embed)
    assert self.cond_buffer.shape == cond.shape
    self.cond_buffer.copy_(cond)

    # One slot per generated step, as wide as the input embedding.
    output_ids = torch.zeros(
        (batch_size // 2, self.max_new_tokens, embed.shape[2]),
        dtype=torch.int,
        device=self.device,
    )

    def emit_step(raw_logits, step):
        # Restrict to the id window, apply guidance, convert the logits with
        # process_logits_assembly, and store the result for the next step.
        scores = raw_logits[..., self.min_id : self.max_id]
        if guidance_scale > 0.0:
            cond_scores, uncond_scores = scores.float().chunk(2, dim=0)
            gamma = guidance_scale
            scores = (1 + gamma) * cond_scores - gamma * uncond_scores

        # NOTE(review): the result is written into an int tensor here and
        # into embed_buffer below — confirm its dtype/shape against
        # process_logits_assembly.
        step_embed = process_logits_assembly(scores, step)
        output_ids[:, step] = step_embed
        step_embed = step_embed.repeat(2, 1, 1)
        self.embed_buffer[:, step + input_seq_len, :].copy_(step_embed.squeeze(1))

    # First step: eager prefill.
    self._set_curr_pos_id(0)
    emit_step(self._prefill_and_return_logits(), 0)

    # Remaining steps: update the position id, then replay the captured graph.
    for step in tqdm(range(1, self.max_new_tokens), desc="generating"):
        self._set_curr_pos_id(step)
        self.graph.replay()
        emit_step(self.logits_buffer[:, 0, ...], step)

    return output_ids