Alon Albalak committed
Commit 65bd58a · 1 Parent(s): 32964cf

use accelerators in app
Files changed (1):
  src/models/llm_manager.py  (+19 -4)
src/models/llm_manager.py CHANGED
@@ -14,12 +14,26 @@ class LLMManager:
     def __init__(self):
         self.model = None
         self.tokenizer = None
-
+
+        if torch.cuda.is_available():
+            device = "cuda"
+            dtype = torch.float16
+        elif torch.backends.mps.is_available():
+            device = "mps"
+            dtype = torch.float16
+        else:
+            device = "cpu"
+            dtype = torch.float32
+
+        self.device = device
+        self.dtype = dtype
+
     def load_models(self, model_name="meta-llama/Llama-3.2-1B-Instruct"):
         """Load the LLM model and tokenizer"""
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN)
-
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype=self.dtype, low_cpu_mem_usage=True)
+        self.model = self.model.to(self.device)
+
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
@@ -60,7 +74,8 @@ class LLMManager:
         full_prompt = f"{prompt}\n\nAssistant: {partial_response}{user_continuation}"
 
         inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
-
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
         with torch.no_grad():
             outputs = self.model.generate(
                 inputs.input_ids,
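
For reference, below is a minimal, self-contained sketch of the accelerator-selection pattern this commit introduces. It is not the app's code: it uses the long-standing torch_dtype= argument where the commit passes dtype=, calls BatchEncoding.to() instead of the commit's dict comprehension, and uses a placeholder prompt and max_new_tokens value. The Llama-3.2 checkpoint is gated, so authentication (or any open causal LM substituted for it) is assumed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick the best available accelerator, mirroring the selection logic added to
# LLMManager.__init__: half precision on CUDA/MPS, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
elif torch.backends.mps.is_available():
    device, dtype = "mps", torch.float16
else:
    device, dtype = "cpu", torch.float32

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # same default as load_models()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# torch_dtype= is the long-established argument name; the commit passes dtype=,
# which recent transformers releases accept in its place.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, low_cpu_mem_usage=True
).to(device)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("Hello, world!", return_tensors="pt", padding=True, truncation=True)
# BatchEncoding.to() moves every tensor to the target device in one call and,
# unlike a plain dict, keeps attribute-style access such as inputs.input_ids.
inputs = inputs.to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))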