IvoHoese commited on
Commit
5ac3ae1
·
verified ·
1 Parent(s): 24aebf2

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +146 -0
utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import OlmoModel, OlmoPreTrainedModel, GenerationMixin, AutoConfig, AutoModelForSequenceClassification
2
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
3
+ import torch
4
+
5
+ from peft import PeftModel, PeftConfig
6
+
7
+ from transformers import AutoConfig
8
+
9
+ import logging
10
+ from contextlib import contextmanager
11
+
12
+ # The custom model for using Olmo with a sequence classification task
13
+
14
# Module-wide default compute device: prefer CUDA when a GPU is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
16
class OlmoForSequenceClassification(OlmoPreTrainedModel, GenerationMixin):
    """Olmo backbone with a linear head for sequence classification.

    Structured after ``OlmoForCausalLM`` but swaps the LM head for a
    ``hidden_size -> num_labels`` classifier applied to the final position's
    hidden state.
    """

    def __init__(self, config):
        # See OlmoForCausalLM.__init__ for the reference layout.
        super().__init__(config)
        self.model = OlmoModel(config)
        self.num_labels = config.num_labels
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        # Weight init and any final processing hooks from the base class.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        """Run the backbone, classify the last position, and optionally compute loss.

        Returns a ``SequenceClassifierOutputWithPast`` whose ``logits`` are the
        pooled (last-position) class scores.
        """
        backbone_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        # Per-token class scores: [batch, seq, hidden] -> [batch, seq, num_labels]
        token_logits = self.classifier(backbone_out.last_hidden_state)
        # The final position stands in for the whole sequence.
        # NOTE: tokenizer.padding_side must be 'left' so index -1 is a real token.
        pooled_logits = token_logits[:, -1]

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=token_logits,
                labels=labels,
                pooled_logits=pooled_logits,
                config=self.config,
            )

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
58
+
59
# The function for loading a fulltuning model

def get_fulltuning_model(model_path, model_type="olmo"):
    """Load a fully fine-tuned sequence-classification model for inference.

    Args:
        model_path: Checkpoint path or hub id of the fine-tuned model.
        model_type: Either ``"olmo"`` (custom head above) or ``"pythia"``
            (stock ``AutoModelForSequenceClassification``).

    Returns:
        The model moved to the module-level ``device`` and set to eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == "olmo":
        model = OlmoForSequenceClassification.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)  # reuse the module-level device instead of recomputing it
    elif model_type == "pythia":
        # num_labels must match the fine-tuned head (3-way here — TODO confirm
        # against the checkpoint; the PEFT loader below uses num_labels=2).
        cfg = AutoConfig.from_pretrained(model_path, num_labels=3)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    # Inference only: disable dropout etc. The original called eval() only in
    # the olmo branch, leaving the pythia model in train mode.
    model.eval()
    return model
80
+
81
# The filter below suppresses a "missing or unexpected params" warning.
# That warning is no cause for concern: the model is first loaded without a
# classifier head, which is attached afterwards.

class DropLoadReport(logging.Filter):
    """Logging filter that drops any record whose message contains "LOAD REPORT"."""

    def filter(self, record: logging.LogRecord) -> bool:
        # False suppresses the record; everything else passes through unchanged.
        message = record.getMessage()
        return "LOAD REPORT" not in message
88
+
89
@contextmanager
def suppress_load_report_only():
    """Temporarily hide "LOAD REPORT" log lines from the transformers loggers.

    Attaches a :class:`DropLoadReport` filter to the relevant loggers for the
    duration of the ``with`` block and always detaches it afterwards, even if
    the body raises.
    """
    report_filter = DropLoadReport()

    target_loggers = [
        logging.getLogger(name)
        for name in (
            "transformers.modeling_utils",
            "transformers.modeling_tf_pytorch_utils",
            "transformers",
        )
    ]

    for target in target_loggers:
        target.addFilter(report_filter)
    try:
        yield
    finally:
        # Guaranteed cleanup so the filter never leaks past the context.
        for target in target_loggers:
            target.removeFilter(report_filter)
107
+
108
# The function for loading a softprompt model

def get_peft_model(model_path, model_type="olmo"):
    """Load a soft-prompt (PEFT) sequence-classification model for inference.

    Resolves the base checkpoint from the adapter's ``PeftConfig``, builds the
    base classifier (2 labels), wraps it with the PEFT adapter stored at
    ``model_path``, and returns it in eval mode on the available device.

    Args:
        model_path: Path or hub id of the PEFT adapter.
        model_type: Either ``"olmo"`` or ``"pythia"``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    peft_config = PeftConfig.from_pretrained(model_path)
    base_path = peft_config.base_model_name_or_path
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_type == "olmo":
        base_config = AutoConfig.from_pretrained(
            base_path,
            trust_remote_code=True,
            num_labels=2,
        )
        # Suppress the harmless "LOAD REPORT" warning: the classifier head is
        # added after the backbone weights are loaded.
        with suppress_load_report_only():
            base = OlmoForSequenceClassification.from_pretrained(
                base_path,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                config=base_config,
            ).to(target_device)
    elif model_type == "pythia":
        base_config = AutoConfig.from_pretrained(base_path, num_labels=2)
        with suppress_load_report_only():
            base = AutoModelForSequenceClassification.from_pretrained(
                base_path,
                config=base_config,
                torch_dtype=torch.float32,
            ).to(target_device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    with suppress_load_report_only():
        model = PeftModel.from_pretrained(base, model_path).to(target_device)

    model.eval()
    return model