import os
import pickle
import sys
import time
from types import SimpleNamespace
from typing import Any, Dict, List

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertTokenizer
| |
|
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout, applied twice.

    Note that the second (output) layer also passes through ReLU and
    dropout, so outputs are non-negative.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        # Attribute names are part of the checkpoint state_dict contract.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply both linear layers, each followed by ReLU and dropout."""
        for linear in (self.fc1, self.fc2):
            x = self.dropout(self.activation(linear(x)))
        return x
| |
|
class BertForSequenceClassificationCustom(nn.Module):
    """BERT encoder with a feed-forward bottleneck head for sequence
    classification.

    Pipeline: BERT pooled output -> dropout -> FeedForward -> linear
    classifier over `num_labels` classes.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.config = config

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Extra non-linear projection between the pooled output and the
        # classifier (hidden -> 2*hidden -> hidden).
        self.ffd = FeedForward(config.hidden_size, config.hidden_size * 2, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run a forward pass.

        Returns a namespace with:
            loss          -- cross-entropy loss, or None when `labels` is None
            logits        -- (batch, num_labels) classification scores
            hidden_states -- BERT's last_hidden_state tensor
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        pooled_output = self.dropout(outputs['pooler_output'])
        internal_output = self.ffd(pooled_output)
        logits = self.classifier(internal_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # SimpleNamespace instead of building a brand-new class object on
        # every forward call (the original `type('ModelOutput', (), ...)()`);
        # callers keep the same attribute access (.loss, .logits, .hidden_states).
        return SimpleNamespace(
            loss=loss,
            logits=logits,
            hidden_states=outputs['last_hidden_state'],
        )
| |
|
| | |
def load_model(path="", num_labels=4) -> nn.Module:
    """Build the custom BERT classifier and load weights from a checkpoint.

    Args:
        path: base directory containing ``checkpoint.chkpt``.
        num_labels: number of output classes (default 4, matching the
            original hard-coded value, so existing callers are unaffected).

    Returns:
        The model in eval mode with all shape-compatible checkpoint
        weights loaded.
    """
    filename = "checkpoint.chkpt"
    filepath = os.path.join(path, filename)
    print(f"Loading checkpoint from: { filepath }")

    config = BertConfig.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassificationCustom(config, num_labels=num_labels)

    # Checkpoints pickled from a training script may reference the model
    # class under __main__; expose it there temporarily so unpickling can
    # resolve the reference, then restore __main__ to its prior state.
    import __main__ as _main
    had_main_attr = hasattr(_main, 'BertForSequenceClassificationCustom')
    if not had_main_attr:
        setattr(_main, 'BertForSequenceClassificationCustom', BertForSequenceClassificationCustom)

    try:
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # load checkpoints from a trusted source.
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
        map_location = None if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    finally:
        if not had_main_attr and hasattr(_main, 'BertForSequenceClassificationCustom'):
            delattr(_main, 'BertForSequenceClassificationCustom')

    model_state_dict = model.state_dict()
    sft_state_dict = checkpoint['model_state_dict']

    # Keep only checkpoint tensors whose key and shape match this model.
    filtered_state_dict = {
        k: v for k, v in sft_state_dict.items()
        if k in model_state_dict and model_state_dict[k].shape == v.shape
    }
    skipped = sorted(set(sft_state_dict) - set(filtered_state_dict))
    if skipped:
        # Surface what was dropped instead of discarding it silently.
        print(f"Warning: skipped {len(skipped)} incompatible checkpoint keys: {skipped}")

    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)
    print("Checkpoint loaded successfully")
    model.eval()
    return model
| |
|
| |
|
| |
|
class EndpointHandler():
    """Inference endpoint wrapper: loads model + tokenizer once, then
    serves classification requests per call.

    NOTE(review): shape matches a Hugging Face custom-handler interface
    (constructed with a model directory, called with a request dict) —
    confirm against the hosting environment.
    """

    def __init__(self, path=""):
        # path: base directory containing checkpoint.chkpt (forwarded to load_model).
        print(f"Initializing model from base path: {path}")
        start = time.perf_counter()
        self.model= load_model(path)
        elapsed = time.perf_counter() - start
        print(f"Model loaded in {elapsed:.2f}s")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Class-index -> label name mapping; order must match the
        # classifier head's output ordering from training.
        self.labels = ["High", "Latent", "Medium", "None"]
        print("Compiling model...")
        start = time.perf_counter()
        # nn.Module.compile wraps the model with torch.compile; the timed
        # section here covers setup only — the first forward pass pays the
        # actual compilation cost.
        self.model.compile()
        elapsed = time.perf_counter() - start
        print(f"Model compiled in {elapsed:.2f}s")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one or more texts.

        Accepts payloads under "inputs" or "text" (value: str, list of
        str, or a one-level nested dict with the same keys); anything else
        is stringified. Returns one dict per input with keys "text",
        "predicted_class", and "score" (max softmax probability).
        """
        # Prefer "inputs", then "text"; fall back to treating the whole
        # request dict as the payload.
        raw_inputs = data.get("inputs", None)
        if raw_inputs is None:
            raw_inputs = data.get("text", data)

        # Unwrap one level of nesting, e.g. {"inputs": {"text": ...}}.
        if isinstance(raw_inputs, dict):
            raw_inputs = raw_inputs.get("text", raw_inputs.get("inputs", raw_inputs))

        # Normalize the payload to a list of strings.
        if isinstance(raw_inputs, str):
            texts = [raw_inputs]
        elif isinstance(raw_inputs, list):
            texts = raw_inputs
        else:
            texts = [str(raw_inputs)]

        # Batch-tokenize with padding/truncation to a fixed 256-token cap.
        inputs_tok = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256
        )

        with torch.no_grad():
            start = time.perf_counter()
            outputs = self.model(
                input_ids=inputs_tok["input_ids"],
                attention_mask=inputs_tok["attention_mask"]
            )
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            preds = torch.argmax(probabilities, dim=-1).tolist()
            elapsed = time.perf_counter() - start
            print(f"Processed {len(texts)} inputs in {elapsed:.2f}s")

        results = []
        for i, p in enumerate(preds):
            results.append({
                "text": texts[i],
                # Falls back to the raw class index if it ever exceeds the
                # label list (defensive; should not happen with num_labels=4).
                "predicted_class": self.labels[int(p)] if int(p) < len(self.labels) else int(p),
                "score": float(probabilities[i].max().item())
            })

        return results
| |
|
| |
|