from typing import Dict, List, Any
import torch.nn as nn
from transformers import BertModel
from transformers import BertConfig
from transformers import BertTokenizer
import torch
import os
import pickle
from typing import Any
import sys
import time
class FeedForward(nn.Module):
    """Two-layer MLP head: each layer is Linear -> ReLU -> Dropout.

    Note that the activation and dropout are applied after BOTH layers,
    including the output projection.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        # First projection: expand, activate, regularise.
        hidden = self.fc1(x)
        hidden = self.dropout(self.activation(hidden))
        # Second projection gets the same treatment before being returned.
        projected = self.fc2(hidden)
        return self.dropout(self.activation(projected))
class BertForSequenceClassificationCustom(nn.Module):
    """BERT model for sequence classification with custom architecture.

    Pipeline: BertModel pooler output -> dropout -> FeedForward
    (hidden -> 2*hidden -> hidden) -> linear classifier over ``num_labels``.
    """

    def __init__(self, config, num_labels):
        """
        Args:
            config: a transformers BertConfig (provides hidden_size and
                hidden_dropout_prob).
            num_labels: number of output classes for the classifier head.
        """
        super().__init__()
        self.num_labels = num_labels
        self.config = config
        self.bert = BertModel(config)  # Replace BertPreTrainedModel with BertModel
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Expand-then-project feedforward block between pooler and classifier.
        self.ffd = FeedForward(config.hidden_size, config.hidden_size*2, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run classification; returns an object with .loss/.logits/.hidden_states.

        ``loss`` is None unless ``labels`` is provided (then cross-entropy).
        """
        from types import SimpleNamespace  # local import: keeps module imports untouched

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = self.dropout(outputs['pooler_output'])
        internal_output = self.ffd(pooled_output)  # Pass through new feedforward layer
        logits = self.classifier(internal_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # FIX: the original built a brand-new class via type('ModelOutput', ...)
        # on EVERY forward call. SimpleNamespace gives identical attribute
        # access (.loss, .logits, .hidden_states) without per-call class churn.
        return SimpleNamespace(
            loss=loss,
            logits=logits,
            # NOTE: despite the name, this carries the encoder's
            # last_hidden_state; kept as 'hidden_states' for compatibility.
            hidden_states=outputs['last_hidden_state'],
        )
def load_model(path="") -> nn.Module:
    """Load a fine-tuned BertForSequenceClassificationCustom from a checkpoint.

    Args:
        path: directory containing ``checkpoint.chkpt``.

    Returns:
        The model in eval() mode with compatible checkpoint weights loaded;
        any keys missing from or shape-mismatched with the current
        architecture are silently skipped (best-effort partial load).
    """
    filename = "checkpoint.chkpt"
    filepath = os.path.join(path, filename)
    print(f"Loading checkpoint from: { filepath }")
    # Build the target architecture from the base config.
    config = BertConfig.from_pretrained("bert-base-uncased")
    num_labels = 4  # must match the label set the checkpoint was trained with
    model = BertForSequenceClassificationCustom(config, num_labels=num_labels)
    # Some checkpoints expect the class to be available in __main__ during unpickling.
    # Temporarily inject the class into the __main__ module to satisfy torch.load.
    import __main__ as _main
    had_main_attr = hasattr(_main, 'BertForSequenceClassificationCustom')
    if not had_main_attr:
        setattr(_main, 'BertForSequenceClassificationCustom', BertForSequenceClassificationCustom)
    try:
        # FIX: map_location="cpu" so checkpoints saved on GPU load on CPU-only
        # hosts instead of raising a CUDA deserialization error.
        # SECURITY: weights_only=False unpickles arbitrary objects -- only ever
        # load checkpoints from a trusted source.
        checkpoint = torch.load(filepath, map_location="cpu", weights_only=False)
    finally:
        # Clean up the injected attribute if we added it
        if not had_main_attr and hasattr(_main, 'BertForSequenceClassificationCustom'):
            delattr(_main, 'BertForSequenceClassificationCustom')
    # Load state dict while ignoring mismatched layers
    model_state_dict = model.state_dict()
    sft_state_dict = checkpoint['model_state_dict']
    # Keep only keys that exist in the current model with identical shapes.
    filtered_state_dict = {
        k: v for k, v in sft_state_dict.items()
        if k in model_state_dict and model_state_dict[k].shape == v.shape
    }
    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)
    print("Checkpoint loaded successfully")
    model.eval()
    return model
class EndpointHandler():
    """Inference-endpoint handler: loads the custom BERT classifier and
    serves batched text-classification requests."""

    def __init__(self, path=""):
        """Load model + tokenizer from ``path`` and compile the model."""
        print(f"Initializing model from base path: {path}")
        start = time.perf_counter()
        self.model = load_model(path)
        elapsed = time.perf_counter() - start
        print(f"Model loaded in {elapsed:.2f}s")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Order must match the label indices used at training time.
        self.labels = ["High", "Latent", "Medium", "None"]  # Update based on your dataset
        print("Compiling model...")
        start = time.perf_counter()
        self.model.compile()
        elapsed = time.perf_counter() - start
        print(f"Model compiled in {elapsed:.2f}s")

    @staticmethod
    def _extract_texts(data) -> List[str]:
        """Normalize a payload ({'inputs': ...}, {'text': ...}, raw str/list)
        to a list of strings.

        FIX: the original called ``data.get(...)`` unconditionally, so the
        documented "raw string/list" payloads crashed with AttributeError.
        """
        if isinstance(data, dict):
            raw_inputs = data.get("inputs")
            if raw_inputs is None:
                raw_inputs = data.get("text", data)
        else:
            raw_inputs = data
        # If payload nested inside inputs as a dict
        if isinstance(raw_inputs, dict):
            raw_inputs = raw_inputs.get("text", raw_inputs.get("inputs", raw_inputs))
        if isinstance(raw_inputs, str):
            return [raw_inputs]
        if isinstance(raw_inputs, list):
            return raw_inputs
        return [str(raw_inputs)]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the payload; returns one result dict per input text."""
        texts = self._extract_texts(data)
        # Tokenize the whole batch at once.
        inputs_tok = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256
        )
        with torch.no_grad():
            start = time.perf_counter()
            outputs = self.model(
                input_ids=inputs_tok["input_ids"],
                attention_mask=inputs_tok["attention_mask"]
            )
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            preds = torch.argmax(probabilities, dim=-1).tolist()
            elapsed = time.perf_counter() - start
            print(f"Processed {len(texts)} inputs in {elapsed:.2f}s")
        results = []
        for i, p in enumerate(preds):
            results.append({
                "text": texts[i],
                # Fall back to the raw index if it is out of label range.
                "predicted_class": self.labels[int(p)] if int(p) < len(self.labels) else int(p),
                "score": float(probabilities[i].max().item())
            })
        return results