Update ChemQ3MTP/modeling_chemq3mtp.py
ChemQ3MTP/modeling_chemq3mtp.py
CHANGED
@@ -5,6 +5,7 @@
 # ========================
 
 import os
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -24,6 +25,7 @@ import json
 import numpy as np
 from collections import Counter
 from rdkit.Chem import rdMolDescriptors
+from transformers.generation.utils import GenerationMixin
 
 logger = logging.get_logger(__name__)
 
@@ -373,16 +375,34 @@ class ChemQ3MTPForCausalLM(Qwen2ForCausalLM):
             **kwargs
         )
 
+    def generate(self, *args, **kwargs):
+        """
+        Wrap HF GenerationMixin.generate so that
+        max_new_tokens = ceil(0.25 * prompt_length) when the caller
+        omits both max_new_tokens and max_length.
+        """
+        # only touch if user did NOT set any length cap
+        if (kwargs.get("max_new_tokens") is None
+                and kwargs.get("max_length") is None):
+
+            # locate input_ids (works when passed positionally or by keyword)
+            input_ids = kwargs.get("input_ids", args[0] if args else None)
+            if input_ids is not None:
+                kwargs["max_new_tokens"] = max(1, math.ceil(input_ids.shape[1] * 0.25))
+
+        # delegate to HF implementation
+        return super().generate(*args, **kwargs)
+
     def generate_with_logprobs(
         self,
         input_ids: torch.LongTensor,
-        max_new_tokens: int =
+        max_new_tokens: int = None,  # ← changed default to None
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         do_sample: bool = True,
         return_probs: bool = True,
-        tokenizer=None
+        tokenizer=None
     ) -> Tuple[List[str], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Generate sequences with log probabilities for RL training.
@@ -392,10 +412,15 @@ class ChemQ3MTPForCausalLM(Qwen2ForCausalLM):
         1. Use log_softmax instead of log(softmax) to avoid log(0) issues
         2. Correct the gather operation for non-sampling case
         3. Handle the case where filtered logits become -inf properly
+        NEW: if max_new_tokens is not given, it is set to ceil(0.25 * prompt_length).
         """
         self.eval()
         device = input_ids.device
 
+        # ---------- auto-compute max_new_tokens ----------
+        if max_new_tokens is None:
+            max_new_tokens = max(1, math.ceil(input_ids.size(1) * 0.25))
+
         # Normalize input shapes
         if input_ids.dim() == 1:
             input_ids = input_ids.unsqueeze(0)
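
For reference, a minimal usage sketch of the new auto-cap behavior (the checkpoint id and SMILES prompt below are placeholders, not taken from this repo; assumes the model loads with trust_remote_code=True):

    import math
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # hypothetical repo id; substitute the actual ChemQ3MTP checkpoint
    tok = AutoTokenizer.from_pretrained("your-org/ChemQ3MTP", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("your-org/ChemQ3MTP", trust_remote_code=True)

    enc = tok("CC(=O)OC1=CC=CC=C1C(=O)O", return_tensors="pt")  # example SMILES prompt
    prompt_len = enc["input_ids"].shape[1]

    # Neither max_new_tokens nor max_length is passed, so the wrapper
    # caps generation at max(1, ceil(0.25 * prompt_len)) new tokens.
    out = model.generate(input_ids=enc["input_ids"])
    assert out.shape[1] - prompt_len <= max(1, math.ceil(prompt_len * 0.25))

Passing max_new_tokens or max_length explicitly bypasses the wrapper's default, and generate_with_logprobs applies the same 25% rule only when its max_new_tokens is left as None; for a 40-token prompt, both paths cap generation at 10 new tokens.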