Spaces:
Running
Running
File size: 2,927 Bytes
0fc77f3 bbb241c f1f1dc0 0fc77f3 90bc37b 0fc77f3 f1f1dc0 0fc77f3 bbb241c 90bc37b 3d8218a bbb241c 3d8218a 90bc37b f1f1dc0 90bc37b 0fc77f3 f1f1dc0 0fc77f3 bbb241c 3d8218a 90bc37b 3d8218a bbb241c 90bc37b 0fc77f3 bbb241c 0fc77f3 bbb241c 79d952a 0fc77f3 bbb241c f1f1dc0 0fc77f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from config import DEVICE, HF_TOKEN
# Marker tokens that ModelWrapper registers with the tokenizer as
# additional special tokens.
SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]
def get_current_branch():
    """Return the model revision named in ``current_branch.txt``.

    The file is looked up relative to the current working directory and
    its content is stripped of surrounding whitespace.

    Returns:
        The branch name from the file, or ``"latest"`` when the file does
        not exist or contains only whitespace.
    """
    # EAFP: open directly instead of exists()+open(), which can race with
    # the file being removed in between.
    try:
        with open("current_branch.txt", "r", encoding="utf-8") as f:
            branch = f.read().strip()
    except FileNotFoundError:
        return "latest"
    # An empty file would otherwise yield "" and be passed on as a revision.
    return branch or "latest"
class ModelWrapper:
    """Load the tokenizer, the merged causal LM and its custom heads.

    Construction performs all loading work:

    * reads the ordered flag list from ``flags.json`` next to this module,
    * loads the tokenizer and model from the ``m97j/npc_LoRA-fps`` repo at
      the revision returned by ``get_current_branch()``,
    * attaches three ``nn.Linear`` heads (``delta_head``, ``flag_head``,
      ``flag_threshold_head``) to the model and, when the matching ``.pt``
      files exist in the working directory, loads their weights,
    * moves the model to ``DEVICE`` and switches it to eval mode.

    Attributes:
        flags_order: Value of the ``"ALL_FLAGS"`` key in ``flags.json``.
        num_flags: ``len(flags_order)``; output width of the two flag heads.
        tokenizer: Tokenizer with ``SPECIALS`` registered.
        model: Causal LM with the custom heads attached.
    """

    def __init__(self):
        # Flag info: flags.json lives in the same directory as this module.
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        # Context manager so the file handle is closed deterministically.
        with open(flags_path, encoding="utf-8") as flags_file:
            self.flags_order = json.load(flags_file)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        branch = get_current_branch()

        # 1) Tokenizer (training-time vocab + SPECIALS)
        self.tokenizer = AutoTokenizer.from_pretrained(
            "m97j/npc_LoRA-fps",  # repo the merged model was uploaded to
            revision=branch,
            subfolder="testcase_output",  # path the merged model was uploaded to
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        # Fall back to EOS as the padding token when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # NOTE(review): if this adds tokens that are not already in the
        # saved vocab, the model's embedding matrix is not resized here --
        # confirm the uploaded checkpoint already includes SPECIALS.
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})

        # 2) Load the merged model (shards are detected automatically)
        self.model = AutoModelForCausalLM.from_pretrained(
            "m97j/npc_LoRA-fps",  # repo the merged model was uploaded to
            revision=branch,
            subfolder="testcase_output",  # path the merged model was uploaded to
            device_map=None,  # offloading disabled
            low_cpu_mem_usage=False,
            trust_remote_code=True,
            token=HF_TOKEN
        )

        # 3) Add the custom heads
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # 4) Load the custom head weights (best effort: failures only warn)
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt")
        ]:
            if not os.path.exists(file_name):
                # Previously skipped silently; surface it so an untrained
                # head does not go unnoticed.
                print(
                    f"[WARN] {file_name} not found; "
                    f"{head_name} keeps its initial weights"
                )
                continue
            try:
                # NOTE(review): torch.load unpickles the file; only load
                # weight files from a trusted source.
                getattr(self.model, head_name).load_state_dict(
                    torch.load(file_name, map_location=DEVICE)
                )
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

        # 5) Device placement
        self.model.to(DEVICE)
        self.model.eval()

    def get(self):
        """Return ``(tokenizer, model, flags_order)`` as a tuple."""
        return self.tokenizer, self.model, self.flags_order
|