Instructions to use baseten/gemma-4-e2b-it-sequence-classification with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use baseten/gemma-4-e2b-it-sequence-classification with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="baseten/gemma-4-e2b-it-sequence-classification", trust_remote_code=True)

# Load model directly
from transformers import AutoProcessor, AutoModelForSequenceClassification

processor = AutoProcessor.from_pretrained("baseten/gemma-4-e2b-it-sequence-classification", trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained("baseten/gemma-4-e2b-it-sequence-classification", trust_remote_code=True)
- Notebooks
- Google Colab
- Kaggle
File size: 7,920 Bytes
4dd975e b3842c5 4dd975e b3842c5 4dd975e 2bf0822 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | """Gemma 4 sequence classifier backed by selected next-token logits.
This module is intentionally small: it reuses the Gemma 4 multimodal backbone and
replaces the LM head with a classifier head containing selected token rows.
"""
from __future__ import annotations
from collections.abc import Sequence
import torch
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config
from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4PreTrainedModel
class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
    """Pool the last text position and score it with selected Gemma 4 token rows.

    The Gemma 4 backbone (``self.model``) is reused as-is. The LM head is
    replaced by ``self.score``, a bias-free linear layer with one output row per
    label. When the instance is built through
    :meth:`from_conditional_generation`, those rows are copies of the LM-head
    rows of the selected token ids.
    """

    config_class = Gemma4Config
    base_model_prefix = "model"

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        """Always answer ``True``.

        NOTE(review): presumably overrides a base-class capability check so the
        experts implementation can still be configured for this custom class;
        confirm against the installed ``transformers`` version.
        """
        return True

    def __init__(
        self,
        config: Gemma4Config,
        source_model: nn.Module | None = None,
        classifier_weight: torch.Tensor | None = None,
    ) -> None:
        """Build the classifier, optionally reusing an existing backbone and head rows.

        Args:
            config: Gemma 4 configuration; ``config.num_labels`` sizes the head
                and ``config.text_config.hidden_size`` is its input width.
            source_model: Optional model whose ``.model`` attribute is reused as
                the backbone (shared, not copied). When ``None`` a fresh
                ``Gemma4Model(config)`` is created.
            classifier_weight: Optional initial weight for ``self.score``. Must
                have exactly the shape of ``self.score.weight``
                (``num_labels`` x ``hidden_size``).

        Raises:
            ValueError: If ``classifier_weight`` does not have the shape of the
                classifier weight.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = source_model.model if source_model is not None else Gemma4Model(config)
        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)

        if classifier_weight is not None:
            # ``copy_`` broadcasts its source, so a wrongly shaped weight (for
            # example a single row) would silently fill every class row.
            if classifier_weight.shape != self.score.weight.shape:
                raise ValueError(
                    "classifier_weight has shape "
                    f"{tuple(classifier_weight.shape)}, expected "
                    f"{tuple(self.score.weight.shape)}."
                )
            self.score.to(device=classifier_weight.device, dtype=classifier_weight.dtype)
            with torch.no_grad():
                self.score.weight.copy_(classifier_weight)

        # Only run post-init when nothing was handed in: neither a source
        # backbone nor explicit classifier weights.
        if source_model is None and classifier_weight is None:
            self.post_init()

    @staticmethod
    def _validate_label_tokens(
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> None:
        """Raise ``ValueError`` unless there is exactly one token id per label."""
        if len(selected_token_ids) != len(labels):
            raise ValueError(
                "selected_token_ids and labels must have the same length, got "
                f"{len(selected_token_ids)} token ids and {len(labels)} labels."
            )

    @classmethod
    def from_conditional_generation(
        cls,
        model_lm: nn.Module,
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> "Gemma4ForSequenceClassification":
        """Create a classifier from a model that exposes ``lm_head`` and ``model``.

        The rows of ``model_lm.lm_head.weight`` at ``selected_token_ids`` are
        cloned into the classifier head, and ``model_lm.model`` is reused as the
        backbone. ``model_lm.config`` is updated in place with the
        classification metadata.

        Args:
            model_lm: Source model providing ``lm_head``, ``model`` and ``config``.
            selected_token_ids: One vocabulary id per label, in label order.
            labels: Label names, in class-index order.

        Returns:
            A new classifier sharing the backbone of ``model_lm``.

        Raises:
            ValueError: If ``selected_token_ids`` and ``labels`` differ in length.
        """
        # Validate before touching the shared config or extracting rows.
        cls._validate_label_tokens(selected_token_ids, labels)

        lm_head_weight = model_lm.lm_head.weight
        token_ids = torch.tensor(
            selected_token_ids,
            dtype=torch.long,  # explicit so an empty sequence is still an index tensor
            device=lm_head_weight.device,
        )
        classifier_weight = lm_head_weight.index_select(0, token_ids).detach().clone()
        cls.configure_classification_config(model_lm.config, selected_token_ids, labels)
        return cls(model_lm.config, source_model=model_lm, classifier_weight=classifier_weight)

    @classmethod
    def configure_classification_config(
        cls,
        config: Gemma4Config,
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> None:
        """Write classification metadata onto ``config`` in place.

        Sets ``num_labels``, ``id2label``, ``label2id``, ``classifier_token_ids``
        (label -> token id), ``architectures`` and ``problem_type``. If
        ``config.pad_token_id`` is missing or ``None`` it is filled from
        ``config.text_config.pad_token_id``.

        Args:
            config: Configuration object to mutate.
            selected_token_ids: One vocabulary id per label, in label order.
            labels: Label names, in class-index order.

        Raises:
            ValueError: If ``selected_token_ids`` and ``labels`` differ in length
                (``zip`` would otherwise truncate silently and leave the config
                inconsistent).
        """
        cls._validate_label_tokens(selected_token_ids, labels)

        config.num_labels = len(labels)
        config.id2label = dict(enumerate(labels))
        config.label2id = {label: index for index, label in enumerate(labels)}
        config.classifier_token_ids = {
            label: int(token_id) for label, token_id in zip(labels, selected_token_ids)
        }
        config.architectures = [cls.__name__]
        config.problem_type = "single_label_classification"
        if getattr(config, "pad_token_id", None) is None:
            config.pad_token_id = config.text_config.pad_token_id

    def get_input_embeddings(self):
        """Return the backbone's input embeddings."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embeddings with ``value``."""
        self.model.set_input_embeddings(value)

    def get_per_layer_input_embeddings(self):
        """Return the backbone's per-layer input embeddings."""
        return self.model.get_per_layer_input_embeddings()

    def set_per_layer_input_embeddings(self, value):
        """Replace the backbone's per-layer input embeddings with ``value``."""
        self.model.set_per_layer_input_embeddings(value)

    def _last_non_pad_token(
        self,
        logits: torch.Tensor,
        input_ids: torch.LongTensor | None,
        attention_mask: torch.Tensor | None,
        inputs_embeds: torch.FloatTensor | None,
    ) -> torch.Tensor | int:
        """Return, per batch row, the position to pool from.

        Resolution order:

        1. ``attention_mask`` given: the highest position whose mask is non-zero.
        2. ``input_ids`` and a configured ``pad_token_id``: the highest position
           whose token differs from the pad token.
        3. Otherwise ``-1`` (the final position), allowed only for batch size 1.

        Args:
            logits: Per-position logits; dimension 0 is the batch and dimension 1
                the sequence position.
            input_ids: Token ids, if the caller supplied them.
            attention_mask: Attention mask, if the caller supplied one.
            inputs_embeds: Input embeddings, if the caller supplied them.

        Returns:
            A tensor of per-row positions, or the int ``-1``.

        Raises:
            ValueError: If positions cannot be inferred for a batch larger than
                one, or if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        batch_size = logits.shape[0]

        if attention_mask is not None:
            positions = torch.arange(logits.shape[1], device=logits.device)
            # mask * position is largest at the last attended position.
            return (attention_mask.to(logits.device) * positions).argmax(-1)

        pad_token_id = getattr(self.config, "pad_token_id", None)
        if input_ids is not None and pad_token_id is not None:
            positions = torch.arange(input_ids.shape[-1], device=logits.device)
            non_pad = input_ids.to(logits.device).ne(pad_token_id)
            return (non_pad * positions).argmax(-1)

        if batch_size != 1:
            raise ValueError(
                "Cannot infer sequence lengths for a padded batch without a pad token."
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Expected input_ids or inputs_embeds.")
        return -1

    def _apply_final_logit_softcapping(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply ``cap * tanh(logits / cap)`` when the text config defines a cap.

        ``cap`` is ``final_logit_softcapping`` from the text config; when it is
        ``None`` the logits are returned unchanged.
        """
        cap = self.config.get_text_config().final_logit_softcapping
        if cap is None:
            return logits
        return torch.tanh(logits / cap) * cap

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        image_position_ids: torch.LongTensor | None = None,
        video_position_ids: torch.LongTensor | None = None,
        past_key_values=None,
        mm_token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone, score every position, and pool one position per row.

        All arguments other than ``labels`` and ``return_dict`` (plus any extra
        ``kwargs``) are forwarded unchanged to the backbone.

        Args:
            labels: Optional targets. When given, a loss is computed according to
                ``config.problem_type``: MSE for ``"regression"``,
                BCE-with-logits for ``"multi_label_classification"``, and
                cross-entropy otherwise.
            return_dict: Whether to return a ``SequenceClassifierOutputWithPast``
                instead of a tuple; defaults to ``config.use_return_dict``.

        Returns:
            ``SequenceClassifierOutputWithPast`` with the pooled logits, or a
            tuple ``(loss?, pooled_logits, *backbone_outputs[1:])`` when
            ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The backbone is always asked for a dict so attributes can be read below.
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            attention_mask=attention_mask,
            input_features_mask=input_features_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            mm_token_type_ids=mm_token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            image_position_ids=image_position_ids,
            video_position_ids=video_position_ids,
            return_dict=True,
            **kwargs,
        )

        logits = self.score(outputs.last_hidden_state)
        logits = self._apply_final_logit_softcapping(logits)

        sequence_lengths = self._last_non_pad_token(
            logits,
            input_ids,
            attention_mask,
            inputs_embeds,
        )
        batch_indices = torch.arange(logits.shape[0], device=logits.device)
        pooled_logits = logits[batch_indices, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(pooled_logits.device)
            problem_type = self.config.problem_type
            if problem_type == "regression":
                loss = nn.MSELoss()(pooled_logits.squeeze(), labels.squeeze())
            elif problem_type == "multi_label_classification":
                loss = nn.BCEWithLogitsLoss()(pooled_logits, labels)
            else:
                loss = nn.CrossEntropyLoss()(
                    pooled_logits.view(-1, self.num_labels),
                    labels.view(-1),
                )

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Register this custom class under ``AutoModelForSequenceClassification``.
# NOTE(review): presumably this is what lets the Auto class resolve the model
# when it is loaded with ``trust_remote_code=True`` — confirm against the
# transformers custom-model documentation.
Gemma4ForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
|