Spaces:
Running
Running
| import os, json, torch | |
| import torch.nn as nn | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from config import DEVICE, HF_TOKEN | |
# Marker tokens registered as additional special tokens on the tokenizer.
SPECIALS = [
    "<SYS>",
    "<CTX>",
    "<PLAYER>",
    "<NPC>",
    "<STATE>",
    "<RAG>",
    "<PLAYER_STATE>",
]
def get_current_branch():
    """Return the Hub revision (branch/tag name) to load the model from.

    Reads ``current_branch.txt`` from the current working directory and
    returns its contents with surrounding whitespace stripped.

    Returns:
        The stripped file contents, or ``"latest"`` when the file does not
        exist or contains only whitespace (an empty string is never returned,
        since it would be passed on as an empty ``revision``).
    """
    try:
        with open("current_branch.txt", "r", encoding="utf-8") as f:
            branch = f.read().strip()
    except FileNotFoundError:
        return "latest"
    # An empty/whitespace-only file is treated the same as a missing one.
    return branch or "latest"
class ModelWrapper:
    """Load the merged causal-LM, its tokenizer, and three custom linear heads.

    Construction performs network and disk I/O:
      * reads the ordered flag names (``ALL_FLAGS``) from ``flags.json`` located
        next to this module;
      * loads the tokenizer and merged model from the Hub repo
        ``m97j/npc_LoRA-fps`` (subfolder ``testcase_output``) at the revision
        returned by ``get_current_branch()``;
      * attaches ``delta_head``, ``flag_head`` and ``flag_threshold_head`` to the
        model and loads their weights from ``<head_name>.pt`` in the current
        working directory when those files exist;
      * moves the model to ``DEVICE`` and puts it in eval mode.
    """

    def __init__(self):
        # Flag metadata: the ordered flag names size the two flag heads.
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        # Context manager so the file handle is closed deterministically.
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        branch = get_current_branch()
        # Hub repo / subfolder where the merged model was uploaded.
        repo_id = "m97j/npc_LoRA-fps"
        subfolder = "testcase_output"

        # 1) Tokenizer (vocab used at training time + SPECIALS).
        self.tokenizer = AutoTokenizer.from_pretrained(
            repo_id,
            revision=branch,
            subfolder=subfolder,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # No pad token defined: reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})

        # 2) Load the merged model (sharded checkpoints are resolved automatically).
        self.model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            revision=branch,
            subfolder=subfolder,
            device_map=None,  # disable offloading / automatic device dispatch
            low_cpu_mem_usage=False,
            trust_remote_code=True,
            token=HF_TOKEN,
        )
        # add_special_tokens above may have created ids past the end of the
        # embedding table; looking those ids up would index out of range.
        # Grow the table only when that is actually the case.
        vocab_size = len(self.tokenizer)
        if vocab_size > self.model.get_input_embeddings().weight.shape[0]:
            self.model.resize_token_embeddings(vocab_size)

        # 3) Attach custom heads (assigned as attributes, so they become
        #    registered submodules of the model).
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # 4) Load custom head weights from "<head_name>.pt" in the CWD.
        #    Best-effort: a missing or unloadable file leaves the head at its
        #    random initialization, but now always emits a warning.
        for head_name in ("delta_head", "flag_head", "flag_threshold_head"):
            file_name = f"{head_name}.pt"
            if not os.path.exists(file_name):
                print(f"[WARN] {file_name} not found; {head_name} keeps random init")
                continue
            try:
                # NOTE(review): torch.load unpickles its input; only load trusted
                # files (consider weights_only=True if the torch version allows).
                state_dict = torch.load(file_name, map_location=DEVICE)
                getattr(self.model, head_name).load_state_dict(state_dict)
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

        # 5) Device placement and inference mode.
        self.model.to(DEVICE)
        self.model.eval()

    def get(self):
        """Return ``(tokenizer, model, flags_order)``."""
        return self.tokenizer, self.model, self.flags_order