File size: 2,651 Bytes
1fe4a4d
 
 
 
 
 
 
 
 
 
 
 
 
77f6a88
1fe4a4d
18994fb
 
 
77f6a88
18994fb
 
1fe4a4d
 
 
 
 
 
 
18994fb
77f6a88
 
 
 
 
 
 
1fe4a4d
 
 
 
 
18994fb
1fe4a4d
 
 
 
 
18994fb
1fe4a4d
 
 
 
 
 
 
 
 
 
 
 
18994fb
1fe4a4d
 
 
77f6a88
1fe4a4d
 
 
 
77f6a88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ClassificationOutput(ModelOutput):
    """Container returned by the classifier's forward pass.

    Attributes:
        logits: Classification logits of shape (batch, num_labels), taken
            from the final sequence position of the base model's output.
        hidden_states: Per-layer hidden states forwarded from the base
            model when it was asked to return them; otherwise None.
    """
    # Fix: the original annotated `logits: torch.FloatTensor = None`,
    # which is invalid typing (None is not a FloatTensor; PEP 484
    # implicit-Optional). Runtime behavior is unchanged.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None

class MoralEmotionVLClassifier(nn.Module):
    """Vision-language sequence classifier built on a 4-bit quantized base.

    Loads a pretrained vision-to-text model with bitsandbytes NF4
    quantization and swaps its vocabulary-sized `lm_head` for a fresh
    linear layer that projects onto `num_labels` classification logits.
    """

    def __init__(self, model_id, num_labels=1, device="auto", max_memory=None, label_names=None):
        super().__init__()

        self.model_id = model_id
        self.device = device
        self.max_memory = max_memory

        # 4-bit NF4 weights with double quantization; matmuls computed in fp16.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )

        # "auto" lets accelerate shard layers across available devices
        # (honoring max_memory); any other value pins the whole model
        # to that single device.
        placement = "auto" if device == "auto" else {"": device}
        self.base_model = AutoModelForVision2Seq.from_pretrained(
            self.model_id,
            device_map=placement,
            torch_dtype=torch.float16,
            quantization_config=quant_config,
            max_memory=self.max_memory if device == "auto" else None
        )

        self.config = self.base_model.config
        self.config.num_labels = num_labels
        # Surface the base model's checkpointing toggle on the wrapper so
        # trainers can call it directly on this module.
        self.gradient_checkpointing_enable = self.base_model.gradient_checkpointing_enable

        # Replace the language-model head with a classification head,
        # created on the same device and dtype as the head it replaces.
        old_head = self.base_model.lm_head
        self.base_model.lm_head = nn.Linear(
            old_head.in_features,
            num_labels,
            device=old_head.weight.device,
            dtype=old_head.weight.dtype
        )

        # Bidirectional label <-> index lookup tables (empty when no
        # label names were supplied).
        self.num_labels = num_labels
        self.label_names = [] if label_names is None else label_names
        self.label2id = {}
        self.id2label = {}
        for idx, name in enumerate(self.label_names):
            self.label2id[name] = idx
            self.id2label[idx] = name

    def forward(self, **kwargs):
        """Run the base model and return logits from the last position.

        Returns a ClassificationOutput whose `logits` has shape
        (batch, num_labels).
        """
        base_out = self.base_model(**kwargs)
        # NOTE(review): slicing position -1 assumes the final token of each
        # sequence is meaningful (left padding or no padding) — confirm
        # against the data collator used for training.
        last_step_logits = base_out.logits[:, -1, :]
        hidden = getattr(base_out, 'hidden_states', None)
        return ClassificationOutput(
            logits=last_step_logits,
            hidden_states=hidden
        )