Update factual/green_score/green.py
Browse files
factual/green_score/green.py
CHANGED
|
@@ -62,6 +62,7 @@ class GREEN:
|
|
| 62 |
category=FutureWarning,
|
| 63 |
module="transformers.tokenization_utils_base",
|
| 64 |
)
|
|
|
|
| 65 |
self.cpu = cpu
|
| 66 |
self.model_name = model_name.split("/")[-1]
|
| 67 |
self.output_dir = output_dir
|
|
@@ -100,8 +101,8 @@ class GREEN:
|
|
| 100 |
model_name,
|
| 101 |
trust_remote_code=False if "Phi" in model_name else True,
|
| 102 |
device_map=(
|
| 103 |
-
{"": "cuda:{
|
| 104 |
-
if not self.cpu
|
| 105 |
else {"": "cpu"}
|
| 106 |
),
|
| 107 |
torch_dtype=torch.float16,
|
|
@@ -211,7 +212,7 @@ class GREEN:
|
|
| 211 |
return self.process_results()
|
| 212 |
|
| 213 |
def tokenize_batch_as_chat(self, batch):
|
| 214 |
-
local_rank =
|
| 215 |
batch = [
|
| 216 |
self.tokenizer.apply_chat_template(
|
| 217 |
i, tokenize=False, add_generation_prompt=True
|
|
|
|
| 62 |
category=FutureWarning,
|
| 63 |
module="transformers.tokenization_utils_base",
|
| 64 |
)
|
| 65 |
+
cpu = cpu or not torch.cuda.is_available()
|
| 66 |
self.cpu = cpu
|
| 67 |
self.model_name = model_name.split("/")[-1]
|
| 68 |
self.output_dir = output_dir
|
|
|
|
| 101 |
model_name,
|
| 102 |
trust_remote_code=False if "Phi" in model_name else True,
|
| 103 |
device_map=(
|
| 104 |
+
{"": f"cuda:{torch.cuda.current_device()}"}
|
| 105 |
+
if (not self.cpu and torch.cuda.is_available())
|
| 106 |
else {"": "cpu"}
|
| 107 |
),
|
| 108 |
torch_dtype=torch.float16,
|
|
|
|
| 212 |
return self.process_results()
|
| 213 |
|
| 214 |
def tokenize_batch_as_chat(self, batch):
|
| 215 |
+
local_rank = self.device
|
| 216 |
batch = [
|
| 217 |
self.tokenizer.apply_chat_template(
|
| 218 |
i, tokenize=False, add_generation_prompt=True
|