gbyuvd committed
Commit bfc65c7 · verified · 1 Parent(s): 137bf2b

Update ChemQ3MTP/modeling_chemq3mtp.py

Files changed (1):
  1. ChemQ3MTP/modeling_chemq3mtp.py +27 -2
ChemQ3MTP/modeling_chemq3mtp.py CHANGED
@@ -5,6 +5,7 @@
 # ========================
 
 import os
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -24,6 +25,7 @@ import json
 import numpy as np
 from collections import Counter
 from rdkit.Chem import rdMolDescriptors
+from transformers.generation.utils import GenerationMixin
 
 logger = logging.get_logger(__name__)
 
@@ -373,16 +375,34 @@ class ChemQ3MTPForCausalLM(Qwen2ForCausalLM):
             **kwargs
         )
 
+    def generate(self, *args, **kwargs):
+        """
+        Wrap HF GenerationMixin.generate so that
+        max_new_tokens = ceil(0.25 * prompt_length) when the caller
+        omits both max_new_tokens and max_length.
+        """
+        # only touch if the user did NOT set any length cap
+        if (kwargs.get("max_new_tokens") is None
+                and kwargs.get("max_length") is None):
+
+            # locate input_ids (works when passed positionally or by keyword)
+            input_ids = kwargs.get("input_ids", args[0] if args else None)
+            if input_ids is not None:
+                kwargs["max_new_tokens"] = max(1, math.ceil(input_ids.shape[1] * 0.25))
+
+        # delegate to the HF implementation
+        return super().generate(*args, **kwargs)
+
     def generate_with_logprobs(
         self,
         input_ids: torch.LongTensor,
-        max_new_tokens: int = 50,
+        max_new_tokens: Optional[int] = None,  # ← changed default to None
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         do_sample: bool = True,
         return_probs: bool = True,
-        tokenizer=None,
+        tokenizer=None
    ) -> Tuple[List[str], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Generate sequences with log probabilities for RL training.
@@ -392,10 +412,15 @@ class ChemQ3MTPForCausalLM(Qwen2ForCausalLM):
         1. Use log_softmax instead of log(softmax) to avoid log(0) issues
         2. Correct the gather operation for non-sampling case
         3. Handle the case where filtered logits become -inf properly
+        NEW: if max_new_tokens is not given, it is set to ceil(0.25 * prompt_length).
         """
         self.eval()
         device = input_ids.device
 
+        # ---------- auto-compute max_new_tokens ----------
+        if max_new_tokens is None:
+            max_new_tokens = max(1, math.ceil(input_ids.size(1) * 0.25))
+
         # Normalize input shapes
         if input_ids.dim() == 1:
             input_ids = input_ids.unsqueeze(0)
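For context, a minimal usage sketch of the new default follows; it is not part of the commit, and the repo id and prompt token count are illustrative assumptions.

# Sketch only: "gbyuvd/ChemQ3MTP" is a hypothetical repo id for a checkpoint
# that exposes ChemQ3MTPForCausalLM via trust_remote_code.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "gbyuvd/ChemQ3MTP"
tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

ids = tok("CCO", return_tensors="pt").input_ids         # assume an 8-token prompt
out = model.generate(input_ids=ids)                     # no max_new_tokens, no max_length
# the wrapper sets max_new_tokens = max(1, ceil(0.25 * 8)) = 2 for this prompt
out = model.generate(input_ids=ids, max_new_tokens=64)  # an explicit cap is left untouched

Passing either max_new_tokens or max_length bypasses the heuristic, so existing callers keep their current behavior.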
 
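As an aside on point 1 of the docstring, the log_softmax-versus-log(softmax) issue is float underflow; a standalone snippet, independent of this model, makes it visible.

# With a large logit gap, softmax underflows to exactly 0.0 in float32 and
# log() of that is -inf; log_softmax computes x - logsumexp(x) in one stable
# pass and returns the finite log-probability instead.
import torch
import torch.nn.functional as F

logits = torch.tensor([0.0, 200.0])
print(torch.log(F.softmax(logits, dim=-1)))  # tensor([-inf, 0.])
print(F.log_softmax(logits, dim=-1))         # tensor([-200., 0.])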