m97j committed on
Commit
79d952a
·
1 Parent(s): 8d29c2e

Update model_loading logic

Browse files
Files changed (2) hide show
  1. config.py +1 -1
  2. model_loader.py +22 -10
config.py CHANGED
@@ -7,7 +7,7 @@ load_dotenv()
7
 
8
  # ๋ชจ๋ธ ๊ฒฝ๋กœ (ํ™˜๊ฒฝ๋ณ€์ˆ˜ ์—†์œผ๋ฉด ๊ธฐ๋ณธ๊ฐ’ ์‚ฌ์šฉ)
9
  BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
10
- ADAPTER_MODEL = os.getenv("ADAPTER_MODEL", "m97j/npc_LoRA-fps")
11
 
12
  # 장치 설정
13
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
 
7
 
8
  # ๋ชจ๋ธ ๊ฒฝ๋กœ (ํ™˜๊ฒฝ๋ณ€์ˆ˜ ์—†์œผ๋ฉด ๊ธฐ๋ณธ๊ฐ’ ์‚ฌ์šฉ)
9
  BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
10
+ ADAPTERS = os.getenv("ADAPTER_MODEL", "m97j/npc_LoRA-fps")
11
 
12
  # 장치 설정
13
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
model_loader.py CHANGED
@@ -2,7 +2,7 @@ import os, json, torch
2
  import torch.nn as nn
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
- from config import BASE_MODEL, ADAPTER_MODEL, DEVICE, HF_TOKEN
6
 
7
  def get_current_branch():
8
  if os.path.exists("current_branch.txt"):
@@ -12,13 +12,14 @@ def get_current_branch():
12
 
13
  class ModelWrapper:
14
  def __init__(self):
 
15
  flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
16
  self.flags_order = json.load(open(flags_path, encoding="utf-8"))["ALL_FLAGS"]
17
  self.num_flags = len(self.flags_order)
18
 
19
- # 토큰 전달
20
  self.tokenizer = AutoTokenizer.from_pretrained(
21
- ADAPTER_MODEL,
22
  use_fast=True,
23
  token=HF_TOKEN
24
  )
@@ -26,6 +27,7 @@ class ModelWrapper:
26
  self.tokenizer.pad_token = self.tokenizer.eos_token
27
  self.tokenizer.padding_side = "right"
28
 
 
29
  branch = get_current_branch()
30
  base = AutoModelForCausalLM.from_pretrained(
31
  BASE_MODEL,
@@ -33,25 +35,35 @@ class ModelWrapper:
33
  trust_remote_code=True,
34
  token=HF_TOKEN
35
  )
 
 
36
  self.model = PeftModel.from_pretrained(
37
  base,
38
- ADAPTER_MODEL,
39
  revision=branch,
40
  device_map="auto",
41
  token=HF_TOKEN
42
  )
43
 
 
44
  hidden_size = self.model.config.hidden_size
45
  self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
46
  self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
47
  self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
48
 
49
- if os.path.exists("delta_head.pt"):
50
- self.model.delta_head.load_state_dict(torch.load("delta_head.pt", map_location=DEVICE))
51
- if os.path.exists("flag_head.pt"):
52
- self.model.flag_head.load_state_dict(torch.load("flag_head.pt", map_location=DEVICE))
53
- if os.path.exists("flag_threshold_head.pt"):
54
- self.model.flag_threshold_head.load_state_dict(torch.load("flag_threshold_head.pt", map_location=DEVICE))
 
 
 
 
 
 
 
55
 
56
  self.model.eval()
57
 
 
2
  import torch.nn as nn
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
+ from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN
6
 
7
  def get_current_branch():
8
  if os.path.exists("current_branch.txt"):
 
12
 
13
  class ModelWrapper:
14
  def __init__(self):
15
+ # Flags 정보 로드
16
  flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
17
  self.flags_order = json.load(open(flags_path, encoding="utf-8"))["ALL_FLAGS"]
18
  self.num_flags = len(self.flags_order)
19
 
20
+ # 토크나이저는 베이스 모델에서 로드
21
  self.tokenizer = AutoTokenizer.from_pretrained(
22
+ BASE_MODEL,
23
  use_fast=True,
24
  token=HF_TOKEN
25
  )
 
27
  self.tokenizer.pad_token = self.tokenizer.eos_token
28
  self.tokenizer.padding_side = "right"
29
 
30
+ # ๋ฒ ์ด์Šค ๋ชจ๋ธ ๋กœ๋“œ
31
  branch = get_current_branch()
32
  base = AutoModelForCausalLM.from_pretrained(
33
  BASE_MODEL,
 
35
  trust_remote_code=True,
36
  token=HF_TOKEN
37
  )
38
+
39
+ # LoRA 어댑터 적용
40
  self.model = PeftModel.from_pretrained(
41
  base,
42
+ ADAPTERS,
43
  revision=branch,
44
  device_map="auto",
45
  token=HF_TOKEN
46
  )
47
 
48
+ # 커스텀 헤드 추가
49
  hidden_size = self.model.config.hidden_size
50
  self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
51
  self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
52
  self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
53
 
54
+ # .pt 파일이 없으면 그냥 넘어감
55
+ for head_name, file_name in [
56
+ ("delta_head", "delta_head.pt"),
57
+ ("flag_head", "flag_head.pt"),
58
+ ("flag_threshold_head", "flag_threshold_head.pt")
59
+ ]:
60
+ try:
61
+ if os.path.exists(file_name):
62
+ getattr(self.model, head_name).load_state_dict(
63
+ torch.load(file_name, map_location=DEVICE)
64
+ )
65
+ except Exception as e:
66
+ print(f"[WARN] Failed to load {file_name}: {e}")
67
 
68
  self.model.eval()
69