prelington commited on
Commit
e71c280
·
verified ·
1 Parent(s): 80273ad

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +31 -8
model_loader.py CHANGED
@@ -6,25 +6,48 @@ from config import DEVICE, MODEL_LIST
6
 
7
  def load_model(model_name):
8
  """
9
- Load a model by name. Supports both Hugging Face repos and local safetensors.
 
 
 
 
 
 
10
  """
11
  try:
12
  if model_name.endswith(".safetensors"):
13
  print(f"[INFO] Loading safetensor model: {model_name}")
14
- tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use compatible tokenizer
15
- # Load safetensor weights into GPT2 model
16
  model = AutoModelForCausalLM.from_pretrained(
17
  "gpt2",
18
  state_dict=load_file(model_name),
19
- device_map="auto",
20
  torch_dtype=torch.float16
21
  )
22
  else:
23
  print(f"[INFO] Loading Hugging Face model: {model_name}")
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
26
- except Exception as e:
27
- raise RuntimeError(f"Failed to load model {model_name}: {e}")
28
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  model.to(DEVICE)
30
  return tokenizer, model
 
6
 
7
  def load_model(model_name):
8
  """
9
+ Load a model efficiently with memory optimization.
10
+ Supports:
11
+ - Hugging Face repos
12
+ - Local safetensor weights
13
+ Optimizations:
14
+ - FP16/BF16
15
+ - CPU offloading if GPU memory is low
16
  """
17
  try:
18
  if model_name.endswith(".safetensors"):
19
  print(f"[INFO] Loading safetensor model: {model_name}")
20
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
21
  model = AutoModelForCausalLM.from_pretrained(
22
  "gpt2",
23
  state_dict=load_file(model_name),
24
+ device_map="auto", # Automatically places layers on GPU/CPU
25
  torch_dtype=torch.float16
26
  )
27
  else:
28
  print(f"[INFO] Loading Hugging Face model: {model_name}")
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ model = AutoModelForCausalLM.from_pretrained(
31
+ model_name,
32
+ device_map="auto",
33
+ torch_dtype=torch.float16
34
+ )
35
+ except RuntimeError as e:
36
+ print(f"[WARN] GPU memory insufficient, switching to CPU offload. {e}")
37
+ # CPU offload
38
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
39
+ from transformers import AutoConfig
40
+
41
+ config = AutoConfig.from_pretrained(model_name)
42
+ with init_empty_weights():
43
+ model = AutoModelForCausalLM.from_config(config)
44
+ model = load_checkpoint_and_dispatch(
45
+ model,
46
+ model_name,
47
+ device_map={"": "cpu"},
48
+ no_split_module_classes=["GPT2Block"]
49
+ )
50
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+
52
  model.to(DEVICE)
53
  return tokenizer, model