britto224 commited on
Commit
3fde5f3
·
verified ·
1 Parent(s): 5669b22

Upload 17 files

Browse files
upgrade_codes/__init__.py ADDED
File without changes
upgrade_codes/compare_yaml.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ruamel.yaml import YAML
2
+
3
+ conf1 = "conf.yaml"
4
+ conf2 = "config_templates/conf.ZH.default.yaml"
5
+
6
+
7
+ def collect_all_key_paths(d, prefix=""):
8
+ keys = set()
9
+ for k, v in d.items():
10
+ full_key = f"{prefix}.{k}" if prefix else k
11
+ keys.add(full_key)
12
+ if isinstance(v, dict):
13
+ keys.update(collect_all_key_paths(v, full_key))
14
+ return keys
15
+
16
+
17
+ def collect_leaf_key_paths(d, prefix=""):
18
+ keys = set()
19
+ for k, v in d.items():
20
+ full_key = f"{prefix}.{k}" if prefix else k
21
+ if isinstance(v, dict):
22
+ keys.update(collect_leaf_key_paths(v, full_key))
23
+ else:
24
+ keys.add(full_key)
25
+ return keys
26
+
27
+
28
+ def get_value_by_path(d, path_str):
29
+ keys = path_str.split(".")
30
+ current = d
31
+ for key in keys:
32
+ if isinstance(current, dict) and key in current:
33
+ current = current[key]
34
+ else:
35
+ return None
36
+ return current
37
+
38
+
39
+ def compare_yaml_keys(dict1, dict2):
40
+ keys1 = collect_all_key_paths(dict1)
41
+ keys2 = collect_all_key_paths(dict2)
42
+ only_in_1 = keys1 - keys2
43
+ only_in_2 = keys2 - keys1
44
+ return only_in_1, only_in_2
45
+
46
+
47
+ def compare_yaml_values(dict1, dict2):
48
+ leaf_keys1 = collect_leaf_key_paths(dict1)
49
+ leaf_keys2 = collect_leaf_key_paths(dict2)
50
+ common_leaf_keys = leaf_keys1 & leaf_keys2
51
+
52
+ differences = []
53
+
54
+ for key in sorted(common_leaf_keys):
55
+ value1 = get_value_by_path(dict1, key)
56
+ value2 = get_value_by_path(dict2, key)
57
+
58
+ if value1 != value2:
59
+ differences.append({"key_path": key, "value1": value1, "value2": value2})
60
+
61
+ if not differences:
62
+ print("✅ 所有共同叶子节点的值完全一致\n")
63
+ else:
64
+ print(f"❌ 发现 {len(differences)} 个值不同的字段:\n\n")
65
+ for diff in differences:
66
+ print(f"键路径: {diff['key_path']}\n")
67
+ print(f" {conf1}中的值: {diff['value1']}\n")
68
+ print(f" {conf2}中的值: {diff['value2']}\n")
69
+ print("-" * 50 + "\n")
70
+
71
+ return differences
72
+
73
+
74
+ if __name__ == "__main__":
75
+ yaml = YAML(typ="safe")
76
+ with (
77
+ open(conf1, "r", encoding="utf-8") as f1,
78
+ open(conf2, "r", encoding="utf-8") as f2,
79
+ ):
80
+ config1 = yaml.load(f1)
81
+ config2 = yaml.load(f2)
82
+
83
+ # Compare differences in keys.
84
+ only_in_1, only_in_2 = compare_yaml_keys(config1, config2)
85
+
86
+ if not only_in_1 and not only_in_2:
87
+ print("✅ 两个 YAML 文件的 key 完全一致")
88
+ else:
89
+ print("❌ 不一致:")
90
+ if only_in_1:
91
+ print(f"仅在 {conf1} 中存在的 key ({len(only_in_1)} 个):")
92
+ for key in sorted(only_in_1):
93
+ print(f" - {key}")
94
+ if only_in_2:
95
+ print(f"\n仅在 {conf2} 中存在的 key ({len(only_in_2)} 个):")
96
+ for key in sorted(only_in_2):
97
+ print(f" - {key}")
98
+
99
+ # Compare differences in values.
100
+ diff_count = compare_yaml_values(config1, config2)
upgrade_codes/config_sync.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from upgrade_codes.upgrade_core.constants import (
4
+ USER_CONF,
5
+ BACKUP_CONF,
6
+ TEXTS,
7
+ ZH_DEFAULT_CONF,
8
+ EN_DEFAULT_CONF,
9
+ TEXTS_COMPARE,
10
+ TEXTS_MERGE,
11
+ )
12
+ import logging
13
+ from ruamel.yaml import YAML
14
+ from src.open_llm_vtuber.config_manager.utils import load_text_file_with_guess_encoding
15
+ from upgrade_codes.upgrade_core.comment_sync import CommentSynchronizer
16
+ from upgrade_codes.version_manager import VersionUpgradeManager
17
+ from upgrade_codes.upgrade_core.upgrade_utils import UpgradeUtility
18
+ from upgrade_codes.upgrade_core.comment_diff_fn import comment_diff_fn
19
+ from packaging import version
20
+
21
+
22
+ class ConfigSynchronizer:
23
+ def __init__(self, lang="en", logger=logging.getLogger(__name__)):
24
+ self.lang = lang
25
+ self.texts = TEXTS[lang]
26
+ self.default_path = ZH_DEFAULT_CONF if lang == "zh" else EN_DEFAULT_CONF
27
+ self.yaml = YAML()
28
+ self.yaml.preserve_quotes = True
29
+ self.user_path = USER_CONF
30
+ self.backup_path = BACKUP_CONF
31
+ self.texts_merge = TEXTS_MERGE.get(lang, TEXTS_MERGE["en"])
32
+ self.texts_compare = TEXTS_COMPARE.get(lang, TEXTS_COMPARE["en"])
33
+ self.logger = logger
34
+ self.upgrade_utils = UpgradeUtility(self.logger, self.lang)
35
+
36
+ def sync_user_config(self) -> None:
37
+ """
38
+ Ensure the user configuration file exists and create a backup if necessary.
39
+ If the user config file does not exist, copy the default config.
40
+ """
41
+ # Check if the user config file exists
42
+ if not os.path.exists(self.user_path):
43
+ self.logger.warning(self.texts["no_config"])
44
+ self.logger.warning(self.texts["copy_default_config"])
45
+ # Copy default config to user path
46
+ shutil.copy2(self.default_path, self.user_path)
47
+ return
48
+
49
+ # Create a backup of the user config file
50
+ self.backup_user_config()
51
+
52
+ def update_user_config(self) -> None:
53
+ """
54
+ Perform the actual update operations on the user configuration file:
55
+ 1. Compare and update configuration fields
56
+ 2. Synchronize comments
57
+ 3. Upgrade version if needed
58
+ """
59
+
60
+ # Step 1: Update config fields
61
+ if not self.compare_field_keys():
62
+ self.merge_and_update_user_config()
63
+ else:
64
+ self.logger.info(self.texts["configs_up_to_date"])
65
+
66
+ # Step 2: Sync comments
67
+ if not self.compare_comments():
68
+ comment_sync = CommentSynchronizer(
69
+ self.default_path,
70
+ self.user_path,
71
+ self.logger,
72
+ self.yaml,
73
+ self.texts_compare,
74
+ )
75
+ comment_sync.sync()
76
+ else:
77
+ self.logger.info(self.texts_compare["comments_up_to_date"])
78
+
79
+ # Step 3: Determine whether upgrade is needed
80
+ new_version = self.get_latest_version()
81
+ old_version = self.get_old_version()
82
+ need_upgrade = old_version != new_version
83
+
84
+ # Step 4: Run upgrade if needed
85
+ if need_upgrade:
86
+ version_upgrade_manager = VersionUpgradeManager(self.lang, self.logger)
87
+ final_version = version_upgrade_manager.upgrade(old_version)
88
+ self.logger.info(
89
+ self.texts["version_upgrade_success"].format(
90
+ old=old_version, new=final_version
91
+ )
92
+ )
93
+ else:
94
+ self.logger.info(
95
+ self.texts["version_upgrade_none"].format(version=old_version)
96
+ )
97
+
98
+ def backup_user_config(self):
99
+ backup_path = os.path.abspath(self.backup_path)
100
+ self.logger.info(
101
+ self.texts["backup_user_config"].format(
102
+ user_conf=self.user_path, backup_conf=self.backup_path
103
+ )
104
+ )
105
+ self.logger.debug(self.texts["config_backup_path"].format(path=backup_path))
106
+ shutil.copy2(self.user_path, self.backup_path)
107
+
108
+ def merge_and_update_user_config(self):
109
+ try:
110
+ new_keys = self.merge_configs()
111
+ if new_keys:
112
+ self.logger.info(self.texts["merged_config_success"])
113
+ for key in new_keys:
114
+ self.logger.info(f" - {key}")
115
+ else:
116
+ self.logger.info(self.texts["merged_config_none"])
117
+ except Exception as e:
118
+ self.logger.error(self.texts["merge_failed"].format(error=e))
119
+
120
+ def merge_configs(self):
121
+ user_config = self.yaml.load(load_text_file_with_guess_encoding(self.user_path))
122
+ default_config = self.yaml.load(
123
+ load_text_file_with_guess_encoding(self.default_path)
124
+ )
125
+
126
+ new_keys = []
127
+
128
+ def merge(d_user, d_default, path=""):
129
+ for k, v in d_default.items():
130
+ current_path = f"{path}.{k}" if path else k
131
+ if k not in d_user:
132
+ d_user[k] = v
133
+ new_keys.append(current_path)
134
+ elif isinstance(v, dict) and isinstance(d_user.get(k), dict):
135
+ merge(d_user[k], v, current_path)
136
+ return d_user
137
+
138
+ merged = merge(user_config, default_config)
139
+
140
+ with open(self.user_path, "w", encoding="utf-8") as f:
141
+ self.yaml.dump(merged, f)
142
+
143
+ for key in new_keys:
144
+ self.logger.info(self.texts_merge["new_config_item"].format(key=key))
145
+ return new_keys
146
+
147
+ def collect_all_subkeys(self, d, base_path):
148
+ """Collect all keys in the dictionary d, recursively, with base_path as the prefix."""
149
+ keys = []
150
+ # Only process if d is a dictionary
151
+ if isinstance(d, dict):
152
+ for key, value in d.items():
153
+ current_path = f"{base_path}.{key}" if base_path else key
154
+ keys.append(current_path)
155
+ if isinstance(value, dict):
156
+ keys.extend(self.collect_all_subkeys(value, current_path))
157
+ return keys
158
+
159
+ def get_missing_keys(self, user, default, path=""):
160
+ """Recursively find keys in default that are missing in user."""
161
+ missing = []
162
+ for key, default_val in default.items():
163
+ current_path = f"{path}.{key}" if path else key
164
+ if key not in user:
165
+ missing.append(current_path)
166
+ else:
167
+ user_val = user[key]
168
+ if isinstance(default_val, dict):
169
+ if isinstance(user_val, dict):
170
+ missing.extend(
171
+ self.get_missing_keys(user_val, default_val, current_path)
172
+ )
173
+ else:
174
+ subtree_missing = self.collect_all_subkeys(
175
+ default_val, current_path
176
+ )
177
+ missing.extend(subtree_missing)
178
+ return missing
179
+
180
+ def get_extra_keys(self, user, default, path=""):
181
+ """Recursively find keys in user that are not present in default."""
182
+ extra = []
183
+ for key, user_val in user.items():
184
+ current_path = f"{path}.{key}" if path else key
185
+ if key not in default:
186
+ # Only collect subkeys if the value is a dictionary
187
+ if isinstance(user_val, dict):
188
+ subtree_extra = self.collect_all_subkeys(user_val, current_path)
189
+ extra.extend(subtree_extra)
190
+ extra.append(current_path)
191
+ else:
192
+ default_val = default[key]
193
+ if isinstance(user_val, dict) and isinstance(default_val, dict):
194
+ extra.extend(
195
+ self.get_extra_keys(user_val, default_val, current_path)
196
+ )
197
+ elif isinstance(user_val, dict):
198
+ subtree_extra = self.collect_all_subkeys(user_val, current_path)
199
+ extra.extend(subtree_extra)
200
+ return extra
201
+
202
+ def delete_extra_keys(self):
203
+ """Delete extra keys in user config that are not present in default config."""
204
+
205
+ user_config = self.yaml.load(load_text_file_with_guess_encoding(self.user_path))
206
+ default_config = self.yaml.load(
207
+ load_text_file_with_guess_encoding(self.default_path)
208
+ )
209
+ extra_keys = self.get_extra_keys(user_config, default_config)
210
+
211
+ def delete_key_by_path(config_dict, key_path):
212
+ keys = key_path.split(".")
213
+ sub_dict = config_dict
214
+ for k in keys[:-1]:
215
+ if k in sub_dict and isinstance(sub_dict[k], dict):
216
+ sub_dict = sub_dict[k]
217
+ else:
218
+ return False
219
+ return sub_dict.pop(keys[-1], None) is not None
220
+
221
+ deleted_keys = []
222
+ for key_path in extra_keys:
223
+ if delete_key_by_path(user_config, key_path):
224
+ deleted_keys.append(key_path)
225
+
226
+ with open(self.user_path, "w", encoding="utf-8") as f:
227
+ self.yaml.dump(user_config, f)
228
+
229
+ self.logger.info(
230
+ self.texts_compare["extra_keys_deleted_count"].format(
231
+ count=len(deleted_keys)
232
+ )
233
+ )
234
+ for key in deleted_keys:
235
+ self.logger.info(
236
+ self.texts_compare["extra_keys_deleted_item"].format(key=key)
237
+ )
238
+
239
+ def compare_field_keys(self) -> bool:
240
+ """Compare field structure differences (missing/extra keys)"""
241
+
242
+ def field_compare_fn(user, default):
243
+ missing = self.get_missing_keys(user, default)
244
+ extra = self.get_extra_keys(user, default)
245
+
246
+ if missing:
247
+ self.logger.warning(
248
+ self.texts_compare["missing_keys"].format(keys=", ".join(missing))
249
+ )
250
+ if extra:
251
+ self.logger.warning(
252
+ self.texts_compare["extra_keys"].format(keys=", ".join(extra))
253
+ )
254
+ self.delete_extra_keys()
255
+ return (not missing, missing + extra)
256
+
257
+ return self.upgrade_utils.compare_dicts(
258
+ name="keys",
259
+ get_a=lambda: self.yaml.load(
260
+ load_text_file_with_guess_encoding(self.user_path)
261
+ ),
262
+ get_b=lambda: self.yaml.load(
263
+ load_text_file_with_guess_encoding(self.default_path)
264
+ ),
265
+ compare_fn=field_compare_fn,
266
+ )
267
+
268
+ def compare_comments(self) -> bool:
269
+ return self.upgrade_utils.compare_dicts(
270
+ name="comments",
271
+ get_a=lambda: load_text_file_with_guess_encoding(self.user_path),
272
+ get_b=lambda: load_text_file_with_guess_encoding(self.default_path),
273
+ compare_fn=comment_diff_fn,
274
+ )
275
+
276
+ def get_latest_version(self):
277
+ with open(self.default_path, "r", encoding="utf-8") as f:
278
+ default_config = self.yaml.load(f)
279
+ return default_config.get("system_config", {}).get("conf_version", "")
280
+
281
+ def get_old_version(self) -> str:
282
+ """
283
+ Extract the old version from backup config.
284
+ If missing or too old (< v1.1.1), fallback to v1.1.1.
285
+ """
286
+ fallback_version = "v1.1.1"
287
+ try:
288
+ yaml = YAML()
289
+ with open(BACKUP_CONF, "r", encoding="utf-8") as f:
290
+ backup_conf = yaml.load(f)
291
+ raw_version = backup_conf.get("system_config", {}).get(
292
+ "conf_version", fallback_version
293
+ )
294
+
295
+ if version.parse(raw_version) < version.parse(fallback_version):
296
+ self.logger.warning(
297
+ self.texts["version_too_old"].format(
298
+ found=raw_version, adjusted=fallback_version
299
+ )
300
+ )
301
+ return fallback_version
302
+
303
+ self.logger.info(
304
+ self.texts["backup_used_version"].format(backup_version=raw_version)
305
+ )
306
+ return raw_version
307
+ except Exception as e:
308
+ self.logger.warning(
309
+ self.texts["backup_read_error"].format(
310
+ version=fallback_version, error=e
311
+ )
312
+ )
313
+ return fallback_version
upgrade_codes/from_version/__init__.py ADDED
File without changes
upgrade_codes/from_version/v_1_1_1.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import yaml
3
+
4
+
5
+ class to_v_1_2_1:
6
+ def __init__(self, old_model_list, conf_yaml_path, language):
7
+ """
8
+ :param old_model_list: list of dicts (each representing a Live2D model config)
9
+ :param conf_yaml_path: path to conf.yaml that should be upgraded
10
+ :param language: language of the configuration ("zh" or "en")
11
+ """
12
+ self.old_models = old_model_list
13
+ self.conf_yaml_path = conf_yaml_path
14
+ self.language = language
15
+
16
+ # Configuration migration mapping table (language-specific)
17
+ self.migration_map = {
18
+ "zh": {
19
+ "shizuku.png": "mao.png",
20
+ "Shizuku": "Mao",
21
+ "shizuku-local": "mao_pro",
22
+ "shizuku-local-001": "mao_pro_001",
23
+ "distil-medium.en": "large-v3-turbo",
24
+ "en": "zh",
25
+ "v1.1.1": "v1.2.1",
26
+ "v1.2.0": "v1.2.1",
27
+ },
28
+ "en": {
29
+ "shizuku.png": "mao.png",
30
+ "Shizuku": "Mao",
31
+ "shizuku-local": "mao_pro",
32
+ "shizuku-local-001": "mao_pro_001",
33
+ "distil-medium.en": "large-v3-turbo",
34
+ "v1.1.1": "v1.2.1",
35
+ "v1.2.0": "v1.2.1",
36
+ },
37
+ }
38
+
39
+ def upgrade(self):
40
+ """
41
+ Return upgraded model_dict structure including 'models' list and new version
42
+ And perform in-place upgrade of conf.yaml
43
+ """
44
+ upgraded_models = self._upgrade_live2d_models(self.old_models)
45
+ self._upgrade_conf_yaml()
46
+ return upgraded_models
47
+
48
+ def _upgrade_live2d_models(self, old_model_list: list) -> list:
49
+ deprecated = {
50
+ "other_unit_90001",
51
+ "player_unit_00003",
52
+ "mashiro",
53
+ "shizuku-local",
54
+ "shizuku",
55
+ }
56
+ upgrades = {"mao_pro"}
57
+ new_models = []
58
+
59
+ for model in old_model_list:
60
+ name = model.get("name")
61
+ if name in deprecated:
62
+ continue
63
+
64
+ upgraded = copy.deepcopy(model)
65
+
66
+ if name in upgrades:
67
+ if name == "mao_pro":
68
+ upgraded["url"] = (
69
+ "/live2d-models/mao_pro/runtime/mao_pro.model3.json"
70
+ )
71
+ upgraded["kScale"] = 0.5
72
+
73
+ new_models.append(upgraded)
74
+
75
+ return new_models
76
+
77
+ def _upgrade_conf_yaml(self):
78
+ try:
79
+ with open(self.conf_yaml_path, "r", encoding="utf-8") as f:
80
+ data = yaml.safe_load(f)
81
+
82
+ # Update system version number
83
+ if "system_config" in data and isinstance(data["system_config"], dict):
84
+ data["system_config"]["conf_version"] = "v1.2.1"
85
+
86
+ # Update VAD config
87
+ vad_config = data.get("character_config", {}).get("vad_config", {})
88
+ if vad_config.get("vad_model") == "silero_vad":
89
+ vad_config["vad_model"] = None
90
+
91
+ # Update role-related configurations
92
+ char_config = data.get("character_config", {})
93
+ self._migrate_field(char_config, "avatar")
94
+ self._migrate_field(char_config, "character_name")
95
+ self._migrate_field(char_config, "conf_name")
96
+ self._migrate_field(char_config, "conf_uid")
97
+ self._migrate_field(char_config, "live2d_model_name")
98
+
99
+ # Update ASR config
100
+ asr_config = char_config.get("asr_config", {}).get("faster_whisper", {})
101
+ self._migrate_field(asr_config, "model_path")
102
+
103
+ if self.language == "zh":
104
+ self._migrate_field(asr_config, "language")
105
+
106
+ with open(self.conf_yaml_path, "w", encoding="utf-8") as f:
107
+ yaml.safe_dump(
108
+ data, f, allow_unicode=True, sort_keys=False, default_style="'"
109
+ ) # Auto formatting with '
110
+
111
+ except Exception as e:
112
+ raise RuntimeError(f"Failed to upgrade conf.yaml: {e}")
113
+
114
+ def _migrate_field(self, config_section: dict, field_name: str):
115
+ if field_name in config_section:
116
+ current_value = config_section[field_name]
117
+ lang_map = self.migration_map.get(self.language, {})
118
+ new_value = lang_map.get(current_value, current_value)
119
+ config_section[field_name] = new_value
upgrade_codes/upgrade_core/__init__.py ADDED
File without changes
upgrade_codes/upgrade_core/comment_diff_fn.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import StringIO
2
+ from ruamel.yaml import YAML
3
+ from ruamel.yaml.comments import CommentedMap
4
+
5
+
6
+ def get_comment_text(comment_list):
7
+ if not comment_list:
8
+ return ""
9
+ flattened = []
10
+ for c in comment_list:
11
+ if isinstance(c, list):
12
+ for sub in c:
13
+ if hasattr(sub, "value"):
14
+ flattened.append(str(sub.value).strip())
15
+ elif hasattr(c, "value"):
16
+ flattened.append(str(c.value).strip())
17
+ return "\n".join(flattened).strip()
18
+
19
+
20
+ def extract_comments(yaml_text: str) -> dict:
21
+ yaml = YAML()
22
+ yaml.preserve_quotes = True
23
+ data = yaml.load(StringIO(yaml_text))
24
+
25
+ comment_map = {}
26
+
27
+ def recurse(node, path=""):
28
+ if not isinstance(node, CommentedMap):
29
+ return
30
+ if hasattr(node, "ca") and isinstance(node.ca.items, dict):
31
+ for key in node:
32
+ full_path = f"{path}.{key}" if path else str(key)
33
+ if key in node.ca.items:
34
+ comment_map[full_path] = get_comment_text(node.ca.items[key])
35
+ recurse(node[key], full_path)
36
+
37
+ recurse(data)
38
+ return comment_map
39
+
40
+
41
+ def comment_diff_fn(default_text: str, user_text: str):
42
+ default_comments = extract_comments(default_text)
43
+ user_comments = extract_comments(user_text)
44
+
45
+ diff_keys = []
46
+
47
+ all_keys = set(default_comments.keys()) | set(user_comments.keys())
48
+ for key in all_keys:
49
+ d = default_comments.get(key, "")
50
+ u = user_comments.get(key, "")
51
+ if d != u:
52
+ diff_keys.append(key)
53
+
54
+ return (len(diff_keys) == 0), diff_keys
upgrade_codes/upgrade_core/comment_sync.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # upgrade/comment_sync.py
2
+ from typing import Dict
3
+ from logging import Logger
4
+ from ruamel.yaml import YAML
5
+ from ruamel.yaml.comments import CommentedMap
6
+
7
+
8
+ class CommentSynchronizer:
9
+ def __init__(
10
+ self,
11
+ default_path: str,
12
+ user_path: str,
13
+ logger: Logger,
14
+ yaml: YAML,
15
+ texts_compare: Dict[str, str],
16
+ ):
17
+ self.default_path = default_path
18
+ self.user_path = user_path
19
+ self.logger = logger
20
+ self.yaml = yaml
21
+ self.texts_compare = texts_compare
22
+
23
+ def sync(self) -> None:
24
+ try:
25
+ with open(self.default_path, "r", encoding="utf-8") as f:
26
+ default_tree: CommentedMap = self.yaml.load(f)
27
+ with open(self.user_path, "r", encoding="utf-8") as f:
28
+ user_tree: CommentedMap = self.yaml.load(f)
29
+
30
+ def sync_comments(
31
+ default_node: CommentedMap, user_node: CommentedMap, path: str = ""
32
+ ) -> None:
33
+ if not isinstance(default_node, CommentedMap) or not isinstance(
34
+ user_node, CommentedMap
35
+ ):
36
+ return
37
+
38
+ for key in default_node:
39
+ if key in user_node:
40
+ current_path = f"{path}.{key}" if path else key
41
+ if hasattr(default_node, "ca") and hasattr(user_node, "ca"):
42
+ if key in default_node.ca.items:
43
+ user_node.ca.items[key] = default_node.ca.items[key]
44
+ sync_comments(default_node[key], user_node[key], current_path)
45
+
46
+ sync_comments(default_tree, user_tree)
47
+
48
+ if hasattr(default_tree, "ca") and hasattr(user_tree, "ca"):
49
+ user_tree.ca.end = default_tree.ca.end
50
+
51
+ with open(self.user_path, "w", encoding="utf-8") as f:
52
+ self.yaml.dump(user_tree, f)
53
+
54
+ self.logger.info(self.texts_compare["comment_sync_success"])
55
+ except Exception as e:
56
+ self.logger.error(self.texts_compare["comment_sync_error"].format(error=e))
upgrade_codes/upgrade_core/constants.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # upgrade/constants.py
2
+ # CURRENT_SCRIPT_VERSION = "0.2.0"
3
+ from ruamel.yaml import YAML
4
+ from src.open_llm_vtuber.config_manager.utils import load_text_file_with_guess_encoding
5
+ import os
6
+
7
+ USER_CONF = "conf.yaml"
8
+ BACKUP_CONF = "conf.yaml.backup"
9
+
10
+ ZH_DEFAULT_CONF = "config_templates/conf.ZH.default.yaml"
11
+ EN_DEFAULT_CONF = "config_templates/conf.default.yaml"
12
+
13
+ yaml = YAML()
14
+ # user_config = yaml.load(load_text_file_with_guess_encoding(USER_CONF))
15
+ # CURRENT_SCRIPT_VERSION = user_config.get("system_config", {}).get("conf_version")
16
+
17
+
18
+ def load_user_config():
19
+ if not os.path.exists(USER_CONF):
20
+ return None
21
+ text = load_text_file_with_guess_encoding(USER_CONF)
22
+ if text is None:
23
+ return None
24
+ return yaml.load(text)
25
+
26
+
27
+ def get_current_script_version():
28
+ config = load_user_config()
29
+ if config:
30
+ return config.get("system_config", {}).get("conf_version", "UNKNOWN")
31
+ return "UNKNOWN"
32
+
33
+
34
+ CURRENT_SCRIPT_VERSION = get_current_script_version()
35
+
36
+ TEXTS = {
37
+ "zh": {
38
+ # "welcome_message": f"Auto-Upgrade Script {CURRENT_SCRIPT_VERSION}\nOpen-LLM-VTuber 升级脚本 - 此脚本仍在实验阶段,可能无法按预期工作。",
39
+ "welcome_message": f"正在从 {CURRENT_SCRIPT_VERSION} 自动升级...",
40
+ # "lang_select": "请选择语言/Please select language (zh/en):",
41
+ # "invalid_lang": "无效的语言选择,使用英文作为默认语言",
42
+ "not_git_repo": "错误:当前目录不是git仓库。请进入 Open-LLM-VTuber 目录后再运行此脚本。\n当然,更有可能的是你下载的Open-LLM-VTuber不包含.git文件夹 (如果你是透过下载压缩包而非使用 git clone 命令下载的话可能会造成这种情况),这种情况下目前无法用脚本升级。",
43
+ "backup_user_config": "正在备份 {user_conf} 到 {backup_conf}",
44
+ "configs_up_to_date": "[DEBUG] 用户配置已是最新。",
45
+ "no_config": "警告:未找到conf.yaml文件",
46
+ "copy_default_config": "正在从模板复制默认配置",
47
+ "uncommitted": "发现未提交的更改,正在暂存...",
48
+ "stash_error": "错误:无法暂存更改",
49
+ "changes_stashed": "更改已暂存",
50
+ "pulling": "正在从远程仓库拉取更新...",
51
+ "pull_error": "错误:无法拉取更新",
52
+ "restoring": "正在恢复暂存的更改...",
53
+ "conflict_warning": "警告:恢复暂存的更改时发生冲突",
54
+ "manual_resolve": "请手动解决冲突",
55
+ "stash_list": "你可以使用 'git stash list' 查看暂存的更改",
56
+ "stash_pop": "使用 'git stash pop' 恢复更改",
57
+ "upgrade_complete": "升级完成!",
58
+ "check_config": "1. 请检查conf.yaml是否需要更新",
59
+ "resolve_conflicts": "2. 如果有配置文件冲突,请手动解决",
60
+ "check_backup": "3. 检查备份的配置文件以确保没有丢失重要设置",
61
+ "git_not_found": "错误:未检测到 Git。请先安装 Git:\nWindows: https://git-scm.com/download/win\nmacOS: brew install git\nLinux: sudo apt install git",
62
+ "operation_preview": """
63
+ 此脚本将执行以下操作:
64
+ 1. 备份当前的 conf.yaml 配置文件
65
+ 2. 暂存所有未提交的更改 (git stash)
66
+ 3. 从远程仓库拉取最新代码 (git pull)
67
+ 4. 尝试恢复之前暂存的更改 (git stash pop)
68
+
69
+ 是否继续?(y/N): """,
70
+ "merged_config_success": "新增配置项已合并:",
71
+ "merged_config_none": "未发现新增配置项。",
72
+ "merge_failed": "配置合并失败: {error}",
73
+ "updating_submodules": "正在更新子模块...",
74
+ "submodules_updated": "子模块更新完成",
75
+ "submodule_error": "更新子模块时出错",
76
+ "no_submodules": "未检测到子模块,跳过更新",
77
+ "env_info": "系统环境: {os_name} {os_version}, Python {python_version}",
78
+ "git_version": "Git 版本: {git_version}",
79
+ "current_branch": "当前分支: {branch}",
80
+ "operation_time": "操作 '{operation}' 完成, 耗时: {time:.2f} 秒",
81
+ "checking_stash": "检查是否有未提交的更改...",
82
+ "detected_changes": "检测到 {count} 个文件有更改",
83
+ "submodule_updating": "正在更新子模块: {submodule}",
84
+ "submodule_updated": "子模块已更新: {submodule}",
85
+ "submodule_update_error": "❌ 子模块更新失败。",
86
+ "checking_remote": "正在检查远程仓库状态...",
87
+ "remote_ahead": "本地版本已是最新",
88
+ "remote_behind": "发现 {count} 个新提交可供更新",
89
+ "config_backup_path": "配置备份路径: {path}",
90
+ "start_upgrade": "开始升级流程...",
91
+ "version_upgrade_success": "配置版本已成功升级: {old} → {new}",
92
+ "version_upgrade_none": "无需升级配置,当前版本为 {version}",
93
+ "version_upgrade_failed": "升级配置时出错: {error}",
94
+ "finish_upgrade": "升级流程结束, 总耗时: {time:.2f} 秒",
95
+ "backup_used_version": "✅ 从备份文件读取配置版本: {backup_version}",
96
+ "backup_read_error": "⚠️ 读取备份文件失败,使用默认版本 {version}。错误信息: {error}",
97
+ "version_too_old": "🔁 检测到旧版本号 {found} 低于最低支持版本,已强制使用 {adjusted}",
98
+ "checking_ahead_status": "🔍 正在检查是否存在未推送的本地提交...",
99
+ "local_ahead": "🚨 你在 'main' 分支上有 {count} 个尚未推送到远程的本地 commit。",
100
+ "push_blocked": (
101
+ "⛔ 你没有权限推送到 'main' 分支。\n"
102
+ "这些 commit 只保存在本地,无法同步到 GitHub。\n"
103
+ "如果继续升级,可能会导致这些提交丢失或与远程版本发生冲突。"
104
+ ),
105
+ "backup_suggestion": (
106
+ "🛟 为了安全保存你的本地提交,你可以选择以下任意方式:\n"
107
+ "🔄 1. 撤销最近的提交(推荐):\n"
108
+ " • GitHub Desktop:点击右下角的 “Undo” 按钮\n"
109
+ " • 终端命令:git reset --soft HEAD~1\n"
110
+ "📦 2. 导出 patch 文件(保留提交记录):\n"
111
+ " → 终端执行:git format-patch origin/main --stdout > backup.patch\n"
112
+ "🌿 3. 创建一个备份分支(保存当前状态):\n"
113
+ " → 终端执行:git checkout -b my-backup-before-upgrade\n"
114
+ "💡 提示:撤销 commit 后,你可以新建分支或导出补丁以继续操作。"
115
+ ),
116
+ "abort_upgrade": "🛑 为保护本地提交,升级流程已中止。",
117
+ "no_config_fatal": (
118
+ "❌ 未找到配置文件 conf.yaml。\n"
119
+ "请执行以下任一操作:\n"
120
+ "👉 将旧版配置文件复制到当前目录\n"
121
+ "👉 或运行 run_server.py 自动生成默认模板"
122
+ ),
123
+ },
124
+ "en": {
125
+ # "welcome_message": f"Auto-Upgrade Script {CURRENT_SCRIPT_VERSION}\nOpen-LLM-VTuber upgrade script - This script is highly experimental and may not work as expected.",
126
+ "welcome_message": f"Starting auto upgrade from {CURRENT_SCRIPT_VERSION}...",
127
+ # "lang_select": "请选择语言/Please select language (zh/en):",
128
+ # "invalid_lang": "Invalid language selection, using English as default",
129
+ "not_git_repo": "Error: Current directory is not a git repository. Please run this script inside the Open-LLM-VTuber directory.\nAlternatively, it is likely that the Open-LLM-VTuber you downloaded does not contain the .git folder (this can happen if you downloaded a zip archive instead of using git clone), in which case you cannot upgrade using this script.",
130
+ "backup_user_config": "Backing up {user_conf} to {backup_conf}",
131
+ "configs_up_to_date": "[DEBUG] User configuration is up-to-date.",
132
+ "no_config": "Warning: conf.yaml not found",
133
+ "copy_default_config": "Copying default configuration from template",
134
+ "uncommitted": "Found uncommitted changes, stashing...",
135
+ "stash_error": "Error: Unable to stash changes",
136
+ "changes_stashed": "Changes stashed",
137
+ "pulling": "Pulling updates from remote repository...",
138
+ "pull_error": "Error: Unable to pull updates",
139
+ "restoring": "Restoring stashed changes...",
140
+ "conflict_warning": "Warning: Conflicts occurred while restoring stashed changes",
141
+ "manual_resolve": "Please resolve conflicts manually",
142
+ "stash_list": "Use 'git stash list' to view stashed changes",
143
+ "stash_pop": "Use 'git stash pop' to restore changes",
144
+ "upgrade_complete": "Upgrade complete!",
145
+ "check_config": "1. Please check if conf.yaml needs updating",
146
+ "resolve_conflicts": "2. Resolve any config file conflicts manually",
147
+ "check_backup": "3. Check backup config to ensure no important settings are lost",
148
+ "git_not_found": "Error: Git not found. Please install Git first:\nWindows: https://git-scm.com/download/win\nmacOS: brew install git\nLinux: sudo apt install git",
149
+ "operation_preview": """
150
+ This script will perform the following operations:
151
+ 1. Backup current conf.yaml configuration file
152
+ 2. Stash all uncommitted changes (git stash)
153
+ 3. Pull latest code from remote repository (git pull)
154
+ 4. Attempt to restore previously stashed changes (git stash pop)
155
+
156
+ Continue? (y/N): """,
157
+ "merged_config_success": "Merged new configuration items:",
158
+ "merged_config_none": "No new configuration items found.",
159
+ "merge_failed": "Configuration merge failed: {error}",
160
+ "updating_submodules": "Updating submodules...",
161
+ "submodules_updated": "Submodules updated successfully",
162
+ "submodule_error": "Error updating submodules",
163
+ "no_submodules": "No submodules detected, skipping update",
164
+ "env_info": "Environment: {os_name} {os_version}, Python {python_version}",
165
+ "git_version": "Git version: {git_version}",
166
+ "current_branch": "Current branch: {branch}",
167
+ "operation_time": "Operation '{operation}' completed in {time:.2f} seconds",
168
+ "checking_stash": "Checking for uncommitted changes...",
169
+ "detected_changes": "Detected changes in {count} files",
170
+ "submodule_updating": "Updating submodule: {submodule}",
171
+ "submodule_updated": "Submodule updated: {submodule}",
172
+ "submodule_update_error": "❌ Submodule update failed.",
173
+ "checking_remote": "Checking remote repository status...",
174
+ "remote_ahead": "Local version is up to date",
175
+ "remote_behind": "Found {count} new commits to pull",
176
+ "config_backup_path": "Config backup path: {path}",
177
+ "start_upgrade": "Starting upgrade process...",
178
+ "version_upgrade_success": "Config version upgraded: {old} → {new}",
179
+ "version_upgrade_none": "No upgrade needed. Current version is {version}",
180
+ "version_upgrade_failed": "Failed to upgrade config version: {error}",
181
+ "finish_upgrade": "Upgrade process completed, total time: {time:.2f} seconds",
182
+ "backup_used_version": "✅ Loaded config version from backup: {backup_version}",
183
+ "backup_read_error": "⚠️ Failed to read backup file. Falling back to default version {version}. Error: {error}",
184
+ "version_too_old": "🔁 Detected old version {found} which is lower than the minimum supported version, forced to use {adjusted}",
185
+ "checking_ahead_status": "🔍 Checking for unpushed local commits...",
186
+ "local_ahead": "🚨 You have {count} local commit(s) on 'main' that are NOT pushed to remote.",
187
+ "push_blocked": (
188
+ "⛔ You do NOT have permission to push to the 'main' branch.\n"
189
+ "Your commits are local only and will NOT be synced to GitHub.\n"
190
+ "Continuing the upgrade may cause those commits to be lost or conflict with remote changes."
191
+ ),
192
+ "backup_suggestion": (
193
+ "🛟 To keep your work safe, you can choose one of the following options:\n"
194
+ "🔄 1. Undo the last commit:\n"
195
+ " • GitHub Desktop: Click the 'Undo' button at the bottom right.\n"
196
+ " • Terminal: Run: git reset --soft HEAD~1\n"
197
+ "📦 2. Export your commit(s) as a patch file:\n"
198
+ " → Run: git format-patch origin/main --stdout > backup.patch\n"
199
+ "🌿 3. Create a backup branch:\n"
200
+ " → Run: git checkout -b my-backup-before-upgrade\n"
201
+ "💡 Recommendation: After undoing the commit, you can switch to a new branch or export changes as needed."
202
+ ),
203
+ "abort_upgrade": "🛑 Upgrade aborted to protect your local commits.",
204
+ "no_config_fatal": (
205
+ "❌ Config file conf.yaml not found.\n"
206
+ "Please either:\n"
207
+ "👉 Copy your old config file to the current directory\n"
208
+ "👉 Or run run_server.py to generate a default template"
209
+ ),
210
+ },
211
+ }
212
+
213
+ # Multilingual texts for merge_configs log messages
214
+ TEXTS_MERGE = {
215
+ "zh": {
216
+ "new_config_item": "[信息] 新增配置项: {key}",
217
+ },
218
+ "en": {
219
+ "new_config_item": "[INFO] New config item: {key}",
220
+ },
221
+ }
222
+
223
+ # Multilingual texts for compare_configs log messages
224
+ TEXTS_COMPARE = {
225
+ "zh": {
226
+ "missing_keys": "用户配置缺少以下键,可能与默认配置不一致: {keys}",
227
+ "extra_keys": "用户配置包含以下默认配置中不存在的键: {keys}",
228
+ "up_to_date": "用户配置与默认配置一致。",
229
+ "compare_passed": "{name} 对比通过。",
230
+ "compare_failed": "{name} 配置不一致。",
231
+ "compare_diff_item": "- {item}",
232
+ "compare_error": "{name} 对比失败: {error}",
233
+ "comments_up_to_date": "注释一致,跳过注释同步。",
234
+ "extra_keys_deleted_count": "已删除 {count} 个额外键:",
235
+ "extra_keys_deleted_item": " - {key}",
236
+ "comment_sync_success": "注释同步成功。",
237
+ "comment_sync_error": "注释同步失败: {error}",
238
+ },
239
+ "en": {
240
+ "missing_keys": "User config is missing the following keys, which may be out-of-date: {keys}",
241
+ "extra_keys": "User config contains the following keys not present in default config: {keys}",
242
+ "up_to_date": "User config is up-to-date with default config.",
243
+ "compare_passed": "{name} comparison passed.",
244
+ "compare_failed": "{name} comparison failed: configs differ.",
245
+ "compare_diff_item": "- {item}",
246
+ "compare_error": "{name} comparison error: {error}",
247
+ "comments_up_to_date": "Comments are up to date, skipping comment sync.",
248
+ "extra_keys_deleted_count": "Deleted {count} extra keys:",
249
+ "extra_keys_deleted_item": " - {key}",
250
+ "comment_sync_success": "All comments synchronized successfully.",
251
+ "comment_sync_error": "Failed to synchronize comments: {error}",
252
+ },
253
+ }
254
+
255
+ UPGRADE_TEXTS = {
256
+ "zh": {
257
+ "model_dict_not_found": "⚠️ 未找到 model_dict.json,跳过升级。",
258
+ "model_dict_read_error": "❌ 读取 model_dict.json 失败: {error}",
259
+ "upgrade_success": "✅ model_dict.json 已成功升级至 v1.2.1 格式 ({language} 语言)",
260
+ "already_latest": "model_dict.json 已是最新格式。",
261
+ "upgrade_error": "❌ 升级 model_dict.json 失败: {error}",
262
+ "no_upgrade_routine": "没有适用于版本 {version} 的升级程序",
263
+ "upgrading_path": "⬆️ 正在升级配置: {from_version} → {to_version}",
264
+ },
265
+ "en": {
266
+ "model_dict_not_found": "⚠️ model_dict.json not found. Skipping upgrade.",
267
+ "model_dict_read_error": "❌ Failed to read model_dict.json: {error}",
268
+ "upgrade_success": "✅ model_dict.json upgraded to v1.2.1 format ({language} language)",
269
+ "already_latest": "model_dict.json already in latest format.",
270
+ "upgrade_error": "❌ Failed to upgrade model_dict.json: {error}",
271
+ "no_upgrade_routine": "No upgrade routine for version {version}",
272
+ "upgrading_path": "⬆️ Upgrading config: {from_version} → {to_version}",
273
+ },
274
+ }
upgrade_codes/upgrade_core/language.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import ctypes
4
+ import locale
5
+ import platform
6
+ import subprocess
7
+
8
+
9
+ def get_system_language():
10
+ """Get system language using a combination of methods."""
11
+
12
+ os_name = platform.system()
13
+
14
+ if os_name == "Windows":
15
+ try:
16
+ # Use Windows API to get the UI language
17
+ windll = ctypes.windll.kernel32 # type: ignore
18
+ ui_lang = windll.GetUserDefaultUILanguage()
19
+ lang_code = locale.windows_locale.get(ui_lang)
20
+ if lang_code:
21
+ lang = lang_code.split("_")[0]
22
+ if lang.startswith("zh"):
23
+ return "zh"
24
+ except Exception:
25
+ pass
26
+
27
+ elif os_name == "Darwin": # macOS
28
+ try:
29
+ # Use defaults command to get the AppleLocale
30
+ result = subprocess.run(
31
+ ["defaults", "read", "-g", "AppleLocale"],
32
+ capture_output=True,
33
+ text=True,
34
+ )
35
+ lang = result.stdout.strip().split("_")[0]
36
+ if lang.startswith("zh"):
37
+ return "zh"
38
+ except Exception:
39
+ pass
40
+
41
+ elif os_name == "Linux":
42
+ # Check the LANG environment variable
43
+ lang = os.environ.get("LANG")
44
+ if lang:
45
+ lang = lang.split("_")[0]
46
+ if lang.startswith("zh"):
47
+ return "zh"
48
+
49
+ # Fallback to using locale.getpreferredencoding()
50
+ encoding = locale.getpreferredencoding()
51
+ if encoding.lower() in ("cp936", "gbk", "big5"):
52
+ return "zh"
53
+
54
+ return "en"
55
+
56
+
57
+ def select_language():
58
+ """Select language based on command-line argument or system language"""
59
+ if len(sys.argv) > 1 and sys.argv[1].lower() in ["zh", "en"]:
60
+ return sys.argv[1].lower()
61
+ return get_system_language()
upgrade_codes/upgrade_core/upgrade_utils.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import subprocess
4
+ import sys
5
+ import time
6
+ from upgrade_codes.upgrade_core.constants import TEXTS, TEXTS_COMPARE
7
+ from typing import Callable, Any
8
+
9
+
10
+ class UpgradeUtility:
11
+ def __init__(self, logger, lang):
12
+ self.logger = logger
13
+ self.lang = lang
14
+ self.texts = TEXTS[lang]
15
+ self.texts_compare = TEXTS_COMPARE[lang]
16
+
17
+ def run_command(self, command):
18
+ """Run shell command and return result"""
19
+ try:
20
+ result = subprocess.run(
21
+ command,
22
+ shell=True,
23
+ check=True,
24
+ capture_output=True,
25
+ text=True,
26
+ encoding="utf-8",
27
+ errors="replace",
28
+ )
29
+ return True, result.stdout
30
+ except subprocess.CalledProcessError as e:
31
+ return (
32
+ False,
33
+ f"Command failed with error code {e.returncode}\nError: {e.stderr}",
34
+ )
35
+ except Exception as e:
36
+ return False, f"Unexpected error: {str(e)}"
37
+
38
+ def check_git_installed(self):
39
+ """Check if Git is installed"""
40
+ command = "where git" if sys.platform == "win32" else "which git"
41
+ try:
42
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
43
+ return result.returncode == 0
44
+ except subprocess.SubprocessError:
45
+ return False
46
+
47
+ def log_system_info(self):
48
+ """Log detailed system information"""
49
+ texts = self.texts
50
+
51
+ # Log OS info
52
+ os_name = platform.system()
53
+ os_version = platform.version()
54
+ python_version = platform.python_version()
55
+ self.logger.info(
56
+ texts["env_info"].format(
57
+ os_name=os_name, os_version=os_version, python_version=python_version
58
+ )
59
+ )
60
+
61
+ # Log Git version
62
+ success, git_version = self.run_command("git --version")
63
+ if success:
64
+ self.logger.info(
65
+ texts["git_version"].format(git_version=git_version.strip())
66
+ )
67
+
68
+ # Log current branch
69
+ success, branch = self.run_command("git branch --show-current")
70
+ if success:
71
+ self.logger.info(texts["current_branch"].format(branch=branch.strip()))
72
+
73
+ def time_operation(self, func, *args, **kwargs):
74
+ """Time an operation and return result with timing information"""
75
+ start_time = time.time()
76
+ result = func(*args, **kwargs)
77
+ end_time = time.time()
78
+ elapsed = end_time - start_time
79
+ return result, elapsed
80
+
81
+ def get_submodule_list(self):
82
+ """Get a list of submodules"""
83
+ if not os.path.exists(".gitmodules"):
84
+ return []
85
+
86
+ success, output = self.run_command(
87
+ "git config --file .gitmodules --get-regexp path"
88
+ )
89
+ if not success:
90
+ return []
91
+
92
+ submodules = []
93
+ for line in output.strip().split("\n"):
94
+ if line:
95
+ parts = line.strip().split(" ", 1)
96
+ if len(parts) >= 2:
97
+ submodules.append(parts[1])
98
+
99
+ return submodules
100
+
101
+ def has_submodules(self):
102
+ """Check if the repository has submodules by looking for .gitmodules file"""
103
+ return os.path.exists(".gitmodules")
104
+
105
+ def compare_dicts(
106
+ self,
107
+ name: str,
108
+ get_a: Callable[[], Any],
109
+ get_b: Callable[[], Any],
110
+ compare_fn: Callable[[Any, Any], Any],
111
+ ) -> bool:
112
+ try:
113
+ a = get_a()
114
+ b = get_b()
115
+ comparison_result = compare_fn(a, b)
116
+
117
+ if isinstance(comparison_result, tuple):
118
+ is_equal, differences = comparison_result
119
+ else:
120
+ is_equal, differences = comparison_result, []
121
+
122
+ if is_equal:
123
+ msg = self.texts_compare.get(
124
+ "compare_passed", "[DEBUG] {name} comparison passed."
125
+ )
126
+ self.logger.debug(msg.format(name=name))
127
+ return True
128
+ else:
129
+ msg = self.texts_compare.get(
130
+ "compare_failed",
131
+ "[WARNING] {name} comparison failed: configs differ.",
132
+ )
133
+ self.logger.warning(msg.format(name=name))
134
+ if differences:
135
+ for item in differences:
136
+ self.logger.warning(
137
+ self.texts_compare.get(
138
+ "compare_diff_item", "- {item}"
139
+ ).format(item=item)
140
+ )
141
+ return False
142
+ except Exception as e:
143
+ msg = self.texts_compare.get(
144
+ "compare_error", "[ERROR] {name} comparison error: {error}"
145
+ )
146
+ self.logger.error(msg.format(name=name, error=e))
147
+ return False
upgrade_codes/upgrade_manager.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ from upgrade_codes.upgrade_core.language import select_language
3
+ from upgrade_codes.config_sync import ConfigSynchronizer
4
+ from upgrade_codes.upgrade_core.upgrade_utils import UpgradeUtility
5
+ import os
6
+ from datetime import datetime
7
+ import sys
8
+ from upgrade_codes.upgrade_core.constants import USER_CONF, TEXTS
9
+
10
+
11
+ class UpgradeManager:
12
+ def __init__(self):
13
+ self.lang = select_language()
14
+ self._configure_logger()
15
+ self.logger = logger
16
+ self.upgrade_utils = UpgradeUtility(self.logger, self.lang)
17
+ self.config_sync = ConfigSynchronizer(self.lang, self.logger)
18
+ self.texts = TEXTS
19
+
20
+ def check_user_config_exists(self):
21
+ if not os.path.exists(USER_CONF):
22
+ print(self.texts[self.lang]["no_config_fatal"])
23
+ exit(1)
24
+
25
+ def _configure_logger(self):
26
+ logger.remove()
27
+ log_dir = "logs"
28
+ os.makedirs(log_dir, exist_ok=True)
29
+ log_file = os.path.join(
30
+ log_dir, f"upgrade_{datetime.now().strftime('%Y-%m-%d-%H-%M')}.log"
31
+ )
32
+
33
+ logger.add(
34
+ sys.stdout,
35
+ level="DEBUG",
36
+ colorize=True,
37
+ format="<green>[{level}]</green> <level>{message}</level>",
38
+ )
39
+ logger.add(
40
+ log_file,
41
+ level="DEBUG",
42
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
43
+ )
44
+
45
+ def sync_user_config(self):
46
+ self.config_sync.sync_user_config()
47
+
48
+ def update_user_config(self):
49
+ self.config_sync.update_user_config()
50
+
51
+ def log_system_info(self):
52
+ return self.upgrade_utils.log_system_info()
53
+
54
+ def check_git_installed(self):
55
+ return self.upgrade_utils.check_git_installed()
56
+
57
+ def run_command(self, command):
58
+ return self.upgrade_utils.run_command(command)
59
+
60
+ def time_operation(self, func, *args, **kwargs):
61
+ return self.upgrade_utils.time_operation(func, *args, **kwargs)
62
+
63
+ def get_submodule_list(self):
64
+ return self.upgrade_utils.get_submodule_list()
65
+
66
+ def has_submodules(self):
67
+ return self.upgrade_utils.has_submodules()
upgrade_codes/version_manager.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from packaging.version import parse as parse_version
4
+ from upgrade_codes.upgrade_core.constants import USER_CONF, UPGRADE_TEXTS
5
+ from upgrade_codes.from_version.v_1_1_1 import to_v_1_2_1
6
+
7
+
8
+ class VersionUpgradeManager:
9
+ def __init__(self, language, logger):
10
+ self.logger = logger
11
+ self.language = language
12
+ self.log_texts = UPGRADE_TEXTS.get(language, UPGRADE_TEXTS["en"])
13
+ self.indent_spaces = 4
14
+ self.user_config = USER_CONF
15
+
16
+ def get_upgrade_mapping(self):
17
+ """
18
+ Define version upgrade tasks using version ranges.
19
+ Each task maps a range [from_version, to_version) to a specific upgrade module.
20
+ """
21
+ return [
22
+ {
23
+ "from_range": (
24
+ "v1.1.1",
25
+ "v1.2.1",
26
+ ), # Inclusive lower bound, exclusive upper bound
27
+ "from_version": "v1.1.1",
28
+ "to_version": "v1.2.1",
29
+ "module": to_v_1_2_1,
30
+ },
31
+ # Future upgrade example:
32
+ # {
33
+ # "from_range": ("v1.2.1", "v1.3.0"),
34
+ # "from_version": "v1.2.1",
35
+ # "to_version": "v1.3.0",
36
+ # "module": to_v_1_3_0,
37
+ # },
38
+ ]
39
+
40
+ def resolve_upgrade_task(self, current_version: str):
41
+ """
42
+ Determine which upgrade task applies to the given current_version.
43
+ Returns a tuple of (from_version, to_version, module) if matched, else None.
44
+ """
45
+ parsed_current = parse_version(current_version.strip("v"))
46
+ for task in self.get_upgrade_mapping():
47
+ low = parse_version(task["from_range"][0].strip("v"))
48
+ high = parse_version(task["from_range"][1].strip("v"))
49
+ if low <= parsed_current < high:
50
+ return task["from_version"], task["to_version"], task["module"]
51
+ return None
52
+
53
+ def upgrade(self, current_version: str) -> str:
54
+ """
55
+ Perform the upgrade process starting from current_version.
56
+ If a matching version range is found, run the corresponding upgrade module.
57
+ """
58
+ task = self.resolve_upgrade_task(current_version)
59
+ if not task:
60
+ self.logger.info(
61
+ self.log_texts["no_upgrade_routine"].format(version=current_version)
62
+ )
63
+ return current_version
64
+
65
+ from_version, to_version, module = task
66
+ self.logger.info(
67
+ self.log_texts["upgrading_path"].format(
68
+ from_version=current_version, to_version=to_version
69
+ )
70
+ )
71
+ upgraded_version = current_version
72
+
73
+ try:
74
+ model_path = Path("model_dict.json")
75
+ with open(model_path, "r", encoding="utf-8") as f:
76
+ model_dict = json.load(f)
77
+
78
+ if isinstance(model_dict, list):
79
+ new_data = module(model_dict, self.user_config, self.language).upgrade()
80
+ with open(model_path, "w", encoding="utf-8") as f:
81
+ json.dump(
82
+ new_data, f, indent=self.indent_spaces, ensure_ascii=False
83
+ )
84
+
85
+ upgraded_version = to_version
86
+ self.logger.info(
87
+ self.log_texts["upgrade_success"].format(language=self.language)
88
+ )
89
+ else:
90
+ self.logger.info(self.log_texts["already_latest"])
91
+ except Exception as e:
92
+ self.logger.error(self.log_texts["upgrade_error"].format(error=e))
93
+
94
+ return upgraded_version
web_tool/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web Tool
2
+
3
+
4
+ ## Why?
5
+ The Open-LLM-VTuber project leverages TTS and ASR (speech recognition) models to deliver an immersive, voice-to-voice AI companion experience.
6
+
7
+ While ASR and TTS technologies are powerful on their own, setting them up can be challenging. Previously, although our users installed the TTS and ASR models into our project, they were exclusively accessible as a part of the Open-LLM-VTuber's AI companion feature, preventing their use for other purposes like transcription or speech generation.
8
+
9
+ ## What is Web Tool?
10
+
11
+ This is a dedicated web page within the Open-LLM-VTuber backend that provides direct access to the ASR and TTS models initialized by the Open-LLM-VTuber server.
12
+
13
+ Access the web page at: http://localhost:12393/web-tool. Note that the ASR and TTS models are the same ones you've set in the `conf.yaml` file, and switching models at runtime is not possible at this point.
web_tool/index.html ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>ASR & TTS Tool</title>
8
+ <style>
9
+ body {
10
+ font-family: Arial, sans-serif;
11
+ max-width: 800px;
12
+ margin: 0 auto;
13
+ padding: 20px;
14
+ line-height: 1.6;
15
+ }
16
+
17
+ .section {
18
+ margin-bottom: 30px;
19
+ padding: 20px;
20
+ border: 1px solid #ccc;
21
+ border-radius: 8px;
22
+ background-color: #f9f9f9;
23
+ }
24
+
25
+ .controls {
26
+ display: flex;
27
+ gap: 20px;
28
+ margin-bottom: 15px;
29
+ }
30
+
31
+ .record-controls,
32
+ .upload-controls {
33
+ display: flex;
34
+ gap: 10px;
35
+ align-items: center;
36
+ }
37
+
38
+ .file-input {
39
+ max-width: 200px;
40
+ }
41
+
42
+ .button {
43
+ padding: 10px 20px;
44
+ margin: 5px;
45
+ border: none;
46
+ border-radius: 5px;
47
+ background-color: #007bff;
48
+ color: white;
49
+ cursor: pointer;
50
+ transition: background-color 0.3s;
51
+ }
52
+
53
+ .button:hover {
54
+ background-color: #0056b3;
55
+ }
56
+
57
+ .button:disabled {
58
+ background-color: #cccccc;
59
+ cursor: not-allowed;
60
+ }
61
+
62
+ textarea {
63
+ width: 100%;
64
+ height: 100px;
65
+ margin: 10px 0;
66
+ padding: 10px;
67
+ border: 1px solid #ddd;
68
+ border-radius: 4px;
69
+ resize: vertical;
70
+ }
71
+
72
+ .status {
73
+ margin: 10px 0;
74
+ padding: 10px;
75
+ border-radius: 5px;
76
+ font-size: 14px;
77
+ }
78
+
79
+ .error {
80
+ background-color: #ffe6e6;
81
+ border: 1px solid #ffcccc;
82
+ color: #cc0000;
83
+ }
84
+
85
+ .success {
86
+ background-color: #e6ffe6;
87
+ border: 1px solid #ccffcc;
88
+ color: #006600;
89
+ }
90
+
91
+ #audioPlayer {
92
+ width: 100%;
93
+ margin: 10px 0;
94
+ }
95
+
96
+ h1,
97
+ h2 {
98
+ color: #333;
99
+ }
100
+
101
+ h1 {
102
+ border-bottom: 2px solid #007bff;
103
+ padding-bottom: 10px;
104
+ margin-bottom: 30px;
105
+ }
106
+ </style>
107
+ </head>
108
+
109
+ <body>
110
+ <h1>ASR & TTS Tool</h1>
111
+
112
+ <div class="section">
113
+ <h2>Speech Recognition</h2>
114
+ <div class="controls">
115
+ <div class="record-controls">
116
+ <button id="startRecording" class="button">Start Recording</button>
117
+ <button id="stopRecording" class="button" disabled>Stop Recording</button>
118
+ </div>
119
+ <div class="upload-controls">
120
+ <input type="file" id="audioFileInput" accept="audio/*"
121
+ title="Supported formats: WAV, MP3, M4A, OGG, etc. Files will be converted to 16kHz mono WAV."
122
+ class="file-input" />
123
+ <button id="uploadAudio" class="button">Upload Audio</button>
124
+ </div>
125
+ </div>
126
+ <div id="asrStatus" class="status"></div>
127
+ <textarea id="transcription" placeholder="Transcription will appear here..." readonly></textarea>
128
+ </div>
129
+
130
+ <div class="section">
131
+ <h2>Text to Speech</h2>
132
+ <textarea id="ttsInput" placeholder="Enter text for TTS..."></textarea>
133
+ <button id="generateSpeech" class="button">Generate Speech</button>
134
+ <div id="ttsStatus" class="status"></div>
135
+ <audio id="audioPlayer" controls></audio>
136
+ <button id="downloadAudio" class="button" disabled>Download Audio</button>
137
+ </div>
138
+
139
+ <script src="recorder.js"></script>
140
+ <script src="main.js"></script>
141
+ </body>
142
+
143
+ </html>
web_tool/main.js ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const API_BASE_URL = window.location.origin;
2
+ const recorder = new AudioRecorder();
3
+
4
+ // Audio context and buffers
5
+ let audioContext = null;
6
+ let audioBuffers = [];
7
+ let pendingAudioPaths = new Set();
8
+ let currentAudioPath = null;
9
+ let ws = null;
10
+
11
+ // DOM Elements
12
+ const startRecordingBtn = document.getElementById('startRecording');
13
+ const stopRecordingBtn = document.getElementById('stopRecording');
14
+ const transcriptionArea = document.getElementById('transcription');
15
+ const asrStatus = document.getElementById('asrStatus');
16
+ const ttsInput = document.getElementById('ttsInput');
17
+ const generateSpeechBtn = document.getElementById('generateSpeech');
18
+ const ttsStatus = document.getElementById('ttsStatus');
19
+ const audioPlayer = document.getElementById('audioPlayer');
20
+ const downloadAudioBtn = document.getElementById('downloadAudio');
21
+ const audioFileInput = document.getElementById('audioFileInput');
22
+ const uploadAudioBtn = document.getElementById('uploadAudio');
23
+
24
+ // File upload handler with format conversion
25
+ uploadAudioBtn.addEventListener('click', async () => {
26
+ const file = audioFileInput.files[0];
27
+ if (!file) {
28
+ asrStatus.textContent = 'Please select an audio file';
29
+ asrStatus.className = 'status error';
30
+ return;
31
+ }
32
+
33
+ try {
34
+ asrStatus.textContent = 'Processing audio file...';
35
+ asrStatus.className = 'status';
36
+
37
+ // Convert audio to WAV format
38
+ const audioContext = new (window.AudioContext || window.webkitAudioContext)();
39
+ const arrayBuffer = await file.arrayBuffer();
40
+ const audioBuffer = await audioContext.decodeAudioData(arrayBuffer);
41
+
42
+ // Create WAV file
43
+ const wavBuffer = await audioBufferToWav(audioBuffer);
44
+ const wavBlob = new Blob([wavBuffer], { type: 'audio/wav' });
45
+
46
+ const formData = new FormData();
47
+ formData.append('file', wavBlob, 'recording.wav');
48
+
49
+ const response = await fetch(`${API_BASE_URL}/asr`, {
50
+ method: 'POST',
51
+ body: formData
52
+ });
53
+
54
+ if (!response.ok) throw new Error('ASR request failed');
55
+
56
+ const data = await response.json();
57
+ transcriptionArea.value = data.text;
58
+ asrStatus.textContent = 'Transcription complete!';
59
+ asrStatus.className = 'status success';
60
+
61
+ // Clean up
62
+ audioContext.close();
63
+ } catch (error) {
64
+ asrStatus.textContent = 'Error: ' + error.message;
65
+ asrStatus.className = 'status error';
66
+ }
67
+ });
68
+
69
+ // Recording handlers
70
+ startRecordingBtn.addEventListener('click', async () => {
71
+ try {
72
+ asrStatus.textContent = 'Starting recording...';
73
+ asrStatus.className = 'status';
74
+ await recorder.start();
75
+ startRecordingBtn.disabled = true;
76
+ stopRecordingBtn.disabled = false;
77
+ asrStatus.textContent = 'Recording...';
78
+ } catch (error) {
79
+ asrStatus.textContent = 'Error starting recording: ' + error.message;
80
+ asrStatus.className = 'status error';
81
+ }
82
+ });
83
+
84
+ stopRecordingBtn.addEventListener('click', async () => {
85
+ try {
86
+ const audioBlob = await recorder.stop();
87
+ startRecordingBtn.disabled = false;
88
+ stopRecordingBtn.disabled = true;
89
+ asrStatus.textContent = 'Processing audio...';
90
+
91
+ // Send to ASR endpoint
92
+ const formData = new FormData();
93
+ formData.append('file', audioBlob);
94
+
95
+ const response = await fetch(`${API_BASE_URL}/asr`, {
96
+ method: 'POST',
97
+ body: formData
98
+ });
99
+
100
+ if (!response.ok) throw new Error('ASR request failed');
101
+
102
+ const data = await response.json();
103
+ transcriptionArea.value = data.text;
104
+ asrStatus.textContent = 'Transcription complete!';
105
+ asrStatus.className = 'status success';
106
+ } catch (error) {
107
+ asrStatus.textContent = 'Error: ' + error.message;
108
+ asrStatus.className = 'status error';
109
+ startRecordingBtn.disabled = false;
110
+ stopRecordingBtn.disabled = true;
111
+ }
112
+ });
113
+
114
+ // TTS handlers
115
+ function connectWebSocket() {
116
+ const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
117
+ ws = new WebSocket(`${wsProtocol}://${window.location.host}/tts-ws`);
118
+
119
+ ws.onopen = () => {
120
+ console.log('WebSocket connected');
121
+ generateSpeechBtn.disabled = false;
122
+ ttsStatus.textContent = 'Connected to TTS service';
123
+ ttsStatus.className = 'status success';
124
+
125
+ // Initialize AudioContext if needed
126
+ if (!audioContext) {
127
+ audioContext = new (window.AudioContext || window.webkitAudioContext)();
128
+ } else if (audioContext.state === 'suspended') {
129
+ audioContext.resume();
130
+ }
131
+ };
132
+
133
+ ws.onmessage = async (event) => {
134
+ const response = JSON.parse(event.data);
135
+
136
+ if (response.status === 'partial') {
137
+ ttsStatus.textContent = 'Generating audio...';
138
+ ttsStatus.className = 'status';
139
+
140
+ try {
141
+ const audioPath = response.audioPath.split('/').pop();
142
+ pendingAudioPaths.add(audioPath);
143
+
144
+ if (audioContext.state === 'suspended') {
145
+ await audioContext.resume();
146
+ }
147
+
148
+ // Use retry mechanism for fetching audio
149
+ const audioResponse = await fetchWithRetry(`${API_BASE_URL}/cache/${audioPath}`);
150
+ const arrayBuffer = await audioResponse.arrayBuffer();
151
+
152
+ if (arrayBuffer.byteLength === 0) {
153
+ throw new Error('Empty audio data received');
154
+ }
155
+
156
+ const audioBuffer = await audioContext.decodeAudioData(arrayBuffer);
157
+ audioBuffers.push(audioBuffer);
158
+ pendingAudioPaths.delete(audioPath);
159
+ } catch (error) {
160
+ console.error('Error loading audio:', error);
161
+ ttsStatus.textContent = 'Error loading audio: ' + error.message;
162
+ ttsStatus.className = 'status error';
163
+ pendingAudioPaths.clear();
164
+ }
165
+ } else if (response.status === 'complete') {
166
+ // Wait for any pending audio loads to complete
167
+ if (pendingAudioPaths.size > 0) {
168
+ ttsStatus.textContent = 'Finalizing audio...';
169
+ await new Promise(resolve => setTimeout(resolve, 500));
170
+ }
171
+
172
+ try {
173
+ // Combine all audio buffers
174
+ const targetSampleRate = 16000;
175
+ const totalLength = audioBuffers.reduce((acc, buffer) => {
176
+ // Calculate resampled length if needed
177
+ const ratio = targetSampleRate / buffer.sampleRate;
178
+ return acc + Math.ceil(buffer.length * ratio);
179
+ }, 0);
180
+
181
+ const combinedBuffer = audioContext.createBuffer(
182
+ 1, // mono
183
+ totalLength,
184
+ targetSampleRate
185
+ );
186
+
187
+ let offset = 0;
188
+ for (const buffer of audioBuffers) {
189
+ // Resample if needed
190
+ let channelData = buffer.getChannelData(0);
191
+ if (buffer.sampleRate !== targetSampleRate) {
192
+ channelData = await resampleAudio(channelData, buffer.sampleRate, targetSampleRate);
193
+ }
194
+ combinedBuffer.copyToChannel(channelData, 0, offset);
195
+ offset += channelData.length;
196
+ }
197
+
198
+ // Convert to WAV for download
199
+ const wavBlob = new Blob([await audioBufferToWav(combinedBuffer)], { type: 'audio/wav' });
200
+ const audioUrl = URL.createObjectURL(wavBlob);
201
+
202
+ // Update audio player
203
+ audioPlayer.src = audioUrl;
204
+ audioPlayer.load();
205
+ downloadAudioBtn.disabled = false;
206
+
207
+ // Store for download
208
+ currentAudioPath = audioUrl;
209
+
210
+ ttsStatus.textContent = 'Audio generated successfully!';
211
+ ttsStatus.className = 'status success';
212
+ } catch (error) {
213
+ console.error('Error combining audio:', error);
214
+ ttsStatus.textContent = 'Error combining audio: ' + error.message;
215
+ ttsStatus.className = 'status error';
216
+ } finally {
217
+ // Clear buffers
218
+ audioBuffers = [];
219
+ pendingAudioPaths.clear();
220
+ }
221
+ } else if (response.status === 'error') {
222
+ ttsStatus.textContent = 'Error: ' + response.message;
223
+ ttsStatus.className = 'status error';
224
+ audioBuffers = [];
225
+ pendingAudioPaths.clear();
226
+ }
227
+ };
228
+
229
+ ws.onclose = () => {
230
+ console.log('WebSocket disconnected');
231
+ generateSpeechBtn.disabled = true;
232
+ ttsStatus.textContent = 'Disconnected. Trying to reconnect...';
233
+ ttsStatus.className = 'status error';
234
+
235
+ // Clean up any pending audio resources
236
+ audioBuffers = [];
237
+ pendingAudioPaths.clear();
238
+ if (currentAudioPath) {
239
+ URL.revokeObjectURL(currentAudioPath);
240
+ currentAudioPath = null;
241
+ }
242
+
243
+ setTimeout(connectWebSocket, 5000);
244
+ };
245
+
246
+ ws.onerror = (error) => {
247
+ console.error('WebSocket error:', error);
248
+ ttsStatus.textContent = 'Connection error. Retrying...';
249
+ ttsStatus.className = 'status error';
250
+
251
+ // Clean up audio resources on error
252
+ audioBuffers = [];
253
+ pendingAudioPaths.clear();
254
+ if (currentAudioPath) {
255
+ URL.revokeObjectURL(currentAudioPath);
256
+ currentAudioPath = null;
257
+ }
258
+ };
259
+ }
260
+
261
+ // Convert AudioBuffer to WAV with specific format requirements
262
+ async function audioBufferToWav(buffer) {
263
+ // Resample to 16kHz if needed
264
+ let audioData = buffer.getChannelData(0);
265
+ if (buffer.sampleRate !== 16000) {
266
+ audioData = await resampleAudio(audioData, buffer.sampleRate, 16000);
267
+ }
268
+
269
+ const numChannels = 1; // Mono
270
+ const sampleRate = 16000;
271
+ const format = 1; // PCM
272
+ const bitDepth = 16;
273
+
274
+ const dataLength = audioData.length * (bitDepth / 8);
275
+ const headerLength = 44;
276
+ const totalLength = headerLength + dataLength;
277
+
278
+ const arrayBuffer = new ArrayBuffer(totalLength);
279
+ const view = new DataView(arrayBuffer);
280
+
281
+ // Write WAV header
282
+ writeString(view, 0, 'RIFF');
283
+ view.setUint32(4, totalLength - 8, true);
284
+ writeString(view, 8, 'WAVE');
285
+ writeString(view, 12, 'fmt ');
286
+ view.setUint32(16, 16, true);
287
+ view.setUint16(20, format, true);
288
+ view.setUint16(22, numChannels, true);
289
+ view.setUint32(24, sampleRate, true);
290
+ view.setUint32(28, sampleRate * numChannels * (bitDepth / 8), true);
291
+ view.setUint16(32, numChannels * (bitDepth / 8), true);
292
+ view.setUint16(34, bitDepth, true);
293
+ writeString(view, 36, 'data');
294
+ view.setUint32(40, dataLength, true);
295
+
296
+ // Write audio data
297
+ floatTo16BitPCM(view, 44, audioData);
298
+
299
+ return arrayBuffer;
300
+ }
301
+
302
+ function resampleAudio(audioData, originalSampleRate, targetSampleRate) {
303
+ const ratio = targetSampleRate / originalSampleRate;
304
+ const newLength = Math.round(audioData.length * ratio);
305
+ const result = new Float32Array(newLength);
306
+
307
+ for (let i = 0; i < newLength; i++) {
308
+ const position = i / ratio;
309
+ const index = Math.floor(position);
310
+ const fraction = position - index;
311
+
312
+ if (index + 1 < audioData.length) {
313
+ result[i] = audioData[index] * (1 - fraction) + audioData[index + 1] * fraction;
314
+ } else {
315
+ result[i] = audioData[index];
316
+ }
317
+ }
318
+
319
+ return result;
320
+ }
321
+
322
+ function writeString(view, offset, string) {
323
+ for (let i = 0; i < string.length; i++) {
324
+ view.setUint8(offset + i, string.charCodeAt(i));
325
+ }
326
+ }
327
+
328
+ function floatTo16BitPCM(view, offset, input) {
329
+ for (let i = 0; i < input.length; i++, offset += 2) {
330
+ const s = Math.max(-1, Math.min(1, input[i]));
331
+ view.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
332
+ }
333
+ }
334
+
335
+ generateSpeechBtn.addEventListener('click', () => {
336
+ const text = ttsInput.value.trim();
337
+ if (!text) {
338
+ ttsStatus.textContent = 'Please enter some text';
339
+ ttsStatus.className = 'status error';
340
+ return;
341
+ }
342
+
343
+ if (ws && ws.readyState === WebSocket.OPEN) {
344
+ ws.send(JSON.stringify({ text }));
345
+ ttsStatus.textContent = 'Generating audio...';
346
+ ttsStatus.className = 'status';
347
+ } else {
348
+ ttsStatus.textContent = 'Connection lost. Reconnecting...';
349
+ ttsStatus.className = 'status error';
350
+ connectWebSocket();
351
+ }
352
+ });
353
+
354
+ downloadAudioBtn.addEventListener('click', () => {
355
+ if (currentAudioPath) {
356
+ const link = document.createElement('a');
357
+ link.href = currentAudioPath;
358
+ link.download = `combined_audio_${Date.now()}.wav`;
359
+ document.body.appendChild(link);
360
+ link.click();
361
+ document.body.removeChild(link);
362
+ }
363
+ });
364
+
365
+ // Clean up resources when leaving the page
366
+ window.addEventListener('beforeunload', () => {
367
+ if (audioContext) {
368
+ audioContext.close();
369
+ }
370
+ if (ws) {
371
+ ws.close();
372
+ }
373
+ // Clean up any blob URLs
374
+ if (currentAudioPath) {
375
+ URL.revokeObjectURL(currentAudioPath);
376
+ }
377
+ // Clear any pending audio buffers
378
+ audioBuffers = [];
379
+ pendingAudioPaths.clear();
380
+ });
381
+
382
+ // Initialize WebSocket connection
383
+ connectWebSocket();
384
+
385
+ async function fetchWithRetry(url, maxRetries = 3, retryDelay = 1000) {
386
+ for (let i = 0; i < maxRetries; i++) {
387
+ try {
388
+ const response = await fetch(url);
389
+ if (!response.ok) {
390
+ throw new Error(`HTTP error! status: ${response.status}`);
391
+ }
392
+ return response;
393
+ } catch (error) {
394
+ if (i === maxRetries - 1) throw error;
395
+ await new Promise(resolve => setTimeout(resolve, retryDelay));
396
+ }
397
+ }
398
+ }
web_tool/recorder.js ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class AudioRecorder {
2
+ constructor() {
3
+ this.mediaRecorder = null;
4
+ this.audioChunks = [];
5
+ this.isRecording = false;
6
+ this.audioContext = new (window.AudioContext || window.webkitAudioContext)();
7
+ }
8
+
9
+ async start() {
10
+ try {
11
+ const stream = await navigator.mediaDevices.getUserMedia({
12
+ audio: {
13
+ channelCount: 1,
14
+ sampleRate: 16000
15
+ }
16
+ });
17
+ this.mediaRecorder = new MediaRecorder(stream);
18
+ this.audioChunks = [];
19
+ this.isRecording = true;
20
+
21
+ this.mediaRecorder.addEventListener("dataavailable", (event) => {
22
+ this.audioChunks.push(event.data);
23
+ });
24
+
25
+ this.mediaRecorder.start();
26
+ return true;
27
+ } catch (error) {
28
+ console.error("Error starting recording:", error);
29
+ throw error;
30
+ }
31
+ }
32
+
33
+ async stop() {
34
+ return new Promise(async (resolve) => {
35
+ this.mediaRecorder.addEventListener("stop", async () => {
36
+ const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' });
37
+ this.isRecording = false;
38
+
39
+ // Convert to WAV with correct format
40
+ const arrayBuffer = await audioBlob.arrayBuffer();
41
+ const audioBuffer = await this.audioContext.decodeAudioData(arrayBuffer);
42
+
43
+ // Create WAV file
44
+ const wavBuffer = await this.createWAV(audioBuffer);
45
+ const wavBlob = new Blob([wavBuffer], { type: 'audio/wav' });
46
+
47
+ resolve(wavBlob);
48
+ });
49
+
50
+ this.mediaRecorder.stop();
51
+ this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
52
+ });
53
+ }
54
+
55
+ async createWAV(audioBuffer) {
56
+ const numChannels = 1; // Mono
57
+ const sampleRate = 16000; // Target sample rate
58
+ const format = 1; // PCM
59
+ const bitDepth = 16;
60
+
61
+ // Resample if needed
62
+ let samples = audioBuffer.getChannelData(0);
63
+ if (audioBuffer.sampleRate !== sampleRate) {
64
+ samples = await this.resampleAudio(samples, audioBuffer.sampleRate, sampleRate);
65
+ }
66
+
67
+ const dataLength = samples.length * (bitDepth / 8);
68
+ const headerLength = 44;
69
+ const totalLength = headerLength + dataLength;
70
+
71
+ const buffer = new ArrayBuffer(totalLength);
72
+ const view = new DataView(buffer);
73
+
74
+ // Write WAV header
75
+ this.writeString(view, 0, 'RIFF');
76
+ view.setUint32(4, totalLength - 8, true);
77
+ this.writeString(view, 8, 'WAVE');
78
+ this.writeString(view, 12, 'fmt ');
79
+ view.setUint32(16, 16, true);
80
+ view.setUint16(20, format, true);
81
+ view.setUint16(22, numChannels, true);
82
+ view.setUint32(24, sampleRate, true);
83
+ view.setUint32(28, sampleRate * numChannels * (bitDepth / 8), true);
84
+ view.setUint16(32, numChannels * (bitDepth / 8), true);
85
+ view.setUint16(34, bitDepth, true);
86
+ this.writeString(view, 36, 'data');
87
+ view.setUint32(40, dataLength, true);
88
+
89
+ // Write audio data
90
+ this.floatTo16BitPCM(view, 44, samples);
91
+
92
+ return buffer;
93
+ }
94
+
95
+ writeString(view, offset, string) {
96
+ for (let i = 0; i < string.length; i++) {
97
+ view.setUint8(offset + i, string.charCodeAt(i));
98
+ }
99
+ }
100
+
101
+ floatTo16BitPCM(view, offset, input) {
102
+ for (let i = 0; i < input.length; i++, offset += 2) {
103
+ const s = Math.max(-1, Math.min(1, input[i]));
104
+ view.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
105
+ }
106
+ }
107
+
108
+ async resampleAudio(audioData, originalSampleRate, targetSampleRate) {
109
+ const originalLength = audioData.length;
110
+ const ratio = targetSampleRate / originalSampleRate;
111
+ const newLength = Math.round(originalLength * ratio);
112
+ const result = new Float32Array(newLength);
113
+
114
+ for (let i = 0; i < newLength; i++) {
115
+ const position = i / ratio;
116
+ const index = Math.floor(position);
117
+ const fraction = position - index;
118
+
119
+ if (index + 1 < originalLength) {
120
+ result[i] = audioData[index] * (1 - fraction) + audioData[index + 1] * fraction;
121
+ } else {
122
+ result[i] = audioData[index];
123
+ }
124
+ }
125
+
126
+ return result;
127
+ }
128
+
129
+ isActive() {
130
+ return this.isRecording;
131
+ }
132
+ }