"""Custom BERT sequence classifier with a HF-inference-endpoint handler.

Defines a BERT encoder topped by an extra feed-forward layer and a linear
classification head, a checkpoint loader that tolerates shape-mismatched
keys, and an ``EndpointHandler`` that tokenizes incoming text and returns
per-input class predictions with confidence scores.
"""

import os
import pickle  # NOTE(review): unused here — kept in case another chunk of this file needs it
import sys     # NOTE(review): unused here — kept in case another chunk of this file needs it
import time
from typing import Any, Dict, List

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertTokenizer


class FeedForward(nn.Module):
    """Two-layer MLP with ReLU activation and dropout.

    NOTE(review): dropout and activation are applied to the *output* of
    ``fc2`` as well as ``fc1``. That is unusual (the final projection is
    typically left linear) but it is the original behavior, so it is
    preserved here — confirm intent before changing.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.dropout(self.activation(self.fc1(x)))
        x = self.dropout(self.activation(self.fc2(x)))
        return x


class BertForSequenceClassificationCustom(nn.Module):
    """BERT model for sequence classification with a custom head.

    Architecture: BertModel -> dropout -> FeedForward (hidden -> 2*hidden
    -> hidden) -> linear classifier over ``num_labels`` classes.
    """

    def __init__(self, config, num_labels: int):
        super().__init__()
        self.num_labels = num_labels
        self.config = config
        # Plain BertModel (not BertPreTrainedModel) so weights load from a
        # config alone; actual weights come from the checkpoint in load_model.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Extra feed-forward layer between the pooled output and the classifier.
        self.ffd = FeedForward(config.hidden_size, config.hidden_size * 2, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run the encoder and classification head.

        Returns a lightweight output object with ``loss`` (None when no
        labels are given), ``logits`` of shape (batch, num_labels), and
        ``hidden_states`` (the encoder's last hidden state).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = outputs["pooler_output"]
        pooled_output = self.dropout(pooled_output)
        internal_output = self.ffd(pooled_output)
        logits = self.classifier(internal_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Anonymous namespace-style object mimicking HF's ModelOutput
        # attribute access (``.loss``, ``.logits``, ``.hidden_states``).
        return type('ModelOutput', (), {
            'loss': loss,
            'logits': logits,
            'hidden_states': outputs['last_hidden_state']
        })()


def load_model(path: str = "") -> nn.Module:
    """Load a fine-tuned checkpoint from ``<path>/checkpoint.chkpt``.

    Builds a fresh ``BertForSequenceClassificationCustom`` from the
    bert-base-uncased config, then copies in every checkpoint tensor whose
    key and shape match the fresh model (mismatched layers are skipped).
    Returns the model in eval mode.
    """
    filename = "checkpoint.chkpt"
    filepath = os.path.join(path, filename)
    print(f"Loading checkpoint from: {filepath}")

    # Load the configuration and tokenizer
    config = BertConfig.from_pretrained("bert-base-uncased")

    # Initialize the model
    num_labels = 4  # Update this based on your dataset
    model = BertForSequenceClassificationCustom(config, num_labels=num_labels)

    # Some checkpoints expect the class to be available in __main__ during
    # unpickling. Temporarily inject the class into __main__ to satisfy
    # torch.load, and remove it afterwards if we added it.
    import __main__ as _main
    had_main_attr = hasattr(_main, 'BertForSequenceClassificationCustom')
    if not had_main_attr:
        setattr(_main, 'BertForSequenceClassificationCustom', BertForSequenceClassificationCustom)
    try:
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(filepath, weights_only=False)
    finally:
        if not had_main_attr and hasattr(_main, 'BertForSequenceClassificationCustom'):
            delattr(_main, 'BertForSequenceClassificationCustom')

    # Load state dict while ignoring mismatched layers
    model_state_dict = model.state_dict()
    sft_state_dict = checkpoint['model_state_dict']

    # Keep only checkpoint tensors whose key exists in the fresh model AND
    # whose shape matches — protects against head-size changes.
    filtered_state_dict = {
        k: v for k, v in sft_state_dict.items()
        if k in model_state_dict and model_state_dict[k].shape == v.shape
    }

    # Update the model's state dict
    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)
    print("Checkpoint loaded successfully")

    model.eval()
    return model


class EndpointHandler():
    """Inference handler: tokenizes text payloads and returns predictions."""

    def __init__(self, path: str = ""):
        print(f"Initializing model from base path: {path}")
        start = time.perf_counter()
        self.model = load_model(path)
        elapsed = time.perf_counter() - start
        print(f"Model loaded in {elapsed:.2f}s")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.labels = ["High", "Latent", "Medium", "None"]  # Update based on your dataset
        print("Compiling model...")
        start = time.perf_counter()
        self.model.compile()
        elapsed = time.perf_counter() - start
        print(f"Model compiled in {elapsed:.2f}s")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one or more texts.

        Accepts ``{'inputs': ...}``, ``{'text': ...}``, a raw string, or a
        list of strings. Returns one dict per input with the original text,
        the predicted class label, and the softmax confidence score.
        """
        # BUGFIX: the original called data.get(...) unconditionally, which
        # raised AttributeError for the documented raw string/list payloads.
        if isinstance(data, dict):
            raw_inputs = data.get("inputs")
            if raw_inputs is None:
                raw_inputs = data.get("text", data)
        else:
            raw_inputs = data

        # If payload nested inside inputs as a dict
        if isinstance(raw_inputs, dict):
            raw_inputs = raw_inputs.get("text", raw_inputs.get("inputs", raw_inputs))

        # Normalize to list of strings
        if isinstance(raw_inputs, str):
            texts = [raw_inputs]
        elif isinstance(raw_inputs, list):
            texts = raw_inputs
        else:
            texts = [str(raw_inputs)]

        # Tokenize in batch
        inputs_tok = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )

        with torch.no_grad():
            start = time.perf_counter()
            outputs = self.model(
                input_ids=inputs_tok["input_ids"],
                attention_mask=inputs_tok["attention_mask"],
            )
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            preds = torch.argmax(probabilities, dim=-1).tolist()
            elapsed = time.perf_counter() - start
            print(f"Processed {len(texts)} inputs in {elapsed:.2f}s")

        results = []
        for i, p in enumerate(preds):
            results.append({
                "text": texts[i],
                # Fall back to the raw index if it is out of label range.
                "predicted_class": self.labels[int(p)] if int(p) < len(self.labels) else int(p),
                "score": float(probabilities[i].max().item()),
            })
        return results