import os
import torch
import torch.nn as nn

from core.device import DEVICE
from language.embeddings import EmbeddingLayer
from language.encoder import SentenceEncoder
from language.tokenizer import SimpleTokenizer

ARTIFACTS_DIR = "artifacts"


class ProgrammingModel(nn.Module):
    def __init__(self, tokenizer: SimpleTokenizer):
        super().__init__()

        self.tokenizer = tokenizer

        # Architecture must match the training setup exactly, or the
        # saved state dicts loaded below will fail to apply.
        self.embedder = EmbeddingLayer(
            len(tokenizer.vocab),
            pad_index=tokenizer.vocab[tokenizer.PAD_TOKEN]
        )

        self.encoder = SentenceEncoder()
        self.classifier = nn.Linear(
            self.encoder.projection.out_features,
            2
        )

        self.load_models()
        self.to(DEVICE)
        self.eval()

    def load_models(self):
        """Restore the trained weights for each sub-module from ARTIFACTS_DIR."""
        self.embedder.load_state_dict(
            torch.load(os.path.join(ARTIFACTS_DIR, "programming_embedding.pt"), map_location=DEVICE)
        )
        self.encoder.load_state_dict(
            torch.load(os.path.join(ARTIFACTS_DIR, "programming_encoder.pt"), map_location=DEVICE)
        )
        self.classifier.load_state_dict(
            torch.load(os.path.join(ARTIFACTS_DIR, "programming_classifier.pt"), map_location=DEVICE)
        )

    def forward(self, token_ids):
        embeddings = self.embedder(token_ids)
        # Mask out padding positions so they do not contribute to the
        # pooled sentence representation.
        attention_mask = (token_ids != self.tokenizer.vocab[self.tokenizer.PAD_TOKEN]).long()
        sentence_vec = self.encoder(embeddings, attention_mask=attention_mask)
        return self.classifier(sentence_vec)

    def predict(self, text: str):
        # Encode a single sentence as a batch of one.
        token_ids = torch.tensor(
            [self.tokenizer.encode(text)],
            dtype=torch.long,
            device=DEVICE
        )

        with torch.no_grad():
            # Call the module rather than .forward() directly so hooks run.
            logits = self(token_ids)
            probs = torch.softmax(logits, dim=-1)
            label_idx = torch.argmax(probs, dim=-1).item()

        return {
            "label": "programming" if label_idx == 1 else "non_programming",
            "confidence": probs[0][label_idx].item()
        }
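

if __name__ == "__main__":
    # Minimal usage sketch. How SimpleTokenizer is constructed or restored
    # is an assumption here; adjust to however the project persists its
    # vocabulary (e.g. alongside the weights in ARTIFACTS_DIR).
    tokenizer = SimpleTokenizer()
    model = ProgrammingModel(tokenizer)

    print(model.predict("How do I merge two dicts in Python?"))
    print(model.predict("The weather was lovely this weekend."))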