gbyuvd committed
Commit 9627161 · verified · 1 Parent(s): 8750859

Update modeling_chemq3mtp.py

Files changed (1):
  1. modeling_chemq3mtp.py  +466 -456
modeling_chemq3mtp.py CHANGED
# ========================
# ChemQ3-MTP - HuggingFace Compatible Version
# MODEL COMPONENTS
# by gbyuvd
# ========================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from typing import List, Union, Optional, Tuple, Dict, Any
from transformers import Qwen2Config, Qwen2ForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski
import selfies as sf
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import json
import numpy as np
from collections import Counter
from rdkit.Chem import rdMolDescriptors

logger = logging.get_logger(__name__)

# ========================
# CONFIGURATION CLASS
# ========================

class ChemQ3MTPConfig(Qwen2Config):
    """
    Configuration class for ChemQ3MTP model.
    """
    model_type = "chemq3_mtp"

    def __init__(
        self,
        num_future_tokens: int = 3,
        horizon_weights: Optional[List[float]] = None,
        use_mtp_training: bool = True,
        entropy_controller_config: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_future_tokens = num_future_tokens
        self.horizon_weights = horizon_weights or [0.9 ** i for i in range(num_future_tokens)]
        self.use_mtp_training = use_mtp_training
        self.entropy_controller_config = entropy_controller_config or {
            "min_entropy": 0.5,
            "max_entropy": 3.0,
            "target_entropy": 1.5,
            "adaptation_rate": 0.01
        }

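# Illustrative sketch for ChemQ3MTPConfig above (not part of the shipped checkpoint): the
# config composes ordinary Qwen2 settings with the MTP-specific fields. The sizes used here
# are made-up placeholders; kept commented out so importing this module stays side-effect free.
# _demo_cfg = ChemQ3MTPConfig(
#     vocab_size=512,
#     hidden_size=256,
#     num_future_tokens=3,
#     horizon_weights=[1.0, 0.9, 0.81],
# )
# _demo_cfg.entropy_controller_config["target_entropy"]   # -> 1.5 (default)
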
# ========================
# UTILITY FUNCTIONS (kept minimal for HF compatibility)
# ========================

def selfies_to_smiles(selfies_str: str) -> str | None:
    """Convert SELFIES string to SMILES, handling tokenizer artifacts."""
    try:
        clean_selfies = selfies_str.replace(" ", "")
        return sf.decoder(clean_selfies)
    except Exception:
        return None

def is_valid_smiles(smiles: str) -> bool:
    if not isinstance(smiles, str) or len(smiles.strip()) == 0:
        return False
    return Chem.MolFromSmiles(smiles.strip()) is not None

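# Illustrative sketch for the two helpers above (example string assumed, not taken from the
# training data): decode a space-separated SELFIES token stream, then validate the SMILES.
# Commented out so importing this module has no side effects.
# _smiles = selfies_to_smiles("[C] [C] [O]")            # -> "CCO" (ethanol)
# _ok = _smiles is not None and is_valid_smiles(_smiles)  # -> True
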
# ========================
# MODEL COMPONENTS
# ========================

class MTPHead(nn.Module):
    """Multi-Token Prediction Head for predicting future tokens."""

    def __init__(self, hidden_size: int, vocab_size: int, num_future_tokens: int = 3):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.vocab_size = vocab_size
        self.prediction_heads = nn.ModuleList([
            nn.Linear(hidden_size, vocab_size, bias=False)
            for _ in range(num_future_tokens)
        ])
        self.position_embeddings = nn.Embedding(num_future_tokens, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        batch_size, seq_len, hidden_size = hidden_states.shape
        outputs = {}

        for i in range(self.num_future_tokens):
            pos_emb = self.position_embeddings(torch.tensor(i, device=hidden_states.device))
            enhanced_hidden = self.layer_norm(hidden_states + pos_emb)
            logits = self.prediction_heads[i](enhanced_hidden)
            outputs[f'logits_t{i+1}'] = logits

        return outputs


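# Illustrative shape check for MTPHead above (hypothetical sizes): each horizon head maps the
# same [batch, seq, hidden] states to [batch, seq, vocab] logits, so the returned dict holds
# one logits tensor per future offset. Commented out to keep import side-effect free.
# _head = MTPHead(hidden_size=256, vocab_size=512, num_future_tokens=3)
# _out = _head(torch.randn(2, 10, 256))
# sorted(_out.keys())        # ['logits_t1', 'logits_t2', 'logits_t3']
# _out['logits_t1'].shape    # torch.Size([2, 10, 512])
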
class HorizonLoss(nn.Module):
    """Loss function for multi-horizon prediction."""

    def __init__(self, num_future_tokens: int = 3, horizon_weights: Optional[List[float]] = None):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        if horizon_weights is None:
            self.horizon_weights = [0.9 ** i for i in range(num_future_tokens)]
        else:
            self.horizon_weights = horizon_weights
        self.log_weights = nn.Parameter(torch.log(torch.tensor(self.horizon_weights)))

    def forward(
        self,
        mtp_outputs: Dict[str, torch.Tensor],
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:

        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        weights = F.softmax(self.log_weights, dim=0)
        total_loss = 0.0
        horizon_losses = {}

        for i in range(self.num_future_tokens):
            logits_key = f'logits_t{i+1}'
            if logits_key not in mtp_outputs:
                continue

            logits = mtp_outputs[logits_key]
            shift = i + 1
            if seq_len <= shift:
                continue

            shifted_logits = logits[:, :-shift, :].contiguous()
            shifted_targets = input_ids[:, shift:].contiguous()

            if attention_mask is not None:
                shifted_mask = attention_mask[:, shift:].contiguous()
                mask_expanded = shifted_mask.view(-1)
                valid_indices = mask_expanded == 1
                if valid_indices.sum() == 0:
                    continue
                flat_logits = shifted_logits.view(-1, logits.size(-1))[valid_indices]
                flat_targets = shifted_targets.view(-1)[valid_indices]
            else:
                flat_logits = shifted_logits.view(-1, logits.size(-1))
                flat_targets = shifted_targets.view(-1)

            horizon_loss = F.cross_entropy(flat_logits, flat_targets, reduction='mean')
            horizon_losses[f'horizon_loss_t{i+1}'] = horizon_loss
            total_loss += weights[i] * horizon_loss

        return {'loss': total_loss, 'horizon_weights': weights, **horizon_losses}


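# Worked note on HorizonLoss above (reading of the code, default weights assumed): the
# learnable log-weights start at log([1.0, 0.9, 0.81]) and are re-normalised on every forward
# pass with a softmax, so the combined objective is
#     loss = sum_i softmax(log_weights)[i] * CE(logits_t{i+1}, input_ids shifted by i+1)
# With the default 0.9 decay the initial mixture is roughly [0.369, 0.332, 0.299].
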
class EnhancedEntropyController:
    """Enhanced entropy controller for adaptive training."""

    def __init__(self, min_entropy: float = 0.5, max_entropy: float = 3.0,
                 target_entropy: float = 1.5, adaptation_rate: float = 0.01):
        self.min_entropy = min_entropy
        self.max_entropy = max_entropy
        self.target_entropy = target_entropy
        self.adaptation_rate = adaptation_rate
        self.entropy_history = []
        self.entropy_weight = 0.01

    def update_entropy_weight(self, current_entropy: float) -> float:
        """Dynamically adjust entropy weight based on current entropy levels."""
        self.entropy_history.append(current_entropy)

        if len(self.entropy_history) > 100:
            self.entropy_history = self.entropy_history[-100:]

        if len(self.entropy_history) >= 10:
            avg_entropy = np.mean(self.entropy_history[-10:])

            if avg_entropy < self.target_entropy * 0.8:
                self.entropy_weight = min(0.05, self.entropy_weight * 1.1)
            elif avg_entropy > self.target_entropy * 1.2:
                self.entropy_weight = max(0.001, self.entropy_weight * 0.95)

        return self.entropy_weight

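# Illustrative sketch for EnhancedEntropyController above (made-up entropy values): the
# controller nudges entropy_weight up when the recent average falls below 0.8 * target and
# down when it exceeds 1.2 * target, clamped to [0.001, 0.05]. Commented out for import safety.
# _ctrl = EnhancedEntropyController(target_entropy=1.5)
# for _ in range(10):
#     _ctrl.update_entropy_weight(0.9)   # persistently low entropy
# _ctrl.entropy_weight                   # ~0.011 after the first upward nudge (0.01 * 1.1)
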
# ========================
# MAIN MODEL CLASS
# ========================

class ChemQ3MTPForCausalLM(Qwen2ForCausalLM):
    """
    ChemQ3MTP model for causal language modeling with multi-token prediction.

    This model extends Qwen2ForCausalLM with additional capabilities for
    multi-token prediction and chemistry-specific training.
    """

    config_class = ChemQ3MTPConfig
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: ChemQ3MTPConfig):
        super().__init__(config)

        # Initialize MTP components
        self.mtp_head = MTPHead(
            config.hidden_size,
            config.vocab_size,
            config.num_future_tokens
        )
        self.horizon_loss = HorizonLoss(
            num_future_tokens=config.num_future_tokens,
            horizon_weights=config.horizon_weights
        )

        # Training configuration
        self.use_mtp_training = config.use_mtp_training

        # Initialize entropy controller
        self.entropy_controller = EnhancedEntropyController(
            **config.entropy_controller_config
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass of the ChemQ3MTP model.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Default attention mask if not provided
        if attention_mask is None and input_ids is not None:
            # Handle case where pad_token_id is None
            if hasattr(self.config, 'pad_token_id') and self.config.pad_token_id is not None:
                attention_mask = (input_ids != self.config.pad_token_id).long()
            else:
                # Default to all 1s if no pad_token_id is defined
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # Call parent forward with required hidden states
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,  # Handle labels manually
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # Always need hidden states for MTP
            return_dict=True,
            cache_position=cache_position,
        )

        # Extract the final hidden states and LM logits from the base model output
        hidden_states = outputs.hidden_states[-1]
        lm_logits = outputs.logits
        loss = None

        # Compute loss if labels are provided
        if labels is not None:
            if self.training and self.use_mtp_training:
                # Multi-token prediction training
                mtp_outputs = self.mtp_head(hidden_states)
                horizon_loss_dict = self.horizon_loss(mtp_outputs, input_ids, attention_mask)

                # Standard causal LM loss
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                if attention_mask is not None:
                    shift_mask = attention_mask[..., 1:].contiguous()
                    loss_mask = shift_mask.view(-1) == 1
                    if loss_mask.sum() == 0:
                        causal_lm_loss = torch.tensor(0.0, device=lm_logits.device)
                    else:
                        flat_logits = shift_logits.view(-1, shift_logits.size(-1))[loss_mask]
                        flat_labels = shift_labels.view(-1)[loss_mask]
                        causal_lm_loss = F.cross_entropy(flat_logits, flat_labels, reduction='mean')
                else:
                    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                    flat_labels = shift_labels.view(-1)
                    causal_lm_loss = F.cross_entropy(flat_logits, flat_labels, reduction='mean')

                # Combine losses
                loss = 0.7 * horizon_loss_dict['loss'] + 0.3 * causal_lm_loss

            else:
                # Standard causal LM training
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

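    # Note on the design (as implemented in forward above): the MTP heads contribute only to
    # the training loss; the returned logits always come from the base Qwen2 LM head, so
    # generation behaviour is unchanged whether MTP training is on or off. During MTP training
    # the blended objective is
    #     loss = 0.7 * horizon_loss + 0.3 * causal_lm_loss,
    # where horizon_loss is already the softmax-weighted sum over future-token offsets.
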
    def set_mtp_training(self, use_mtp: bool):
        """Enable or disable multi-token prediction training."""
        self.use_mtp_training = use_mtp

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs
    ):
        """
        Prepare inputs for generation. This method is required for compatibility
        with HuggingFace's generation utilities.
        """
        # This delegates to the parent class implementation
        return super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs
        )

    def generate_with_logprobs(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        do_sample: bool = True,
        return_probs: bool = True,
        tokenizer=None,
    ) -> Tuple[List[str], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Generate sequences with log probabilities for RL training.

        FIXED VERSION: Corrects log probability calculation to avoid numerical issues.
        Changes:
        1. Use log_softmax instead of log(softmax) to avoid log(0) issues
        2. Correct the gather operation for the non-sampling case
        3. Handle the case where filtered logits become -inf properly
        """
        self.eval()
        device = input_ids.device

        # Normalize input shapes
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.dim() == 3 and input_ids.size(1) == 1:
            input_ids = input_ids.squeeze(1)
        assert input_ids.dim() == 2, f"input_ids must be 2-D, got {input_ids.shape}"

        batch_size, seq_len = input_ids.shape
        current_input = input_ids

        generated_tokens, generated_logprobs, generated_probs = [], [], []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self(current_input, use_cache=False)
                logits = outputs.logits[:, -1, :] / temperature

                # Apply top-k filtering
                if top_k is not None:
                    values, indices = torch.topk(logits, k=top_k)
                    logits = torch.full_like(logits, float("-inf"))
                    logits.scatter_(1, indices, values)

                # Apply top-p filtering
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    mask = cumprobs > top_p
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False
                    logits[mask.scatter(1, sorted_indices, mask)] = float("-inf")

                # FIX: Calculate log probabilities using log_softmax for numerical stability
                log_probs = F.log_softmax(logits, dim=-1)
                probs = F.softmax(logits, dim=-1)

                if do_sample:
                    dist = Categorical(probs)
                    next_token = dist.sample()
                    # FIX: Get log prob directly from the log_probs tensor
                    log_p = torch.gather(log_probs, 1, next_token.unsqueeze(1)).squeeze(1)
                else:
                    next_token = torch.argmax(probs, dim=-1)
                    # FIX: Use log_probs instead of log(probs) to avoid numerical issues
                    log_p = torch.gather(log_probs, 1, next_token.unsqueeze(1)).squeeze(1)

                generated_tokens.append(next_token.unsqueeze(1))
                generated_logprobs.append(log_p.unsqueeze(1))
                if return_probs:
                    generated_probs.append(probs.unsqueeze(1))

                current_input = torch.cat([current_input, next_token.unsqueeze(1)], dim=1)

        generated_tokens = torch.cat(generated_tokens, dim=1)
        generated_logprobs = torch.cat(generated_logprobs, dim=1)
        generated_probs = torch.cat(generated_probs, dim=1) if return_probs else None

        # Decode generated tokens
        if tokenizer is None:
            tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided to decode generated tokens.")

        decoded_list = [
            tokenizer.decode(tok_ids, skip_special_tokens=True)
            for tok_ids in generated_tokens
        ]

        return decoded_list, generated_logprobs, generated_tokens, generated_probs

# ========================
# REGISTRATION
# ========================

# Register the configuration and model classes with the Auto* factories
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
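
# Usage sketch (assumptions flagged): once this module is imported, the registration above
# lets the Auto* classes resolve "chemq3_mtp" configs; when the file ships inside a Hub repo,
# loading instead goes through trust_remote_code=True. "<repo-id>" below is a placeholder,
# not a real checkpoint reference, and the prompt/BOS handling depends on the tokenizer.
# Kept commented out so importing this module stays side-effect free.
#
# from transformers import AutoTokenizer, AutoModelForCausalLM
# tok = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)
# prompt_ids = tok(tok.bos_token or "", return_tensors="pt").input_ids
# selfies_strs, logprobs, tokens, probs = model.generate_with_logprobs(
#     prompt_ids, max_new_tokens=32, temperature=1.0, top_k=50, tokenizer=tok
# )
# smiles = [s for s in (selfies_to_smiles(x) for x in selfies_strs)
#           if s and is_valid_smiles(s)]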