File size: 6,300 Bytes
27e5fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cd9d3d
27e5fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import Dict, List, Any
import torch.nn as nn
from transformers import BertModel
from transformers import BertConfig
from transformers import BertTokenizer
import torch
import os
import pickle
from typing import Any
import sys
import time

class FeedForward(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Dropout, applied twice.

    Note: the activation and dropout are applied after the second linear
    layer as well, so the returned features are non-negative.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Same treatment after each linear layer: activate, then drop out.
        for layer in (self.fc1, self.fc2):
            x = self.dropout(self.activation(layer(x)))
        return x

class BertForSequenceClassificationCustom(nn.Module):
    """BERT encoder followed by a custom feedforward head and a linear classifier."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.config = config

        # Plain BertModel (not BertPreTrainedModel) so we own the classification head.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Expand-then-project feedforward block between the pooled output and classifier.
        self.ffd = FeedForward(config.hidden_size, config.hidden_size * 2, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        encoder_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        pooled = self.dropout(encoder_out['pooler_output'])
        logits = self.classifier(self.ffd(pooled))

        loss = None
        if labels is not None:
            # Standard cross-entropy over flattened (batch, num_labels) logits.
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        # Ad-hoc result object mimicking HF ModelOutput attribute access.
        # NOTE(review): 'hidden_states' actually carries last_hidden_state only.
        return type('ModelOutput', (), {
            'loss': loss,
            'logits': logits,
            'hidden_states': encoder_out['last_hidden_state']
        })()

    
def load_model(path="") -> nn.Module:
    """Build a BertForSequenceClassificationCustom and load weights from
    ``<path>/checkpoint.chkpt``.

    Checkpoint tensors whose name or shape do not match the freshly built
    model are skipped, so loading succeeds even after minor architecture
    changes.

    Args:
        path: Directory containing ``checkpoint.chkpt``.

    Returns:
        The model in eval mode with the matching weights loaded.
    """
    filename = "checkpoint.chkpt"
    filepath = os.path.join(path, filename)
    print(f"Loading checkpoint from: {filepath}")

    # Load the configuration
    config = BertConfig.from_pretrained("bert-base-uncased")

    # Initialize the model
    num_labels = 4  # Update this based on your dataset
    model = BertForSequenceClassificationCustom(config, num_labels=num_labels)

    # Some checkpoints expect the class to be available in __main__ during unpickling.
    # Temporarily inject the class into the __main__ module to satisfy torch.load.
    import __main__ as _main
    had_main_attr = hasattr(_main, 'BertForSequenceClassificationCustom')
    if not had_main_attr:
        setattr(_main, 'BertForSequenceClassificationCustom', BertForSequenceClassificationCustom)

    try:
        # map_location="cpu" so GPU-saved checkpoints load on CPU-only hosts.
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # acceptable because the checkpoint comes from our own training run.
        checkpoint = torch.load(filepath, map_location="cpu", weights_only=False)
    finally:
        # Clean up the injected attribute if we added it
        if not had_main_attr and hasattr(_main, 'BertForSequenceClassificationCustom'):
            delattr(_main, 'BertForSequenceClassificationCustom')

    # Load state dict while ignoring mismatched layers
    model_state_dict = model.state_dict()
    sft_state_dict = checkpoint['model_state_dict']

    # Keep only tensors whose key exists in the model and whose shape matches.
    filtered_state_dict = {
        k: v for k, v in sft_state_dict.items()
        if k in model_state_dict and model_state_dict[k].shape == v.shape
    }
    skipped = len(sft_state_dict) - len(filtered_state_dict)
    if skipped:
        # Surface silently-dropped weights so shape drift is noticed.
        print(f"Skipped {skipped} mismatched checkpoint tensors")

    # Update the model's state dict with the compatible checkpoint weights.
    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)
    print("Checkpoint loaded successfully")
    model.eval()
    return model



class EndpointHandler():
    """Inference endpoint: loads the classifier once, then scores text payloads."""

    def __init__(self, path=""):
        print(f"Initializing model from base path: {path}")
        t0 = time.perf_counter()
        self.model = load_model(path)
        print(f"Model loaded in {time.perf_counter() - t0:.2f}s")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.labels = ["High", "Latent", "Medium", "None"]  # Update based on your dataset
        print("Compiling model...")
        t0 = time.perf_counter()
        self.model.compile()
        print(f"Model compiled in {time.perf_counter() - t0:.2f}s")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Accept {'inputs': ...}, {'text': ...}, or the raw payload itself.
        payload = data.get("inputs")
        if payload is None:
            payload = data.get("text", data)

        # Payload may be nested one level deeper, e.g. {'inputs': {'text': ...}}.
        if isinstance(payload, dict):
            payload = payload.get("text", payload.get("inputs", payload))

        # Normalize to a list of strings.
        if isinstance(payload, str):
            texts = [payload]
        elif isinstance(payload, list):
            texts = payload
        else:
            texts = [str(payload)]

        # Batch-tokenize all inputs at once.
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256
        )

        with torch.no_grad():
            t0 = time.perf_counter()
            outputs = self.model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"]
            )
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predictions = torch.argmax(probs, dim=-1).tolist()
            print(f"Processed {len(texts)} inputs in {time.perf_counter() - t0:.2f}s")

        # Fall back to the raw class index when it is outside the label list.
        return [
            {
                "text": texts[idx],
                "predicted_class": self.labels[int(cls)] if int(cls) < len(self.labels) else int(cls),
                "score": float(probs[idx].max().item())
            }
            for idx, cls in enumerate(predictions)
        ]