ianfe commited on
Commit
27e5fcf
·
verified ·
1 Parent(s): 7ee6cac

Upload 2 files

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. checkpoint.chkpt +3 -0
  3. handler_intent.py +173 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint.chkpt filter=lfs diff=lfs merge=lfs -text
checkpoint.chkpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b89adfd2378cc1e237b28678ae25014b5481fa8cdc9732f2763513d56d211bf7
3
+ size 1342454211
handler_intent.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import pickle
import sys
import time
from types import SimpleNamespace
from typing import Any, Dict, List

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertTokenizer
+
13
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout, applied twice.

    Note that activation and dropout are applied after the second linear
    layer as well, so the output is a non-negative hidden representation
    intended to feed a downstream classifier head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        """
        Args:
            input_dim: size of the incoming features.
            hidden_dim: width of the intermediate layer.
            output_dim: size of the produced representation.
            dropout: dropout probability used after each layer.
        """
        super().__init__()
        # Attribute names (fc1/fc2/...) are part of the state_dict contract
        # with the project checkpoint — do not rename.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Transform ``x`` through both layers; shape (..., input_dim) -> (..., output_dim)."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.dropout(self.activation(self.fc2(hidden)))
25
+
26
class BertForSequenceClassificationCustom(nn.Module):
    """BERT encoder with an extra feed-forward block before the classifier head.

    Architecture: BertModel -> dropout -> FeedForward -> linear classifier.
    ``forward`` returns an object exposing ``loss``, ``logits`` and
    ``hidden_states`` attributes, mirroring HuggingFace output objects.
    """

    def __init__(self, config, num_labels):
        """
        Args:
            config: a ``BertConfig`` describing the encoder.
            num_labels: number of target classes for the classifier head.
        """
        super().__init__()
        self.num_labels = num_labels
        self.config = config

        # Plain BertModel (not BertPreTrainedModel): weights are restored
        # manually from the project checkpoint in load_model().
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Extra feed-forward block between pooled output and classifier.
        self.ffd = FeedForward(config.hidden_size, config.hidden_size * 2, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run a forward pass; computes cross-entropy loss when labels are given.

        Returns:
            A namespace with ``loss`` (tensor or None), ``logits``
            (batch, num_labels) and ``hidden_states`` (the encoder's
            last_hidden_state).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        pooled_output = self.dropout(outputs['pooler_output'])
        internal_output = self.ffd(pooled_output)
        logits = self.classifier(internal_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # FIX: the original built a brand-new throwaway class with
        # type('ModelOutput', (), {...})() on every call, which is
        # non-idiomatic and makes type()/isinstance() comparisons unstable.
        # SimpleNamespace provides the same attribute access.
        return SimpleNamespace(
            loss=loss,
            logits=logits,
            hidden_states=outputs['last_hidden_state'],
        )
61
+
62
+
63
def load_model(path="") -> nn.Module:
    """Build the custom BERT classifier and restore weights from checkpoint.chkpt.

    Args:
        path: directory containing ``checkpoint.chkpt``.

    Returns:
        The model in eval mode, with every checkpoint tensor whose name and
        shape match the freshly built model loaded into it.
    """
    filename = "checkpoint.chkpt"
    filepath = os.path.join(path, filename)
    print(f"Loading checkpoint from: { filepath }")

    # Load the configuration (the tokenizer is created by the caller)
    config = BertConfig.from_pretrained("bert-base-uncased")

    # Initialize the model
    num_labels = 4  # must match the label set the checkpoint was trained with
    model = BertForSequenceClassificationCustom(config, num_labels=num_labels)

    # Some checkpoints pickle the model class under __main__; temporarily
    # inject the class there so torch.load can resolve the reference.
    import __main__ as _main
    had_main_attr = hasattr(_main, 'BertForSequenceClassificationCustom')
    if not had_main_attr:
        setattr(_main, 'BertForSequenceClassificationCustom', BertForSequenceClassificationCustom)

    try:
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        # FIX: map_location="cpu" lets a checkpoint saved on GPU load on a
        # CPU-only inference host (the handler runs the model on CPU tensors).
        checkpoint = torch.load(filepath, map_location="cpu", weights_only=False)
    finally:
        # Clean up the injected attribute if we added it
        if not had_main_attr and hasattr(_main, 'BertForSequenceClassificationCustom'):
            delattr(_main, 'BertForSequenceClassificationCustom')

    model_state_dict = model.state_dict()
    sft_state_dict = checkpoint['model_state_dict']

    # Keep only keys present in the fresh model with a matching shape.
    filtered_state_dict = {
        k: v
        for k, v in sft_state_dict.items()
        if k in model_state_dict and model_state_dict[k].shape == v.shape
    }
    # FIX: the original silently discarded mismatched tensors; report them so
    # an architecture/checkpoint mismatch is visible in the logs.
    skipped = len(sft_state_dict) - len(filtered_state_dict)
    if skipped:
        print(f"Warning: skipped {skipped} checkpoint tensors (missing key or shape mismatch)")

    # Update the model's state dict with the compatible tensors
    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)
    print("Checkpoint loaded successfully")
    model.eval()
    return model
105
+
106
+
107
+
108
class EndpointHandler():
    """Inference endpoint wrapper: loads the classifier once, serves batches."""

    def __init__(self, path=""):
        """Load model and tokenizer from ``path`` and compile the model.

        Args:
            path: base directory holding ``checkpoint.chkpt``.
        """
        print(f"Initializing model from base path: {path}")
        start = time.perf_counter()
        self.model = load_model(path)
        elapsed = time.perf_counter() - start
        print(f"Model loaded in {elapsed:.2f}s")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Index-aligned with the classifier outputs of the checkpoint.
        self.labels = ["High", "Low", "Medium", "UNKNOWN"]  # Update based on your dataset
        print("Compiling model...")
        start = time.perf_counter()
        self.model.compile()
        elapsed = time.perf_counter() - start
        print(f"Model compiled in {elapsed:.2f}s")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one or more texts.

        Accepts ``{'inputs': ...}``, ``{'text': ...}``, a raw string, or a
        list of strings.

        Returns:
            One dict per input: ``text``, ``predicted_class`` (label string,
            or the raw index if out of range), and ``score`` (max softmax
            probability).
        """
        # FIX: the original called data.get(...) unconditionally, which raised
        # AttributeError for the documented raw string/list payloads.
        if isinstance(data, dict):
            raw_inputs = data.get("inputs")
            if raw_inputs is None:
                raw_inputs = data.get("text", data)
        else:
            raw_inputs = data

        # Payload may be nested one level deeper inside 'inputs'.
        if isinstance(raw_inputs, dict):
            raw_inputs = raw_inputs.get("text", raw_inputs.get("inputs", raw_inputs))

        # Normalize to a list of strings; FIX: coerce non-string list items,
        # which would otherwise crash the tokenizer.
        if isinstance(raw_inputs, str):
            texts = [raw_inputs]
        elif isinstance(raw_inputs, list):
            texts = [t if isinstance(t, str) else str(t) for t in raw_inputs]
        else:
            texts = [str(raw_inputs)]

        # Tokenize the whole batch at once.
        inputs_tok = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )

        with torch.no_grad():
            start = time.perf_counter()
            outputs = self.model(
                input_ids=inputs_tok["input_ids"],
                attention_mask=inputs_tok["attention_mask"],
            )
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            preds = torch.argmax(probabilities, dim=-1).tolist()
            elapsed = time.perf_counter() - start
            print(f"Processed {len(texts)} inputs in {elapsed:.2f}s")

        results = []
        for i, p in enumerate(preds):
            results.append({
                "text": texts[i],
                "predicted_class": self.labels[int(p)] if int(p) < len(self.labels) else int(p),
                "score": float(probabilities[i].max().item()),
            })

        return results
173
+