from __future__ import annotations

import json
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


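def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the sequence axis, ignoring padded positions."""
    # Broadcast the (batch, seq) mask to (batch, seq, hidden) so padding tokens contribute zero.
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    # Clamp the token count so a row with no real tokens does not divide by zero.
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / denom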


class EmbeddingClassifier(nn.Module):
    """Pretrained encoder followed by mean pooling, dropout, and a linear classification head."""

    def __init__(self, model_name: str, num_labels: int):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, **inputs) -> torch.Tensor:
        # `inputs` is the tokenizer output and must include "attention_mask".
        outputs = self.encoder(**inputs)
        pooled = mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])
        # Returns raw logits of shape (batch, num_labels); no softmax is applied.
        return self.classifier(self.dropout(pooled))


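def load_bundle(bundle_dir: str | Path):
    """Load the tokenizer, fine-tuned model, and id-to-label map from an exported bundle.

    The bundle directory must contain `hf_export_config.json`, `best.pt`,
    and a `tokenizer/` subdirectory.
    """
    bundle_dir = Path(bundle_dir)
    config = json.loads((bundle_dir / "hf_export_config.json").read_text(encoding="utf-8"))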
    model = EmbeddingClassifier(
        model_name=config["base_model_name"],
        num_labels=config["num_labels"],
    )
    # weights_only=False unpickles arbitrary Python objects, so only load checkpoints you trust.
    checkpoint = torch.load(bundle_dir / "best.pt", map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(bundle_dir / "tokenizer")
    id2label = {int(k): v for k, v in config["id2label"].items()}
    return tokenizer, model, id2label


if __name__ == "__main__":
    tokenizer, model, id2label = load_bundle(Path(__file__).resolve().parent)
    # Korean sample input: "The likelihood of an interest rate cut next year is expected to rise."
    text = "내년 금리 인하 가능성이 높아질 것으로 전망된다."
    encoded = tokenizer(
        [text],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    # Inference only: skip gradient tracking and take the highest-scoring class.
    with torch.no_grad():
        logits = model(**encoded)
        pred = int(torch.argmax(logits, dim=1).item())
    print({"text": text, "pred_label": id2label[pred]})