theapemachine committed on
Commit
86eaefa
·
1 Parent(s): a02f16c

Add inference settings for model loading, preferring CUDA, then MPS, then CPU. Update the runner and pipeline to use the new settings for dtype and device placement.

Browse files
tensegrity/bench/runner.py CHANGED
@@ -33,6 +33,7 @@ from dataclasses import dataclass, field, asdict
33
  from pathlib import Path
34
 
35
  from tensegrity.bench.tasks import TaskSample, TaskConfig, TASK_REGISTRY, load_task_samples
 
36
 
37
  logger = logging.getLogger(__name__)
38
 
@@ -164,17 +165,19 @@ class EvalRunner:
164
  return
165
 
166
  from transformers import AutoTokenizer, AutoModelForCausalLM
167
- import torch
168
 
 
169
  logger.info(f"Loading model {self.model_name}...")
170
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
171
  if self._tokenizer.pad_token is None:
172
  self._tokenizer.pad_token = self._tokenizer.eos_token
173
  self._model = AutoModelForCausalLM.from_pretrained(
174
  self.model_name,
175
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
176
- device_map="auto" if torch.cuda.is_available() else None,
177
  )
 
 
178
  self._model.eval()
179
  logger.info("Model loaded.")
180
 
 
33
  from pathlib import Path
34
 
35
  from tensegrity.bench.tasks import TaskSample, TaskConfig, TASK_REGISTRY, load_task_samples
36
+ from tensegrity.torch_device import inference_load_settings
37
 
38
  logger = logging.getLogger(__name__)
39
 
 
165
  return
166
 
167
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
168
 
169
+ dtype, device_map, move_to = inference_load_settings()
170
  logger.info(f"Loading model {self.model_name}...")
171
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
172
  if self._tokenizer.pad_token is None:
173
  self._tokenizer.pad_token = self._tokenizer.eos_token
174
  self._model = AutoModelForCausalLM.from_pretrained(
175
  self.model_name,
176
+ torch_dtype=dtype,
177
+ device_map=device_map,
178
  )
179
+ if move_to is not None:
180
+ self._model = self._model.to(move_to)
181
  self._model.eval()
182
  logger.info("Model loaded.")
183
 
tensegrity/graft/pipeline.py CHANGED
@@ -30,6 +30,7 @@ from tensegrity.graft.logit_bias import (
30
  StaticLogitBiasBuilder,
31
  GraftState,
32
  )
 
33
 
34
  logger = logging.getLogger(__name__)
35
 
@@ -96,15 +97,17 @@ class HybridPipeline:
96
  return
97
 
98
  from transformers import AutoTokenizer, AutoModelForCausalLM
99
- import torch
100
-
101
  logger.info(f"Loading {self.model_name}...")
102
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
103
  self._model = AutoModelForCausalLM.from_pretrained(
104
  self.model_name,
105
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
106
- device_map="auto" if torch.cuda.is_available() else None,
107
  )
 
 
108
 
109
  # Build vocabulary grounding
110
  if self._hypothesis_keywords:
 
30
  StaticLogitBiasBuilder,
31
  GraftState,
32
  )
33
+ from tensegrity.torch_device import inference_load_settings
34
 
35
  logger = logging.getLogger(__name__)
36
 
 
97
  return
98
 
99
  from transformers import AutoTokenizer, AutoModelForCausalLM
100
+
101
+ dtype, device_map, move_to = inference_load_settings()
102
  logger.info(f"Loading {self.model_name}...")
103
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
104
  self._model = AutoModelForCausalLM.from_pretrained(
105
  self.model_name,
106
+ torch_dtype=dtype,
107
+ device_map=device_map,
108
  )
109
+ if move_to is not None:
110
+ self._model = self._model.to(move_to)
111
 
112
  # Build vocabulary grounding
113
  if self._hypothesis_keywords:
tensegrity/torch_device.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pick inference dtype and placement for transformers models.
3
+
4
+ Preference order: CUDA (device_map auto) → Apple MPS → CPU.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Optional, Tuple
10
+
11
+
12
def inference_load_settings() -> Tuple[Any, Optional[str], Optional[Any]]:
    """Choose the dtype and placement settings for loading a model.

    Backends are probed in order of preference: CUDA, then Apple MPS,
    then CPU.

    Returns:
        A ``(torch_dtype, device_map, move_to_device)`` triple:

        - CUDA: ``(torch.float16, "auto", None)``
        - MPS: ``(torch.float16, None, torch.device("mps"))``
        - CPU: ``(torch.float32, None, None)``
    """
    # Function-scope import: torch is only loaded when settings are requested.
    import torch

    if torch.cuda.is_available():
        return torch.float16, "auto", None

    # getattr with a None default tolerates torch builds that do not expose
    # ``torch.backends.mps``.
    mps_backend = getattr(torch.backends, "mps", None)
    has_mps = mps_backend is not None and mps_backend.is_available()
    if not has_mps:
        return torch.float32, None, None

    return torch.float16, None, torch.device("mps")