# test_ui / upgrade_codes /config_sync.py
# britto224's picture
# Upload 17 files
# 3fde5f3 verified
import os
import shutil
from upgrade_codes.upgrade_core.constants import (
USER_CONF,
BACKUP_CONF,
TEXTS,
ZH_DEFAULT_CONF,
EN_DEFAULT_CONF,
TEXTS_COMPARE,
TEXTS_MERGE,
)
import logging
from ruamel.yaml import YAML
from src.open_llm_vtuber.config_manager.utils import load_text_file_with_guess_encoding
from upgrade_codes.upgrade_core.comment_sync import CommentSynchronizer
from upgrade_codes.version_manager import VersionUpgradeManager
from upgrade_codes.upgrade_core.upgrade_utils import UpgradeUtility
from upgrade_codes.upgrade_core.comment_diff_fn import comment_diff_fn
from packaging import version
class ConfigSynchronizer:
    """Synchronize the user's YAML config file with the shipped default config.

    The class can back up the user file, add keys that exist only in the
    default config, remove keys that exist only in the user config, trigger a
    comment synchronization and run a version upgrade when the version stored
    in the backup differs from the one in the default config.
    """

    def __init__(self, lang="en", logger=logging.getLogger(__name__)):
        """Set up paths, localized message tables and the YAML round-tripper.

        Args:
            lang: Language code. ``"zh"`` selects ``ZH_DEFAULT_CONF``; any
                other value selects ``EN_DEFAULT_CONF``.
            logger: Logger used for all messages emitted by this instance.
        """
        self.lang = lang
        # Fall back to English for unknown languages, exactly like
        # texts_merge / texts_compare / default_path below, instead of
        # raising KeyError for this one table only.
        self.texts = TEXTS.get(lang, TEXTS["en"])
        self.default_path = ZH_DEFAULT_CONF if lang == "zh" else EN_DEFAULT_CONF
        self.yaml = YAML()
        self.yaml.preserve_quotes = True
        self.user_path = USER_CONF
        self.backup_path = BACKUP_CONF
        self.texts_merge = TEXTS_MERGE.get(lang, TEXTS_MERGE["en"])
        self.texts_compare = TEXTS_COMPARE.get(lang, TEXTS_COMPARE["en"])
        self.logger = logger
        self.upgrade_utils = UpgradeUtility(self.logger, self.lang)

    def sync_user_config(self) -> None:
        """Ensure the user config file exists, backing it up when it does.

        If the user config file does not exist, the default config is copied
        to the user path and no backup is made. Otherwise the existing user
        config is copied to the backup path.
        """
        if not os.path.exists(self.user_path):
            self.logger.warning(self.texts["no_config"])
            self.logger.warning(self.texts["copy_default_config"])
            # Seed the user config from the language-specific default.
            shutil.copy2(self.default_path, self.user_path)
            return

        self.backup_user_config()

    def update_user_config(self) -> None:
        """Bring the user config file up to date.

        Steps performed, in order:
            1. Compare configuration keys and merge missing ones.
            2. Synchronize comments when they differ.
            3. Compare the default config's version with the backup's version
               and run the version upgrade when they differ.
        """
        # Step 1: update config fields.
        if not self.compare_field_keys():
            self.merge_and_update_user_config()
        else:
            self.logger.info(self.texts["configs_up_to_date"])

        # Step 2: sync comments.
        if not self.compare_comments():
            comment_sync = CommentSynchronizer(
                self.default_path,
                self.user_path,
                self.logger,
                self.yaml,
                self.texts_compare,
            )
            comment_sync.sync()
        else:
            self.logger.info(self.texts_compare["comments_up_to_date"])

        # Step 3: determine whether an upgrade is needed. The comparison is a
        # plain inequality of the two version values, not an ordering check.
        new_version = self.get_latest_version()
        old_version = self.get_old_version()
        need_upgrade = old_version != new_version

        # Step 4: run the upgrade if needed.
        if need_upgrade:
            version_upgrade_manager = VersionUpgradeManager(self.lang, self.logger)
            final_version = version_upgrade_manager.upgrade(old_version)
            self.logger.info(
                self.texts["version_upgrade_success"].format(
                    old=old_version, new=final_version
                )
            )
        else:
            self.logger.info(
                self.texts["version_upgrade_none"].format(version=old_version)
            )

    def backup_user_config(self):
        """Copy the user config file to the backup path and log the action."""
        backup_path = os.path.abspath(self.backup_path)
        self.logger.info(
            self.texts["backup_user_config"].format(
                user_conf=self.user_path, backup_conf=self.backup_path
            )
        )
        self.logger.debug(self.texts["config_backup_path"].format(path=backup_path))
        shutil.copy2(self.user_path, self.backup_path)

    def merge_and_update_user_config(self):
        """Merge default keys into the user config and log the outcome.

        Any exception raised by the merge is logged as an error and not
        re-raised, so the caller continues with its remaining steps.
        """
        try:
            new_keys = self.merge_configs()
            if new_keys:
                self.logger.info(self.texts["merged_config_success"])
                for key in new_keys:
                    self.logger.info(f" - {key}")
            else:
                self.logger.info(self.texts["merged_config_none"])
        except Exception as e:
            self.logger.error(self.texts["merge_failed"].format(error=e))

    def merge_configs(self):
        """Add keys that exist only in the default config to the user config.

        Existing user values are never overwritten. Where both sides hold a
        mapping under the same key, the mappings are merged recursively. The
        merged result is written back to ``self.user_path`` as UTF-8.

        Returns:
            list: Dotted paths of the keys that were added.
        """
        user_config = self.yaml.load(load_text_file_with_guess_encoding(self.user_path))
        default_config = self.yaml.load(
            load_text_file_with_guess_encoding(self.default_path)
        )
        new_keys = []

        def merge(d_user, d_default, path=""):
            for k, v in d_default.items():
                current_path = f"{path}.{k}" if path else k
                if k not in d_user:
                    d_user[k] = v
                    new_keys.append(current_path)
                elif isinstance(v, dict) and isinstance(d_user.get(k), dict):
                    merge(d_user[k], v, current_path)
            return d_user

        merged = merge(user_config, default_config)
        with open(self.user_path, "w", encoding="utf-8") as f:
            self.yaml.dump(merged, f)

        for key in new_keys:
            self.logger.info(self.texts_merge["new_config_item"].format(key=key))
        return new_keys

    def collect_all_subkeys(self, d, base_path):
        """Return the dotted paths of every key nested inside ``d``.

        Args:
            d: Value to inspect; anything that is not a dict yields ``[]``.
            base_path: Prefix prepended (with a ``.``) to each collected key.

        Returns:
            list: Paths in depth-first order, each parent before its children.
        """
        keys = []
        if isinstance(d, dict):
            for key, value in d.items():
                current_path = f"{base_path}.{key}" if base_path else key
                keys.append(current_path)
                if isinstance(value, dict):
                    keys.extend(self.collect_all_subkeys(value, current_path))
        return keys

    def get_missing_keys(self, user, default, path=""):
        """Recursively find keys in ``default`` that are missing in ``user``.

        A key absent from ``user`` is reported once (its children are not
        listed). When ``default`` holds a mapping but ``user`` holds a
        non-mapping under the same key, every key of the default subtree is
        reported.

        Returns:
            list: Dotted paths of the missing keys.
        """
        missing = []
        for key, default_val in default.items():
            current_path = f"{path}.{key}" if path else key
            if key not in user:
                missing.append(current_path)
            else:
                user_val = user[key]
                if isinstance(default_val, dict):
                    if isinstance(user_val, dict):
                        missing.extend(
                            self.get_missing_keys(user_val, default_val, current_path)
                        )
                    else:
                        subtree_missing = self.collect_all_subkeys(
                            default_val, current_path
                        )
                        missing.extend(subtree_missing)
        return missing

    def get_extra_keys(self, user, default, path=""):
        """Recursively find keys in ``user`` that are not present in ``default``.

        For an extra key holding a mapping, the nested keys are listed first
        and the key itself last. When ``user`` holds a mapping but ``default``
        holds a non-mapping under the same key, only the nested user keys are
        reported, not the key itself.

        Returns:
            list: Dotted paths of the extra keys.
        """
        extra = []
        for key, user_val in user.items():
            current_path = f"{path}.{key}" if path else key
            if key not in default:
                # Only collect subkeys if the value is a dictionary.
                if isinstance(user_val, dict):
                    subtree_extra = self.collect_all_subkeys(user_val, current_path)
                    extra.extend(subtree_extra)
                extra.append(current_path)
            else:
                default_val = default[key]
                if isinstance(user_val, dict) and isinstance(default_val, dict):
                    extra.extend(
                        self.get_extra_keys(user_val, default_val, current_path)
                    )
                elif isinstance(user_val, dict):
                    subtree_extra = self.collect_all_subkeys(user_val, current_path)
                    extra.extend(subtree_extra)
        return extra

    @staticmethod
    def _delete_key_by_path(config_dict, key_path):
        """Remove the entry addressed by a dotted ``key_path`` from a nested dict.

        Args:
            config_dict: Nested mapping that is modified in place.
            key_path: Path whose segments are separated by ``.``; a key that
                itself contains a dot therefore cannot be addressed.

        Returns:
            bool: True if the entry existed and was removed, otherwise False.
        """
        *parents, leaf = key_path.split(".")
        sub_dict = config_dict
        for k in parents:
            if k in sub_dict and isinstance(sub_dict[k], dict):
                sub_dict = sub_dict[k]
            else:
                return False
        # Test membership rather than the popped value, so that a key whose
        # value is None (YAML null) is still reported as deleted.
        if leaf in sub_dict:
            del sub_dict[leaf]
            return True
        return False

    def delete_extra_keys(self):
        """Delete keys from the user config that the default config lacks.

        Both files are re-read from disk, the extra keys are removed, the user
        config is written back as UTF-8, and the deleted keys are logged.
        """
        user_config = self.yaml.load(load_text_file_with_guess_encoding(self.user_path))
        default_config = self.yaml.load(
            load_text_file_with_guess_encoding(self.default_path)
        )
        extra_keys = self.get_extra_keys(user_config, default_config)

        deleted_keys = [
            key_path
            for key_path in extra_keys
            if self._delete_key_by_path(user_config, key_path)
        ]

        with open(self.user_path, "w", encoding="utf-8") as f:
            self.yaml.dump(user_config, f)

        self.logger.info(
            self.texts_compare["extra_keys_deleted_count"].format(
                count=len(deleted_keys)
            )
        )
        for key in deleted_keys:
            self.logger.info(
                self.texts_compare["extra_keys_deleted_item"].format(key=key)
            )

    def compare_field_keys(self) -> bool:
        """Compare field structure differences (missing/extra keys).

        The comparison callback logs missing and extra keys and, as a side
        effect, deletes extra keys from the user config file. It reports a
        match when no keys are missing (extra keys alone do not count as a
        mismatch because they have already been removed).

        Returns:
            bool: Whatever ``upgrade_utils.compare_dicts`` returns; the caller
            in this class treats a truthy value as "already up to date".
        """

        def field_compare_fn(user, default):
            missing = self.get_missing_keys(user, default)
            extra = self.get_extra_keys(user, default)
            if missing:
                self.logger.warning(
                    self.texts_compare["missing_keys"].format(keys=", ".join(missing))
                )
            if extra:
                self.logger.warning(
                    self.texts_compare["extra_keys"].format(keys=", ".join(extra))
                )
                self.delete_extra_keys()
            return (not missing, missing + extra)

        return self.upgrade_utils.compare_dicts(
            name="keys",
            get_a=lambda: self.yaml.load(
                load_text_file_with_guess_encoding(self.user_path)
            ),
            get_b=lambda: self.yaml.load(
                load_text_file_with_guess_encoding(self.default_path)
            ),
            compare_fn=field_compare_fn,
        )

    def compare_comments(self) -> bool:
        """Compare the comments of the user and default config files.

        The raw text of both files is handed to ``comment_diff_fn`` through
        ``upgrade_utils.compare_dicts``.

        Returns:
            bool: Whatever ``upgrade_utils.compare_dicts`` returns; the caller
            in this class treats a truthy value as "comments up to date".
        """
        return self.upgrade_utils.compare_dicts(
            name="comments",
            get_a=lambda: load_text_file_with_guess_encoding(self.user_path),
            get_b=lambda: load_text_file_with_guess_encoding(self.default_path),
            compare_fn=comment_diff_fn,
        )

    def get_latest_version(self):
        """Return ``system_config.conf_version`` from the default config.

        Returns:
            The stored version value, or ``""`` when the key is absent.
        """
        with open(self.default_path, "r", encoding="utf-8") as f:
            default_config = self.yaml.load(f)
        return default_config.get("system_config", {}).get("conf_version", "")

    def get_old_version(self) -> str:
        """Extract the old version from the backup config.

        If the version is missing, older than v1.1.1, or the backup cannot be
        read or parsed, fall back to v1.1.1.

        Returns:
            str: The version found in the backup, or the fallback version.
        """
        fallback_version = "v1.1.1"
        try:
            yaml = YAML()
            # Read the same file backup_user_config() writes, so that an
            # instance with a customized backup_path stays self-consistent.
            with open(self.backup_path, "r", encoding="utf-8") as f:
                backup_conf = yaml.load(f)
            raw_version = backup_conf.get("system_config", {}).get(
                "conf_version", fallback_version
            )
            if version.parse(raw_version) < version.parse(fallback_version):
                self.logger.warning(
                    self.texts["version_too_old"].format(
                        found=raw_version, adjusted=fallback_version
                    )
                )
                return fallback_version
            self.logger.info(
                self.texts["backup_used_version"].format(backup_version=raw_version)
            )
            return raw_version
        except Exception as e:
            # Deliberate best effort: any read/parse problem yields the fallback.
            self.logger.warning(
                self.texts["backup_read_error"].format(
                    version=fallback_version, error=e
                )
            )
            return fallback_version