krislette committed on
Commit
814391d
1 Parent(s): a71ea0b

Added explicit cache dir parameter for llm2vec

Files changed (1)
  1. src/llm2vectrain/model.py +9 -2
src/llm2vectrain/model.py CHANGED
@@ -4,16 +4,21 @@ from peft import PeftModel
 from src.llm2vectrain.config import access_token
 import torch
 from torchao.quantization import quantize_, Int8WeightOnlyConfig
+import os


 def load_llm2vec_model():
+    # Get cache directory from environment or use default
+    cache_dir = os.getenv("TRANSFORMERS_CACHE", "/app/.cache/huggingface")

     model_id = "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp"

     tokenizer = AutoTokenizer.from_pretrained(
-        model_id, padding=True, truncation=True, max_length=512
+        model_id, padding=True, truncation=True, max_length=512, cache_dir=cache_dir
+    )
+    config = AutoConfig.from_pretrained(
+        model_id, trust_remote_code=True, cache_dir=cache_dir
     )
-    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

     if torch.cuda.is_available():
         # GPU path: use bf16 for speed
@@ -24,6 +29,7 @@ def load_llm2vec_model():
             torch_dtype=torch.bfloat16,
             device_map="cuda",
             token=access_token,
+            cache_dir=cache_dir,
         )
     else:
         # CPU path: use float32 first, then quantize
@@ -34,6 +40,7 @@ def load_llm2vec_model():
             torch_dtype=torch.float32,  # quantization requires fp32
             device_map="cpu",
             token=access_token,
+            cache_dir=cache_dir,
         )

     try:
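
For context, a minimal sketch of how the new cache-directory handling could be exercised from calling code. The TRANSFORMERS_CACHE variable and the /app/.cache/huggingface fallback come from the diff above; the /tmp path below is an illustrative assumption, and the function's return value is not shown in this commit, so it is left unnamed.

import os

# Optional: point the Hugging Face cache at a writable directory *before*
# the loader runs. If TRANSFORMERS_CACHE is unset, load_llm2vec_model()
# falls back to /app/.cache/huggingface as written in the diff above.
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf-cache")  # illustrative path

from src.llm2vectrain.model import load_llm2vec_model

# Downloads (or reuses) the tokenizer, config, and weights under the cache dir.
load_llm2vec_model()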