# NOTE(review): removed non-Python extraction artifacts (file-size header,
# hash dump, and a line-number listing) that made this file unparseable.
from torch import nn
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.utils import auto_docstring
from transformers.utils.generic import TransformersKwargs, can_return_tuple
from typing import Optional, Union
from transformers.processing_utils import Unpack
import torch
from transformers import Cache, Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel, Qwen3Model
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
logger = logging.get_logger(__name__)
class ZeroEntropyTokenizer(PreTrainedTokenizerFast):
    """Tokenizer that renders (query, document) pairs through the chat template.

    Each pair becomes a two-message chat — the query as the ``system`` message
    and the document as the ``user`` message — rendered with
    ``apply_chat_template`` and then tokenized as one batch by the base class.
    """

    # NOTE: the original redundant `__init__(**kwargs)` that only forwarded to
    # super() has been removed; the inherited constructor is used directly.

    def __call__(self, pairs, *args, **kwargs):
        """Tokenize a batch of (query, document) string pairs.

        Args:
            pairs: iterable of ``(query, document)`` string tuples.
            *args, **kwargs: forwarded unchanged to
                ``PreTrainedTokenizerFast.__call__`` (e.g. ``padding``,
                ``truncation``, ``return_tensors``).

        Returns:
            The batch encoding produced by the base tokenizer.

        Raises:
            TypeError: if ``apply_chat_template`` does not return a string.
        """
        input_texts: list[str] = []
        for query, document in pairs:
            messages = [
                {"role": "system", "content": query.strip()},
                {"role": "user", "content": document.strip()},
            ]
            # tokenize=False: we only want the rendered prompt string here;
            # actual tokenization happens once for the whole batch below.
            input_text = self.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            if not isinstance(input_text, str):
                # `assert` would be stripped under `python -O`; raise instead.
                raise TypeError(
                    f"apply_chat_template returned {type(input_text).__name__}, "
                    "expected str"
                )
            input_texts.append(input_text)
        return super().__call__(input_texts, *args, **kwargs)
class ZeroEntropyConfig(Qwen3Config):
    """Qwen3 configuration extended with the id of the "yes" token.

    ``yes_token_id`` identifies the vocabulary entry whose logit the reranker
    head reads out as the relevance score.
    """

    model_type = "zeroentropy"

    def __init__(self, yes_token_id: int = 9454, **kwargs):
        super().__init__(**kwargs)
        # Vocabulary id of the token treated as the positive ("yes") class.
        self.yes_token_id = yes_token_id
class ZeroEntropyForSequenceClassification(Qwen3PreTrainedModel):
    """Qwen3-based reranker that scores (query, document) chat sequences.

    The score is the temperature-scaled logit of the configured "yes" token
    (``config.yes_token_id``) at the last non-padded position of each
    sequence, returned as classification logits of shape ``(batch_size, 1)``.
    """

    config: ZeroEntropyConfig
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    # Divisor applied to the raw "yes" logit before it is returned.
    # NOTE(review): value presumably fixed by the training setup — confirm
    # against the training code before changing.
    LOGIT_TEMPERATURE = 5.0

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Accepted for API compatibility but currently unused; no loss is
            computed and the returned `loss` is always `None`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
            Accepted for API compatibility. The score depends only on the last
            non-padded position, whose logits are always computed, so this
            argument has no effect on the output.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state  # (batch, seq, hidden)
        batch_size = hidden_states.shape[0]

        # Index of the last non-padded token for each sequence. Without a mask
        # (previously a crash), assume no padding and use the final position.
        if attention_mask is not None:
            last_positions = attention_mask.sum(dim=1) - 1
        else:
            last_positions = torch.full(
                (batch_size,),
                hidden_states.shape[1] - 1,
                dtype=torch.long,
                device=hidden_states.device,
            )
        batch_indices = torch.arange(batch_size, device=hidden_states.device)

        # Gather the last-position hidden state first, then project only those
        # vectors through lm_head. This fixes the original latent bug where a
        # nonzero `logits_to_keep` sliced the logits to fewer positions than
        # the absolute `last_positions` indices addressed, and it avoids
        # computing vocab logits for positions that are never read.
        last_hidden = hidden_states[batch_indices, last_positions]  # (batch, hidden)
        logits = self.lm_head(last_hidden)  # (batch, vocab)

        # Relevance score = scaled logit of the "yes" token, shape (batch, 1).
        yes_logits = logits[:, self.config.yes_token_id] / self.LOGIT_TEMPERATURE

        return SequenceClassifierOutputWithPast(
            loss=None,
            logits=yes_logits.unsqueeze(-1),
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# NOTE(review): removed a stray trailing "|" extraction artifact.