IvoHoese committed on
Commit
8b0ace6
·
verified ·
1 Parent(s): 5ac3ae1

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -146
utils.py DELETED
@@ -1,146 +0,0 @@
1
- from transformers import OlmoModel, OlmoPreTrainedModel, GenerationMixin, AutoConfig, AutoModelForSequenceClassification
2
- from transformers.modeling_outputs import SequenceClassifierOutputWithPast
3
- import torch
4
-
5
- from peft import PeftModel, PeftConfig
6
-
7
- from transformers import AutoConfig
8
-
9
- import logging
10
- from contextlib import contextmanager
11
-
12
- # The custom model for using Olmo with a sequence classification task
13
-
14
# Default device for model placement: prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
class OlmoForSequenceClassification(OlmoPreTrainedModel, GenerationMixin):
    """Olmo backbone with a linear sequence-classification head.

    Structured after ``OlmoForCausalLM`` but replaces the LM head with a
    single linear ``classifier`` that maps hidden states to
    ``config.num_labels`` class logits.
    """

    def __init__(self, config):
        # Check OlmoForCausalLM.__init__
        super().__init__(config)
        self.model = OlmoModel(config)
        self.num_labels = config.num_labels
        # Linear head projecting each token's hidden state to class logits.
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        """Classify a sequence from the hidden state of its last token.

        Args:
            input_ids: Token ids, shape [B, N].
            attention_mask: Optional attention mask, same shape as input_ids.
            labels: Optional classification targets; when given, a loss is
                computed via the transformers-provided ``loss_function``.
            **kwargs: Forwarded verbatim to the Olmo backbone.

        Returns:
            SequenceClassifierOutputWithPast whose ``logits`` field holds the
            pooled (last-token) logits of shape [B, C].
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        logits = self.classifier(outputs.last_hidden_state)  # [B, N, H] => [B, N, C]
        # Pool by taking the final position; only valid when real tokens end
        # the sequence, hence the left-padding requirement below.
        pooled_logits = logits[:, -1]  # NOTE: tokenizer.padding_side must be 'left'

        loss = None
        if labels is not None:
            # loss_function is inherited from the transformers base class;
            # presumably it chooses the loss from config.problem_type —
            # TODO(review): confirm against the installed transformers version.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                pooled_logits=pooled_logits,
                config=self.config,
            )

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
58
-
59
- # The function for loading a fulltuning model
60
-
61
def get_fulltuning_model(model_path, model_type="olmo"):
    """Load a fully fine-tuned sequence-classification model for inference.

    Args:
        model_path: Local path or hub id of the fine-tuned checkpoint.
        model_type: Either "olmo" (custom OlmoForSequenceClassification) or
            "pythia" (generic AutoModelForSequenceClassification).

    Returns:
        The loaded model, moved to the module-level ``device`` and set to
        eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == "olmo":
        model = OlmoForSequenceClassification.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)  # fix: reuse module-level device instead of recomputing it
    elif model_type == "pythia":
        # NOTE(review): this branch hard-codes 3 labels while the olmo branch
        # relies on the checkpoint's own config — confirm the asymmetry is intended.
        cfg = AutoConfig.from_pretrained(model_path, num_labels=3)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    # Bug fix: eval() was previously applied only on the olmo branch, leaving
    # the pythia model in train mode (dropout active) at inference time.
    model.eval()
    return model
80
-
81
- # The following function is used to suppress a "missing or unexpected params" warning.
82
- # This warning is no reason for concern. It stems from the fact that the model is first loaded
83
- # without a classifier head, which is added afterwards.
84
-
85
class DropLoadReport(logging.Filter):
    """Logging filter that drops any record whose message mentions "LOAD REPORT"."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Keep the record only when the rendered message lacks the marker text.
        message = record.getMessage()
        return message.find("LOAD REPORT") == -1
88
-
89
@contextmanager
def suppress_load_report_only():
    """Temporarily attach a DropLoadReport filter to the transformers loggers.

    While the context is active, records containing "LOAD REPORT" emitted by
    the transformers loading machinery are suppressed; the filters are always
    removed on exit, even if the body raises.
    """
    report_filter = DropLoadReport()

    target_loggers = [
        logging.getLogger(name)
        for name in (
            "transformers.modeling_utils",
            "transformers.modeling_tf_pytorch_utils",
            "transformers",
        )
    ]

    for target in target_loggers:
        target.addFilter(report_filter)
    try:
        yield
    finally:
        # Restore the loggers unconditionally so the filter never leaks.
        for target in target_loggers:
            target.removeFilter(report_filter)
107
-
108
- # The function for loading a softprompt model
109
-
110
def get_peft_model(model_path, model_type="olmo"):
    """Load a soft-prompt (PEFT) adapter stacked on its base model.

    Args:
        model_path: Path or hub id of the PEFT adapter checkpoint.
        model_type: Either "olmo" (custom OlmoForSequenceClassification base)
            or "pythia" (generic AutoModelForSequenceClassification base).

    Returns:
        A PeftModel wrapping the 2-label base model, on the detected device,
        in eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    peft_config = PeftConfig.from_pretrained(model_path)
    base_name = peft_config.base_model_name_or_path
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_type == "olmo":
        config = AutoConfig.from_pretrained(
            base_name,
            trust_remote_code=True,
            num_labels=2,
        )
        # The base checkpoint ships no classifier head, so transformers warns
        # about missing params; suppress just that load report.
        with suppress_load_report_only():
            base = OlmoForSequenceClassification.from_pretrained(
                base_name,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                config=config,
            ).to(target_device)
    elif model_type == "pythia":
        config = AutoConfig.from_pretrained(base_name, num_labels=2)
        with suppress_load_report_only():
            base = AutoModelForSequenceClassification.from_pretrained(
                base_name,
                config=config,
                torch_dtype=torch.float32,
            ).to(target_device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    with suppress_load_report_only():
        model = PeftModel.from_pretrained(base, model_path).to(target_device)

    model.eval()
    return model