# NOTE: web-viewer residue (Hugging Face Spaces page chrome: run status, file
# size, per-line commit hashes and the line-number gutter) had been fused into
# the top of this file by the page scrape. It is not part of the module.
import json
import os
import torch
import torch.nn as nn
from config import DEVICE, HF_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
# Marker tokens that ModelWrapper.__init__ registers on the tokenizer as
# additional special tokens.
SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]
def get_current_branch() -> str:
    """Return the model revision named in ``current_branch.txt``.

    The file is looked up relative to the current working directory and its
    content is returned with surrounding whitespace stripped.

    Returns:
        The stripped file content, or ``"latest"`` when the file does not
        exist or contains only whitespace.
    """
    try:
        # Open directly instead of checking for existence first, so a file
        # removed between the check and the open cannot raise.
        with open("current_branch.txt", "r", encoding="utf-8") as branch_file:
            branch = branch_file.read().strip()
    except FileNotFoundError:
        return "latest"
    # An empty string is not a usable revision name; use the same fallback as
    # for a missing file.
    return branch or "latest"
class ModelWrapper:
    """Load the tokenizer, the causal LM and its custom heads at construction.

    Attributes set by ``__init__``:
        flags_order: value stored under ``"ALL_FLAGS"`` in ``flags.json``
            (read from the directory containing this module).
        num_flags: ``len(flags_order)``; output width of the two flag heads.
        tokenizer: tokenizer loaded from the ``m97j/npc_LoRA-fps`` repository.
        model: causal LM loaded from the same repository, with ``delta_head``,
            ``flag_head`` and ``flag_threshold_head`` attached, moved to
            ``DEVICE`` and switched to eval mode.
    """

    def __init__(self):
        """Read the flag list, then load tokenizer, model and head weights.

        A head weight file that is missing or fails to load is reported with a
        ``[WARN]`` message and construction continues; that head keeps the
        weights ``nn.Linear`` initialised it with.
        """
        # Flags info
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        # Context manager so the file handle is closed deterministically.
        with open(flags_path, encoding="utf-8") as flags_file:
            self.flags_order = json.load(flags_file)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        branch = get_current_branch()

        # 1) Tokenizer (vocab + SPECIALS at the time of training LoRA)
        self.tokenizer = AutoTokenizer.from_pretrained(
            "m97j/npc_LoRA-fps",
            revision=branch,
            subfolder="testcase_output",
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # Fall back to the EOS token when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})

        # 2) Base model (LoRA model with merged weights, but without custom heads)
        self.model = AutoModelForCausalLM.from_pretrained(
            "m97j/npc_LoRA-fps",
            revision=branch,
            subfolder="testcase_output",
            device_map=None,
            low_cpu_mem_usage=False,
            trust_remote_code=True,
            token=HF_TOKEN,
        )

        # 3) Add custom heads (delta, flag, flag_threshold) - architecture only,
        #    weights are loaded separately below.
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # 4) Load custom head weights separately (if available)
        #    - this is necessary because the LoRA merging process may not include
        #      these heads, and they might be trained separately.
        #    File names are resolved relative to the current working directory.
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt"),
        ]:
            if not os.path.exists(file_name):
                # Make the fallback visible: without this file the head runs
                # with the weights nn.Linear initialised it with.
                print(f"[WARN] {file_name} not found; {head_name} keeps its initial weights")
                continue
            try:
                # NOTE(review): torch.load unpickles the file; only head files
                # from a trusted source should be placed here.
                getattr(self.model, head_name).load_state_dict(
                    torch.load(file_name, map_location=DEVICE)
                )
            except Exception as e:
                # Best-effort by design: a bad head file must not abort start-up.
                print(f"[WARN] Failed to load {file_name}: {e}")

        # 5) Move model to device and set to eval mode
        self.model.to(DEVICE)
        self.model.eval()

    def get(self) -> tuple:
        """Return ``(tokenizer, model, flags_order)`` as loaded by ``__init__``."""
        return self.tokenizer, self.model, self.flags_order
# (end of module; a trailing table-cell delimiter left by the web page scrape was removed)