Yzy00518 commited on
Commit
d41829b
·
1 Parent(s): 1a9922b

Upload src/utils/inference_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/utils/inference_utils.py +56 -0
src/utils/inference_utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+
6
+ def set_all_seeds(seed=42):
7
+ print("set all seeds", flush=True)
8
+ random.seed(seed)
9
+ np.random.seed(seed)
10
+ torch.manual_seed(seed)
11
+ torch.cuda.manual_seed(seed)
12
+ torch.cuda.manual_seed_all(seed)
13
+ torch.backends.cudnn.deterministic = True
14
+ torch.backends.cudnn.benchmark = False
15
+
16
+ def fix_state_dict(state_dict):
17
+ new_state_dict = {}
18
+ for k, v in state_dict.items():
19
+ name = k[7:] if k.startswith('module.') else k
20
+ new_state_dict[name] = v
21
+ return new_state_dict
22
+
23
+
24
+ #######################################################
25
+ def load_hint_texts_from_file(file_path):
26
+ hint_texts = []
27
+ with open(file_path, 'r') as file:
28
+ for line in file:
29
+ hint_texts.append([line.strip()])
30
+ return hint_texts
31
+
32
+ def load_mask_from_file(file_path):
33
+ mask = []
34
+ with open(file_path, 'r') as file:
35
+ for line in file:
36
+ mask.append([line.strip()])
37
+ return mask
38
+
39
+ def load_file_names(file_path):
40
+ with open(file_path, 'r') as file:
41
+ file_names = [line.strip() for line in file]
42
+ return file_names
43
+
44
+ def gen_prog_ind(num_cases=16, sublist_length=4):
45
+ total_range = 0.9
46
+ step = total_range / sublist_length
47
+ ranges = [(i * step, i * step + step / 5) for i in range(sublist_length)]
48
+
49
+ prog_ind_all = []
50
+ for _ in range(num_cases):
51
+ while True:
52
+ case = [random.uniform(r[0], r[1]) for r in ranges]
53
+ if all(step*0.8 <= case[i+1] - case[i] <= step*1.6 for i in range(len(case) - 1)):
54
+ prog_ind_all.append([case])
55
+ break
56
+ return prog_ind_all