File size: 2,927 Bytes
0fc77f3
 
 
bbb241c
f1f1dc0
 
 
0fc77f3
 
 
 
90bc37b
0fc77f3
 
 
f1f1dc0
0fc77f3
 
 
 
bbb241c
 
 
90bc37b
3d8218a
bbb241c
3d8218a
90bc37b
f1f1dc0
 
90bc37b
0fc77f3
 
 
f1f1dc0
0fc77f3
bbb241c
 
3d8218a
90bc37b
3d8218a
bbb241c
 
 
90bc37b
 
0fc77f3
bbb241c
0fc77f3
 
 
 
 
bbb241c
79d952a
 
 
 
 
 
 
 
 
 
 
 
0fc77f3
bbb241c
f1f1dc0
0fc77f3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from config import DEVICE, HF_TOKEN

# Marker tokens registered on the tokenizer via
# add_special_tokens({"additional_special_tokens": SPECIALS}) in ModelWrapper.
SPECIALS = [
    "<SYS>",
    "<CTX>",
    "<PLAYER>",
    "<NPC>",
    "<STATE>",
    "<RAG>",
    "<PLAYER_STATE>",
]

def get_current_branch(path="current_branch.txt", default="latest"):
    """Return the repo revision (branch name) the model should be loaded from.

    The name is read from *path*, which is resolved against the current
    working directory, and surrounding whitespace is stripped.

    Args:
        path: File holding the branch name. Defaults to "current_branch.txt".
        default: Value returned when the file does not exist or is empty or
            whitespace-only. An empty file would otherwise produce an empty
            revision string.

    Returns:
        The branch name, or *default*.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            branch = f.read().strip()
    except FileNotFoundError:
        return default
    # Treat an empty or blank file the same as a missing one.
    return branch or default

class ModelWrapper:
    """Load the merged NPC model, its tokenizer, and custom prediction heads.

    The tokenizer and merged causal-LM weights are fetched from the Hugging
    Face repo ``REPO_ID``, subfolder ``SUBFOLDER``. They are loaded at the
    revision returned by ``get_current_branch()``, with ``trust_remote_code``
    enabled.

    Three ``nn.Linear`` heads (``delta_head``, ``flag_head``,
    ``flag_threshold_head``) are attached to the model as submodules. Each is
    initialised from a ``.pt`` state-dict file when that file exists in the
    current working directory. Otherwise a warning is printed and the head
    keeps its random initialisation.
    """

    # Repo and subfolder where the merged model was uploaded.
    REPO_ID = "m97j/npc_LoRA-fps"
    SUBFOLDER = "testcase_output"

    # (attribute name on the model, state-dict file). File paths are relative
    # to the current working directory.
    HEAD_FILES = (
        ("delta_head", "delta_head.pt"),
        ("flag_head", "flag_head.pt"),
        ("flag_threshold_head", "flag_threshold_head.pt"),
    )

    def __init__(self):
        # Flag names, read from flags.json next to this module.
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        self.flags_order = self._load_flags_order(flags_path)
        self.num_flags = len(self.flags_order)

        branch = get_current_branch()

        # 1) Tokenizer (training-time vocab + SPECIALS)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.REPO_ID,
            revision=branch,
            subfolder=self.SUBFOLDER,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # NOTE(review): if these tokens are not already in the saved
        # tokenizer, new ids are added without resizing the model embeddings.
        # Confirm the uploaded tokenizer already contains SPECIALS.
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})

        # 2) Merged model (sharded checkpoints are detected automatically)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.REPO_ID,
            revision=branch,
            subfolder=self.SUBFOLDER,
            device_map=None,  # disable offloading
            low_cpu_mem_usage=False,
            trust_remote_code=True,
            token=HF_TOKEN,
        )

        # 3) Attach custom heads. Assigning an nn.Module attribute registers
        # it as a submodule, so the .to()/.eval() calls below cover it too.
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # 4) Load custom head weights (best effort: failures only warn).
        for head_name, file_name in self.HEAD_FILES:
            self._load_head_weights(getattr(self.model, head_name), file_name, DEVICE)

        # 5) Device placement and inference mode
        self.model.to(DEVICE)
        self.model.eval()

    @staticmethod
    def _load_flags_order(path):
        """Return the ``ALL_FLAGS`` entry of the UTF-8 JSON file at *path*."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)["ALL_FLAGS"]

    @staticmethod
    def _load_head_weights(head, file_name, device):
        """Best-effort load of the state dict in *file_name* into *head*.

        Args:
            head: Module whose ``load_state_dict`` receives the weights.
            file_name: Path to a file written by ``torch.save(state_dict)``.
            device: ``map_location`` passed to ``torch.load``.

        Returns:
            True if the weights were loaded. False if the file is missing or
            loading failed; a warning is printed in both cases.
        """
        if not os.path.exists(file_name):
            print(f"[WARN] {file_name} not found; head keeps random initialisation")
            return False
        try:
            # NOTE(review): torch.load unpickles the file; only load trusted files.
            head.load_state_dict(torch.load(file_name, map_location=device))
        except Exception as e:
            print(f"[WARN] Failed to load {file_name}: {e}")
            return False
        return True

    def get(self):
        """Return the ``(tokenizer, model, flags_order)`` triple."""
        return self.tokenizer, self.model, self.flags_order