kimyeonz committed
Commit 18994fb · verified · 1 Parent(s): 1fe4a4d

edit for inference

Files changed (1)
  model.py +51 -11
model.py CHANGED
@@ -11,9 +11,14 @@ class ClassificationOutput(ModelOutput):
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 
 class MoralEmotionVLClassifier(nn.Module):
-    def __init__(self, model_id, num_labels=1, device="auto", label_names=None):
+    def __init__(self, model_id_or_save_dir, num_labels=1, device="auto", max_memory=None, label_names=None, train=True):
         super().__init__()
-
+
+        self.device = device
+        self.max_memory = max_memory
+        self.model_id_or_save_dir = model_id_or_save_dir
+
+        # Bits and bytes config for model quantization
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
@@ -21,23 +26,37 @@ class MoralEmotionVLClassifier(nn.Module):
             bnb_4bit_compute_dtype=torch.float16
         )
 
-        self.base_model = AutoModelForVision2Seq.from_pretrained(
-            model_id,
-            device_map='auto' if device == 'auto' else {"": device},
-            torch_dtype=torch.float16,
-            quantization_config=bnb_config
-        )
+        # Load base model (vision-to-text)
+        if device == 'auto':
+            self.base_model = AutoModelForVision2Seq.from_pretrained(
+                self.model_id_or_save_dir,
+                device_map=self.device,
+                torch_dtype=torch.float16,
+                quantization_config=bnb_config if train else None,
+                ignore_mismatched_sizes=not train,
+            )
+
+        else:
+            self.base_model = AutoModelForVision2Seq.from_pretrained(
+                self.model_id_or_save_dir,
+                device_map={"": device},
+                torch_dtype=torch.float16,
+                quantization_config=bnb_config if train else None,
+                max_memory=self.max_memory,
+                ignore_mismatched_sizes=not train,
+            )
 
         self.config = self.base_model.config
         self.config.num_labels = num_labels
         self.gradient_checkpointing_enable = self.base_model.gradient_checkpointing_enable
 
+        # Modify the final classification head (lm_head)
         original_lm_head = self.base_model.lm_head
         hidden_size = original_lm_head.in_features
         head_device = original_lm_head.weight.device
         head_dtype = original_lm_head.weight.dtype
 
-        # change to classification head
+        # Change to classification head for the number of labels required
         self.base_model.lm_head = nn.Linear(
             hidden_size,
             num_labels,
@@ -45,16 +64,37 @@ class MoralEmotionVLClassifier(nn.Module):
             dtype=head_dtype
         )
 
+        if not train:
+            try:
+                from safetensors import safe_open
+                import os
+
+                safetensors_path = os.path.join(model_id_or_save_dir, "model.safetensors")
+                if os.path.exists(safetensors_path):
+                    with safe_open(safetensors_path, framework="pt") as f:
+                        lm_head_weight = f.get_tensor("lm_head.weight")
+                        lm_head_bias = f.get_tensor("lm_head.bias") if "lm_head.bias" in f.keys() else None
+
+                    target_device = self.base_model.lm_head.weight.device
+                    self.base_model.lm_head.weight.data = lm_head_weight.to(target_device)
+                    if lm_head_bias is not None:
+                        self.base_model.lm_head.bias.data = lm_head_bias.to(target_device)
+                    print('\nLoaded the custom layer weights successfully!\n')
+            except Exception as e:
+                print(f"Warning: Could not load lm_head weights: {e}")
+
+
         # label mapping
         self.num_labels = num_labels
         self.label_names = label_names if label_names is not None else []
         self.label2id = {label: i for i, label in enumerate(self.label_names)}
         self.id2label = {i: label for i, label in enumerate(self.label_names)}
-
+
     def forward(self, **kwargs):
+        # Forward pass through the model
         outputs = self.base_model(**kwargs)
         logits = outputs.logits
-        classification_logits = logits[:, -1, :]
+        classification_logits = logits[:, -1, :]  # Assuming we want to use the last token's logits
 
         return ClassificationOutput(
             logits=classification_logits,
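
For context, the new `train=False` path could be exercised roughly as in the minimal sketch below. The checkpoint directory, label names, image path, and the AutoProcessor workflow are all placeholder assumptions for illustration; the commit itself defines none of them.

import torch
from PIL import Image
from transformers import AutoProcessor

from model import MoralEmotionVLClassifier

# Hypothetical paths and labels -- substitute the real ones.
SAVE_DIR = "./checkpoints/moral-emotion-vl"  # should contain model.safetensors so the custom lm_head is restored
LABELS = ["care", "fairness", "loyalty"]

# train=False skips 4-bit quantization, passes ignore_mismatched_sizes=True,
# and reloads the classification head weights from model.safetensors if present.
model = MoralEmotionVLClassifier(
    SAVE_DIR,
    num_labels=len(LABELS),
    device="auto",
    label_names=LABELS,
    train=False,
)
model.eval()

# Assumes a processor matching the fine-tuned base model was saved alongside it.
processor = AutoProcessor.from_pretrained(SAVE_DIR)
image = Image.open("example.jpg")
inputs = processor(images=image, text="Describe the moral emotion.", return_tensors="pt")
inputs = inputs.to(model.base_model.device)  # move tensors to wherever the model landed

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch, num_labels)
pred_id = logits.argmax(dim=-1).item()
print(model.id2label.get(pred_id, str(pred_id)))

One consequence of reading `logits[:, -1, :]` in forward is that batched inputs would need left padding so the final position always holds real content; with a single example, as above, that concern does not arise.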