SaraAlthubaiti committed on
Commit
da9202d
·
verified ·
1 Parent(s): 37444a8

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.py +64 -0
  2. decode_config.yaml +58 -0
config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import logging
17
+
18
+ from omegaconf import OmegaConf
19
+
20
+
21
class Config:
    """Load a run configuration by merging a YAML file with CLI overrides.

    Parameters
    ----------
    args : argparse.Namespace-like
        Must expose ``cfg_path`` (path to the YAML config file) and
        ``options`` (a list of command-line overrides, either in dot-list
        form ``["model.lora=True"]`` or in alternating pair form
        ``["model.lora", "True"]``; may be ``None``).
    """

    def __init__(self, args):
        self.config = {}

        self.args = args
        # Command-line overrides are merged last, so they take precedence
        # over values in the YAML file.
        user_config = self._build_opt_list(self.args.options)
        config = OmegaConf.load(self.args.cfg_path)
        config = OmegaConf.merge(config, user_config)
        self.config = config

    def _convert_to_dot_list(self, opts):
        """Normalize override options into ``key=value`` dot-list form.

        Accepts either ``["k=v", ...]`` (returned unchanged) or the
        alternating pair form ``["k", "v", ...]``. ``None`` is treated as
        no overrides. NOTE: in pair form, an unpaired trailing key is
        silently dropped (zip truncates to the shorter slice).
        """
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        # If the first item already contains '=', assume the whole list
        # is already in dot-list form.
        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def _build_opt_list(self, opts):
        """Build an OmegaConf config object from CLI override options."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    def pretty_print(self):
        """Log the run, dataset, and model config sections as indented JSON."""
        logging.info("\n===== Running Parameters =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n====== Dataset Attributes ======")
        logging.info(self._convert_node_to_json(self.config.datasets))

        # Plain literal: the original used an f-string with no placeholders.
        logging.info("\n====== Model Attributes ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        """Render one config node as pretty-printed, key-sorted JSON."""
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged config as a plain Python container (no resolve)."""
        return OmegaConf.to_container(self.config)
decode_config.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ model:
16
+ # paths
17
+ llama_path: "DeepSeek-R1-Distill-Qwen-1.5B/"
18
+ whisper_path: "distil-whisper/distil-large-v3/"
19
+ beats_path: "BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt"
20
+
21
+ ckpt: "tiny_all_tasks_319.pth"
22
+
23
+ freeze_whisper: True
24
+ freeze_beats: True
25
+
26
+ # window-level Q-Former
27
+ use_speech_Qformer: True
28
+ freeze_speech_QFormer: False
29
+ window_level_Qformer: True
30
+ num_speech_query_token: 1
31
+ second_per_window: 0.333333
32
+ second_stride: 0.333333
33
+
34
+ speech_llama_proj_model: ""
35
+ freeze_speech_llama_proj: False
36
+
37
+ # LoRA
38
+ lora: True
39
+ lora_rank: 8
40
+ lora_alpha: 32
41
+ lora_dropout: 0.1
42
+
43
+ multi_prompt: True
44
+ prompt_template: "USER: {}\nASSISTANT:"
45
+ prompt_path: "prompts/train_prompt.json"
46
+ test_prompt_path: "prompts/test_prompt.json"
47
+ max_txt_len: 300
48
+ end_sym: "</s>"
49
+
50
+ generate:
51
+ max_new_tokens: 200
52
+ num_beams: 4
53
+ do_sample: False
54
+ min_length: 1
55
+ temperature: 1.0
56
+ top_p: 0.9
57
+ repetition_penalty: 1.0
58
+ length_penalty: 1.0