# kimyeonz's picture
# Update model.py
# 77f6a88 verified
import torch
import torch.nn as nn
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class ClassificationOutput(ModelOutput):
    """Output container returned by ``MoralEmotionVLClassifier.forward``.

    Follows the ``transformers.modeling_outputs.ModelOutput`` convention of
    optional, defaulted fields.
    """

    # Classification scores taken from the final token position in
    # MoralEmotionVLClassifier.forward; presumably (batch, num_labels)
    # since the lm_head is replaced with a num_labels-wide Linear.
    logits: Optional[torch.FloatTensor] = None
    # Per-layer hidden states when the base model returned them; None otherwise.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class MoralEmotionVLClassifier(nn.Module):
    """Sequence classifier built on a 4-bit-quantized vision-to-text backbone.

    Loads a pretrained vision-to-text model with NF4 4-bit quantization and
    swaps its language-model head (``lm_head``) for a linear classification
    head producing ``num_labels`` scores per example.
    """

    def __init__(self, model_id, num_labels=1, device="auto", max_memory=None, label_names=None):
        super().__init__()
        self.device = device
        self.max_memory = max_memory
        self.model_id = model_id

        # 4-bit NF4 quantization with double quantization; compute in fp16.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # "auto" lets HF shard the model across available devices (optionally
        # bounded by max_memory); any other value pins all weights to it.
        placement = "auto" if device == "auto" else {"": device}
        self.base_model = AutoModelForVision2Seq.from_pretrained(
            self.model_id,
            device_map=placement,
            torch_dtype=torch.float16,
            quantization_config=quant_config,
            max_memory=self.max_memory if device == "auto" else None,
        )

        self.config = self.base_model.config
        self.config.num_labels = num_labels
        # Expose the backbone's gradient-checkpointing toggle on this wrapper.
        self.gradient_checkpointing_enable = self.base_model.gradient_checkpointing_enable

        # Replace the LM head (hidden -> vocab) with a classification head
        # (hidden -> num_labels), created on the old head's device and dtype.
        old_head = self.base_model.lm_head
        self.base_model.lm_head = nn.Linear(
            old_head.in_features,
            num_labels,
            device=old_head.weight.device,
            dtype=old_head.weight.dtype,
        )

        # Label bookkeeping for downstream id <-> name lookups.
        self.num_labels = num_labels
        self.label_names = [] if label_names is None else label_names
        self.label2id = {name: idx for idx, name in enumerate(self.label_names)}
        self.id2label = dict(enumerate(self.label_names))

    def forward(self, **kwargs):
        """Run the backbone and return logits from the final token position.

        NOTE(review): the last position is meaningful only with right-aligned
        sequences (left padding) — confirm against the tokenizer setup.
        """
        outputs = self.base_model(**kwargs)
        # lm_head was swapped out, so outputs.logits is per-token class scores;
        # keep only the last position as the sequence-level prediction.
        return ClassificationOutput(
            logits=outputs.logits[:, -1, :],
            hidden_states=getattr(outputs, "hidden_states", None),
        )