YongganFu commited on
Commit
7542260
·
verified ·
1 Parent(s): 41d881f

Upload model

Browse files
Files changed (2) hide show
  1. chat_utils.py +0 -13
  2. modeling_nvrdiff.py +15 -38
chat_utils.py CHANGED
@@ -11,20 +11,7 @@ import torch
11
  import torch.nn.functional as F
12
  from transformers import AutoTokenizer
13
 
14
- sys.path.insert(1, "/lustre/fsw/portfolios/nvr/users/yongganf/adlr-megatron-lm")
15
- from get_hf_model import get_torchtitan_model_sft # noqa: E402
16
 
17
-
18
- # --------------------------- Reproducibility ----------------------------------
19
- def set_seed(seed: int = 42):
20
- torch.manual_seed(seed)
21
- random.seed(seed)
22
- np.random.seed(seed)
23
- torch.backends.cudnn.deterministic = True
24
- torch.backends.cudnn.benchmark = False
25
-
26
-
27
- # -------------------- Diffusion helpers (unchanged logic) --------------------
28
  def get_transfer_index(
29
  logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False
30
  ):
 
11
  import torch.nn.functional as F
12
  from transformers import AutoTokenizer
13
 
 
 
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  def get_transfer_index(
16
  logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False
17
  ):
modeling_nvrdiff.py CHANGED
@@ -535,41 +535,18 @@ class DiffEncoderModel(Qwen3PreTrainedModel, GenerationMixin):
535
  )
536
 
537
 
538
- def chat(self, tokenizer, max_new_tokens, steps, block_length, threshold):
539
- print("Stateless chat (type 'exit' to quit)")
540
- print("------------------------------------")
541
-
542
- try:
543
- while True:
544
- user_input = input("User: ").strip()
545
- if user_input.lower() in {"exit", "quit", "q"}:
546
- print("Conversation ended.")
547
- break
548
-
549
- prompt_ids = tokenizer(
550
- user_input,return_tensors='pt'
551
- ).input_ids.to(device='cuda')
552
-
553
- out_ids, nfe = generate_with_prefix_cache_block_diff(
554
- model=self,
555
- prompt=prompt_ids,
556
- gen_length=max_new_tokens,
557
- steps=steps,
558
- block_length=block_length,
559
- remasking="low_confidence",
560
- mask_id=self.mask_token_id,
561
- threshold=threshold,
562
- shift_logits=True,
563
- neg_entropy=True,
564
- )
565
-
566
- generated_tokens = out_ids[:, prompt_ids.shape[1]:]
567
- tokenized_out = tokenizer.batch_decode(
568
- generated_tokens,
569
- skip_special_tokens=True
570
- )[0]
571
- print(f"Model: {tokenized_out}")
572
- print(f"[nfe={nfe}]")
573
-
574
- except KeyboardInterrupt:
575
- print("\n[info] interrupted by user (Ctrl-C).")
 
535
  )
536
 
537
 
538
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, threshold):
539
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
540
+ model=self,
541
+ prompt=prompt_ids,
542
+ gen_length=max_new_tokens,
543
+ steps=steps,
544
+ block_length=block_length,
545
+ remasking="low_confidence",
546
+ mask_id=self.mask_token_id,
547
+ threshold=threshold,
548
+ shift_logits=True,
549
+ neg_entropy=True,
550
+ )
551
+
552
+ return out_ids, nfe