kimyeonz committed on
Commit
77f6a88
·
verified ·
1 Parent(s): 18994fb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +11 -43
model.py CHANGED
@@ -11,12 +11,12 @@ class ClassificationOutput(ModelOutput):
11
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
12
 
13
  class MoralEmotionVLClassifier(nn.Module):
14
- def __init__(self, model_id_or_save_dir, num_labels=1, device="auto", max_memory=None, label_names=None, train=True):
15
  super().__init__()
16
 
17
  self.device = device
18
  self.max_memory = max_memory
19
- self.model_id_or_save_dir = model_id_or_save_dir
20
 
21
  # Bits and bytes config for model quantization
22
  bnb_config = BitsAndBytesConfig(
@@ -27,24 +27,13 @@ class MoralEmotionVLClassifier(nn.Module):
27
  )
28
 
29
  # Load base model (vision-to-text)
30
- if device == 'auto':
31
- self.base_model = AutoModelForVision2Seq.from_pretrained(
32
- self.model_id_or_save_dir,
33
- device_map=self.device,
34
- torch_dtype=torch.float16,
35
- quantization_config=bnb_config if train else None,
36
- ignore_mismatched_sizes=not train,
37
- )
38
-
39
- else:
40
- self.base_model = AutoModelForVision2Seq.from_pretrained(
41
- self.model_id_or_save_dir,
42
- device_map={"": device},
43
- torch_dtype=torch.float16,
44
- quantization_config=bnb_config if train else None,
45
- max_memory=self.max_memory,
46
- ignore_mismatched_sizes=not train,
47
- )
48
 
49
  self.config = self.base_model.config
50
  self.config.num_labels = num_labels
@@ -64,26 +53,6 @@ class MoralEmotionVLClassifier(nn.Module):
64
  dtype=head_dtype
65
  )
66
 
67
- if not train:
68
- try:
69
- from safetensors import safe_open
70
- import os
71
-
72
- safetensors_path = os.path.join(model_id_or_save_dir, "model.safetensors")
73
- if os.path.exists(safetensors_path):
74
- with safe_open(safetensors_path, framework="pt") as f:
75
- lm_head_weight = f.get_tensor("lm_head.weight")
76
- lm_head_bias = f.get_tensor("lm_head.bias") if "lm_head.bias" in f.keys() else None
77
-
78
- target_device = self.base_model.lm_head.weight.device
79
- self.base_model.lm_head.weight.data = lm_head_weight.to(target_device)
80
- if lm_head_bias is not None:
81
- self.base_model.lm_head.bias.data = lm_head_bias.to(target_device)
82
- print('\nload the custom layer weights successed!\n')
83
- except Exception as e:
84
- print(f"Warning: Could not load lm_head weights: {e}")
85
-
86
-
87
  # label mapping
88
  self.num_labels = num_labels
89
  self.label_names = label_names if label_names is not None else []
@@ -91,12 +60,11 @@ class MoralEmotionVLClassifier(nn.Module):
91
  self.id2label = {i: label for i, label in enumerate(self.label_names)}
92
 
93
  def forward(self, **kwargs):
94
- # Forward pass through the model
95
  outputs = self.base_model(**kwargs)
96
  logits = outputs.logits
97
- classification_logits = logits[:, -1, :] # Assuming we want to use the last token's logits
98
 
99
  return ClassificationOutput(
100
  logits=classification_logits,
101
  hidden_states=outputs.hidden_states if hasattr(outputs, 'hidden_states') else None
102
- )
 
11
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
12
 
13
  class MoralEmotionVLClassifier(nn.Module):
14
+ def __init__(self, model_id, num_labels=1, device="auto", max_memory=None, label_names=None):
15
  super().__init__()
16
 
17
  self.device = device
18
  self.max_memory = max_memory
19
+ self.model_id = model_id
20
 
21
  # Bits and bytes config for model quantization
22
  bnb_config = BitsAndBytesConfig(
 
27
  )
28
 
29
  # Load base model (vision-to-text)
30
+ self.base_model = AutoModelForVision2Seq.from_pretrained(
31
+ self.model_id,
32
+ device_map="auto" if device == "auto" else {"": device},
33
+ torch_dtype=torch.float16,
34
+ quantization_config=bnb_config,
35
+ max_memory=self.max_memory if device == "auto" else None
36
+ )
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  self.config = self.base_model.config
39
  self.config.num_labels = num_labels
 
53
  dtype=head_dtype
54
  )
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # label mapping
57
  self.num_labels = num_labels
58
  self.label_names = label_names if label_names is not None else []
 
60
  self.id2label = {i: label for i, label in enumerate(self.label_names)}
61
 
62
def forward(self, **kwargs):
    """Forward pass: delegate to the wrapped vision-language model and
    return the final-token logits as the classification scores.

    All keyword arguments are passed straight through to ``self.base_model``.
    Returns a ``ClassificationOutput`` whose ``logits`` field holds the
    logits at the last sequence position and whose ``hidden_states`` field
    carries the base model's hidden states when it exposes them, else None.
    """
    model_out = self.base_model(**kwargs)
    # Only the last token position is used as the classification head input.
    last_token_logits = model_out.logits[:, -1, :]
    # Equivalent to the hasattr-guarded access: missing attribute -> None.
    hidden = getattr(model_out, 'hidden_states', None)
    return ClassificationOutput(
        logits=last_token_logits,
        hidden_states=hidden,
    )