# Author: ianfe
# Commit message: Update handler.py
# Commit: 3cd9d3d (verified)
from typing import Dict, List, Any
import torch.nn as nn
from transformers import BertModel
from transformers import BertConfig
from transformers import BertTokenizer
import torch
import os
import pickle
from typing import Any
import sys
import time
class FeedForward(nn.Module):
    """Two-layer MLP where every linear layer is followed by ReLU and dropout.

    Because the activation and dropout are applied after *both* linear
    layers (including the last one), the output is non-negative.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the intermediate layer.
        output_dim: size of the last dimension of the output.
        dropout: dropout probability used after each layer (default 0.1).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        # The attribute names below become state_dict keys ("fc1.weight", ...);
        # keep them stable so existing checkpoints keep loading.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Same Linear -> ReLU -> Dropout step for each of the two layers.
        for linear in (self.fc1, self.fc2):
            x = self.dropout(self.activation(linear(x)))
        return x
class BertForSequenceClassificationCustom(nn.Module):
    """BERT model for sequence classification with custom architecture.

    Forward pipeline: ``BertModel`` pooled output -> dropout ->
    ``FeedForward`` (hidden_size -> 2 * hidden_size -> hidden_size) ->
    linear classifier producing ``num_labels`` logits.

    Args:
        config: BERT configuration; ``hidden_size`` and
            ``hidden_dropout_prob`` are read here and the object is also
            passed to ``BertModel``.
        num_labels: number of output classes.
    """

    class ModelOutput:
        """Result container returned by ``forward`` (attribute access only).

        Defined once at class-creation time rather than via
        ``type(...)`` inside ``forward``, so no new type object is
        allocated on every inference call.
        """

        def __init__(self, loss, logits, hidden_states):
            # Mean cross-entropy over the flattened labels, or None when no
            # labels were passed to forward.
            self.loss = loss
            # Unnormalized class scores (last dimension = num_labels).
            self.logits = logits
            # BertModel's last_hidden_state: a single tensor, not a
            # per-layer tuple as in Hugging Face's own outputs.
            self.hidden_states = hidden_states

    def __init__(self, config, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.config = config
        # Plain BertModel built from the config alone (not from_pretrained),
        # so its weights must be loaded separately.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Extra MLP between the pooler and the classification head.
        self.ffd = FeedForward(config.hidden_size, config.hidden_size * 2, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded unchanged to
                ``BertModel`` (assumed shape (batch, seq_len) -- TODO confirm).
            labels: optional class indices; flattened with ``view(-1)`` and
                scored against logits reshaped to ``(-1, num_labels)``.

        Returns:
            ``ModelOutput`` with ``loss`` (None when ``labels`` is None),
            ``logits`` and ``hidden_states``.
        """
        bert_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(bert_out['pooler_output'])
        logits = self.classifier(self.ffd(pooled))

        loss = None
        if labels is not None:
            # Same as nn.CrossEntropyLoss() with its default mean reduction.
            loss = nn.functional.cross_entropy(
                logits.view(-1, self.num_labels), labels.view(-1)
            )
        return self.ModelOutput(loss, logits, bert_out['last_hidden_state'])
def _torch_load_with_main_alias(filepath):
    """Run ``torch.load`` on *filepath* with the model class visible in ``__main__``.

    Some checkpoints expect ``BertForSequenceClassificationCustom`` to be
    resolvable as ``__main__.BertForSequenceClassificationCustom`` during
    unpickling (presumably they were saved from a script run as ``__main__``).
    The attribute is injected only for the duration of the load and removed
    again if this function was the one that added it.
    """
    import __main__ as _main

    attr = 'BertForSequenceClassificationCustom'
    had_main_attr = hasattr(_main, attr)
    if not had_main_attr:
        setattr(_main, attr, BertForSequenceClassificationCustom)
    try:
        # SECURITY: weights_only=False performs full pickle deserialization,
        # which can execute arbitrary code -- only load trusted checkpoints.
        # map_location="cpu" lets checkpoints saved from a GPU load on hosts
        # without CUDA; the model built in load_model is never moved off CPU.
        return torch.load(filepath, map_location="cpu", weights_only=False)
    finally:
        # Clean up the injected attribute if we added it.
        if not had_main_attr and hasattr(_main, attr):
            delattr(_main, attr)


def _load_matching_weights(model, state_dict):
    """Copy the entries of *state_dict* whose key and shape match *model*.

    Returns:
        A ``(skipped, missing)`` pair of key lists: checkpoint keys ignored
        because the model has no such key or its shape differs, and model
        keys that received no value from the checkpoint (they keep whatever
        values the model was initialized with).
    """
    model_state_dict = model.state_dict()
    matched = {
        k: v for k, v in state_dict.items()
        if k in model_state_dict and model_state_dict[k].shape == v.shape
    }
    skipped = [k for k in state_dict if k not in matched]
    missing = [k for k in model_state_dict if k not in matched]
    model_state_dict.update(matched)
    model.load_state_dict(model_state_dict)
    return skipped, missing


def load_model(path="", num_labels=4, filename="checkpoint.chkpt") -> nn.Module:
    """Build ``BertForSequenceClassificationCustom`` and load checkpoint weights.

    Args:
        path: directory containing the checkpoint ("" means the path is
            relative to the current working directory).
        num_labels: size of the classifier head (default 4). If it differs
            from the checkpoint's head, those head weights are skipped.
        filename: checkpoint file name inside *path*.

    Returns:
        The model, with every shape-compatible checkpoint tensor loaded,
        switched to eval mode.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    filepath = os.path.join(path, filename)
    print(f"Loading checkpoint from: { filepath }")
    # Only the architecture config is fetched here; no pretrained weights.
    config = BertConfig.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassificationCustom(config, num_labels=num_labels)

    checkpoint = _torch_load_with_main_alias(filepath)
    sft_state_dict = checkpoint['model_state_dict']

    # Deliberately tolerant: unknown or shape-mismatched keys are ignored
    # instead of failing the load. Report them, though, so a checkpoint that
    # matches nothing (e.g. keys with an unexpected prefix) is not silently
    # served with freshly initialized weights.
    skipped, missing = _load_matching_weights(model, sft_state_dict)
    if skipped:
        print(
            f"Warning: ignored {len(skipped)} checkpoint tensor(s) absent from "
            f"the model or with mismatched shape, e.g. {skipped[:5]}"
        )
    if missing:
        print(
            f"Warning: {len(missing)} model tensor(s) not found in checkpoint; "
            f"they keep their initial values, e.g. {missing[:5]}"
        )
    print("Checkpoint loaded successfully")
    model.eval()
    return model
class EndpointHandler:
    """Inference endpoint that classifies text with the custom BERT model.

    Construction loads the checkpoint (via ``load_model``) and the
    ``bert-base-uncased`` tokenizer, then compiles the model. Calling the
    instance classifies one or more texts.

    Args:
        path: directory containing the checkpoint file.
    """

    def __init__(self, path=""):
        print(f"Initializing model from base path: {path}")
        start = time.perf_counter()
        self.model = load_model(path)
        elapsed = time.perf_counter() - start
        print(f"Model loaded in {elapsed:.2f}s")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Class index i maps to self.labels[i]; presumably this must match the
        # label order used in training -- confirm against the dataset.
        self.labels = ["High", "Latent", "Medium", "None"]  # Update based on your dataset
        print("Compiling model...")
        start = time.perf_counter()
        # torch.compile is lazy: the actual compilation happens on the first
        # forward call, so this timing mostly excludes it.
        self.model.compile()
        elapsed = time.perf_counter() - start
        print(f"Model compiled in {elapsed:.2f}s")

    @staticmethod
    def _extract_texts(data):
        """Normalize a request payload into a list of texts.

        Accepts ``{'inputs': ...}``, ``{'text': ...}``, either of those nested
        one level inside a dict, or a bare string/list. Any other value is
        stringified into a single text.
        """
        if isinstance(data, dict):
            raw_inputs = data.get("inputs", None)
            if raw_inputs is None:
                raw_inputs = data.get("text", data)
        else:
            # Bare string / list payloads have no .get(); use them directly.
            raw_inputs = data
        # If payload nested inside inputs as a dict
        if isinstance(raw_inputs, dict):
            raw_inputs = raw_inputs.get("text", raw_inputs.get("inputs", raw_inputs))
        if isinstance(raw_inputs, str):
            return [raw_inputs]
        if isinstance(raw_inputs, list):
            return raw_inputs
        return [str(raw_inputs)]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the text(s) in *data*.

        Returns:
            One dict per input text with keys ``"text"``, ``"predicted_class"``
            (label name, or the raw class index if it has no name in
            ``self.labels``) and ``"score"`` (softmax probability of the
            predicted class). An empty input list yields an empty list.
        """
        texts = self._extract_texts(data)
        if not texts:
            # Nothing to classify; skip tokenization and inference entirely.
            return []

        # Tokenize in batch
        inputs_tok = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )
        with torch.no_grad():
            start = time.perf_counter()
            outputs = self.model(
                input_ids=inputs_tok["input_ids"],
                attention_mask=inputs_tok["attention_mask"],
            )
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Max probability and its index per row; ties resolve to the
            # first maximal index, like argmax.
            scores, preds = probabilities.max(dim=-1)
            elapsed = time.perf_counter() - start
            print(f"Processed {len(texts)} inputs in {elapsed:.2f}s")

        num_labels = len(self.labels)
        return [
            {
                "text": text,
                "predicted_class": self.labels[p] if p < num_labels else p,
                "score": float(s),
            }
            for text, p, s in zip(texts, preds.tolist(), scores.tolist())
        ]