import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Categorical
from typing import List, Union, Optional, Tuple, Dict, Any
from transformers import Qwen2Config, Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from rdkit import Chem, RDLogger
import selfies as sf

# Silence RDKit parser warnings emitted during SMILES validation.
RDLogger.DisableLog('rdApp.*')

logger = logging.get_logger(__name__)


class ChemQ3MTPConfig(Qwen2Config):
    """
    Configuration class for the ChemQ3MTP model.

    Args:
        num_future_tokens: Number of future tokens predicted by the MTP head.
        horizon_weights: Per-horizon loss weights; defaults to a 0.9 geometric decay.
        use_mtp_training: Whether to blend the multi-token prediction loss with
            the standard causal LM loss during training.
        entropy_controller_config: Keyword arguments for EnhancedEntropyController.
    """
    model_type = "chemq3_mtp"

    def __init__(
        self,
        num_future_tokens: int = 3,
        horizon_weights: Optional[List[float]] = None,
        use_mtp_training: bool = True,
        entropy_controller_config: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_future_tokens = num_future_tokens
        self.horizon_weights = horizon_weights or [0.9 ** i for i in range(num_future_tokens)]
        self.use_mtp_training = use_mtp_training
        self.entropy_controller_config = entropy_controller_config or {
            "min_entropy": 0.5,
            "max_entropy": 3.0,
            "target_entropy": 1.5,
            "adaptation_rate": 0.01,
        }


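# Illustrative sketch: constructing a small config. The base-model sizes below
# are arbitrary toy values; only num_future_tokens and horizon_weights are
# ChemQ3MTP-specific.
def _demo_config() -> ChemQ3MTPConfig:
    config = ChemQ3MTPConfig(
        num_future_tokens=3,
        horizon_weights=[1.0, 0.9, 0.81],  # matches the default 0.9 decay
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    assert config.model_type == "chemq3_mtp"
    return config

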
def selfies_to_smiles(selfies_str: str) -> Optional[str]:
    """Convert a SELFIES string to SMILES, tolerating tokenizer artifacts."""
    try:
        clean_selfies = selfies_str.replace(" ", "")
        return sf.decoder(clean_selfies)
    except Exception:
        return None


def is_valid_smiles(smiles: str) -> bool:
    """Return True if the string parses to a valid RDKit molecule."""
    if not isinstance(smiles, str) or len(smiles.strip()) == 0:
        return False
    return Chem.MolFromSmiles(smiles.strip()) is not None


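# Quick sanity check for the helpers above. "[C][C][O]" is ethanol in SELFIES
# notation (an arbitrary example molecule, not taken from any dataset).
def _demo_selfies_helpers():
    smiles = selfies_to_smiles("[C][C][O]")
    assert smiles is not None and is_valid_smiles(smiles)

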
class MTPHead(nn.Module):
    """Multi-token prediction head: one linear projection per future offset."""

    def __init__(self, hidden_size: int, vocab_size: int, num_future_tokens: int = 3):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.vocab_size = vocab_size
        self.prediction_heads = nn.ModuleList([
            nn.Linear(hidden_size, vocab_size, bias=False)
            for _ in range(num_future_tokens)
        ])
        # A learned embedding per horizon offset tells each head how far
        # ahead it is predicting.
        self.position_embeddings = nn.Embedding(num_future_tokens, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        outputs = {}
        for i in range(self.num_future_tokens):
            pos_emb = self.position_embeddings(torch.tensor(i, device=hidden_states.device))
            enhanced_hidden = self.layer_norm(hidden_states + pos_emb)
            outputs[f'logits_t{i+1}'] = self.prediction_heads[i](enhanced_hidden)
        return outputs


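# Shape check for MTPHead (illustrative toy sizes): each horizon head maps
# (batch, seq_len, hidden) -> (batch, seq_len, vocab).
def _demo_mtp_head():
    head = MTPHead(hidden_size=64, vocab_size=100, num_future_tokens=3)
    out = head(torch.randn(2, 10, 64))
    assert set(out) == {'logits_t1', 'logits_t2', 'logits_t3'}
    assert out['logits_t1'].shape == (2, 10, 100)

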
class HorizonLoss(nn.Module):
    """Weighted cross-entropy loss over multiple prediction horizons."""

    def __init__(self, num_future_tokens: int = 3, horizon_weights: Optional[List[float]] = None):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        if horizon_weights is None:
            horizon_weights = [0.9 ** i for i in range(num_future_tokens)]
        self.horizon_weights = horizon_weights
        # Weights are learned in log-space and normalized with a softmax,
        # keeping them positive and summing to one.
        self.log_weights = nn.Parameter(torch.log(torch.tensor(self.horizon_weights)))

    def forward(
        self,
        mtp_outputs: Dict[str, torch.Tensor],
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        seq_len = input_ids.size(1)
        device = input_ids.device
        weights = F.softmax(self.log_weights, dim=0)
        # Initialize as a tensor so the returned loss supports backward()
        # even when every horizon is skipped.
        total_loss = torch.zeros((), device=device)
        horizon_losses = {}

        for i in range(self.num_future_tokens):
            logits_key = f'logits_t{i+1}'
            if logits_key not in mtp_outputs:
                continue

            logits = mtp_outputs[logits_key]
            shift = i + 1
            if seq_len <= shift:
                continue

            # Position t predicts the token at t + shift.
            shifted_logits = logits[:, :-shift, :].contiguous()
            shifted_targets = input_ids[:, shift:].contiguous()

            if attention_mask is not None:
                # Drop positions whose target token is padding.
                shifted_mask = attention_mask[:, shift:].contiguous()
                valid_indices = shifted_mask.view(-1) == 1
                if valid_indices.sum() == 0:
                    continue
                flat_logits = shifted_logits.view(-1, logits.size(-1))[valid_indices]
                flat_targets = shifted_targets.view(-1)[valid_indices]
            else:
                flat_logits = shifted_logits.view(-1, logits.size(-1))
                flat_targets = shifted_targets.view(-1)

            horizon_loss = F.cross_entropy(flat_logits, flat_targets, reduction='mean')
            horizon_losses[f'horizon_loss_t{i+1}'] = horizon_loss
            total_loss = total_loss + weights[i] * horizon_loss

        return {'loss': total_loss, 'horizon_weights': weights, **horizon_losses}


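# Illustrative pairing of MTPHead and HorizonLoss on random data: the result
# carries the weighted scalar loss plus one cross-entropy term per horizon.
def _demo_horizon_loss():
    head = MTPHead(hidden_size=64, vocab_size=100, num_future_tokens=3)
    criterion = HorizonLoss(num_future_tokens=3)
    input_ids = torch.randint(0, 100, (2, 12))
    result = criterion(head(torch.randn(2, 12, 64)), input_ids)
    assert 'horizon_loss_t1' in result and result['loss'].ndim == 0

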
class EnhancedEntropyController:
    """Adapts an entropy-bonus weight to keep policy entropy near a target."""

    def __init__(self, min_entropy: float = 0.5, max_entropy: float = 3.0,
                 target_entropy: float = 1.5, adaptation_rate: float = 0.01):
        self.min_entropy = min_entropy
        self.max_entropy = max_entropy
        self.target_entropy = target_entropy
        self.adaptation_rate = adaptation_rate
        self.entropy_history = []
        self.entropy_weight = 0.01

    def update_entropy_weight(self, current_entropy: float) -> float:
        """Adjust the entropy weight from a moving window of recent entropies."""
        self.entropy_history.append(current_entropy)

        # Keep only the 100 most recent observations.
        if len(self.entropy_history) > 100:
            self.entropy_history = self.entropy_history[-100:]

        if len(self.entropy_history) >= 10:
            avg_entropy = np.mean(self.entropy_history[-10:])
            # Entropy too low: raise the bonus; too high: decay it.
            if avg_entropy < self.target_entropy * 0.8:
                self.entropy_weight = min(0.05, self.entropy_weight * 1.1)
            elif avg_entropy > self.target_entropy * 1.2:
                self.entropy_weight = max(0.001, self.entropy_weight * 0.95)

        return self.entropy_weight


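# Sketch of the adaptation behaviour: repeatedly reporting entropy below
# 0.8 * target makes the controller grow the entropy bonus (capped at 0.05).
def _demo_entropy_controller():
    controller = EnhancedEntropyController(target_entropy=1.5)
    weight = controller.entropy_weight
    for _ in range(20):
        weight = controller.update_entropy_weight(0.9)  # below 0.8 * 1.5
    assert weight > 0.01  # increased from the initial 0.01

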
class ChemQ3MTPForCausalLM(Qwen2ForCausalLM):
    """
    ChemQ3MTP model for causal language modeling with multi-token prediction.

    Extends Qwen2ForCausalLM with a multi-token prediction head, a weighted
    multi-horizon loss, and an adaptive entropy controller for
    chemistry-specific training.
    """

    config_class = ChemQ3MTPConfig
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: ChemQ3MTPConfig):
        super().__init__(config)

        self.mtp_head = MTPHead(
            config.hidden_size,
            config.vocab_size,
            config.num_future_tokens
        )
        self.horizon_loss = HorizonLoss(
            num_future_tokens=config.num_future_tokens,
            horizon_weights=config.horizon_weights
        )
        self.use_mtp_training = config.use_mtp_training
        self.entropy_controller = EnhancedEntropyController(
            **config.entropy_controller_config
        )

        # Re-run weight initialization so the new MTP head is covered too.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass. During training with `use_mtp_training=True`, the loss is
        a 0.7/0.3 blend of the multi-horizon MTP loss and the standard causal
        LM loss; otherwise it falls back to plain next-token cross-entropy.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Build a default attention mask from the padding token if none is given.
        if attention_mask is None and input_ids is not None:
            if getattr(self.config, 'pad_token_id', None) is not None:
                attention_mask = (input_ids != self.config.pad_token_id).long()
            else:
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # Run the base model; the loss is computed below, so labels are withheld.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs.hidden_states[-1]
        lm_logits = outputs.logits
        loss = None

        if labels is not None:
            if self.training and self.use_mtp_training:
                # Multi-token prediction loss over all horizons.
                mtp_outputs = self.mtp_head(hidden_states)
                horizon_loss_dict = self.horizon_loss(mtp_outputs, input_ids, attention_mask)

                # Standard next-token causal LM loss.
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                if attention_mask is not None:
                    shift_mask = attention_mask[..., 1:].contiguous()
                    loss_mask = shift_mask.view(-1) == 1
                    if loss_mask.sum() == 0:
                        causal_lm_loss = torch.tensor(0.0, device=lm_logits.device)
                    else:
                        flat_logits = shift_logits.view(-1, shift_logits.size(-1))[loss_mask]
                        flat_labels = shift_labels.view(-1)[loss_mask]
                        causal_lm_loss = F.cross_entropy(
                            flat_logits, flat_labels, reduction='mean', ignore_index=-100
                        )
                else:
                    causal_lm_loss = F.cross_entropy(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                        reduction='mean',
                        ignore_index=-100,
                    )

                # Blend the two objectives, weighting the MTP loss more heavily.
                loss = 0.7 * horizon_loss_dict['loss'] + 0.3 * causal_lm_loss
            else:
                # Plain next-token cross-entropy (evaluation, or MTP disabled).
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def set_mtp_training(self, use_mtp: bool):
        """Enable or disable multi-token prediction training."""
        self.use_mtp_training = use_mtp

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs
    ):
        """Delegate to the Qwen2 implementation; kept explicit for
        compatibility with HuggingFace's generation utilities."""
        return super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs
        )

    def generate_with_logprobs(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        do_sample: bool = True,
        return_probs: bool = True,
        tokenizer=None,
    ) -> Tuple[List[str], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Autoregressively generate sequences and return per-token log
        probabilities for RL training.

        Log probabilities come from `log_softmax` rather than `log(softmax)`,
        so tokens masked to -inf by top-k/top-p filtering remain numerically
        safe, and the chosen token's log probability is gathered from the same
        filtered distribution in both the sampling and greedy cases.
        """
        self.eval()

        # Normalize input shape to (batch, seq_len).
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.dim() == 3 and input_ids.size(1) == 1:
            input_ids = input_ids.squeeze(1)
        assert input_ids.dim() == 2, f"input_ids must be 2-D, got {input_ids.shape}"

        current_input = input_ids
        generated_tokens, generated_logprobs, generated_probs = [], [], []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self(current_input, use_cache=False)
                logits = outputs.logits[:, -1, :] / temperature

                # Top-k filtering: keep only the k largest logits.
                if top_k is not None:
                    values, indices = torch.topk(logits, k=top_k)
                    logits = torch.full_like(logits, float("-inf"))
                    logits.scatter_(1, indices, values)

                # Nucleus (top-p) filtering: mask tokens outside the smallest
                # set whose cumulative probability exceeds top_p.
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    mask = cumprobs > top_p
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False  # always keep the most likely token
                    # Scatter the sorted-order mask back to vocabulary order.
                    logits[mask.scatter(1, sorted_indices, mask)] = float("-inf")

                log_probs = F.log_softmax(logits, dim=-1)
                probs = F.softmax(logits, dim=-1)

                if do_sample:
                    next_token = Categorical(probs).sample()
                else:
                    next_token = torch.argmax(probs, dim=-1)
                log_p = torch.gather(log_probs, 1, next_token.unsqueeze(1)).squeeze(1)

                generated_tokens.append(next_token.unsqueeze(1))
                generated_logprobs.append(log_p.unsqueeze(1))
                if return_probs:
                    generated_probs.append(probs.unsqueeze(1))

                current_input = torch.cat([current_input, next_token.unsqueeze(1)], dim=1)

        generated_tokens = torch.cat(generated_tokens, dim=1)
        generated_logprobs = torch.cat(generated_logprobs, dim=1)
        generated_probs = torch.cat(generated_probs, dim=1) if return_probs else None

        # Decode with the supplied tokenizer, falling back to one attached to the model.
        if tokenizer is None:
            tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided to decode generated tokens.")

        decoded_list = [
            tokenizer.decode(tok_ids, skip_special_tokens=True)
            for tok_ids in generated_tokens
        ]

        return decoded_list, generated_logprobs, generated_tokens, generated_probs


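# Usage sketch for RL rollouts (illustrative; assumes a `model` and a SELFIES
# `tokenizer` already exist, e.g. from the __main__ demo below). The
# sequence-level log probability is the per-token sum, as used by
# policy-gradient updates.
def _demo_generate_with_logprobs(model, tokenizer, prompt_ids):
    texts, logprobs, token_ids, _ = model.generate_with_logprobs(
        prompt_ids,
        max_new_tokens=32,
        temperature=0.8,
        top_p=0.9,
        tokenizer=tokenizer,
    )
    seq_logprob = logprobs.sum(dim=1)
    return texts, seq_logprob

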
# Register the custom classes so they are discoverable through the Auto* factories.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
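

# Minimal end-to-end smoke test (illustrative: all sizes are toy values, and
# the random token ids stand in for a real tokenized SELFIES batch). Run the
# module directly to execute it.
if __name__ == "__main__":
    demo_config = ChemQ3MTPConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_future_tokens=3,
    )
    model = ChemQ3MTPForCausalLM(demo_config)
    model.train()  # training mode + use_mtp_training=True -> blended loss
    ids = torch.randint(0, 128, (2, 16))
    out = model(input_ids=ids, labels=ids)
    print("blended training loss:", out.loss.item())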