File size: 3,007 Bytes
39eaf79
 
 
 
0fc77f3
bbb241c
39eaf79
f1f1dc0
 
 
0fc77f3
 
 
 
90bc37b
0fc77f3
 
 
39eaf79
0fc77f3
 
 
 
bbb241c
 
39eaf79
90bc37b
39eaf79
bbb241c
39eaf79
90bc37b
f1f1dc0
 
90bc37b
0fc77f3
 
 
f1f1dc0
0fc77f3
39eaf79
bbb241c
39eaf79
90bc37b
39eaf79
 
bbb241c
 
90bc37b
 
0fc77f3
39eaf79
0fc77f3
 
 
 
 
39eaf79
 
79d952a
 
 
 
 
 
 
 
 
 
 
 
0fc77f3
39eaf79
f1f1dc0
0fc77f3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import json
import os

import torch
import torch.nn as nn
from config import DEVICE, HF_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer

# Marker tokens that ModelWrapper registers with the tokenizer as
# "additional_special_tokens".
SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]

def get_current_branch():
    """Return the model revision named in ``current_branch.txt``.

    The file is looked up relative to the current working directory and its
    content is stripped of surrounding whitespace.

    Returns:
        The stripped file content, or ``"latest"`` when the file does not
        exist or contains only whitespace (an empty revision name is never
        returned).
    """
    try:
        with open("current_branch.txt", "r", encoding="utf-8") as f:
            branch = f.read().strip()
    except FileNotFoundError:
        return "latest"
    # An empty file must not produce an empty revision string.
    return branch or "latest"

class ModelWrapper:
    """Load the tokenizer, the causal LM and its custom heads in one place.

    Attributes set by ``__init__``:
        flags_order: value of the "ALL_FLAGS" key in ``flags.json`` located
            next to this module.
        num_flags: ``len(flags_order)``.
        tokenizer: tokenizer loaded from the Hub repository.
        model: causal LM with ``delta_head``, ``flag_head`` and
            ``flag_threshold_head`` attached, moved to ``DEVICE`` and put in
            eval mode.
    """

    # Hub location shared by the tokenizer and the model.
    _REPO_ID = "m97j/npc_LoRA-fps"
    _SUBFOLDER = "testcase_output"

    # (attribute name on the model, weight file looked up in the working dir)
    _HEAD_FILES = (
        ("delta_head", "delta_head.pt"),
        ("flag_head", "flag_head.pt"),
        ("flag_threshold_head", "flag_threshold_head.pt"),
    )

    def __init__(self):
        # Flags info
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        self.flags_order = self._read_flags_order(flags_path)
        self.num_flags = len(self.flags_order)

        branch = get_current_branch()

        # 1) Tokenizer (vocab + SPECIALS at the time of training LoRA)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._REPO_ID,
            revision=branch,
            subfolder=self._SUBFOLDER,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # NOTE(review): presumably every entry of SPECIALS is already part of
        # the saved vocabulary; if one were new, the model's embeddings would
        # have to be resized to match -- confirm against the training setup.
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})

        # 2) Base model (LoRA model with merged weights, but without custom heads)
        self.model = AutoModelForCausalLM.from_pretrained(
            self._REPO_ID,
            revision=branch,
            subfolder=self._SUBFOLDER,
            device_map=None,
            low_cpu_mem_usage=False,
            trust_remote_code=True,
            token=HF_TOKEN,
        )

        # 3) Add custom heads (delta, flag, flag_threshold) - architecture
        #    only, weights are loaded separately below.
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # 4) Load custom head weights separately (if available)
        #  - this is necessary because the LoRA merging process may not include
        #    these heads, and they might be trained separately.
        self._load_head_weights()

        # 5) Move model to device and set to eval mode
        self.model.to(DEVICE)
        self.model.eval()

    @staticmethod
    def _read_flags_order(flags_path):
        """Return the "ALL_FLAGS" entry of the JSON file at *flags_path*.

        The file is opened in a ``with`` block so the handle is closed
        deterministically.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if the JSON object has no "ALL_FLAGS" key.
        """
        with open(flags_path, encoding="utf-8") as f:
            return json.load(f)["ALL_FLAGS"]

    def _load_head_weights(self):
        """Best-effort load of each custom head's state dict from disk.

        A missing or unreadable file never aborts construction; a warning is
        printed instead, and the affected head keeps the weights it was
        created with.
        """
        for head_name, file_name in self._HEAD_FILES:
            if not os.path.exists(file_name):
                print(f"[WARN] {file_name} not found; {head_name} keeps its initial weights")
                continue
            try:
                # NOTE(review): torch.load unpickles the file; only load weight
                # files from a trusted source (consider weights_only=True where
                # the installed torch version supports it).
                state_dict = torch.load(file_name, map_location=DEVICE)
                getattr(self.model, head_name).load_state_dict(state_dict)
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

    def get(self):
        """Return ``(tokenizer, model, flags_order)`` as a tuple."""
        return self.tokenizer, self.model, self.flags_order