| import os |
| import time |
| import math |
| import copy |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch.nn.utils import prune |
| from transformers import ( |
| AutoTokenizer, |
| AutoConfig, |
| DataCollatorForLanguageModeling, |
| Trainer, |
| TrainingArguments, |
| AutoModelForCausalLM, |
| AutoModel, |
| EarlyStoppingCallback, |
| pipeline, |
| get_scheduler, |
| logging as hf_logging |
| ) |
| try: |
| from peft import PeftModel, LoraConfig, get_peft_model, TaskType, PeftConfig |
| _peft_installed = True |
| except ImportError: |
| _peft_installed = False |
| PeftModel = None |
| LoraConfig = None |
| get_peft_model = None |
| TaskType = None |
| PeftConfig = None |
|
|
| from datasets import load_dataset, interleave_datasets, concatenate_datasets, Dataset, Features, Value, IterableDataset, DatasetDict |
| from huggingface_hub import login, create_repo, HfApi, hf_hub_download |
| import wandb |
| import gradio as gr |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch |
| import re |
| import json |
| import gc |
| from accelerate import Accelerator |
| import logging |
| import traceback |
| from collections import Counter, OrderedDict |
| import requests |
| import gzip |
| import inspect |
| import shutil |
| from functools import partial |
| import types |
| import psutil |
|
|
| hf_logging.set_verbosity_error() |
| logging.getLogger("datasets").setLevel(logging.ERROR) |
| logging.getLogger("huggingface_hub").setLevel(logging.ERROR) |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
| padding = True |
| truncation = True |
| TOKENIZERS_PARALLELISM = True |
| os.environ["TOKENIZERS_PARALLELISM"] = "true" if TOKENIZERS_PARALLELISM else "false" |
|
|
| BATCH_SIZE = 8 |
| LEARNING_RATE = 1.5e-4 |
| EPOCHS = 1 |
| MAX_STEPS = 1 |
| USE_CPU = False |
| NUM_CPU_CORES = -1 |
| MERGE_ALPHA = 0.7 |
| CONTEXT_LENGTH = 256 |
| HEADS = 4 |
| DIMENSIONS = 256 |
| LAYERS = 1 |
| INTERMEDIATE_SIZE = 1024 |
| USE_WANDB = False |
| ACTIVATION_FUNCTIONS = { |
| "relu": nn.ReLU, |
| "gelu": nn.GELU, |
| "silu": nn.SiLU, |
| "mish": nn.Mish, |
| "leaky_relu": nn.LeakyReLU, |
| "elu": nn.ELU, |
| "relu6": nn.ReLU6, |
| "tanh": nn.Tanh, |
| "sigmoid": nn.Sigmoid, |
| "identity": nn.Identity |
| } |
| DEFAULT_ACTIVATION_FUNCTION = "gelu" |
| OPTIMIZERS = { |
| "adamw_torch": torch.optim.AdamW, |
| "adam_torch": torch.optim.Adam, |
| "sgd": torch.optim.SGD, |
| "adamax": torch.optim.Adamax, |
| "adagrad": torch.optim.Adagrad, |
| "rmsprop": torch.optim.RMSprop |
| } |
| DEFAULT_OPTIMIZER = "adamw_torch" |
| PRUNING_AMOUNT = 0.2 |
| QUANTIZATION_MODES = ['float32', 'float16', 'bfloat16'] |
| DEFAULT_QUANTIZATION = 'float32' |
| SCHEDULER_TYPES = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] |
| DEFAULT_SCHEDULER = "cosine" |
| GRADIENT_ACCUMULATION_STEPS = 1 |
| EVAL_STEPS = 500 |
| SAVE_STEPS = 500 |
| LOGGING_STEPS = 100 |
| EARLY_STOPPING_PATIENCE = 5 |
| LOAD_BEST_MODEL_AT_END = True |
| METRIC_FOR_BEST_MODEL = "eval_loss" |
|
|
| AVAILABLE_MODALITIES = ['Image', 'Audio'] |
| MODALITY_ENCODERS = { |
| 'Image': 'google/vit-base-patch16-224-in21k', |
| 'Audio': 'openai/whisper-base' |
| } |
| DEFAULT_PEFT_CONFIG_DICT = { |
| "task_type": TaskType.CAUSAL_LM if _peft_installed else None, |
| "inference_mode": False, |
| "r": 8, |
| "lora_alpha": 32, |
| "lora_dropout": 0.1, |
| "target_modules": None |
| } if _peft_installed else {} |
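| # Usage sketch (only if peft is installed): the dict above maps directly onto peft.LoraConfig, |
| # e.g. get_peft_model(model, LoraConfig(**DEFAULT_PEFT_CONFIG_DICT)) attaches a LoRA adapter. |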
|
|
| global_model = None |
| global_tokenizer = None |
| global_pipe = None |
| original_num_layers_global = LAYERS |
| config = None |
| target_layers = LAYERS |
| current_peft_config = copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT) if _peft_installed else {} |
|
|
| RAM_LIMIT_PERCENT = 85.0 |
| DISK_LIMIT_GB = 5.0 |
| BYPASS_RESOURCE_LIMITS = False |
|
|
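| # Root-mean-square layer normalization (no mean-centering), as used in LLaMA-style models; |
| # with elementwise_affine=False it is parameter-free. |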
| class RMSNorm(nn.Module): |
| def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, device=None, dtype=None): |
| factory_kwargs = {'device': device, 'dtype': dtype} |
| super().__init__() |
| self.dim = dim |
| self.eps = eps |
| self.elementwise_affine = elementwise_affine |
| if self.elementwise_affine: |
| self.weight = nn.Parameter(torch.empty(dim, **factory_kwargs)) |
| else: |
| self.register_parameter('weight', None) |
| self.reset_parameters() |
|
|
| def reset_parameters(self): |
| if self.elementwise_affine: |
| nn.init.ones_(self.weight) |
|
|
| def _norm(self, x): |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
| def forward(self, x): |
| output = self._norm(x.float()).type_as(x) |
| if self.elementwise_affine: |
| output = output * self.weight |
| return output |
|
|
| def extra_repr(self): |
| return f'{self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' |
|
|
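| # BitNet-style fake-quantization helpers: per-token 8-bit absmax scaling for activations and |
| # absmean round-and-clamp to {-1, 0, 1} for weights. Both stay differentiable via the |
| # straight-through trick applied in BitLinear.forward below. |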
| def activation_quant(x): |
| scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) |
| y = (x * scale).round().clamp(-128, 127) / scale |
| return y |
|
|
| def weight_quant(w): |
| scale = 1.0 / w.abs().mean().clamp(min=1e-5) |
| u = (w * scale).round().clamp(-1, 1) / scale |
| return u |
|
|
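| # Drop-in replacement for nn.Linear that normalizes inputs with a parameter-free RMSNorm and |
| # applies the fake quantization above, while keeping full-precision weights for the backward pass. |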
| class BitLinear(nn.Linear): |
| def forward(self, x): |
| w = self.weight |
| device = w.device |
| if x.device != device: |
| x = x.to(device) |
|
|
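| # The RMSNorm below has no parameters (elementwise_affine=False), so rebuilding it on every |
| # forward pass adds a small allocation but no trainable state. |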
| rms_norm_module = RMSNorm(x.shape[-1], eps=1e-6, elementwise_affine=False, device=device, dtype=x.dtype) |
| x_norm = rms_norm_module(x) |
|
|
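| # Straight-through estimator: quantized values in the forward pass, identity gradients in backward. |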
| x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() |
| w_quant = w + (weight_quant(w) - w).detach() |
|
|
| output = F.linear(x_quant, w_quant, None) |
|
|
| # Add the full-precision bias once, cast to the output dtype. |
| if self.bias is not None: |
| output = output + self.bias.to(output.dtype) |
|
|
| return output |
|
|
| def to(self, *args, **kwargs): |
| # nn.Module.to already moves all parameters (including bias) in place; reassigning |
| # self.bias with a plain tensor would raise a TypeError, so just defer to the parent. |
| return super().to(*args, **kwargs) |
|
|
|
|
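| # LayerNorm variant with a runtime `bypass` switch; when bypass is True the layer acts as an identity. |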
| class BypassLayerNorm(nn.Module): |
| def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): |
| factory_kwargs = {'device': device, 'dtype': dtype} |
| super().__init__() |
| if isinstance(normalized_shape, int): |
| self.normalized_shape = (normalized_shape,) |
| else: |
| self.normalized_shape = tuple(normalized_shape) |
| self.eps = eps |
| self.elementwise_affine = elementwise_affine |
| if self.elementwise_affine: |
| self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) |
| self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) |
| else: |
| self.register_parameter('weight', None) |
| self.register_parameter('bias', None) |
| self.bypass = False |
| self.reset_parameters() |
|
|
| def reset_parameters(self) -> None: |
| if self.elementwise_affine: |
| nn.init.ones_(self.weight) |
| nn.init.zeros_(self.bias) |
|
|
| def forward(self, x): |
| if self.bypass: |
| return x |
| original_dtype = x.dtype |
| if original_dtype not in [torch.float32, torch.float16, torch.bfloat16]: |
| x = x.float() |
|
|
| output = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) |
| return output.to(original_dtype) |
|
|
| def extra_repr(self) -> str: |
| return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, bypass={self.bypass}' |
|
|
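| # Dropout variant with the same runtime `bypass` switch as BypassLayerNorm. |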
| class BypassDropout(nn.Module): |
| def __init__(self, p=0.5, inplace=False): |
| super().__init__() |
| self.p = p |
| self.inplace = inplace |
| self.bypass = False |
|
|
| def forward(self, x): |
| if self.bypass or not self.training or self.p == 0: |
| return x |
| return F.dropout(x, self.p, self.training, self.inplace) |
|
|
| def extra_repr(self) -> str: |
| return f'p={self.p}, inplace={self.inplace}, bypass={self.bypass}' |
|
|
| def get_device(): |
| if torch.cuda.is_available() and not USE_CPU: |
| return torch.device("cuda") |
| elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and not USE_CPU: |
| logging.info("MPS backend detected on Mac. Note: MPS support is experimental and may have limitations.") |
| return torch.device("mps") |
| else: |
| if not USE_CPU: |
| logging.warning("CUDA/MPS not available or USE_CPU=True. Falling back to CPU.") |
| return torch.device("cpu") |
|
|
| def clean_memory(): |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| logging.debug("Cleaned memory.") |
|
|
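| # Checks RAM usage and free disk space against the configured limits. Fails open (returns True) |
| # if the psutil probe itself raises, so a broken check never blocks the workflow. |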
| def check_resources(ram_limit_pct=RAM_LIMIT_PERCENT, disk_limit_gb=DISK_LIMIT_GB): |
| if BYPASS_RESOURCE_LIMITS: |
| logging.info("Resource limit checks bypassed.") |
| return True, "Resource checks bypassed." |
|
|
| try: |
| ram = psutil.virtual_memory() |
| ram_used_pct = ram.percent |
| ram_ok = ram_used_pct < ram_limit_pct |
|
|
| disk = psutil.disk_usage('/') |
| disk_free_gb = disk.free / (1024**3) |
| disk_ok = disk_free_gb > disk_limit_gb |
|
|
| status_msg = (f"Resource Check: RAM Used: {ram_used_pct:.1f}% (Limit: <{ram_limit_pct}%), " |
| f"Disk Free: {disk_free_gb:.1f}GB (Limit: >{disk_limit_gb}GB).") |
|
|
| if ram_ok and disk_ok: |
| logging.info(status_msg + " Status: OK") |
| return True, status_msg + " Status: OK" |
| else: |
| warning_msg = status_msg + " Status: LIMIT EXCEEDED!" |
| logging.warning(warning_msg) |
| return False, warning_msg |
|
|
| except Exception as e: |
| logging.error(f"Failed to check resources: {e}") |
| return True, f"Resource check failed: {e}" |
|
|
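| # Attaches this tool's bookkeeping flags to a (Pretrained)Config, leaving attributes that already |
| # exist untouched, then reconciles the flash-attention and quantization flags. |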
| def initialize_config_flags(existing_config=None): |
| if existing_config is None: |
| from transformers import PretrainedConfig |
| config_obj = PretrainedConfig() |
| elif isinstance(existing_config, dict): |
| from transformers import PretrainedConfig |
| try: |
| config_obj = PretrainedConfig(**existing_config) |
| except Exception as e: |
| logging.warning(f"Could not initialize PretrainedConfig from dict, using default. Error: {e}") |
| config_obj = PretrainedConfig() |
| else: |
| config_obj = existing_config |
|
|
| default_flags = { |
| "reduced_layers": False, "original_num_layers": None, "removed_bias": False, "untied_embeddings": False, |
| "limits_configured": False, "qa_restrictions_removed": False, "additional_mechanisms_applied": False, |
| "safety_settings_enabled": True, "perfect_precision_recovered": False, "token_gen_speed_maximized": False, |
| "coherence_improvement_enabled": False, "inconsistencies_biases_removed": False, |
| "quantization_applied": False, "quantization_mode": DEFAULT_QUANTIZATION, |
| "layer_norm_bypassed": False, "replaced_layer_norm": False, "dropout_bypassed": False, "replaced_dropout": False, |
| "activation_function_swapped": False, "current_activation_function": DEFAULT_ACTIVATION_FUNCTION, |
| "embedding_normalized": False, "gradient_clipping_disabled": False, "weight_decay_disabled": False, |
| "lr_scheduler_disabled": False, "bitnet_applied": False, "gradient_checkpointing_enabled": False, |
| "pruning_applied": False, "pruning_amount": None, "frozen_layers": None, |
| "enhanced_security_enabled": False, "debug_mode_enabled": False, "auto_optimization_enabled": False, |
| "internal_logging_enabled": False, "drift_detection_enabled": False, "ultra_fast_mode": False, |
| "optimizer": DEFAULT_OPTIMIZER, "rms_norm_applied": False, "layerdrop_enabled": False, "layerdrop_prob": 0.0, |
| "lora_merged": False, "lora_adapter_path": None, "knowledge_distillation_setup": False, "kd_num_labels": None, |
| "reward_modeling_setup": False, "rm_num_outputs": 1, |
| "swa_applied": False, "knowledge_edited": False, "head_pruning_applied": False, "qat_applied": False, |
| "architecture_merged": False, "weight_init_applied": False, "gradient_noise_applied": False, |
| "rope_scaling_type": None, "rope_scaling_factor": None, "sliding_window_size": None, "attention_variant": None, |
| "response_filters": True, "harassment_filter": True, "hate_filter": True, "sexually_explicit_filter": True, |
| "dangerous_content_filter": True, "civic_integrity_filter": True, "code_filter": True, |
| "medical_advice_filter": True, "legal_advice_filter": True, "financial_advice_filter": True, |
| "pii_filter": True, "political_filter": True, "religious_filter": True, "profanity_filter": True, |
| "stereotype_filter": True, "misinfo_filter": True, "self_harm_filter": True, "personal_attack_filter": True, |
| "toxicity_filter": True, "spam_filter": True, "off_topic_filter": True, "tone_filter": True, |
| "min_max_length_filter": True, "repetition_filter_enabled": False, "factuality_filter_enabled": False, |
| "baseline_distribution": None, "remove_censorship": False, "no_response_filters": False, |
| "no_advert_warning": False, "no_limits": False, "knowledge_date": None, "cutoff_date": None, |
| "max_input_tokens": None, "max_output_tokens": None, |
| "multimodal_applied": False, "supported_modalities": [], "modality_encoders": {}, "modality_projection_dim": None, "modality_special_tokens": {}, |
| "use_flash_attention_2": getattr(config_obj, 'attn_implementation', None) == 'flash_attention_2', |
| "attn_implementation": getattr(config_obj, 'attn_implementation', 'auto'), |
| "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS, |
| "peft_adapter_added": False, "peft_config": None |
| } |
|
|
| for flag, default_value in default_flags.items(): |
| if not hasattr(config_obj, flag): |
| setattr(config_obj, flag, default_value) |
|
|
| if getattr(config_obj, 'attn_implementation', 'auto') == 'flash_attention_2': |
| config_obj.use_flash_attention_2 = True |
| else: |
| config_obj.use_flash_attention_2 = False |
|
|
| if getattr(config_obj, 'quantization_mode', DEFAULT_QUANTIZATION) == 'float32': |
| config_obj.quantization_applied = False |
| config_obj.perfect_precision_recovered = True |
| else: |
| config_obj.quantization_applied = True |
| config_obj.perfect_precision_recovered = False |
|
|
| if _peft_installed and isinstance(existing_config, PeftConfig): |
| config_obj.peft_adapter_added = True |
| config_obj.peft_config = existing_config.to_dict() |
|
|
| return config_obj |
|
|
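| # Sets a dotted attribute path (e.g. "model.decoder.layers") on obj, returning False instead of |
| # raising when any component along the path is missing. |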
| def _recursive_setattr(obj, attr_str, value): |
| parts = attr_str.split('.') |
| obj_to_set = obj |
| try: |
| for part in parts[:-1]: |
| if not hasattr(obj_to_set, part): |
| logging.warning(f"Intermediate attribute {part} not found in {attr_str} for object {type(obj_to_set)}") |
| return False |
| obj_to_set = getattr(obj_to_set, part) |
| if obj_to_set is None: |
| logging.warning(f"Intermediate attribute {part} is None in {attr_str}") |
| return False |
| if hasattr(obj_to_set, parts[-1]): |
| setattr(obj_to_set, parts[-1], value) |
| return True |
| else: |
| logging.warning(f"Final attribute {parts[-1]} not found in {attr_str} on object {type(obj_to_set)}") |
| return False |
| except AttributeError as e: |
| logging.error(f"AttributeError setting {attr_str}: {e}") |
| return False |
| except Exception as e: |
| logging.error(f"Generic error setting attribute {attr_str}: {e}") |
| return False |
|
|
|
|
| def _get_encoder_hidden_size(encoder_model_id, trust_remote_code=True): |
| try: |
| encoder_config = AutoConfig.from_pretrained(encoder_model_id, trust_remote_code=trust_remote_code) |
|
|
| possible_keys = ['hidden_size', 'd_model', 'embed_dim'] |
| for key in possible_keys: |
| if hasattr(encoder_config, key): |
| size = getattr(encoder_config, key) |
| if isinstance(size, int) and size > 0: |
| return size |
|
|
| nested_configs = ['vision_config', 'audio_config', 'encoder'] |
| for nested_name in nested_configs: |
| if hasattr(encoder_config, nested_name): |
| nested_cfg = getattr(encoder_config, nested_name) |
| if nested_cfg is not None: |
| for key in possible_keys: |
| if hasattr(nested_cfg, key): |
| size = getattr(nested_cfg, key) |
| if isinstance(size, int) and size > 0: |
| return size |
|
|
| raise ValueError(f"Could not automatically determine hidden/embedding size for encoder {encoder_model_id}. Checked attributes: {possible_keys} and nested configs: {nested_configs}.") |
| except Exception as e: |
| logging.error(f"Failed to get config or hidden size for encoder {encoder_model_id}: {e}") |
| raise ValueError(f"Failed to get config or hidden size for encoder {encoder_model_id}") from e |
|
|
|
|
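| # Swaps attention/MLP nn.Linear modules for BitLinear, optionally copying the existing weights. |
| # Usage sketch: convert_to_bitnet(model, model.config, copy_weights=True). |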
| def convert_to_bitnet(model, config, copy_weights=True): |
| # BitLinear always uses the custom RMSNorm defined above, so no torch-native nn.RMSNorm is required. |
|
|
| device = get_device() |
| converted_count = 0 |
| modules_to_process = list(model.named_modules()) |
| processed_names = set() |
|
|
| with torch.no_grad(): |
| for name, module in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| is_target_linear = isinstance(module, nn.Linear) and ( |
| any(sub in name.lower() for sub in ["attn", "mlp", "fc", "dense", "query", "key", "value", "out", "wi", "wo"]) |
| and "norm" not in name.lower() |
| and "embedding" not in name.lower() |
| ) |
|
|
| if is_target_linear: |
| try: |
| current_dtype = module.weight.dtype if hasattr(module, 'weight') and module.weight is not None else torch.float32 |
| has_bias = module.bias is not None |
|
|
| bl = BitLinear(module.in_features, module.out_features, has_bias).to(device=device, dtype=current_dtype) |
|
|
| if copy_weights and hasattr(module, 'weight') and module.weight is not None: |
| if bl.weight.shape == module.weight.shape: |
| bl.weight.data.copy_(module.weight.data) |
| else: |
| logging.warning(f"Shape mismatch for weight {name}: Expected {bl.weight.shape}, got {module.weight.shape}. Skipping weight copy.") |
|
|
| if has_bias and bl.bias is not None: |
| if bl.bias.shape == module.bias.shape: |
| bl.bias.data.copy_(module.bias.data) |
| else: |
| logging.warning(f"Shape mismatch for bias {name}: Expected {bl.bias.shape}, got {module.bias.shape}. Skipping bias copy.") |
| elif not has_bias and bl.bias is not None: |
| nn.init.zeros_(bl.bias) |
| elif has_bias and bl.bias is None: |
| logging.warning(f"Module {name} had bias, but BitLinear does not. Bias info lost.") |
|
|
| elif not copy_weights: |
| nn.init.xavier_uniform_(bl.weight) |
| if bl.bias is not None: |
| nn.init.zeros_(bl.bias) |
|
|
| if _recursive_setattr(model, name, bl): |
| converted_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Converted layer {name} to BitLinear.") |
| else: |
| logging.warning(f"Failed to set BitLinear for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error replacing {name} with BitLinear: {e} \n{traceback.format_exc()}") |
| processed_names.add(name) |
|
|
| if converted_count > 0: |
| config.bitnet_applied = True |
| logging.info(f"Applied BitNet conversion to {converted_count} linear layers.") |
| return f"Applied BitNet conversion to {converted_count} linear layers." |
| else: |
| logging.info("No applicable linear layers found or converted for BitNet.") |
| config.bitnet_applied = False |
| return "No applicable layers found for BitNet conversion." |
|
|
| def revert_bitnet(model, config): |
| if not getattr(config, 'bitnet_applied', False): |
| return "BitNet not applied according to config, nothing to revert." |
|
|
| device = get_device() |
| model.to(device) |
| reverted_count = 0 |
| modules_to_process = list(model.named_modules()) |
| processed_names = set() |
|
|
| with torch.no_grad(): |
| for name, module in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(module, BitLinear): |
| try: |
| dtype = module.weight.dtype if hasattr(module, 'weight') and module.weight is not None else torch.float32 |
| has_bias = module.bias is not None |
|
|
| lin = nn.Linear(module.in_features, module.out_features, bias=has_bias).to(device, dtype=dtype) |
|
|
| if hasattr(module, 'weight') and module.weight is not None: |
| if lin.weight.shape == module.weight.shape: |
| lin.weight.data.copy_(module.weight.data) |
| else: |
| logging.warning(f"Shape mismatch reverting weight {name}: Expected {lin.weight.shape}, got {module.weight.shape}. Re-initializing nn.Linear weight.") |
| nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5)) |
| else: |
| nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5)) |
|
|
| if has_bias and lin.bias is not None: |
| if lin.bias.shape == module.bias.shape: |
| lin.bias.data.copy_(module.bias.data) |
| else: |
| logging.warning(f"Shape mismatch reverting bias {name}: Expected {lin.bias.shape}, got {module.bias.shape}. Re-initializing nn.Linear bias.") |
| fan_in, _ = nn.init._calculate_fan_in_and_fan_out(lin.weight) |
| bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
| nn.init.uniform_(lin.bias, -bound, bound) |
| elif has_bias and lin.bias is None: |
| logging.error(f"BitLinear layer {name} had bias, but reverted nn.Linear does not. Reversion failed for bias.") |
| elif not has_bias and lin.bias is not None: |
| logging.error(f"BitLinear layer {name} lacked bias, but reverted nn.Linear has one. Setting to zero.") |
| nn.init.zeros_(lin.bias) |
|
|
| if _recursive_setattr(model, name, lin): |
| reverted_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Reverted BitLinear layer {name} to nn.Linear.") |
| else: |
| logging.warning(f"Failed to revert BitLinear for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error reverting BitLinear {name}: {e} \n{traceback.format_exc()}") |
| processed_names.add(name) |
|
|
| if reverted_count > 0: |
| config.bitnet_applied = False |
| logging.info(f"Reverted {reverted_count} BitNet layers to standard nn.Linear.") |
| return f"Reverted {reverted_count} BitNet layers." |
| else: |
| config.bitnet_applied = False |
| logging.info("No BitNet layers found to revert, but flag was true. Resetting flag.") |
| return "No BitNet layers found to revert." |
|
|
|
|
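| # Heuristically locates the transformer block list (e.g. model.layers, transformer.h) across |
| # common architectures, returning (parent module, attribute name, ModuleList). |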
| def _find_decoder_layers_module(model): |
| prefixes = [ |
| ('model.decoder', 'layers'), |
| ('model.layers', None), |
| ('transformer.h', None), |
| ('transformer.blocks', None), |
| ('encoder.layer', None), |
| ('model.encoder.layers', None), |
| ('', 'layers'), |
| ('', 'h'), |
| ('model', 'layers'), |
| ('decoder.block', None), |
| ('decoder.layers', None) |
| ] |
| if hasattr(model, 'model'): |
| base_obj = model.model |
| else: |
| base_obj = model |
|
|
| direct_attrs = ['layers', 'h', 'blocks', 'block'] |
| for attr in direct_attrs: |
| if hasattr(base_obj, attr): |
| layer_list = getattr(base_obj, attr) |
| if isinstance(layer_list, nn.ModuleList) and len(layer_list) > 0: |
| logging.info(f"Found layer list at 'model.{attr}' or '{attr}' with {len(layer_list)} layers.") |
| return base_obj, attr, layer_list |
| elif isinstance(layer_list, (list, tuple)) and len(layer_list) > 0 and isinstance(layer_list[0], nn.Module): |
| logging.warning(f"Found layers as list/tuple at 'model.{attr}' or '{attr}'. Converting to ModuleList.") |
| setattr(base_obj, attr, nn.ModuleList(layer_list)) |
| return base_obj, attr, getattr(base_obj, attr) |
|
|
| for p_base, attr_name_explicit in prefixes: |
| mod = model |
| valid_path = True |
| if p_base: |
| for comp in p_base.split('.'): |
| if not hasattr(mod, comp): |
| valid_path = False |
| break |
| mod = getattr(mod, comp) |
| if mod is None: |
| valid_path = False |
| break |
| if not valid_path: |
| continue |
|
|
| attrs_to_check = [attr_name_explicit] if attr_name_explicit else ['layers', 'h', 'blocks', 'block', 'layer'] |
|
|
| for attr in attrs_to_check: |
| if hasattr(mod, attr): |
| layer_list = getattr(mod, attr) |
| if isinstance(layer_list, nn.ModuleList) and len(layer_list) > 0: |
| logging.info(f"Found layer list at '{p_base}.{attr}' with {len(layer_list)} layers.") |
| return mod, attr, layer_list |
| elif isinstance(layer_list, (list, tuple)) and len(layer_list) > 0 and isinstance(layer_list[0], nn.Module): |
| logging.warning(f"Found layers as list/tuple at '{p_base}.{attr}'. Converting to ModuleList.") |
| setattr(mod, attr, nn.ModuleList(layer_list)) |
| return mod, attr, getattr(mod, attr) |
|
|
| logging.warning("Could not automatically find the standard decoder/transformer layer list module.") |
| return None, None, None |
|
|
|
|
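| # Truncates the transformer block list down to target_layers, remembering the original depth in |
| # config.original_num_layers so it can be restored later. |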
| def _reduce_layers_to_one(base_model, config, target_layers=1): |
| if not isinstance(target_layers, int) or target_layers < 1: |
| logging.error(f"Invalid target_layers value: {target_layers}. Must be an integer >= 1.") |
| return f"Error: Target layers must be >= 1, got {target_layers}." |
|
|
| layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model) |
|
|
| if layer_module and layer_attr and current_layers is not None: |
| current_len = len(current_layers) |
| if current_len <= 0: |
| logging.warning("Found layer attribute but the ModuleList is empty. Cannot reduce.") |
| return "Warning: Layer list found but it's empty. Cannot reduce." |
|
|
| if current_len > target_layers: |
| logging.info(f"Reducing layers: {current_len} -> {target_layers}...") |
| original_layer_count = current_len |
|
|
| if not hasattr(config, 'original_num_layers') or config.original_num_layers is None or config.original_num_layers < current_len: |
| config.original_num_layers = original_layer_count |
| logging.info(f"Stored/Updated original layer count in config: {original_layer_count}") |
|
|
| new_layer_list = nn.ModuleList(current_layers[:target_layers]) |
| setattr(layer_module, layer_attr, new_layer_list) |
|
|
| config.num_hidden_layers = target_layers |
| config.reduced_layers = True |
| if hasattr(config, 'n_layer'): config.n_layer = target_layers |
| if hasattr(config, 'num_layers'): config.num_layers = target_layers |
| if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = target_layers |
| if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = target_layers |
|
|
|
|
| logging.info(f"Successfully reduced layers to {target_layers}.") |
| clean_memory() |
| return f"Layers reduced to {target_layers}. Original count was: {original_layer_count}." |
| elif current_len == target_layers: |
| logging.info(f"Model already has {current_len} layers, matching the target {target_layers}. No reduction needed.") |
| config.reduced_layers = current_len != (getattr(config, 'original_num_layers', None) or current_len) |
| config.num_hidden_layers = current_len |
| if hasattr(config, 'n_layer'): config.n_layer = current_len |
| if hasattr(config, 'num_layers'): config.num_layers = current_len |
| if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len |
| if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len |
| return f"Model already has {current_len} layers (target {target_layers}). No reduction performed." |
| else: |
| logging.info(f"Model has {current_len} layers, which is less than the target {target_layers}. No reduction needed.") |
| config.reduced_layers = current_len != (getattr(config, 'original_num_layers', None) or current_len) |
| config.num_hidden_layers = current_len |
| if hasattr(config, 'n_layer'): config.n_layer = current_len |
| if hasattr(config, 'num_layers'): config.num_layers = current_len |
| if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len |
| if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len |
| return f"Model already has {current_len} layers (< target {target_layers}). No reduction performed." |
| else: |
| logging.warning("Could not find standard layer structure for reduction.") |
| config.reduced_layers = False |
| return "Warning: Could not find standard layer structure for reduction." |
|
|
|
|
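| # Restores the block list to its original depth by deep-copying the first remaining layer and |
| # re-initializing the copies; the added layers are freshly initialized, not the original weights. |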
| def _enable_full_layers(base_model, config, original_num_layers=None): |
| if not getattr(config, 'reduced_layers', False): |
| layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model) |
| current_len = len(current_layers) if current_layers is not None else 0 |
| orig_len_config = getattr(config, 'original_num_layers', None) |
| if current_len > 0 and orig_len_config is not None and current_len == orig_len_config: |
| config.reduced_layers = False |
| return "Layers already seem to be at the original count. Flag corrected if necessary." |
| else: |
| return "Layers not previously reduced according to config flag, or cannot verify current/original counts." |
|
|
| orig_layers = original_num_layers if original_num_layers is not None else getattr(config, 'original_num_layers', None) |
|
|
| if orig_layers is None: |
| global original_num_layers_global |
| orig_layers = original_num_layers_global |
| if orig_layers is not None: |
| logging.warning(f"Using globally stored original layer count: {orig_layers} as it was missing in config.") |
| config.original_num_layers = orig_layers |
| else: |
| logging.error("Cannot restore layers: Original layer count is missing from config and global state.") |
| return "Error: Cannot revert - Original layer count unknown." |
|
|
| if not isinstance(orig_layers, int) or orig_layers <= 0: |
| logging.error(f"Cannot restore layers: Invalid original layer count found ({orig_layers}).") |
| return f"Error: Cannot revert - Invalid original layer count ({orig_layers})." |
|
|
| layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model) |
|
|
| if layer_module and layer_attr and current_layers is not None: |
| current_len = len(current_layers) |
| if current_len < orig_layers: |
| logging.info(f"Restoring layers: {current_len} -> {orig_layers}..."); T = time.time() |
| try: |
| if current_len == 0: |
| logging.error("Cannot restore layers: No existing layers found to copy structure from.") |
| return "Error: Cannot restore layers - no template layer available." |
|
|
| device = next(iter(current_layers[0].parameters()), torch.tensor([])).device |
| # Deep-copy the template so the existing first layer is not moved off its device in place. |
| template_layer = copy.deepcopy(current_layers[0]).to('cpu') |
|
|
| layers_to_add = [] |
| num_layers_to_add = orig_layers - current_len |
| logging.info(f"Need to add {num_layers_to_add} layers.") |
|
|
| for i in range(num_layers_to_add): |
| new_layer = copy.deepcopy(template_layer) |
| for _, sub_module in new_layer.named_modules(): |
| if hasattr(sub_module, 'reset_parameters'): |
| try: |
| sub_module.reset_parameters() |
| except Exception as reset_e: |
| logging.warning(f"Could not reset parameters for submodule {sub_module} in new layer {i}: {reset_e}") |
| elif isinstance(sub_module, nn.Linear): |
| nn.init.kaiming_uniform_(sub_module.weight, a=math.sqrt(5)) |
| if sub_module.bias is not None: |
| fan_in, _ = nn.init._calculate_fan_in_and_fan_out(sub_module.weight) |
| bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
| nn.init.uniform_(sub_module.bias, -bound, bound) |
| elif isinstance(sub_module, nn.Embedding): |
| nn.init.normal_(sub_module.weight) |
| if sub_module.padding_idx is not None: |
| with torch.no_grad(): sub_module.weight[sub_module.padding_idx].fill_(0) |
| elif isinstance(sub_module, (nn.LayerNorm, RMSNorm, BypassLayerNorm)): |
| if sub_module.elementwise_affine: |
| if hasattr(sub_module, 'weight') and sub_module.weight is not None: nn.init.ones_(sub_module.weight) |
| if hasattr(sub_module, 'bias') and sub_module.bias is not None: nn.init.zeros_(sub_module.bias) |
|
|
| new_layer = new_layer.to(device) |
| layers_to_add.append(new_layer) |
|
|
| full_layer_list = nn.ModuleList(list(current_layers) + layers_to_add) |
| setattr(layer_module, layer_attr, full_layer_list) |
|
|
| config.num_hidden_layers = orig_layers |
| config.reduced_layers = False |
| if hasattr(config, 'n_layer'): config.n_layer = orig_layers |
| if hasattr(config, 'num_layers'): config.num_layers = orig_layers |
| if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = orig_layers |
| if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = orig_layers |
|
|
|
|
| msg = f"Restored layer structure to {orig_layers} layers in {time.time()-T:.2f}s." |
| logging.info(msg) |
| clean_memory() |
| return msg |
|
|
| except Exception as e: |
| logging.error(f"Error restoring layers: {e}\n{traceback.format_exc()}") |
| setattr(layer_module, layer_attr, current_layers) |
| config.num_hidden_layers = current_len |
| config.reduced_layers = True |
| if hasattr(config, 'n_layer'): config.n_layer = current_len |
| if hasattr(config, 'num_layers'): config.num_layers = current_len |
| if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len |
| if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len |
|
|
| return f"Error restoring layers: {e}. State might be inconsistent." |
| else: |
| config.reduced_layers = False |
| config.num_hidden_layers = current_len |
| if hasattr(config, 'n_layer'): config.n_layer = current_len |
| if hasattr(config, 'num_layers'): config.num_layers = current_len |
| if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len |
| if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len |
|
|
|
|
| msg = f"Model already has {current_len} layers (>= original {orig_layers}). No restoration needed. Corrected flags if necessary." |
| logging.info(msg) |
| return msg |
| elif layer_module and layer_attr and current_layers is None: |
| logging.warning(f"Layer attribute '{layer_attr}' exists but is None or invalid. Cannot restore layers.") |
| return "Warning: Layer attribute found but invalid. Cannot restore layers." |
| else: |
| logging.warning("Could not find standard layer structure for restoration.") |
| return "Warning: Could not find standard layer structure for restoration." |
|
|
| def _replace_linear_without_bias(module, config): |
| device = get_device() |
| replaced_count = 0 |
| modules_to_process = list(module.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(child, nn.Linear) and child.bias is not None: |
| try: |
| dtype = child.weight.dtype |
| current_device = child.weight.device |
|
|
| new_linear = nn.Linear(child.in_features, child.out_features, bias=False).to(device=current_device, dtype=dtype) |
|
|
| with torch.no_grad(): |
| if new_linear.weight.shape == child.weight.shape: |
| new_linear.weight.copy_(child.weight) |
| else: |
| logging.warning(f"Shape mismatch removing bias for weight {name}: Expected {new_linear.weight.shape}, got {child.weight.shape}. Re-initializing.") |
| nn.init.kaiming_uniform_(new_linear.weight, a=math.sqrt(5)) |
|
|
| if _recursive_setattr(module, name, new_linear): |
| replaced_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Removed bias from layer {name}") |
| else: |
| logging.warning(f"Failed to set bias-less Linear for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error removing bias for layer {name}: {e}") |
| processed_names.add(name) |
|
|
| if replaced_count > 0: |
| config.removed_bias = True |
| logging.info(f"Removed bias from {replaced_count} linear layers.") |
| return f"Removed bias from {replaced_count} linear layers." |
| else: |
| logging.info("No linear layers with bias found to modify.") |
| return "No linear layers with bias found to modify." |
|
|
|
|
| def _enable_bias_in_linear(module, config): |
| if not getattr(config, 'removed_bias', False): |
| return "Bias not previously removed according to config flag. Cannot enable (revert)." |
|
|
| device = get_device() |
| enabled_count = 0 |
| modules_to_process = list(module.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(child, nn.Linear) and child.bias is None: |
| try: |
| dtype = child.weight.dtype |
| current_device = child.weight.device |
|
|
| new_linear = nn.Linear(child.in_features, child.out_features, bias=True).to(device=current_device, dtype=dtype) |
|
|
| with torch.no_grad(): |
| if new_linear.weight.shape == child.weight.shape: |
| new_linear.weight.copy_(child.weight) |
| else: |
| logging.warning(f"Shape mismatch enabling bias for weight {name}: Expected {new_linear.weight.shape}, got {child.weight.shape}. Re-initializing weight.") |
| nn.init.kaiming_uniform_(new_linear.weight, a=math.sqrt(5)) |
|
|
| fan_in, _ = nn.init._calculate_fan_in_and_fan_out(new_linear.weight) |
| bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
| nn.init.uniform_(new_linear.bias, -bound, bound) |
|
|
| if _recursive_setattr(module, name, new_linear): |
| enabled_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Enabled bias for layer {name}") |
| else: |
| logging.warning(f"Failed to set biased Linear for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error enabling bias for layer {name}: {e}") |
| processed_names.add(name) |
|
|
| if enabled_count > 0: |
| config.removed_bias = False |
| logging.info(f"Enabled (restored) bias for {enabled_count} linear layers.") |
| return f"Enabled bias for {enabled_count} linear layers." |
| else: |
| config.removed_bias = False |
| logging.info("No bias-less linear layers found to enable bias for. Resetting flag.") |
| return "No bias-less linear layers found to enable bias for." |
|
|
|
|
|
|
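| # Replaces every standard nn.LayerNorm with BypassLayerNorm (weights copied), so normalization |
| # can later be toggled off at runtime via _enable_layer_norm_bypass / _disable_layer_norm_bypass. |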
| def _replace_layer_norm_with_bypass(module, config): |
| replaced_count = 0 |
| modules_to_process = list(module.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(child, nn.LayerNorm) and not isinstance(child, (BypassLayerNorm, RMSNorm)): |
| try: |
| child_device = get_device() |
| child_dtype = torch.float32 |
| if hasattr(child, 'weight') and child.weight is not None: |
| child_device = child.weight.device |
| child_dtype = child.weight.dtype |
| elif hasattr(child, 'bias') and child.bias is not None: |
| child_device = child.bias.device |
| child_dtype = child.bias.dtype |
| elif hasattr(child, '_parameters') and child._parameters: |
| first_param = next(iter(child.parameters()), None) |
| if first_param is not None: |
| child_device = first_param.device |
| child_dtype = first_param.dtype |
|
|
| norm_shape = child.normalized_shape |
| eps = child.eps |
| affine = child.elementwise_affine |
|
|
| new_layer_norm = BypassLayerNorm(norm_shape, eps, affine, device=child_device, dtype=child_dtype) |
|
|
| if affine: |
| with torch.no_grad(): |
| if hasattr(child, 'weight') and child.weight is not None and new_layer_norm.weight is not None: |
| if new_layer_norm.weight.shape == child.weight.shape: |
| new_layer_norm.weight.copy_(child.weight) |
| else: |
| logging.warning(f"Shape mismatch replacing LN weight {name}. Expected {new_layer_norm.weight.shape}, got {child.weight.shape}. Initializing BypassLN weight.") |
| nn.init.ones_(new_layer_norm.weight) |
| elif new_layer_norm.weight is not None: |
| nn.init.ones_(new_layer_norm.weight) |
|
|
| if hasattr(child, 'bias') and child.bias is not None and new_layer_norm.bias is not None: |
| if new_layer_norm.bias.shape == child.bias.shape: |
| new_layer_norm.bias.copy_(child.bias) |
| else: |
| logging.warning(f"Shape mismatch replacing LN bias {name}. Expected {new_layer_norm.bias.shape}, got {child.bias.shape}. Initializing BypassLN bias.") |
| nn.init.zeros_(new_layer_norm.bias) |
| elif new_layer_norm.bias is not None: |
| nn.init.zeros_(new_layer_norm.bias) |
|
|
| if _recursive_setattr(module, name, new_layer_norm): |
| replaced_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Replaced LayerNorm {name} with BypassLayerNorm.") |
| else: |
| logging.warning(f"Failed to set BypassLayerNorm for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error replacing LayerNorm {name} with Bypass version: {e}\n{traceback.format_exc()}") |
| processed_names.add(name) |
|
|
| if replaced_count > 0: |
| config.replaced_layer_norm = True |
| config.layer_norm_bypassed = False |
| logging.info(f"Replaced {replaced_count} LayerNorm layers with Bypass version.") |
| return f"Replaced {replaced_count} LayerNorm layers with Bypass version." |
| else: |
| logging.info("No standard nn.LayerNorm layers found to replace with BypassLayerNorm.") |
| return "No standard LayerNorm layers found to replace." |
|
|
| def _revert_bypass_layer_norm(module, config): |
| if not getattr(config, 'replaced_layer_norm', False): |
| return "BypassLayerNorm not previously applied according to config flag. Cannot revert." |
|
|
| reverted_count = 0 |
| modules_to_process = list(module.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(child, BypassLayerNorm): |
| try: |
| child_device = get_device() |
| child_dtype = torch.float32 |
| if child.elementwise_affine: |
| if child.weight is not None: |
| child_device = child.weight.device |
| child_dtype = child.weight.dtype |
| elif child.bias is not None: |
| child_device = child.bias.device |
| child_dtype = child.bias.dtype |
| else: |
| pass |
|
|
| norm_shape = child.normalized_shape |
| eps = child.eps |
| affine = child.elementwise_affine |
|
|
| if isinstance(norm_shape, tuple) and len(norm_shape) == 1: |
| norm_arg = norm_shape[0] |
| elif isinstance(norm_shape, (list, tuple)): |
| norm_arg = list(norm_shape) |
| elif isinstance(norm_shape, int): |
| norm_arg = norm_shape |
| else: |
| raise ValueError(f"Unsupported normalized_shape type for nn.LayerNorm: {type(norm_shape)}") |
|
|
| new_layer_norm = nn.LayerNorm(norm_arg, eps, affine, device=child_device, dtype=child_dtype) |
|
|
| if affine: |
| with torch.no_grad(): |
| if hasattr(child, 'weight') and child.weight is not None and new_layer_norm.weight is not None: |
| if new_layer_norm.weight.shape == child.weight.shape: |
| new_layer_norm.weight.copy_(child.weight) |
| else: |
| logging.warning(f"Shape mismatch reverting BypassLN weight {name}. Expected {new_layer_norm.weight.shape}, got {child.weight.shape}. Initializing LayerNorm weight.") |
| nn.init.ones_(new_layer_norm.weight) |
| elif new_layer_norm.weight is not None: |
| nn.init.ones_(new_layer_norm.weight) |
|
|
| if hasattr(child, 'bias') and child.bias is not None and new_layer_norm.bias is not None: |
| if new_layer_norm.bias.shape == child.bias.shape: |
| new_layer_norm.bias.copy_(child.bias) |
| else: |
| logging.warning(f"Shape mismatch reverting BypassLN bias {name}. Expected {new_layer_norm.bias.shape}, got {child.bias.shape}. Initializing LayerNorm bias.") |
| nn.init.zeros_(new_layer_norm.bias) |
| elif new_layer_norm.bias is not None: |
| nn.init.zeros_(new_layer_norm.bias) |
|
|
| if _recursive_setattr(module, name, new_layer_norm): |
| reverted_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Reverted BypassLayerNorm {name} to standard nn.LayerNorm.") |
| else: |
| logging.warning(f"Failed to revert BypassLayerNorm for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error reverting BypassLayerNorm {name} to standard LayerNorm: {e}\n{traceback.format_exc()}") |
| processed_names.add(name) |
|
|
| if reverted_count > 0: |
| config.replaced_layer_norm = False |
| config.layer_norm_bypassed = False |
| logging.info(f"Reverted {reverted_count} BypassLayerNorm layers back to standard nn.LayerNorm.") |
| return f"Reverted {reverted_count} BypassLayerNorm layers." |
| else: |
| config.replaced_layer_norm = False |
| config.layer_norm_bypassed = False |
| logging.info("No BypassLayerNorm layers found to revert. Resetting flags.") |
| return "No BypassLayerNorm layers found to revert." |
|
|
|
|
| def _enable_layer_norm_bypass(model): |
| count = 0 |
| found_bypass_layers = False |
| for m in model.modules(): |
| if isinstance(m, BypassLayerNorm): |
| found_bypass_layers = True |
| if not m.bypass: |
| m.bypass = True |
| count += 1 |
|
|
| if not found_bypass_layers: |
| if getattr(model.config, 'replaced_layer_norm', False): |
| logging.warning("Config indicates LN were replaced with BypassLN, but none found. Cannot enable bypass.") |
| model.config.layer_norm_bypassed = False |
| return "Replaced LN flag is true, but no BypassLN layers found. Run 'Replace LN' first or revert." |
| else: |
| return "No BypassLayerNorm layers found in the model. Replace standard LayerNorm first to enable bypass functionality." |
|
|
| elif count > 0: |
| model.config.layer_norm_bypassed = True |
| logging.info(f"Enabled bypass for {count} BypassLayerNorm layers.") |
| return f"Enabled bypass for {count} LN layers." |
| else: |
| model.config.layer_norm_bypassed = True |
| logging.info("All existing BypassLayerNorm layers already have bypass enabled.") |
| return "No changes made (layers might already be bypassed)." |
|
|
| def _disable_layer_norm_bypass(model): |
| count = 0 |
| found_bypass_layers = False |
| for m in model.modules(): |
| if isinstance(m, BypassLayerNorm): |
| found_bypass_layers = True |
| if m.bypass: |
| m.bypass = False |
| count += 1 |
|
|
| if not found_bypass_layers: |
| if getattr(model.config, 'replaced_layer_norm', False): |
| model.config.layer_norm_bypassed = False |
| return "Replaced LN flag is true, but no BypassLN layers found to disable bypass on." |
| else: |
| return "No BypassLayerNorm layers found in the model to disable bypass on." |
|
|
| elif count > 0: |
| model.config.layer_norm_bypassed = False |
| logging.info(f"Disabled bypass for {count} BypassLayerNorm layers.") |
| return f"Disabled bypass for {count} LN layers." |
| else: |
| model.config.layer_norm_bypassed = False |
| logging.info("All existing BypassLayerNorm layers already have bypass disabled.") |
| return "No changes made (layers might already have bypass disabled)." |
|
|
|
|
| def _replace_dropout_with_bypass(module, config): |
| replaced_count = 0 |
| modules_to_process = list(module.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if type(child) == nn.Dropout: |
| try: |
| new_dropout = BypassDropout(child.p, child.inplace) |
| try: |
| parent_name = '.'.join(name.split('.')[:-1]) |
| parent_module = module.get_submodule(parent_name) if parent_name else module |
| first_param = next(iter(parent_module.parameters()), None) |
| if first_param is not None: |
| new_dropout.to(device=first_param.device) |
| except Exception: |
| new_dropout.to(device=get_device()) |
|
|
| if _recursive_setattr(module, name, new_dropout): |
| replaced_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Replaced Dropout {name} with BypassDropout.") |
| else: |
| logging.warning(f"Failed to set BypassDropout for {name}") |
| processed_names.add(name) |
| except Exception as e: |
| logging.error(f"Error replacing Dropout {name} with Bypass version: {e}") |
| processed_names.add(name) |
|
|
| if replaced_count > 0: |
| config.replaced_dropout = True |
| config.dropout_bypassed = False |
| logging.info(f"Replaced {replaced_count} nn.Dropout layers with BypassDropout version.") |
| return f"Replaced {replaced_count} Dropout layers." |
| else: |
| logging.info("No standard nn.Dropout layers found to replace with BypassDropout.") |
| return "No standard Dropout layers found to replace." |
|
|
|
|
| def _revert_bypass_dropout(module, config): |
| if not getattr(config, 'replaced_dropout', False): |
| return "BypassDropout not previously applied according to config flag. Cannot revert." |
|
|
| reverted_count = 0 |
| modules_to_process = list(module.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(child, BypassDropout): |
| try: |
| new_dropout = nn.Dropout(child.p, child.inplace) |
| try: |
| parent_name = '.'.join(name.split('.')[:-1]) |
| parent_module = module.get_submodule(parent_name) if parent_name else module |
| first_param = next(iter(parent_module.parameters()), None) |
| if first_param is not None: |
| new_dropout.to(device=first_param.device) |
| except Exception: |
| new_dropout.to(device=get_device()) |
|
|
|
|
| if _recursive_setattr(module, name, new_dropout): |
| reverted_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Reverted BypassDropout {name} to standard nn.Dropout.") |
| else: |
| logging.warning(f"Failed to revert BypassDropout for {name}") |
| processed_names.add(name) |
| except Exception as e: |
| logging.error(f"Error reverting BypassDropout {name} to standard nn.Dropout: {e}") |
| processed_names.add(name) |
|
|
| if reverted_count > 0: |
| config.replaced_dropout = False |
| config.dropout_bypassed = False |
| logging.info(f"Reverted {reverted_count} BypassDropout layers back to standard nn.Dropout.") |
| return f"Reverted {reverted_count} BypassDropout layers." |
| else: |
| config.replaced_dropout = False |
| config.dropout_bypassed = False |
| logging.info("No BypassDropout layers found to revert. Resetting flags.") |
| return "No BypassDropout layers found to revert." |
|
|
| def _enable_dropout_bypass(model): |
| count = 0 |
| found_bypass_layers = False |
| for m in model.modules(): |
| if isinstance(m, BypassDropout): |
| found_bypass_layers = True |
| if not m.bypass: |
| m.bypass = True |
| count += 1 |
|
|
| if not found_bypass_layers: |
| if getattr(model.config, 'replaced_dropout', False): |
| model.config.dropout_bypassed = False |
| return "Replaced Dropout flag is true, but no BypassDropout layers found. Run 'Replace DO' first or revert." |
| else: |
| return "No BypassDropout layers found in the model. Replace standard Dropout first to enable bypass." |
|
|
| elif count > 0: |
| model.config.dropout_bypassed = True |
| logging.info(f"Enabled bypass for {count} BypassDropout layers.") |
| return f"Enabled bypass for {count} Dropout layers." |
| else: |
| model.config.dropout_bypassed = True |
| logging.info("All existing BypassDropout layers already have bypass enabled.") |
| return "No changes made (layers might already be bypassed)." |
|
|
| def _disable_dropout_bypass(model): |
| count = 0 |
| found_bypass_layers = False |
| for m in model.modules(): |
| if isinstance(m, BypassDropout): |
| found_bypass_layers = True |
| if m.bypass: |
| m.bypass = False |
| count += 1 |
|
|
| if not found_bypass_layers: |
| if getattr(model.config, 'replaced_dropout', False): |
| model.config.dropout_bypassed = False |
| return "Replaced Dropout flag is true, but no BypassDropout layers found to disable bypass on." |
| else: |
| return "No BypassDropout layers found in the model to disable bypass on." |
|
|
| elif count > 0: |
| model.config.dropout_bypassed = False |
| logging.info(f"Disabled bypass for {count} BypassDropout layers.") |
| return f"Disabled bypass for {count} Dropout layers." |
| else: |
| model.config.dropout_bypassed = False |
| logging.info("All existing BypassDropout layers already have bypass disabled.") |
| return "No changes made (layers might already have bypass disabled)." |
|
|
|
|
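| # Swaps every recognized activation module (see ACTIVATION_FUNCTIONS) for the requested one and |
| # records the choice in the config. |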
| def _swap_activation_function(model, config, activation_fn_name): |
| activation_fn_class = ACTIVATION_FUNCTIONS.get(activation_fn_name) |
| if not activation_fn_class: |
| msg = f"Warning: Activation function '{activation_fn_name}' not found or invalid. Using default '{DEFAULT_ACTIVATION_FUNCTION}'." |
| logging.warning(msg) |
| activation_fn_class = ACTIVATION_FUNCTIONS[DEFAULT_ACTIVATION_FUNCTION] |
| activation_fn_name = DEFAULT_ACTIVATION_FUNCTION |
| if not activation_fn_class: |
| logging.error(f"Default activation function '{DEFAULT_ACTIVATION_FUNCTION}' is also missing! Cannot swap.") |
| return f"Error: Cannot find '{activation_fn_name}' or the default '{DEFAULT_ACTIVATION_FUNCTION}'." |
| else: |
| msg = "" |
|
|
| replaced_count = 0 |
| current_act_classes = tuple(f for f in ACTIVATION_FUNCTIONS.values() if f is not None and inspect.isclass(f) and issubclass(f, nn.Module)) |
| target_act_class = activation_fn_class |
|
|
| modules_to_process = list(model.named_modules()) |
| processed_names = set() |
|
|
| for name, child in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if type(child) in current_act_classes: |
| if type(child) == target_act_class: |
| processed_names.add(name) |
| continue |
|
|
| try: |
| new_activation = target_act_class() |
| try: |
| parent_name = '.'.join(name.split('.')[:-1]) |
| parent_module = model.get_submodule(parent_name) if parent_name else model |
| first_param = next(iter(parent_module.parameters()), None) |
| if first_param is not None: |
| new_activation.to(device=first_param.device) |
| except Exception: |
| new_activation.to(device=get_device()) |
|
|
| if _recursive_setattr(model, name, new_activation): |
| replaced_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Swapped activation {name} from {type(child).__name__} to {target_act_class.__name__}") |
| else: |
| logging.warning(f"Failed to set new activation function for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error replacing activation function {name} of type {type(child).__name__} with {target_act_class.__name__}: {e}") |
| processed_names.add(name) |
|
|
| if replaced_count > 0: |
| msg += f"Swapped {replaced_count} activation functions to {activation_fn_name}." |
| config.activation_function_swapped = True |
| config.current_activation_function = activation_fn_name |
| if hasattr(config, 'hidden_act'): |
| config.hidden_act = activation_fn_name |
| if hasattr(config, 'activation_function'): |
| config.activation_function = activation_fn_name |
| else: |
| msg += f"No eligible activation functions found to swap to {activation_fn_name} (or already using it)." |
| if not config.activation_function_swapped: |
| current_in_config = getattr(config, 'hidden_act', getattr(config, 'activation_function', DEFAULT_ACTIVATION_FUNCTION)) |
| config.current_activation_function = current_in_config if current_in_config in ACTIVATION_FUNCTIONS else DEFAULT_ACTIVATION_FUNCTION |
|
|
| logging.info(msg) |
| return msg |
|
|
| def _revert_activation_function(model, config): |
| current_activation = getattr(config, 'current_activation_function', DEFAULT_ACTIVATION_FUNCTION) |
| was_swapped = getattr(config, 'activation_function_swapped', False) |
|
|
| if not was_swapped and current_activation == DEFAULT_ACTIVATION_FUNCTION: |
| return f"Activation function is already the default ('{DEFAULT_ACTIVATION_FUNCTION}') and was not marked as swapped." |
| elif not was_swapped: |
| logging.info(f"Activation function is '{current_activation}' but wasn't marked as swapped. Attempting to revert to '{DEFAULT_ACTIVATION_FUNCTION}' anyway.") |
| pass |
| else: |
| logging.info(f"Reverting activation function from '{current_activation}' to default '{DEFAULT_ACTIVATION_FUNCTION}'...") |
|
|
| result_msg = _swap_activation_function(model, config, DEFAULT_ACTIVATION_FUNCTION) |
|
|
| config.activation_function_swapped = False |
| config.current_activation_function = DEFAULT_ACTIVATION_FUNCTION |
| if hasattr(config, 'hidden_act'): |
| config.hidden_act = DEFAULT_ACTIVATION_FUNCTION |
| if hasattr(config, 'activation_function'): |
| config.activation_function = DEFAULT_ACTIVATION_FUNCTION |
|
|
| final_msg = f"Reverted to default activation ('{DEFAULT_ACTIVATION_FUNCTION}'). Result: {result_msg}" |
| return final_msg |
|
|
|
|
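| # Converts between nn.LayerNorm and RMSNorm in either direction, copying affine weights where |
| # shapes allow, and keeps the rms_norm_applied config flag in sync. |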
| def _swap_normalization_layer(model, config, target_norm_type='RMSNorm'): |
| device = get_device() |
| swapped_count = 0 |
| processed_names = set() |
|
|
| if target_norm_type == 'RMSNorm': |
| current_norm_class = nn.LayerNorm |
| new_norm_class = RMSNorm |
| config_flag_name = 'rms_norm_applied' |
| target_flag_value = True |
| elif target_norm_type == 'LayerNorm': |
| current_norm_class = RMSNorm |
| new_norm_class = nn.LayerNorm |
| config_flag_name = 'rms_norm_applied' |
| target_flag_value = False |
| else: |
| msg = f"Error: Unsupported target normalization type '{target_norm_type}'. Use 'RMSNorm' or 'LayerNorm'." |
| logging.error(msg) |
| return msg |
|
|
| already_configured = getattr(config, config_flag_name, False) == target_flag_value |
| has_current_norm_instances = any(isinstance(m, current_norm_class) for name, m in model.named_modules() if not isinstance(m, (BypassLayerNorm, new_norm_class))) |
|
|
| if already_configured and not has_current_norm_instances: |
| logging.info(f"Model config flag '{config_flag_name}' is already {target_flag_value}, and no instances of {current_norm_class.__name__} found to swap. No action needed.") |
| return f"Model already configured for {target_norm_type} (or no swappable layers found)." |
| elif already_configured and has_current_norm_instances: |
| logging.warning(f"Model config flag '{config_flag_name}' is {target_flag_value}, but instances of {current_norm_class.__name__} were found. Attempting swap anyway to ensure consistency.") |
| pass |
| elif not already_configured and not has_current_norm_instances: |
| logging.info(f"No instances of {current_norm_class.__name__} found to swap to {target_norm_type}. Updating config flag to {target_flag_value}.") |
| setattr(config, config_flag_name, target_flag_value) |
| if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False |
| return f"No {current_norm_class.__name__} layers found to swap. Config flag '{config_flag_name}' set to {target_flag_value}." |
|
|
|
|
| modules_to_process = list(model.named_modules()) |
| for name, module in modules_to_process: |
| if name in processed_names: |
| continue |
|
|
| if isinstance(module, current_norm_class) and not isinstance(module, BypassLayerNorm): |
| try: |
| eps = module.eps |
| elementwise_affine = module.elementwise_affine |
|
|
| module_device = get_device() |
| module_dtype = torch.float32 |
| params = list(module.parameters()) |
| if params: |
| module_device = params[0].device |
| module_dtype = params[0].dtype |
| elif elementwise_affine and hasattr(module, 'weight') and module.weight is not None: |
| module_device = module.weight.device |
| module_dtype = module.weight.dtype |
|
|
| dim = None |
| if isinstance(module, nn.LayerNorm): |
| dim = module.normalized_shape |
| elif isinstance(module, RMSNorm): |
| if elementwise_affine and hasattr(module, 'weight') and module.weight is not None: |
| dim = module.weight.shape[0] |
| else: |
| logging.warning(f"Cannot determine dimension for affine-less RMSNorm {name}. Cannot swap this layer.") |
| processed_names.add(name) |
| continue |
| else: |
| raise ValueError(f"Module {name} is unexpected type {type(module)} during norm swap.") |
|
|
| if new_norm_class == nn.LayerNorm: |
| if isinstance(dim, int): |
| norm_arg = dim |
| elif isinstance(dim, (list, tuple)): |
| norm_arg = list(dim) |
| else: |
| raise ValueError(f"Unsupported dimension type {type(dim)} '{dim}' for creating LayerNorm from {current_norm_class.__name__} layer {name}.") |
| new_norm = new_norm_class(norm_arg, eps=eps, elementwise_affine=elementwise_affine, device=module_device, dtype=module_dtype) |
|
|
| elif new_norm_class == RMSNorm: |
| if isinstance(dim, int): |
| norm_arg = dim |
| elif isinstance(dim, (list, tuple)): |
| if len(dim) == 1: |
| norm_arg = dim[0] |
| else: |
| logging.warning(f"LayerNorm shape {dim} has multiple dimensions. Using last dim ({dim[-1]}) for RMSNorm {name}.") |
| norm_arg = dim[-1] |
| else: |
| raise ValueError(f"Unsupported dimension type {type(dim)} '{dim}' for creating RMSNorm from {current_norm_class.__name__} layer {name}.") |
| new_norm = new_norm_class(norm_arg, eps=eps, elementwise_affine=elementwise_affine, device=module_device, dtype=module_dtype) |
|
|
| else: |
| raise ValueError("Invalid new_norm_class.") |
|
|
| if elementwise_affine: |
| with torch.no_grad(): |
| if hasattr(module, 'weight') and module.weight is not None and hasattr(new_norm, 'weight') and new_norm.weight is not None: |
| if new_norm.weight.shape == module.weight.shape: |
| new_norm.weight.copy_(module.weight) |
| else: |
| logging.warning(f"Weight shape mismatch swapping norm {name}: {module.weight.shape} -> {new_norm.weight.shape}. Re-initializing target weight.") |
| nn.init.ones_(new_norm.weight) |
| elif hasattr(new_norm, 'weight') and new_norm.weight is not None: |
| logging.debug(f"Initializing weight for new norm {name} as source lacked it.") |
| nn.init.ones_(new_norm.weight) |
|
|
| if hasattr(module, 'bias') and module.bias is not None and hasattr(new_norm, 'bias') and new_norm.bias is not None: |
| if new_norm.bias.shape == module.bias.shape: |
| new_norm.bias.copy_(module.bias) |
| else: |
| logging.warning(f"Bias shape mismatch swapping norm {name}: {module.bias.shape} -> {new_norm.bias.shape}. Re-initializing target bias.") |
| nn.init.zeros_(new_norm.bias) |
| elif hasattr(new_norm, 'bias') and new_norm.bias is not None: |
| logging.debug(f"Initializing bias for new LayerNorm {name} as source RMSNorm lacked it.") |
| nn.init.zeros_(new_norm.bias) |
|
|
| if _recursive_setattr(model, name, new_norm): |
| swapped_count += 1 |
| processed_names.add(name) |
| logging.debug(f"Swapped {current_norm_class.__name__} layer {name} to {new_norm_class.__name__}.") |
| else: |
| logging.warning(f"Failed to set swapped normalization layer for {name}") |
| processed_names.add(name) |
|
|
| except Exception as e: |
| logging.error(f"Error swapping norm layer {name} from {current_norm_class.__name__} to {new_norm_class.__name__}: {e}\n{traceback.format_exc()}") |
| processed_names.add(name) |
|
|
| if swapped_count > 0: |
| setattr(config, config_flag_name, target_flag_value) |
| if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False |
| msg = f"Swapped {swapped_count} {current_norm_class.__name__} layers to {new_norm_class.__name__}." |
| else: |
| if not already_configured: |
| setattr(config, config_flag_name, target_flag_value) |
| if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False |
| msg = f"No {current_norm_class.__name__} layers found or matched criteria to swap to {new_norm_class.__name__}. Updated config flag." |
| else: |
| msg = f"No {current_norm_class.__name__} layers were swapped (already configured or other issue)." |
|
|
| logging.info(msg) |
| return msg |
|
|
|
|
| def _normalize_embeddings(model, config): |
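| """L2-normalize each row of the input embedding matrix in place (norms clamped to >= 1e-12) and set config.embedding_normalized.""" |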
| emb_layer = None |
| if hasattr(model, 'get_input_embeddings'): |
| try: |
| emb_layer_candidate = model.get_input_embeddings() |
| if isinstance(emb_layer_candidate, nn.Embedding): |
| emb_layer = emb_layer_candidate |
| logging.info("Found embedding layer via get_input_embeddings()") |
| except Exception as e: |
| logging.warning(f"Error calling get_input_embeddings(): {e}") |
|
|
| if emb_layer is None: |
| potential_emb_names = ['embed_tokens', 'wte', 'word_embeddings', 'embeddings.word_embeddings', 'shared'] |
| model_base = getattr(model, 'model', model) |
|
|
| for name in potential_emb_names: |
| try: |
| candidate = model_base |
| parts = name.split('.') |
| valid_path = True |
| for part in parts: |
| if hasattr(candidate, part): |
| candidate = getattr(candidate, part) |
| if candidate is None: |
| valid_path = False |
| break |
| else: |
| valid_path = False |
| break |
| if valid_path and isinstance(candidate, nn.Embedding) and hasattr(candidate, 'weight') and candidate.weight is not None: |
| emb_layer = candidate |
| logging.info(f"Found embedding layer via attribute: '{name}'") |
| break |
| except AttributeError: |
| continue |
| except Exception as e: |
| logging.warning(f"Error accessing potential embedding layer '{name}': {e}") |
|
|
|
|
| if emb_layer is not None and hasattr(emb_layer, 'weight') and emb_layer.weight is not None: |
| try: |
| with torch.no_grad(): |
| w = emb_layer.weight.data |
| norms = torch.norm(w, p=2, dim=-1, keepdim=True) |
| safe_norms = norms.clamp(min=1e-12) |
| w.div_(safe_norms) |
|
|
| config.embedding_normalized = True |
| logging.info("Input embeddings normalized (L2 norm).") |
| return "Input embeddings normalized (L2 norm)." |
| except Exception as e: |
| logging.error(f"Error normalizing embeddings: {e}") |
| config.embedding_normalized = False |
| return f"Error normalizing embeddings: {e}" |
| else: |
| msg="Input embedding layer or its weights not found using common methods. Cannot normalize." |
| logging.warning(msg) |
| config.embedding_normalized = False |
| return msg |
|
|
| def _revert_embedding_normalization(model, config): |
| if not getattr(config, 'embedding_normalized', False): |
| return "Embedding normalization flag is already false (or was never applied)." |
|
|
| config.embedding_normalized = False |
| logging.info("Embedding normalization flag reverted. Note: Original embedding weights are NOT restored.") |
| return "Embedding normalization flag reverted (weights NOT restored)." |
|
|
|
|
| def _prune_weights_magnitude(model, config, amount=0.2): |
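| """Apply global unstructured L1-magnitude pruning to the weights of all Linear/BitLinear layers, make it permanent with prune.remove(), and record the resulting sparsity and flags in the config.""" |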
| if not isinstance(amount, (float, int)) or not (0 < amount < 1): |
| msg="Error: Pruning amount must be a float between 0 and 1 (exclusive)." |
| logging.error(msg) |
| return msg |
|
|
| logging.info(f"Applying global unstructured L1 magnitude pruning (amount={amount:.2f})...") |
| device = get_device() |
| model.to(device) |
|
|
| params_to_prune = [] |
| for module_name, module in model.named_modules(): |
| if isinstance(module, (nn.Linear, BitLinear)): |
| if hasattr(module, 'weight') and module.weight is not None and module.weight.requires_grad: |
| params_to_prune.append((module, 'weight')) |
|
|
| if not params_to_prune: |
| msg="No prunable Linear or BitLinear layers with trainable weights found." |
| logging.warning(msg) |
| config.pruning_applied = False |
| config.pruning_amount = None |
| return msg |
|
|
| try: |
| prune.global_unstructured( |
| parameters=params_to_prune, |
| pruning_method=prune.L1Unstructured, |
| amount=amount |
| ) |
|
|
| pruned_count = 0 |
| total_params = 0 |
| modules_made_permanent = 0 |
| for module, name in params_to_prune: |
| if prune.is_pruned(module): |
| prune.remove(module, name) |
| modules_made_permanent += 1 |
|
|
| if hasattr(module, name): |
| weight = getattr(module, name) |
| if weight is not None: |
| pruned_count += torch.sum(weight == 0).item() |
| total_params += weight.nelement() |
|
|
| if modules_made_permanent > 0: |
| sparsity = 100. * pruned_count / total_params if total_params > 0 else 0 |
| msg = (f"Pruning applied and made permanent on {modules_made_permanent} parameter groups. " |
| f"Final Sparsity: {sparsity:.2f}% ({pruned_count:,}/{total_params:,} zeros).") |
| config.pruning_applied = True |
| config.pruning_amount = amount |
| elif any(prune.is_pruned(mod) for mod, _ in params_to_prune): |
| msg = "Pruning hooks were applied but removal failed or was incomplete. Pruning might not be permanent." |
| config.pruning_applied = False |
| config.pruning_amount = None |
| else: |
| msg = "Pruning was attempted, but no modules seem to have been pruned or made permanent." |
| config.pruning_applied = False |
| config.pruning_amount = None |
|
|
|
|
| except Exception as e: |
| msg = f"Error during pruning: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| for module, name in params_to_prune: |
| if prune.is_pruned(module): |
| try: |
| prune.remove(module, name) |
| logging.info(f"Cleaned up pruning hook from {module}.{name} during error handling.") |
| except Exception as remove_e: |
| logging.warning(f"Couldn't remove pruning hook from {module}.{name} during cleanup: {remove_e}") |
| config.pruning_applied = False |
| config.pruning_amount = None |
|
|
| logging.info(msg) |
| return msg |
|
|
| def _revert_pruning(model, config): |
| if not getattr(config, 'pruning_applied', False): |
| return "Pruning flag is already false (or pruning was never applied/made permanent)." |
|
|
| config.pruning_applied = False |
| config.pruning_amount = None |
| logging.info("Pruning flag reverted. Note: Pruned weights (zeros) are NOT restored.") |
| return "Pruning flag reverted (weights NOT restored)." |
|
|
|
|
| def _quantize_model(model, config, mode='bfloat16'): |
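| """Cast the whole model to the requested dtype ('float32', 'float16' or 'bfloat16') via model.to(), updating the quantization flags in the config; bfloat16 is only attempted when CUDA reports bf16 support.""" |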
| logging.info(f"Attempting to change model dtype to {mode}...") |
| original_dtype_str = getattr(config, 'quantization_mode', DEFAULT_QUANTIZATION) |
|
|
| target_dtype = None |
| if mode == 'bfloat16': |
| if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): |
| target_dtype = torch.bfloat16 |
| else: |
| msg="Device does not support bfloat16. Cannot quantize to bfloat16. Keeping current dtype." |
| logging.warning(msg) |
| return msg |
| elif mode == 'float16': |
| target_dtype = torch.float16 |
| elif mode == 'float32': |
| target_dtype = torch.float32 |
| else: |
| msg = f"Unsupported quantization mode '{mode}'. Choose from {QUANTIZATION_MODES}." |
| logging.error(msg) |
| return msg |
|
|
| try: |
| first_param = next(model.parameters(), None) |
| if first_param is None: |
| msg = "Model has no parameters. Cannot determine or change dtype." |
| logging.error(msg) |
| return msg |
| current_dtype = first_param.dtype |
| except StopIteration: |
| msg = "Model has no parameters. Cannot determine or change dtype." |
| logging.error(msg) |
| return msg |
| except Exception as e: |
| msg = f"Could not determine current model dtype: {e}" |
| logging.error(msg) |
| return msg |
|
|
|
|
| if current_dtype == target_dtype: |
| msg = f"Model is already in {mode} ({target_dtype}). No change needed." |
| logging.info(msg) |
| config.quantization_applied = (mode != 'float32') |
| config.quantization_mode = mode |
| config.perfect_precision_recovered = (mode == 'float32') |
| return msg |
|
|
| try: |
| device = get_device() |
| model.to(device=device, dtype=target_dtype) |
|
|
| new_dtype = next(model.parameters()).dtype |
| if new_dtype == target_dtype: |
| config.quantization_applied = (mode != 'float32') |
| config.quantization_mode = mode |
| config.perfect_precision_recovered = (mode == 'float32') |
| msg = f"Model dtype successfully changed to {mode} ({target_dtype}) on device {device}." |
| logging.info(msg) |
| clean_memory() |
| return msg |
| else: |
| logging.error(f"Model dtype did not change as expected after .to() call. Still {new_dtype}. Reverting config flags.") |
| config.quantization_applied = (original_dtype_str != 'float32') |
| config.quantization_mode = original_dtype_str |
| config.perfect_precision_recovered = (original_dtype_str == 'float32') |
| raise RuntimeError(f"Model dtype did not change as expected. Still {new_dtype}.") |
|
|
| except Exception as e: |
| msg=f"Error converting model to {target_dtype}: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| config.quantization_applied = (original_dtype_str != 'float32') |
| config.quantization_mode = original_dtype_str |
| config.perfect_precision_recovered = (original_dtype_str == 'float32') |
| try: |
| original_torch_dtype = getattr(torch, original_dtype_str, torch.float32) |
| model.to(device=get_device(), dtype=original_torch_dtype) |
| logging.info(f"Attempted to restore model to original dtype {original_dtype_str} after error.") |
| except Exception as revert_e: |
| logging.error(f"Failed to restore original dtype after error: {revert_e}") |
| return msg |
|
|
| def _revert_quantization(model, config): |
| logging.info("Reverting quantization to float32...") |
| return _quantize_model(model, config, mode='float32') |
|
|
|
|
| def _freeze_layers(model, config, layers_to_freeze_str): |
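| """Freeze (requires_grad=False) the decoder layers named in a comma-separated spec such as '0-3, 7, 10-11'. |
| All parameters are unfrozen first, so each call fully replaces the previous freeze configuration.""" |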
| if not layers_to_freeze_str or not isinstance(layers_to_freeze_str, str): |
| msg="No layers specified to freeze or invalid input type." |
| logging.warning(msg) |
| config.frozen_layers = None |
| return msg |
|
|
| layer_indices = set() |
| try: |
| raw_parts = layers_to_freeze_str.split(',') |
| for part in raw_parts: |
| part = part.strip() |
| if not part: continue |
| if '-' in part: |
| start_end = part.split('-') |
| if len(start_end) == 2: |
| s = int(start_end[0].strip()) |
| e = int(start_end[1].strip()) |
| if s < 0 or e < 0: raise ValueError("Negative indices are not allowed.") |
| if s <= e: |
| layer_indices.update(range(s, e + 1)) |
| else: |
| layer_indices.update(range(e, s + 1)) |
| logging.warning(f"Interpreted range '{part}' as descending: {list(range(e, s + 1))}") |
| else: |
| raise ValueError(f"Invalid range format: {part}") |
| else: |
| idx = int(part) |
| if idx < 0: raise ValueError("Negative indices are not allowed.") |
| layer_indices.add(idx) |
| except ValueError as e: |
| msg=f"Error parsing layer specification '{layers_to_freeze_str}': {e}. Use non-negative, comma-separated numbers or ranges (e.g., '0-3, 7, 10-11')." |
| logging.error(msg) |
| return msg |
|
|
| layer_module, layer_attr, layer_list = _find_decoder_layers_module(model) |
| if not (layer_module and layer_attr and layer_list is not None): |
| msg="Could not determine layer structure for freezing. No layers frozen." |
| logging.warning(msg) |
| return msg |
|
|
| total_layers = len(layer_list) |
| frozen_params_count = 0 |
| actual_frozen_indices = set() |
|
|
| unfrozen_globally = 0 |
| for param in model.parameters(): |
| if not param.requires_grad: |
| param.requires_grad = True |
| unfrozen_globally += 1 |
| if unfrozen_globally > 0: |
| logging.info(f"Unfroze {unfrozen_globally} parameters globally before applying new freeze spec.") |
| else: |
| logging.info("No parameters were frozen globally before applying new spec.") |
|
|
| invalid_indices_skipped = set() |
| for i in layer_indices: |
| if 0 <= i < total_layers: |
| try: |
| current_layer = layer_list[i] |
| params_in_layer = 0 |
| for param in current_layer.parameters(): |
| if param.requires_grad: |
| param.requires_grad = False |
| frozen_params_count += 1 |
| params_in_layer += 1 |
| if params_in_layer > 0: |
| actual_frozen_indices.add(i) |
| logging.debug(f"Froze {params_in_layer} parameters in layer {i}.") |
| else: |
| logging.debug(f"Layer {i} had no trainable parameters to freeze.") |
|
|
| except IndexError: |
| logging.warning(f"Index {i} seems out of bounds for layer list during freezing loop, although check passed earlier. Skipping.") |
| invalid_indices_skipped.add(i) |
| except Exception as e: |
| logging.error(f"Error accessing or freezing parameters for layer {i}: {e}") |
| invalid_indices_skipped.add(i) |
| else: |
| logging.warning(f"Layer index {i} is out of bounds (0-{total_layers-1}). Skipping.") |
| invalid_indices_skipped.add(i) |
|
|
| frozen_list_str = ",".join(map(str, sorted(list(actual_frozen_indices)))) |
| config.frozen_layers = frozen_list_str if actual_frozen_indices else None |
|
|
| msg = f"Froze {frozen_params_count} parameters in layers: {frozen_list_str} (Total layers: {total_layers})." |
| if invalid_indices_skipped: |
| msg += f" Skipped invalid indices: {sorted(list(invalid_indices_skipped))}." |
| if frozen_params_count == 0 and not invalid_indices_skipped: |
| msg = f"No parameters were frozen. Specified layers {frozen_list_str} might have already been frozen or had no trainable params." |
|
|
| logging.info(msg) |
| return msg |
|
|
| def _unfreeze_all_layers(model, config): |
| unfrozen_count = 0 |
| for name, param in model.named_parameters(): |
| if not param.requires_grad: |
| param.requires_grad = True |
| unfrozen_count += 1 |
|
|
| config.frozen_layers = None |
|
|
| msg = f"Unfroze {unfrozen_count} parameters across the entire model." if unfrozen_count > 0 else "No parameters needed unfreezing." |
| logging.info(msg) |
| return msg |
|
|
|
|
| def _enable_gradient_checkpointing(model, config): |
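| """Enable gradient checkpointing via model.gradient_checkpointing_enable() (non-reentrant when supported), falling back to the config attribute, and turn off use_cache, which is incompatible with checkpointing.""" |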
| gc_enabled_in_model = False |
| if hasattr(model, 'gradient_checkpointing_enable'): |
| try: |
| sig = inspect.signature(model.gradient_checkpointing_enable) |
| if 'gradient_checkpointing_kwargs' in sig.parameters: |
| model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) |
| msg = "Gradient Checkpointing enabled via model method (non-reentrant)." |
| else: |
| model.gradient_checkpointing_enable() |
| msg = "Gradient Checkpointing enabled via model method." |
|
|
| gc_enabled_in_model = True |
| logging.info(msg) |
|
|
| except Exception as e: |
| logging.warning(f"Failed to enable gradient checkpointing via standard method: {e}. Trying config attribute.") |
|
|
| if hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing'): |
| if not gc_enabled_in_model: |
| logging.info("Enabling gradient checkpointing via model config attribute.") |
| model.config.gradient_checkpointing = True |
| gc_enabled_in_model = True |
|
|
| if gc_enabled_in_model: |
| if hasattr(model.config, 'use_cache'): |
| if model.config.use_cache: |
| model.config.use_cache = False |
| logging.info("Set model.config.use_cache = False (required for Gradient Checkpointing).") |
| else: |
| logging.warning("Model config missing 'use_cache' attribute. Gradient checkpointing might not work correctly or efficiently.") |
|
|
| config.gradient_checkpointing_enabled = True |
| final_msg = "Gradient Checkpointing enabled." |
| if not hasattr(model, 'gradient_checkpointing_enable') and not (hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing')): |
| final_msg += " (Set via main config flag only; ensure Trainer args/model support it)." |
| return final_msg |
| else: |
| config.gradient_checkpointing_enabled = False |
| msg = "Could not enable Gradient Checkpointing via model methods or config attributes." |
| logging.error(msg) |
| return f"[Error] {msg}" |
|
|
|
|
| def _disable_gradient_checkpointing(model, config): |
| gc_disabled_in_model = False |
| if hasattr(model, 'gradient_checkpointing_disable'): |
| try: |
| model.gradient_checkpointing_disable() |
| gc_disabled_in_model = True |
| logging.info("Gradient Checkpointing disabled via model method.") |
| except Exception as e: |
| logging.warning(f"Failed to disable gradient checkpointing via standard method: {e}. Trying config attribute.") |
|
|
| if hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing'): |
| if not gc_disabled_in_model: |
| logging.info("Disabling gradient checkpointing via model config attribute.") |
| model.config.gradient_checkpointing = False |
| gc_disabled_in_model = True |
|
|
| if gc_disabled_in_model: |
| if hasattr(model.config, 'use_cache'): |
| if not model.config.use_cache: |
| model.config.use_cache = True |
| logging.info("Set model.config.use_cache = True (restored after disabling Gradient Checkpointing).") |
|
|
| config.gradient_checkpointing_enabled = False |
| final_msg = "Gradient Checkpointing disabled." |
| if not hasattr(model, 'gradient_checkpointing_disable') and not (hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing')): |
| final_msg += " (Set via main config flag only)." |
| return final_msg |
| else: |
| config.gradient_checkpointing_enabled = False |
| msg = "Could not disable Gradient Checkpointing via model methods or config attributes (may not have been enabled)." |
| logging.warning(msg) |
| return msg |
|
|
|
|
| def _swap_optimizer(config, optimizer_name): |
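| """Record the optimizer preference in the config (and in the module-level DEFAULT_OPTIMIZER); the actual optimizer is only instantiated when training starts.""" |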
| if optimizer_name in OPTIMIZERS: |
| config.optimizer = optimizer_name |
| global DEFAULT_OPTIMIZER |
| DEFAULT_OPTIMIZER = optimizer_name |
| msg=f"Optimizer preference set to '{optimizer_name}' in config. This will be used by the Trainer if training starts." |
| logging.info(msg) |
| return msg |
| else: |
| available_opts = ", ".join(OPTIMIZERS.keys()) |
| msg=f"Error: Optimizer '{optimizer_name}' unknown or not available. Choose from: {available_opts}." |
| logging.error(msg) |
| return msg |
|
|
| def _revert_optimizer(config): |
| original_default_optimizer = "adamw_torch" |
| logging.info(f"Reverting optimizer preference to script default: '{original_default_optimizer}'.") |
| return _swap_optimizer(config, original_default_optimizer) |
|
|
| def _untie_embeddings(model, config): |
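| """Give the output embedding / lm_head its own detached copy of the currently shared input embedding weight (and bias, if tied), so the two can be trained independently; updates tie_word_embeddings and untied_embeddings in the config.""" |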
| try: |
| input_embeddings = model.get_input_embeddings() |
| output_embeddings = model.get_output_embeddings() |
|
|
| if output_embeddings is None: |
| if hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear): |
| output_embeddings = model.lm_head |
| logging.info("Using 'lm_head' as the output embedding layer for untie check.") |
| else: |
| msg="Could not get output embedding layer (get_output_embeddings() returned None and 'lm_head' not found/Linear). Cannot untie." |
| logging.warning(msg) |
| config.untied_embeddings = True |
| if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False |
| return msg |
|
|
| if input_embeddings is None: |
| msg="Could not get input embedding layer. Cannot untie." |
| logging.warning(msg) |
| config.untied_embeddings = False |
| if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True |
| return msg |
|
|
| are_tied = False |
| if hasattr(input_embeddings, "weight") and hasattr(output_embeddings, "weight") and \ |
| input_embeddings.weight is not None and output_embeddings.weight is not None: |
| if input_embeddings.weight.data_ptr() == output_embeddings.weight.data_ptr(): |
| are_tied = True |
| elif input_embeddings.weight.storage().data_ptr() == output_embeddings.weight.storage().data_ptr(): |
| are_tied = True |
| logging.info("Weights appear tied (share storage).") |
|
|
| if are_tied: |
| logging.info("Detected tied input/output embeddings. Attempting to untie...") |
| device = input_embeddings.weight.device |
| dtype = input_embeddings.weight.dtype |
|
|
| new_output_weight = input_embeddings.weight.clone().detach() |
| new_output_weight.requires_grad_(output_embeddings.weight.requires_grad) |
|
|
| output_embeddings.weight = nn.Parameter(new_output_weight.to(device, dtype=dtype)) |
|
|
| if hasattr(input_embeddings, "bias") and input_embeddings.bias is not None and \ |
| hasattr(output_embeddings, "bias") and output_embeddings.bias is not None and \ |
| input_embeddings.bias.data_ptr() == output_embeddings.bias.data_ptr(): |
| logging.info("Detected tied bias, untying as well.") |
| new_output_bias = input_embeddings.bias.clone().detach() |
| new_output_bias.requires_grad_(output_embeddings.bias.requires_grad) |
| output_embeddings.bias = nn.Parameter(new_output_bias.to(device, dtype=dtype)) |
|
|
| if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False |
| config.untied_embeddings = True |
| msg="Embeddings untied successfully (output layer weights/bias are now distinct copies)." |
| logging.info(msg) |
| clean_memory() |
| return msg |
| else: |
| config.untied_embeddings = True |
| if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False |
| msg="Embeddings are already untied (or weights are missing/different objects)." |
| logging.info(msg) |
| return msg |
|
|
| except Exception as e: |
| msg=f"Error untying embeddings: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| return msg |
|
|
|
|
| def _retie_embeddings(model, config): |
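| """Re-tie the output embedding / lm_head to the input embedding by sharing the same weight Parameter (output bias set to None), provided the shapes match; updates the tie flags in the config.""" |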
| if not getattr(config, 'untied_embeddings', False): |
| try: |
| input_emb = model.get_input_embeddings() |
| output_emb = model.get_output_embeddings() |
| if output_emb is None and hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear): |
| output_emb = model.lm_head |
|
|
| if input_emb is not None and output_emb is not None and \ |
| hasattr(input_emb, 'weight') and input_emb.weight is not None and \ |
| hasattr(output_emb, 'weight') and output_emb.weight is not None and \ |
| input_emb.weight.data_ptr() == output_emb.weight.data_ptr(): |
| msg = "Embeddings seem already tied. Resetting flag if needed." |
| config.untied_embeddings = False |
| if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True |
| logging.info(msg) |
| return msg |
| else: |
| msg = "Cannot re-tie: Flag 'untied_embeddings' is false or cannot verify current state." |
| logging.info(msg) |
| return msg |
| except Exception as e: |
| msg = f"Cannot re-tie: Error checking current state ({e}). Flag 'untied_embeddings' is false." |
| logging.warning(msg) |
| return msg |
|
|
|
|
| try: |
| input_embeddings = model.get_input_embeddings() |
| output_embeddings = model.get_output_embeddings() |
| if output_embeddings is None and hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear): |
| output_embeddings = model.lm_head |
| logging.info("Using 'lm_head' as output layer for re-tying.") |
|
|
| if input_embeddings is None or output_embeddings is None: |
| msg="Could not get both input and output embedding layers for re-tying." |
| logging.warning(msg) |
| return msg |
|
|
| if hasattr(input_embeddings, "weight") and input_embeddings.weight is not None and \ |
| hasattr(output_embeddings, "weight") and output_embeddings.weight is not None: |
|
|
| if input_embeddings.weight.shape == output_embeddings.weight.shape: |
| logging.info("Attempting to re-tie embeddings by sharing input embedding weight...") |
| device = input_embeddings.weight.device |
| dtype = input_embeddings.weight.dtype |
| output_embeddings = output_embeddings.to(device=device, dtype=dtype) |
|
|
| output_embeddings.weight = input_embeddings.weight |
|
|
| if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None: |
| logging.info("Setting output embedding bias to None as part of re-tying.") |
| output_embeddings.bias = None |
|
|
| if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True |
| config.untied_embeddings = False |
| msg="Embeddings re-tied successfully (output layer now shares input layer's weight, bias set to None)." |
| logging.info(msg) |
| clean_memory() |
| return msg |
| else: |
| msg = f"Cannot re-tie embeddings: Weight shapes mismatch. Input: {input_embeddings.weight.shape}, Output: {output_embeddings.weight.shape}." |
| logging.warning(msg) |
| return msg |
| else: |
| msg = "Cannot re-tie embeddings: Input or output embedding weights missing or None." |
| logging.warning(msg) |
| return msg |
|
|
| except Exception as e: |
| msg=f"Error re-tying embeddings: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| return msg |
|
|
|
|
| def _configure_limits(config): |
| config.knowledge_date = "2045-03-28" |
| config.cutoff_date = "2045-03-28" |
|
|
| current_max_pos = getattr(config, 'max_position_embeddings', 512) |
| new_max_pos = current_max_pos * 100 |
| config.max_position_embeddings = new_max_pos |
|
|
| config.limits_configured = True |
| config.no_limits = True |
|
|
| logging.info(f"Set knowledge/cutoff date flags and increased max_position_embeddings in config to {config.max_position_embeddings}.") |
| return f"Limit-related flags configured (Knowledge Date: 2045, Max Pos Emb: {config.max_position_embeddings}). Requires model reload or RoPE scaling for actual effect." |
|
|
| def _remove_limits_configuration(config): |
| if not getattr(config, 'limits_configured', False): |
| return "Limit configuration flags are already in their default state." |
|
|
| config.knowledge_date = None |
| config.cutoff_date = None |
|
|
| config.limits_configured = False |
| config.no_limits = False |
|
|
| logging.info("Reset knowledge date and cutoff date flags in config. Max position embeddings remain modified.") |
| return "Limit-related flags removed/reset. Max position embeddings NOT reverted." |
|
|
| def _remove_qa_restrictions(config): |
| config.qa_restrictions_removed = True |
| logging.info("QA restrictions removal flag set in config. Actual effect depends on model usage/fine-tuning and inference logic.") |
| return "QA Restrictions Removal Flag Enabled (symbolic)." |
|
|
| def _enable_qa_restrictions(config): |
| config.qa_restrictions_removed = False |
| logging.info("QA restrictions removal flag disabled in config.") |
| return "QA Restrictions Removal Flag Disabled (symbolic)." |
|
|
| def _enable_coherence_improvement(config): |
| config.coherence_improvement_enabled = True |
| logging.info("Coherence improvement flag enabled. Inference will use beam search if this is active.") |
| return "Coherence Improvement Flag ON (uses beam search in inference)." |
|
|
| def _disable_coherence_improvement(config): |
| config.coherence_improvement_enabled = False |
| logging.info("Coherence improvement flag disabled.") |
| return "Coherence Improvement Flag OFF." |
|
|
|
|
| def _set_flag_only(config, flag_name, value, msg_on, msg_off): |
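| """Set a boolean flag on the config and return the matching on/off message; used by the purely symbolic toggles below that do not modify the model itself.""" |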
| if not hasattr(config, flag_name): |
| logging.warning(f"Config object does not have flag '{flag_name}'. Adding it.") |
|
|
| bool_value = bool(value) |
| setattr(config, flag_name, bool_value) |
|
|
| msg = msg_on if bool_value else msg_off |
| logging.info(f"Config flag '{flag_name}' set to {bool_value}. Message: {msg}") |
| return msg |
|
|
| def _apply_swa(model, config): return _set_flag_only(config, "swa_applied", True, "SWA flag set. Requires SWA callback/logic during training.", "SWA flag disabled.") |
| def _revert_swa(model, config): return _set_flag_only(config, "swa_applied", False, "SWA flag set.", "SWA flag disabled.") |
| def _apply_knowledge_editing(model, config): return _set_flag_only(config, "knowledge_edited", True, "Knowledge Editing flag set. Indicates manual edits or specific editing techniques were applied (symbolic).", "Knowledge Editing flag disabled.") |
| def _revert_knowledge_editing(model, config): return _set_flag_only(config, "knowledge_edited", False, "Knowledge Editing flag set.", "Knowledge Editing flag disabled.") |
| def _apply_head_pruning(model, config): return _set_flag_only(config, "head_pruning_applied", True, "Head Pruning flag set. Requires specific pruning implementation outside this script (symbolic).", "Head Pruning flag disabled.") |
| def _revert_head_pruning(model, config): return _set_flag_only(config, "head_pruning_applied", False, "Head Pruning flag set.", "Head Pruning flag disabled.") |
| def _apply_qat(model, config): return _set_flag_only(config, "qat_applied", True, "QAT flag set. Requires Quantization-Aware Training setup and execution (symbolic).", "QAT flag disabled.") |
| def _revert_qat(model, config): return _set_flag_only(config, "qat_applied", False, "QAT flag set.", "QAT flag disabled.") |
| def _apply_architecture_merge_flag(model, config): return _set_flag_only(config, "architecture_merged", True, "Architecture Merged flag set. Indicates model is likely a result of parameter averaging.", "Architecture Merged flag disabled.") |
| def _revert_architecture_merge_flag(model, config): return _set_flag_only(config, "architecture_merged", False, "Architecture Merged flag set.", "Architecture Merged flag disabled.") |
| def _apply_weight_init(model, config): return _set_flag_only(config, "weight_init_applied", True, "Weight Initialization flag set. Indicates a specific init strategy was used (symbolic).", "Weight Initialization flag disabled.") |
| def _revert_weight_init(model, config): return _set_flag_only(config, "weight_init_applied", False, "Weight Initialization flag set.", "Weight Initialization flag disabled.") |
| def _apply_gradient_noise(model, config): return _set_flag_only(config, "gradient_noise_applied", True, "Gradient Noise flag set. Requires implementation in optimizer/trainer (symbolic).", "Gradient Noise flag disabled.") |
| def _revert_gradient_noise(model, config): return _set_flag_only(config, "gradient_noise_applied", False, "Gradient Noise flag set.", "Gradient Noise flag disabled.") |
|
|
|
|
| def _apply_additional_mechanisms(base_model, config): |
| logging.info("Applying various additional experimental mechanisms flags and simple optimizations...") |
| _set_flag_only(config, "enhanced_security_enabled", True, "Enhanced Security Flag ON.", "Enhanced Security Flag OFF.") |
| _set_flag_only(config, "debug_mode_enabled", True, "Debug Mode Flag ON.", "Debug Mode Flag OFF.") |
| _set_flag_only(config, "internal_logging_enabled", True, "Internal Logging Flag ON.", "Internal Logging Flag OFF.") |
| _set_flag_only(config, "drift_detection_enabled", True, "Drift Detection Flag ON.", "Drift Detection Flag OFF.") |
| _set_flag_only(config, "ultra_fast_mode", True, "Ultra Fast Mode Flag ON.", "Ultra Fast Mode Flag OFF.") |
|
|
| coherence_msg = _enable_coherence_improvement(config) |
| speed_msg = _optimize_token_generation_speed(config) |
|
|
| config.additional_mechanisms_applied = True |
| logging.info("Applied various additional mechanism flags and optimizations.") |
| return f"Applied Additional Mechanism Flags & Optimizations. Coherence: {coherence_msg}, Speed: {speed_msg}" |
|
|
| def _disable_additional_mechanisms(config): |
| if not getattr(config, 'additional_mechanisms_applied', False): |
| return "Additional mechanisms flag is already off. No changes made." |
|
|
| logging.info("Disabling various additional experimental mechanisms flags and reverting optimizations...") |
| _set_flag_only(config, "enhanced_security_enabled", False, "Enhanced Security Flag ON.", "Enhanced Security Flag OFF.") |
| _set_flag_only(config, "debug_mode_enabled", False, "Debug Mode Flag ON.", "Debug Mode Flag OFF.") |
| _set_flag_only(config, "internal_logging_enabled", False, "Internal Logging Flag ON.", "Internal Logging Flag OFF.") |
| _set_flag_only(config, "drift_detection_enabled", False, "Drift Detection Flag ON.", "Drift Detection Flag OFF.") |
| _set_flag_only(config, "ultra_fast_mode", False, "Ultra Fast Mode Flag ON.", "Ultra Fast Mode Flag OFF.") |
|
|
| coherence_msg = _disable_coherence_improvement(config) |
| speed_msg = _revert_token_generation_speed_optimization(config) |
|
|
| config.additional_mechanisms_applied = False |
| logging.info("Disabled various additional mechanism flags and reverted optimizations.") |
| return f"Disabled Additional Mechanism Flags & Reverted Optimizations. Coherence: {coherence_msg}, Speed: {speed_msg}" |
|
|
| def _disable_all_safety_settings(config): |
| flags_to_disable = [ |
| "response_filters", "safety_settings_enabled", |
| "harassment_filter", "hate_filter", "sexually_explicit_filter", |
| "dangerous_content_filter", "civic_integrity_filter", "code_filter", |
| "medical_advice_filter", "legal_advice_filter", "financial_advice_filter", |
| "pii_filter", "political_filter", "religious_filter", "profanity_filter", |
| "stereotype_filter", "misinfo_filter", "self_harm_filter", |
| "personal_attack_filter", "toxicity_filter", "spam_filter", |
| "off_topic_filter", "tone_filter", "min_max_length_filter", |
| "repetition_filter_enabled", "factuality_filter_enabled" |
| ] |
| flags_to_set_true = [ |
| "remove_censorship", "no_response_filters", "no_advert_warning", "no_limits" |
| ] |
|
|
| config = initialize_config_flags(config) |
|
|
| updated_flags = 0 |
| for flag in flags_to_disable: |
| if hasattr(config, flag) and getattr(config, flag) is not False: |
| setattr(config, flag, False) |
| updated_flags += 1 |
| for flag in flags_to_set_true: |
| if hasattr(config, flag) and getattr(config, flag) is not True: |
| setattr(config, flag, True) |
| updated_flags += 1 |
|
|
| config.safety_settings_enabled = False |
| config.response_filters = False |
|
|
| logging.info(f"Disabled all known safety/content filters and related flags in config ({updated_flags} flags updated).") |
| return "All safety filter flags disabled in config." |
|
|
| def _enable_all_safety_settings(config): |
| flags_to_set_default_true = [ |
| "safety_settings_enabled", "response_filters", |
| "harassment_filter", "hate_filter", "sexually_explicit_filter", |
| "dangerous_content_filter", "self_harm_filter", "pii_filter", |
| "min_max_length_filter", |
| "toxicity_filter", "personal_attack_filter", |
| ] |
| flags_to_set_optional_true = [ |
| "civic_integrity_filter", "code_filter", |
| "medical_advice_filter", "legal_advice_filter", "financial_advice_filter", |
| "political_filter", "religious_filter", "profanity_filter", |
| "stereotype_filter", "misinfo_filter", |
| "spam_filter", "off_topic_filter", "tone_filter" |
| ] |
| flags_to_set_false = [ |
| "remove_censorship", "no_response_filters", "no_advert_warning", "no_limits" |
| ] |
| flags_to_set_default_false = [ |
| "repetition_filter_enabled", "factuality_filter_enabled" |
| ] |
|
|
| config = initialize_config_flags(config) |
|
|
| updated_flags = 0 |
| all_flags_to_enable = flags_to_set_default_true + flags_to_set_optional_true |
| for flag in all_flags_to_enable: |
| if hasattr(config, flag) and getattr(config, flag) is not True: |
| setattr(config, flag, True) |
| updated_flags += 1 |
| for flag in flags_to_set_false: |
| if hasattr(config, flag) and getattr(config, flag) is not False: |
| setattr(config, flag, False) |
| updated_flags += 1 |
| for flag in flags_to_set_default_false: |
| if hasattr(config, flag) and getattr(config, flag) is not False: |
| setattr(config, flag, False) |
| updated_flags += 1 |
|
|
| config.safety_settings_enabled = True |
| config.response_filters = True |
|
|
| logging.info(f"Enabled default safety/content filters and related flags in config ({updated_flags} flags updated).") |
| return "Default safety filter flags enabled in config." |
|
|
| def _remove_inconsistencias_and_biases(base_model, config): |
| bias_adjusted_count = 0 |
| params_adjusted_count = 0 |
| device = get_device() |
| base_model.to(device) |
|
|
| if getattr(config, 'inconsistencies_biases_removed', False): |
| return "Inconsistencies/Biases removal flag already set. No action taken." |
|
|
| with torch.no_grad(): |
| for name, param in base_model.named_parameters(): |
| if "bias" in name and isinstance(param, nn.Parameter) and param.requires_grad: |
| if any(lin_name in name.lower() for lin_name in ['linear', 'dense', 'fc', 'out_proj', 'q_proj', 'k_proj', 'v_proj', 'wi', 'wo', 'lm_head']): |
| try: |
| original_mean = torch.mean(param.data.float()).item() |
| if abs(original_mean) > 1e-6: |
| param.sub_(original_mean) |
| bias_adjusted_count += 1 |
| params_adjusted_count += param.numel() |
| logging.debug(f"Centered bias for {name} (original mean: {original_mean:.4e})") |
| except Exception as e: |
| logging.warning(f"Could not center bias for {name}: {e}") |
|
|
| if bias_adjusted_count > 0: |
| config.inconsistencies_biases_removed = True |
| logging.info(f"Centered {bias_adjusted_count} bias terms ({params_adjusted_count} parameters) to potentially reduce inconsistencies.") |
| return f"{bias_adjusted_count} bias terms centered." |
| else: |
| config.inconsistencies_biases_removed = True |
| logging.info("Attempted bias centering, but no adjustable bias terms with significant mean found or no bias terms present.") |
| return "Attempted bias centering (no significant changes made or no biases found)." |
|
|
| def _reenable_inconsistencias_and_biases(config): |
| if not getattr(config, 'inconsistencies_biases_removed', False): |
| return "Inconsistencies/Biases removal flag already disabled." |
|
|
| config.inconsistencies_biases_removed = False |
| logging.info("Inconsistencies/Biases removal flag reverted. Note: Original bias values are NOT restored.") |
| return "Inconsistencies/Biases removal flag reverted (biases NOT restored)." |
|
|
| def _enable_layerdrop(config, probability=0.1): |
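| """Store a LayerDrop probability on the config; whether it has any effect depends on the model architecture honouring 'layerdrop' during training or inference.""" |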
| if not isinstance(probability, (float, int)) or not (0 <= probability <= 1): |
| msg=f"Error: LayerDrop probability must be between 0 and 1. Got {probability}." |
| logging.error(msg) |
| return msg |
|
|
| if hasattr(config, 'layerdrop'): |
| config.layerdrop = float(probability) |
| else: |
| logging.warning("Config does not have a standard 'layerdrop' attribute. Setting custom flag only.") |
| setattr(config, 'layerdrop', float(probability)) |
|
|
| config.layerdrop_enabled = (probability > 0) |
| config.layerdrop_prob = float(probability) |
|
|
| logging.info(f"LayerDrop enabled flag set in config with probability {probability}. Actual effect depends on model architecture support during training/inference.") |
| return f"LayerDrop flag {'ON' if probability > 0 else 'OFF'} (p={probability:.2f}). Requires model/Trainer support." |
|
|
| def _disable_layerdrop(config): |
| return _enable_layerdrop(config, probability=0.0) |
|
|
|
|
| def _apply_lora_merge(model, config): |
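| """Merge the LoRA adapter at config.lora_adapter_path into the base weights via PeftModel.merge_and_unload(), loading and activating the adapter first if needed, and replace global_model with the merged model.""" |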
| global global_model |
|
|
| adapter_path = getattr(config, 'lora_adapter_path', None) |
| if not adapter_path: |
| msg="No LoRA adapter path specified in config ('lora_adapter_path'). Use 'Set Path in Config' first or train/load an adapter." |
| logging.warning(msg) |
| return msg |
|
|
| if not _peft_installed: |
| msg="Error: PEFT library not installed, cannot merge LoRA." |
| logging.error(msg) |
| return msg |
|
|
| current_model = model |
|
|
| if not isinstance(current_model, PeftModel): |
| logging.warning(f"Model is not a PeftModel. Attempting to load adapter '{adapter_path}' onto it first.") |
| try: |
| peft_model_instance = PeftModel.from_pretrained(current_model, adapter_path, is_trainable=False) |
| current_model = peft_model_instance |
| logging.info(f"Successfully loaded adapter '{adapter_path}' onto the base model.") |
| except Exception as e: |
| msg = f"Error loading adapter '{adapter_path}' onto base model: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| return msg |
| else: |
| active_adapter = getattr(current_model, 'active_adapter', 'default') |
| target_adapter_name = os.path.basename(os.path.normpath(adapter_path)) |
| if not target_adapter_name: target_adapter_name = 'default' |
|
|
| if target_adapter_name not in current_model.peft_config: |
| logging.info(f"Adapter '{target_adapter_name}' (from path {adapter_path}) not found in existing PeftModel config. Loading it now.") |
| try: |
| current_model.load_adapter(adapter_path, adapter_name=target_adapter_name, is_trainable=False) |
| logging.info(f"Loaded new adapter '{target_adapter_name}'.") |
| except Exception as e: |
| msg = f"Error loading adapter '{target_adapter_name}' from path '{adapter_path}' onto existing PeftModel: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| return msg |
|
|
| if active_adapter != target_adapter_name: |
| try: |
| current_model.set_adapter(target_adapter_name) |
| logging.info(f"Set active adapter to '{target_adapter_name}' for merging.") |
| active_adapter = target_adapter_name |
| except Exception as e: |
| msg = f"Error setting adapter '{target_adapter_name}' active on existing PeftModel: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| return msg |
| else: |
| active_adapter = target_adapter_name |
|
|
| try: |
| logging.info(f"Merging active LoRA adapter ('{active_adapter}') into the base model..."); T = time.time() |
| merged_model = current_model.merge_and_unload() |
| merge_time = time.time() - T |
|
|
| merged_config = merged_model.config |
| merged_config = initialize_config_flags(merged_config) |
| merged_config.lora_merged = True |
| merged_config.lora_adapter_path = adapter_path |
| merged_config.peft_adapter_added = False |
| merged_config.peft_config = None |
|
|
| global_model = merged_model |
| config = merged_config |
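| # NOTE: rebinding the local 'config' does not propagate to the caller; after this call |
| # the merged config remains reachable via global_model.config. |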
|
|
| msg = f"LoRA adapter '{active_adapter}' (from {adapter_path}) merged successfully in {merge_time:.2f}s. Global model updated to the merged base model." |
| logging.info(msg) |
| clean_memory() |
| return msg |
| except ValueError as ve: |
| msg = f"Error merging LoRA adapter '{active_adapter}': {ve}. Adapter type might not support merging." |
| logging.error(msg) |
| return msg |
| except Exception as e: |
| msg = f"Error merging LoRA adapter '{active_adapter}': {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| return msg |
|
|
|
|
| def _revert_lora_merge(model, config): |
| if not getattr(config, 'lora_merged', False): |
| return "LoRA merge flag is already false (or merge never applied/recorded)." |
|
|
| config.lora_merged = False |
| config.lora_adapter_path = None |
| msg = "LoRA merge flag reverted. IMPORTANT: Model weights are NOT restored to pre-merge state. Reload the original base model if needed."; |
| logging.warning(msg) |
| return msg |
|
|
|
|
| def _set_lora_adapter_path(config, path): |
| if path and isinstance(path, str) and path.strip(): |
| path = path.strip() |
| config.lora_adapter_path = path |
| msg = f"LoRA adapter path set in config to: '{path}'" |
| logging.info(msg) |
| return msg |
| else: |
| msg = "Invalid or empty LoRA adapter path provided. Path not set." |
| logging.warning(msg) |
| return msg |
|
|
|
|
| def _setup_knowledge_distillation(model, config, num_labels=2): |
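| """Attach a fresh Linear classification head ('kd_classifier': hidden_size -> num_labels, Xavier-initialized) for knowledge-distillation experiments; the training loop must be adapted to actually use it.""" |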
| if not isinstance(num_labels, int) or num_labels <= 0: |
| msg = f"Error: Number of labels for KD must be a positive integer, got {num_labels}." |
| logging.error(msg) |
| return msg |
|
|
| try: |
| device=get_device() |
| try: |
| dtype = next(iter(model.parameters())).dtype |
| except StopIteration: |
| dtype = torch.float32 |
| if not isinstance(dtype, torch.dtype): dtype = torch.float32 |
|
|
| classifier_name = 'kd_classifier' |
| if hasattr(model, classifier_name): |
| logging.warning(f"Model already has an attribute named '{classifier_name}'. Overwriting.") |
|
|
| hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'embed_dim', None))) |
| if not isinstance(hidden_size, int) or hidden_size <= 0: |
| raise ValueError("Cannot setup KD: Model config missing valid 'hidden_size', 'd_model', or 'embed_dim' attribute.") |
|
|
| classifier_layer = nn.Linear(hidden_size, num_labels).to(device, dtype=dtype) |
| nn.init.xavier_uniform_(classifier_layer.weight) |
| if classifier_layer.bias is not None: |
| nn.init.zeros_(classifier_layer.bias) |
|
|
| setattr(model, classifier_name, classifier_layer) |
|
|
| if not hasattr(config, 'num_labels') or config.num_labels is None: |
| config.num_labels = num_labels |
| else: |
| logging.warning(f"Model config already has 'num_labels'={config.num_labels}. KD setup might conflict if used for other classification tasks.") |
|
|
| config.knowledge_distillation_setup = True |
| config.kd_num_labels = num_labels |
|
|
| msg = (f"Knowledge Distillation head ('{classifier_name}') added with {num_labels} labels (outputs). " |
| f"Requires training changes: loss calculation using this head (e.g., cross-entropy on its logits), " |
| f"and appropriate data format (e.g., sequence inputs + target labels).") |
| logging.info(msg) |
| return msg |
| except Exception as e: |
| msg = f"Error setting up Knowledge Distillation head: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| if hasattr(model, 'kd_classifier'): delattr(model, 'kd_classifier') |
| config.knowledge_distillation_setup = False |
| config.kd_num_labels = None |
| return msg |
|
|
| def _revert_knowledge_distillation(model, config): |
| classifier_name = 'kd_classifier' |
| if hasattr(model, classifier_name): |
| delattr(model, classifier_name) |
| config.knowledge_distillation_setup = False |
| config.kd_num_labels = None |
| msg = f"Knowledge Distillation setup reverted (removed '{classifier_name}' head and reset config flags)." |
| logging.info(msg) |
| clean_memory() |
| return msg |
| else: |
| config.knowledge_distillation_setup = False |
| config.kd_num_labels = None |
| msg = f"Knowledge Distillation head ('{classifier_name}') not found, nothing to revert. Reset flags." |
| logging.info(msg) |
| return msg |
|
|
|
|
| def _setup_reward_modeling(model, config, num_outputs=1): |
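| """Attach a fresh Linear reward head ('reward_head': hidden_size -> num_outputs, Xavier-initialized); a ranking-style loss and preference-pair data are still required to train it.""" |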
| if not isinstance(num_outputs, int) or num_outputs <= 0: |
| msg = f"Error: Number of outputs for Reward Model must be a positive integer, got {num_outputs}." |
| logging.error(msg) |
| return msg |
|
|
| try: |
| device=get_device() |
| try: |
| dtype = next(iter(model.parameters())).dtype |
| except StopIteration: |
| dtype = torch.float32 |
| if not isinstance(dtype, torch.dtype): dtype = torch.float32 |
|
|
| rm_head_name = 'reward_head' |
| if hasattr(model, rm_head_name): |
| logging.warning(f"Model already has an attribute named '{rm_head_name}'. Overwriting.") |
|
|
| hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'embed_dim', None))) |
| if not isinstance(hidden_size, int) or hidden_size <= 0: |
| raise ValueError("Cannot setup Reward Model head: Model config missing valid 'hidden_size', 'd_model', or 'embed_dim'.") |
|
|
| reward_head = nn.Linear(hidden_size, num_outputs).to(device, dtype=dtype) |
| nn.init.xavier_uniform_(reward_head.weight) |
| if reward_head.bias is not None: |
| nn.init.zeros_(reward_head.bias) |
|
|
| setattr(model, rm_head_name, reward_head) |
|
|
| config.reward_modeling_setup = True |
| config.rm_num_outputs = num_outputs |
|
|
| msg = (f"Reward Modeling head ('{rm_head_name}') added with {num_outputs} output(s). " |
| f"Requires training changes: loss targeting rewards (e.g., ranking loss), specific data format (prompt, chosen_resp, rejected_resp), " |
| f"and likely using the final hidden state of the sequence as input to this head.") |
| logging.info(msg) |
| return msg |
| except Exception as e: |
| msg = f"Error setting up Reward Modeling head: {e}\n{traceback.format_exc()}" |
| logging.error(msg) |
| if hasattr(model, 'reward_head'): delattr(model, 'reward_head') |
| config.reward_modeling_setup = False |
| config.rm_num_outputs = None |
| return msg |
|
|
| def _revert_reward_modeling(model, config): |
| rm_head_name = 'reward_head' |
| if hasattr(model, rm_head_name): |
| delattr(model, rm_head_name) |
| config.reward_modeling_setup = False |
| config.rm_num_outputs = None |
| msg = f"Reward Modeling setup reverted (removed '{rm_head_name}' head and reset config flags)." |
| logging.info(msg) |
| clean_memory() |
| return msg |
| else: |
| config.reward_modeling_setup = False |
| config.rm_num_outputs = None |
| msg = f"Reward Modeling head ('{rm_head_name}') not found, nothing to revert. Reset flags." |
| logging.info(msg) |
| return msg |
|
|
|
|
| def _set_rope_scaling_config(model, config, scaling_type="linear", factor=2.0): |
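| """Write a RoPE scaling dict ({'type': 'linear'|'dynamic', 'factor': >= 1.0}) into the config; the model must be reloaded with this config for the scaling to take effect.""" |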
| valid_types = ["linear", "dynamic"] |
| if not scaling_type or not isinstance(scaling_type, str) or scaling_type not in valid_types: |
| msg = f"Error: RoPE scaling type must be one of {valid_types}. Got '{scaling_type}'." |
| logging.error(msg) |
| return msg |
| try: |
| factor = float(factor) |
| if factor < 1.0: raise ValueError("Factor must be >= 1.0.") |
| if factor == 1.0: logging.warning(f"RoPE scaling factor set to {factor}, which implies no scaling.") |
| except (ValueError, TypeError) as e: |
| msg=f"Error: Invalid RoPE scaling factor '{factor}'. Must be a number >= 1.0. Error: {e}" |
| logging.error(msg) |
| return msg |
|
|
| rope_config = {"type": scaling_type, "factor": factor} |
| config.rope_scaling = rope_config |
|
|
| config.rope_scaling_type = scaling_type |
| config.rope_scaling_factor = factor |
|
|
| msg = (f"RoPE Scaling set in config: type='{scaling_type}', factor={factor:.2f}. " |
| f"Requires model architecture support and **reloading the model** with this config for the changes to take effect.") |
| logging.warning(msg) |
| return msg |
|
|
| def _revert_rope_scaling(model, config): |
| if hasattr(config, 'rope_scaling') and config.rope_scaling is not None: |
| config.rope_scaling = None |
| config.rope_scaling_type = None |
| config.rope_scaling_factor = None |
| msg = "RoPE Scaling configuration removed from config. Model reload required to revert RoPE behavior." |
| logging.warning(msg) |
| return msg |
| else: |
| config.rope_scaling_type = None |
| config.rope_scaling_factor = None |
| msg = "RoPE Scaling was not configured. No changes made." |
| logging.info(msg) |
| return msg |
|
|
|
|
| def _set_sliding_window_config(model, config, window_size=4096): |
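| """Set config.sliding_window to the given size (0 disables it); only architectures with sliding-window attention (e.g. Mistral) honour this, typically after a reload.""" |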
| try: |
| window_size = int(window_size) |
| if window_size < 0: raise ValueError("Window size must be non-negative (0 or None to disable).") |
| except (ValueError, TypeError) as e: |
| msg=f"Error: Invalid sliding window size '{window_size}'. Must be a non-negative integer. Error: {e}" |
| logging.error(msg) |
| return msg |
|
|
| effective_window_size = window_size if window_size > 0 else None |
| config.sliding_window = effective_window_size |
| config.sliding_window_size = effective_window_size |
|
|
| if effective_window_size: |
| msg = (f"Sliding Window Attention size set in config to: {effective_window_size}. " |
| f"Requires model architecture support (e.g., Mistral) and potentially reloading the model.") |
| else: |
| msg = "Sliding Window Attention disabled in config (size set to 0 or None). Model reload may be needed." |
|
|
| logging.warning(msg) |
| return msg |
|
|
| def _revert_sliding_window(model, config): |
| if hasattr(config, 'sliding_window') and config.sliding_window is not None: |
| config.sliding_window = None |
| config.sliding_window_size = None |
| msg = "Sliding Window Attention configuration removed from config. Model reload may be needed to revert behavior." |
| logging.warning(msg) |
| return msg |
| else: |
| config.sliding_window_size = None |
| msg = "Sliding Window Attention was not configured. No changes made." |
| logging.info(msg) |
| return msg |
|
|
|
|
| def _set_attention_variant_config(model, config, variant="auto"): |
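| """Record the preferred attention implementation ('auto', 'eager', 'sdpa' or 'flash_attention_2') in the config; it only takes effect when the model is reloaded with attn_implementation set accordingly.""" |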
| valid_variants = ["auto", "eager", "sdpa", "flash_attention_2"] |
| if not variant or not isinstance(variant, str) or variant not in valid_variants: |
| msg = f"Error: Invalid attention variant '{variant}'. Choose from: {', '.join(valid_variants)}." |
| logging.error(msg) |
| return msg |
|
|
| config.attn_implementation = variant |
| config.attention_variant = variant |
| config.use_flash_attention_2 = (variant == "flash_attention_2") |
|
|
| msg = (f"Attention implementation preference set in config to: '{variant}'. " |
| f"Effective implementation depends on model, hardware, and transformers version. **Requires model reload** to take effect.") |
| logging.warning(msg) |
| return msg |
|
|
| def _revert_attention_variant(model, config): |
| default_variant = "auto" |
| current_variant = getattr(config, 'attn_implementation', default_variant) |
|
|
| if current_variant != default_variant: |
| config.attn_implementation = default_variant |
| config.attention_variant = default_variant |
| config.use_flash_attention_2 = False |
| msg = f"Attention implementation preference reverted to '{default_variant}' in config. Model reload required." |
| logging.warning(msg) |
| return msg |
| else: |
| config.attention_variant = default_variant |
| config.use_flash_attention_2 = False |
| msg = f"Attention implementation preference is already '{default_variant}' or was not set. No changes made." |
| logging.info(msg) |
| return msg |
|
|
| def _enable_gradient_clipping(config): return _set_flag_only(config, "gradient_clipping_disabled", False, "Gradient Clipping Enabled (flag for Trainer).", "Gradient Clipping Disabled.") |
| def _disable_gradient_clipping(config): return _set_flag_only(config, "gradient_clipping_disabled", True, "Gradient Clipping Enabled.", "Gradient Clipping Disabled (flag for Trainer).") |
| def _enable_weight_decay(config): return _set_flag_only(config, "weight_decay_disabled", False, "Weight Decay Enabled (flag for Trainer).", "Weight Decay Disabled.") |
| def _disable_weight_decay(config): return _set_flag_only(config, "weight_decay_disabled", True, "Weight Decay Enabled.", "Weight Decay Disabled (flag for Trainer).") |
| def _enable_lr_scheduler(config): return _set_flag_only(config, "lr_scheduler_disabled", False, "LR Scheduler Enabled (flag for Trainer).", "LR Scheduler Disabled.") |
| def _disable_lr_scheduler(config): return _set_flag_only(config, "lr_scheduler_disabled", True, "LR Scheduler Enabled.", "LR Scheduler Disabled (flag for Trainer).") |
|
|
| def _enable_enhanced_security(config): return _set_flag_only(config, "enhanced_security_enabled", True, "Enhanced Security Enabled (symbolic flag).", "Enhanced Security Disabled.") |
| def _disable_enhanced_security(config): return _set_flag_only(config, "enhanced_security_enabled", False, "Enhanced Security Enabled.", "Enhanced Security Disabled (symbolic flag).") |
| def _enable_debug_mode(config): return _set_flag_only(config, "debug_mode_enabled", True, "Debug Mode Enabled (symbolic flag).", "Debug Mode Disabled.") |
| def _disable_debug_mode(config): return _set_flag_only(config, "debug_mode_enabled", False, "Debug Mode Enabled.", "Debug Mode Disabled (symbolic flag).") |
| def _enable_internal_usage_logging(config): return _set_flag_only(config, "internal_logging_enabled", True, "Internal Usage Logging Enabled (symbolic flag).", "Internal Logging Disabled.") |
| def _disable_internal_usage_logging(config): return _set_flag_only(config, "internal_logging_enabled", False, "Internal Logging Enabled.", "Internal Logging Disabled (symbolic flag).") |
| def _enable_drift_detection(config): return _set_flag_only(config, "drift_detection_enabled", True, "Drift Detection Enabled (symbolic flag).", "Drift Detection Disabled.") |
| def _disable_drift_detection(config): return _set_flag_only(config, "drift_detection_enabled", False, "Drift Detection Enabled.", "Drift Detection Disabled (symbolic flag).") |
|
|
|
|
| def _enable_auto_optimization(base_model, config): |
| msg = "" |
| if getattr(config, 'auto_optimization_enabled', False): |
| msg = "Auto Optimization already enabled (flag was true)." |
| logging.info(msg) |
| return msg |
|
|
| logging.info("Enabling Auto Optimization: Applying Quantization and Gradient Checkpointing...") |
| device = get_device() |
| quant_mode = 'bfloat16' if (device.type == 'cuda' and torch.cuda.is_bf16_supported()) else 'float16' |
| if device.type == 'cpu': quant_mode = 'float32' |
|
|
| quant_msg = _quantize_model(base_model, config, mode=quant_mode) |
| gc_msg = _enable_gradient_checkpointing(base_model, config) |
|
|
| config.auto_optimization_enabled = True |
| msg = f"Auto Optimization Enabled. Quantization ({quant_mode}): {quant_msg}. Gradient Checkpointing: {gc_msg}" |
| logging.info(msg) |
| return msg |
|
|
| def _disable_auto_optimization(config): |
| if getattr(config, 'auto_optimization_enabled', False): |
| config.auto_optimization_enabled = False |
| logging.info("Auto Optimization Disabled (flag only). Applied optimizations (like quantization, GC) remain active unless manually reverted.") |
| return "Auto Optimization Disabled (flag only)." |
| else: |
| logging.info("Auto Optimization was already disabled.") |
| return "Auto Optimization already disabled." |
|
|
|
|
| def _recover_perfect_precision(base_model, config): |
| logging.info("Attempting to recover FP32 precision...") |
| msg = _quantize_model(base_model, config, mode='float32') |
|
|
| if getattr(config, 'perfect_precision_recovered', False): |
| logging.info(f"Successfully recovered FP32 precision. Status: {msg}") |
| return "Recovered FP32 Precision. " + msg |
| else: |
| logging.warning(f"FP32 precision recovery might have failed or model was already FP32. Status: {msg}") |
| return "Attempted FP32 Precision Recovery. " + msg |
|
|
| def _revert_perfect_precision(base_model, config): |
| if not getattr(config, 'perfect_precision_recovered', False): |
| return "Model not currently in FP32 mode according to flag (or flag is inconsistent)." |
|
|
| device = get_device() |
| mode_to_revert_to = 'bfloat16' if (device.type=='cuda' and torch.cuda.is_bf16_supported()) else 'float16' if device.type=='cuda' else 'float32' |
|
|
| if mode_to_revert_to == 'float32': |
| logging.info("Cannot revert from FP32 as the target revert type is also FP32 (e.g., on CPU).") |
| return "Cannot revert from FP32 to lower precision on current device." |
|
|
| logging.info(f"Reverting precision from FP32 (target: {mode_to_revert_to})...") |
| msg = _quantize_model(base_model, config, mode=mode_to_revert_to) |
| logging.info(f"Attempted precision revert from FP32: {msg}") |
| return f"Reverted Precision from FP32 (attempted {mode_to_revert_to}). " + msg |
|
|
|
|
| def _optimize_token_generation_speed(config): |
| if not hasattr(config, '_original_do_sample'): |
| config._original_do_sample = getattr(config, 'do_sample', True) |
| if not hasattr(config, '_original_num_beams'): |
| config._original_num_beams = getattr(config, 'num_beams', 1) |
| if not hasattr(config, '_original_use_cache'): |
| default_use_cache = True |
| if hasattr(config, 'model_type'): |
| if config.model_type == "t5" and getattr(config, 'gradient_checkpointing', False): |
| default_use_cache = False |
| config._original_use_cache = getattr(config, 'use_cache', default_use_cache) |
|
|
| config.do_sample = False |
| config.num_beams = 1 |
| config.use_cache = True |
| config.token_gen_speed_maximized = True |
| logging.info("Token Generation Speed Optimized (Flags set for greedy decoding, num_beams=1, use_cache=True).") |
| return "Token Speed Opt flags set (greedy, cache on)." |
|
|
| def _revert_token_generation_speed_optimization(config): |
| if not getattr(config, 'token_gen_speed_maximized', False): |
| return "Token speed optimization not active according to flag." |
|
|
| config.do_sample = getattr(config, '_original_do_sample', True) |
| config.num_beams = getattr(config, '_original_num_beams', 1) |
| config.use_cache = getattr(config, '_original_use_cache', True) |
|
|
| config.token_gen_speed_maximized = False |
|
|
| if hasattr(config, '_original_do_sample'): del config._original_do_sample |
| if hasattr(config, '_original_num_beams'): del config._original_num_beams |
| if hasattr(config, '_original_use_cache'): del config._original_use_cache |
|
|
| logging.info("Token Generation Speed Optimization Reverted to previous/default flags.") |
| return "Token Speed Optimization Reverted." |
|
|
|
|
| def _add_peft_adapter(model, config, peft_config_obj=None): |
| global global_model, current_peft_config |
|
|
| if not _peft_installed: |
| return "[Error] PEFT library (pip install peft) is not installed." |
| if isinstance(model, PeftModel): |
| return "[Warning] Model is already a PEFT model. Merge or remove existing adapters before adding a new one via this button." |
| if getattr(config, 'lora_merged', False): |
| return "[Warning] LoRA adapters were previously merged into this model state. Adding new adapters might have unintended effects without reloading the original base model." |
|
|
| try: |
| if peft_config_obj and isinstance(peft_config_obj, (LoraConfig, PeftConfig)): |
| peft_conf = peft_config_obj |
| logging.info(f"Using provided PEFT config object: {peft_conf}") |
| else: |
| default_config_dict = copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT) |
| if not default_config_dict: |
| raise ValueError("Default PEFT config is not available and no valid config provided.") |
| peft_conf = LoraConfig(**default_config_dict) |
| logging.info(f"Using default PEFT config: {peft_conf}") |
|
|
| if hasattr(peft_conf, 'task_type') and peft_conf.task_type != TaskType.CAUSAL_LM: |
| logging.warning(f"PEFT config task type is {peft_conf.task_type}, overriding to CAUSAL_LM for this platform.") |
| peft_conf.task_type = TaskType.CAUSAL_LM |
| elif not hasattr(peft_conf, 'task_type'): |
| if isinstance(peft_conf, PeftConfig) and not isinstance(peft_conf, LoraConfig): |
| peft_conf.task_type = TaskType.CAUSAL_LM |
|
|
| peft_model = get_peft_model(model, peft_conf) |
|
|
| base_model_config = peft_model.get_base_model().config |
| base_model_config.peft_adapter_added = True |
| base_model_config.peft_config = peft_conf.to_dict() |
| base_model_config.lora_merged = False |
|
|
| current_peft_config = peft_conf |
| global_model = peft_model |
| config = base_model_config |
|
|
| trainable_params, all_params = peft_model.get_nb_trainable_parameters() |
| logging.info( |
| f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params:.4f}" |
| ) |
| msg = f"PEFT adapter ({type(peft_conf).__name__}) added successfully. Model is ready for PEFT training." |
| logging.info(msg) |
| return msg |
|
|
| except Exception as e: |
| logging.error(f"Error adding PEFT adapter: {e}\n{traceback.format_exc()}") |
| if hasattr(model, 'config'): |
| model.config.peft_adapter_added = False |
| model.config.peft_config = None |
| return f"[Error] Failed to add PEFT adapter: {e}" |
|
|
| def _remove_peft_adapter(model, config): |
| global global_model, current_peft_config |
|
|
| if not _peft_installed: |
| return "[Error] PEFT library not installed." |
|
|
| if not isinstance(model, PeftModel): |
| if getattr(config, 'peft_adapter_added', False): |
| logging.warning("Model is not a PeftModel instance, but PEFT flag was set. Resetting flags.") |
| config.peft_adapter_added = False |
| config.peft_config = None |
| current_peft_config = {} |
| return "[Warning] Reset PEFT flags as model was not a PeftModel instance." |
| else: |
| return "[Info] No PEFT adapter currently applied to the model." |
|
|
| try: |
| base_model = model.get_base_model() |
|
|
| global_model = base_model |
| config = base_model.config |
| config.peft_adapter_added = False |
| config.peft_config = None |
| current_peft_config = {} |
|
|
| msg = "PEFT adapter layers removed. Restored base model and reset PEFT config flags." |
| logging.info(msg) |
| clean_memory() |
| return msg |
|
|
| except Exception as e: |
| logging.error(f"Error removing PEFT adapter: {e}\n{traceback.format_exc()}") |
| return f"[Error] Failed to remove PEFT adapter: {e}" |
|
|
|
|
| def _setup_multimodal(model, config, selected_modalities): |
| global global_tokenizer |
|
|
| if not selected_modalities: |
| return "[Info] No modalities selected for setup." |
| if getattr(config, 'multimodal_applied', False): |
| current_modalities = getattr(config, 'supported_modalities', []) |
| return f"[Warning] Multi-modal setup already applied for modalities: {current_modalities}. Revert first to change." |
|
|
| logging.info(f"Attempting multi-modal setup for: {selected_modalities}") |
| device = get_device() |
| llm_hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) |
|
|
| if not llm_hidden_size: |
| return "[Error] Cannot setup multi-modal: LLM config missing 'hidden_size' or 'd_model'." |
|
|
| if global_tokenizer is None: |
| return "[Error] Cannot setup multi-modal: Global tokenizer not loaded." |
|
|
|
|
| try: |
| added_encoders = {} |
| added_projections = {} |
| added_special_tokens = {} |
| new_tokens_added_to_tokenizer = [] |
| current_modality_config = {} |
| current_special_tokens_map = {} |
|
|
| tokens_to_add_struct = [] |
| for modality in selected_modalities: |
| if modality not in MODALITY_ENCODERS: |
| logging.warning(f"Skipping modality '{modality}': No predefined encoder found.") |
| continue |
| special_token = f"<{modality.upper()}>" |
| if special_token not in global_tokenizer.get_vocab(): |
| tokens_to_add_struct.append({'token': special_token, 'modality': modality}) |
| else: |
| token_id = global_tokenizer.convert_tokens_to_ids(special_token) |
| current_special_tokens_map[modality] = {"token": special_token, "id": token_id} |
| logging.info(f"Special token '{special_token}' for {modality} already exists (ID: {token_id}).") |
|
|
| if tokens_to_add_struct: |
| num_added = global_tokenizer.add_tokens([t['token'] for t in tokens_to_add_struct], special_tokens=True) |
| if num_added > 0: |
| logging.info(f"Added {num_added} new special tokens to tokenizer: {[t['token'] for t in tokens_to_add_struct]}") |
| logging.info(f"Resizing LLM token embeddings from {model.config.vocab_size} to {len(global_tokenizer)}.") |
| model.resize_token_embeddings(len(global_tokenizer)) |
| if hasattr(config, 'vocab_size'): |
| config.vocab_size = len(global_tokenizer) |
|
|
| with torch.no_grad(): |
| input_embeddings = model.get_input_embeddings() |
| if input_embeddings is not None and hasattr(input_embeddings, 'weight'): |
| avg_weight = input_embeddings.weight[:-num_added,:].mean(dim=0) |
| input_embeddings.weight[-num_added:,:] = avg_weight |
| logging.info(f"Initialized {num_added} new token embeddings with average weight.") |
|
|
| for t_info in tokens_to_add_struct: |
| modality = t_info['modality'] |
| special_token = t_info['token'] |
| token_id = global_tokenizer.convert_tokens_to_ids(special_token) |
| current_special_tokens_map[modality] = {"token": special_token, "id": token_id} |
| new_tokens_added_to_tokenizer.append(special_token) |
| else: |
| logging.error(f"Failed to add special tokens: {[t['token'] for t in tokens_to_add_struct]}. Aborting multi-modal setup.") |
| return "[Error] Failed to add required special tokens to tokenizer." |
|
|
| successful_modalities = [] |
| for modality in selected_modalities: |
| if modality not in MODALITY_ENCODERS: continue |
|
|
| encoder_id = MODALITY_ENCODERS[modality] |
| encoder_attr_name = f"{modality.lower()}_encoder" |
| projection_attr_name = f"{modality.lower()}_projection" |
|
|
| try: |
| logging.info(f"Loading {modality} encoder: {encoder_id}") |
| encoder = AutoModel.from_pretrained(encoder_id, trust_remote_code=True) |
| encoder = encoder.to(device).eval() |
| for param in encoder.parameters(): |
| param.requires_grad = False |
| added_encoders[encoder_attr_name] = encoder |
| setattr(model, encoder_attr_name, encoder) |
|
|
| encoder_hidden_size = _get_encoder_hidden_size(encoder_id, trust_remote_code=True) |
|
|
| logging.info(f"Creating projection layer for {modality}: {encoder_hidden_size} -> {llm_hidden_size}") |
| projection = nn.Linear(encoder_hidden_size, llm_hidden_size).to(device) |
| nn.init.xavier_uniform_(projection.weight) |
| if projection.bias is not None: nn.init.zeros_(projection.bias) |
|
|
| added_projections[projection_attr_name] = projection |
| setattr(model, projection_attr_name, projection) |
|
|
| current_modality_config[modality] = encoder_id |
| successful_modalities.append(modality) |
|
|
| except Exception as mod_e: |
| logging.error(f"Failed to setup modality '{modality}' with encoder '{encoder_id}': {mod_e}") |
| if hasattr(model, encoder_attr_name): delattr(model, encoder_attr_name) |
| if hasattr(model, projection_attr_name): delattr(model, projection_attr_name) |
|
|
| if successful_modalities: |
| config.multimodal_applied = True |
| config.supported_modalities = successful_modalities |
| config.modality_encoders = current_modality_config |
| config.modality_projection_dim = llm_hidden_size |
| config.modality_special_tokens = current_special_tokens_map |
|
|
| msg = (f"Multi-modal setup partially/fully applied for: {successful_modalities}. " |
| f"Added {len(added_encoders)} encoders and {len(added_projections)} projections. " |
| f"Added/mapped {len(current_special_tokens_map)} special tokens. ") |
| logging.warning(msg) |
| return msg |
| else: |
| config.multimodal_applied = False |
| return "[Error] Multi-modal setup failed for all selected modalities." |
|
|
| except Exception as e: |
| logging.error(f"Error during multi-modal setup: {e}\n{traceback.format_exc()}") |
| for name in added_encoders.keys(): |
| if hasattr(model, name): delattr(model, name) |
| for name in added_projections.keys(): |
| if hasattr(model, name): delattr(model, name) |
| config.multimodal_applied = False |
| config.supported_modalities = [] |
| config.modality_encoders = {} |
| config.modality_projection_dim = None |
| config.modality_special_tokens = {} |
| return (f"[Error] Multi-modal setup failed: {e}. Attempted cleanup, state might be inconsistent " |
| "(tokenizer/embeddings may remain changed). Reload original model/tokenizer for full reset.") |
|
|
|
|
| def _revert_multimodal(model, config): |
| if not getattr(config, 'multimodal_applied', False): |
| return "[Info] Multi-modal setup not applied according to config." |
|
|
| modalities_to_revert = getattr(config, 'supported_modalities', []) |
| if not modalities_to_revert: |
| config.multimodal_applied = False |
| config.modality_encoders = {} |
| config.modality_projection_dim = None |
| config.modality_special_tokens = {} |
| return "[Info] No supported modalities listed in config to revert, but flag was true. Resetting flags." |
|
|
| logging.info(f"Reverting multi-modal setup for modalities: {modalities_to_revert}") |
| removed_count = 0 |
| errors = [] |
|
|
| try: |
| for modality in modalities_to_revert: |
| encoder_attr_name = f"{modality.lower()}_encoder" |
| projection_attr_name = f"{modality.lower()}_projection" |
| try: |
| if hasattr(model, encoder_attr_name): |
| delattr(model, encoder_attr_name) |
| logging.info(f"Removed encoder: {encoder_attr_name}") |
| removed_count += 1 |
| if hasattr(model, projection_attr_name): |
| delattr(model, projection_attr_name) |
| logging.info(f"Removed projection: {projection_attr_name}") |
| removed_count += 1 |
| except Exception as del_e: |
| error_msg = f"Error removing components for modality '{modality}': {del_e}" |
| logging.error(error_msg) |
| errors.append(error_msg) |
|
|
| config.multimodal_applied = False |
| config.supported_modalities = [] |
| config.modality_encoders = {} |
| config.modality_projection_dim = None |
| config.modality_special_tokens = {} |
|
|
| logging.warning("Multi-modal components removed. **Special tokens added to tokenizer and potentially resized embeddings remain.** Reload original model/tokenizer if full reversion needed.") |
| clean_memory() |
|
|
| final_msg = f"Multi-modal setup reverted ({removed_count} components removed, flags reset). Embeddings/tokenizer not shrunk." |
| if errors: |
| final_msg += f" Errors encountered: {'; '.join(errors)}" |
| return final_msg |
|
|
| except Exception as e: |
| logging.error(f"Error reverting multi-modal setup: {e}\n{traceback.format_exc()}") |
| config.multimodal_applied = False |
| config.supported_modalities = [] |
| config.modality_encoders = {} |
| config.modality_projection_dim = None |
| config.modality_special_tokens = {} |
| return f"[Error] Reverting multi-modal setup failed: {e}. Flags reset." |
|
|
|
|
| def auto_extract_text_universal(data_item): |
| if isinstance(data_item, str): |
| return data_item.strip().replace('\\n', '\n') |
| elif isinstance(data_item, bytes): |
| try: |
| return data_item.decode('utf-8', errors='replace').strip().replace('\\n', '\n') |
| except Exception: |
| return "" |
| elif isinstance(data_item, (list, tuple)): |
| texts = [auto_extract_text_universal(item) for item in data_item] |
| return " ".join(filter(None, texts)) |
| elif isinstance(data_item, dict): |
| texts = [] |
| potential_keys = [ |
| 'text', 'content', 'sentence', 'paragraph', 'article', 'abstract', |
| 'summary', 'body', 'passage', 'document', 'script', 'dialogue', |
| 'instruction', 'input', 'output', 'query', 'response', 'title', |
| 'question', 'answer', 'prompt', 'completion', 'target', 'label', |
| 'review', 'comment', 'post', 'code', 'markdown' |
| ] |
| processed_keys = set() |
|
|
| for key in potential_keys: |
| if key in data_item and key not in processed_keys: |
| value = data_item[key] |
| extracted = auto_extract_text_universal(value) |
| if extracted: |
| texts.append(extracted) |
| processed_keys.add(key) |
|
|
| if not texts: |
| for key, value in data_item.items(): |
| if key not in processed_keys: |
| extracted = auto_extract_text_universal(value) |
| if extracted: |
| texts.append(extracted) |
| processed_keys.add(key) |
|
|
| seen = set() |
| unique_texts = [] |
| for t in texts: |
| if t and t not in seen: |
| unique_texts.append(t) |
| seen.add(t) |
| return "\n".join(unique_texts) |
|
|
| elif isinstance(data_item, (int, float, bool)) or data_item is None: |
| return "" |
| else: |
| try: |
| return str(data_item).strip().replace('\\n', '\n') |
| except Exception: |
| return "" |
|
|
|
|
| def process_example_universal(example): |
| extracted_text = auto_extract_text_universal(example) |
| return {"text": extracted_text if extracted_text else "[EMPTY_OR_NON_TEXTUAL]"} |
|
|
|
|
| def parse_datasets(dataset_text): |
| datasets = [] |
| seen_ids = set() |
| for line_num, line in enumerate(dataset_text.strip().splitlines()): |
| line = line.strip() |
| if not line or line.startswith('#'): |
| continue |
|
|
| parts = [s.strip() for s in line.split(",") if s.strip()] |
| ds_name = None |
| ds_config = None |
| ds_split = 'train' |
| ds_weight = 1.0 |
|
|
| if len(parts) >= 1: |
| ds_name = parts[0] |
| if len(parts) >= 2 and parts[1]: |
| ds_config = parts[1] if parts[1].lower() != 'none' else None |
| if len(parts) >= 3 and parts[2]: |
| ds_split = parts[2] |
| if len(parts) >= 4: |
| try: |
| ds_weight = float(parts[3]) |
| if ds_weight <= 0: |
| raise ValueError("Weight must be positive") |
| except (ValueError, IndexError): |
| logging.warning(f"Invalid or missing weight '{parts[3] if len(parts) >= 4 else ''}' on line {line_num+1} ('{line}'). Using default 1.0.") |
| ds_weight = 1.0 |
|
|
| if ds_name: |
| dataset_id = f"{ds_name}_{ds_config or 'DEFAULT'}_{ds_split}" |
| if dataset_id in seen_ids: |
| logging.warning(f"Skipping duplicate dataset entry: {dataset_id} on line {line_num+1}") |
| continue |
|
|
| datasets.append({"id": ds_name, "config": ds_config, "split": ds_split, "weight": ds_weight}) |
| seen_ids.add(dataset_id) |
| else: |
| logging.warning(f"Skipping invalid dataset line (no name found): '{line}' on line {line_num+1}") |
|
|
| if not datasets: |
| raise ValueError("No valid dataset configurations were parsed from the input.") |
|
|
| return datasets |
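| # Expected input format for parse_datasets -- one dataset per line, comma separated: |
| #   name[, config[, split[, weight]]] |
| # e.g. "wikitext, wikitext-2-raw-v1, train, 2.0" or "c4, en, train, 1.0" (illustrative IDs). |
| # A config of "none" maps to None, '#' lines are skipped, and weights drive interleaving. |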
|
|
|
|
| def load_datasets_from_config(datasets_config): |
| ds_list = [] |
| loaded_configs = [] |
| total_weight = 0.0 |
| logging.info(f"Attempting to load datasets based on config: {datasets_config}") |
|
|
| for config_entry in datasets_config: |
| ds_name = config_entry['id'] |
| ds_config = config_entry['config'] |
| ds_split = config_entry['split'] |
| ds_weight = config_entry['weight'] |
| dataset_identifier = f"{ds_name}{'['+ds_config+']' if ds_config else ''} (Split: {ds_split}, Weight: {ds_weight})" |
|
|
| try: |
| logging.info(f"Loading {dataset_identifier}...") |
| d = load_dataset( |
| ds_name, |
| ds_config, |
| streaming=True, |
| split=ds_split, |
| trust_remote_code=True, |
| ) |
|
|
| try: |
| peek = next(iter(d)) |
| original_columns = list(peek.keys()) |
| d = load_dataset(ds_name, ds_config, streaming=True, split=ds_split, trust_remote_code=True) |
| except StopIteration: |
| logging.warning(f"Dataset stream appears empty after loading: {dataset_identifier}. Skipping.") |
| continue |
| except Exception as peek_e: |
| logging.warning(f"Could not reliably peek into dataset {dataset_identifier} to get columns: {peek_e}. Will attempt processing without column removal.") |
| original_columns = None |
|
|
| logging.info(f"Processing {dataset_identifier} (Original cols: {original_columns or 'unknown'}) -> Map to 'text' field") |
| process_partial = partial(process_example_universal) |
| processed_d = d.map(process_partial, remove_columns=original_columns) |
|
|
| processed_d = processed_d.filter(lambda example: example.get("text") != "[EMPTY_OR_NON_TEXTUAL]") |
|
|
| shuffled_d = processed_d.shuffle(buffer_size=10000, seed=42) |
|
|
| ds_list.append(shuffled_d) |
| loaded_configs.append(config_entry) |
| total_weight += ds_weight |
| logging.info(f"Successfully prepared stream: {dataset_identifier}") |
|
|
| except (requests.exceptions.RequestException, gzip.BadGzipFile) as http_e: |
| logging.error(f"Network or File Error loading dataset {dataset_identifier}: {http_e}. Check connection and dataset validity. Skipping.") |
| except FileNotFoundError: |
| logging.error(f"Dataset or config not found for {dataset_identifier}. Check name/config/path. Skipping.") |
| except Exception as e: |
| logging.error(f"General Error loading/processing dataset {dataset_identifier}: {e} \n{traceback.format_exc()}. Skipping.") |
|
|
| if not ds_list: |
| raise ValueError("No valid datasets were loaded. Check dataset names, configurations, splits, availability, and network connection.") |
|
|
| logging.info(f"Successfully loaded {len(ds_list)} dataset streams.") |
|
|
| if total_weight <= 0 or len(loaded_configs) != len(ds_list): |
| probabilities = [1.0 / len(ds_list)] * len(ds_list) if ds_list else [] |
| logging.warning("Using equal probabilities for interleaving due to zero total weight, loading errors, or no datasets.") |
| else: |
| probabilities = [cfg['weight'] / total_weight for cfg in loaded_configs] |
| prob_sum = sum(probabilities) |
| if abs(prob_sum - 1.0) > 1e-6: |
| probabilities = [p / prob_sum for p in probabilities] |
|
|
| if not ds_list: |
| logging.warning("No datasets to interleave.") |
| return None |
|
|
|
|
| logging.info(f"Interleaving {len(ds_list)} datasets with probabilities: {[f'{p:.3f}' for p in probabilities]}") |
| interleaved_ds = interleave_datasets(ds_list, probabilities=probabilities, seed=42, stopping_strategy="all_exhausted") |
|
|
| return interleaved_ds |
|
|
|
|
| def tokenize_function(examples, tokenizer, context_length): |
| texts = [str(t) if t is not None else "" for t in examples["text"]] |
| tokenized_output = tokenizer(texts, truncation=False, padding=False) |
| return tokenized_output |
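| # Note: texts are tokenized without truncation or padding here; fixed-length blocks are |
| # formed later by group_texts, so the context_length argument is not used at this stage. |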
|
|
| def group_texts(examples, block_size): |
| concatenated_examples = {k: sum(examples[k], []) if isinstance(examples[k][0], list) else examples[k] for k in examples} |
| total_length = len(concatenated_examples[list(examples.keys())[0]]) |
|
|
| if total_length >= block_size: |
| total_length = (total_length // block_size) * block_size |
| else: |
| return {k: [] for k in examples.keys()} |
|
|
| result = { |
| k: [t[i : i + block_size] for i in range(0, total_length, block_size)] |
| for k, t in concatenated_examples.items() |
| } |
| result["labels"] = result["input_ids"].copy() |
| return result |
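| # group_texts concatenates every tokenized sequence in the batch and slices the result into |
| # blocks of block_size tokens; e.g. 1000 concatenated tokens with block_size=256 yield three |
| # blocks and the trailing 232 tokens are dropped. labels are a copy of input_ids (causal LM). |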
|
|
| def split_dataset(processed_lm_iterable_dataset): |
| eval_buffer_size = 1000 |
| shuffle_buffer_size = 10000 |
| logging.info(f"Preparing train/eval split. Eval buffer: {eval_buffer_size}, Shuffle buffer: {shuffle_buffer_size}..."); T = time.time() |
|
|
| if not isinstance(processed_lm_iterable_dataset, IterableDataset): |
| logging.error("Input dataset is not an IterableDataset. Cannot perform stream-based splitting.") |
| raise TypeError("Input to split_dataset must be an IterableDataset.") |
|
|
| shuffled_ds = processed_lm_iterable_dataset.shuffle(seed=42, buffer_size=shuffle_buffer_size) |
|
|
| logging.info(f"Taking up to {eval_buffer_size} samples for the evaluation buffer...") |
| eval_samples_iter = shuffled_ds.take(eval_buffer_size) |
|
|
| try: |
| eval_list = list(eval_samples_iter) |
| num_eval_samples = len(eval_list) |
| except Exception as e: |
| logging.error(f"Error collecting evaluation samples: {e}. Proceeding without evaluation set.") |
| num_eval_samples = 0 |
| eval_list = [] |
|
|
| train_ds = shuffled_ds |
|
|
| eval_ds_static = None |
| if num_eval_samples > 0: |
| logging.info(f"Collected {num_eval_samples} samples for evaluation buffer.") |
| train_ds = shuffled_ds.skip(num_eval_samples) |
| logging.info("Training stream prepared (skipped eval samples).") |
|
|
| logging.info("Creating static evaluation dataset from buffer...") |
| try: |
| if not eval_list: raise ValueError("Evaluation buffer list is empty after take().") |
| first_example = eval_list[0] |
| if not isinstance(first_example, dict): raise ValueError("Eval buffer items are not dictionaries.") |
|
|
| expected_keys = ['input_ids', 'attention_mask', 'labels'] |
| eval_features_dict = {} |
| for key in expected_keys: |
| if key not in first_example: |
| raise ValueError(f"Eval buffer items missing required key: '{key}'") |
| from datasets import Sequence |
| inner_dtype = 'int64' |
| if isinstance(first_example[key], list) and first_example[key] and isinstance(first_example[key][0], int): |
| eval_features_dict[key] = Sequence(feature=Value(dtype=inner_dtype)) |
| else: |
| # 'list' is not a valid datasets Value dtype; fall back to an int64 sequence and warn. |
| logging.warning(f"Eval buffer key '{key}' does not look like a non-empty list of ints; assuming an int64 sequence.") |
| eval_features_dict[key] = Sequence(feature=Value(dtype=inner_dtype)) |
|
|
| if not eval_features_dict: |
| raise ValueError("Could not define features for evaluation dataset.") |
|
|
| eval_features = Features(eval_features_dict) |
|
|
| valid_eval_list = [] |
| required_keys_set = set(eval_features.keys()) |
| for i, ex in enumerate(eval_list): |
| if isinstance(ex, dict) and set(ex.keys()) >= required_keys_set: |
| is_valid = all(isinstance(ex.get(k), list) for k in required_keys_set) |
| if is_valid: |
| valid_eval_list.append({k: ex[k] for k in required_keys_set}) |
| else: |
| logging.warning(f"Eval buffer item {i} has invalid type for required keys. Skipping.") |
| else: |
| logging.warning(f"Eval buffer item {i} is invalid (not dict or missing keys). Skipping.") |
|
|
| if not valid_eval_list: |
| logging.warning("No valid examples remained in the evaluation buffer after validation. Eval dataset will be None.") |
| eval_ds_static = None |
| train_ds = shuffled_ds |
| else: |
| eval_ds_static = Dataset.from_list(valid_eval_list, features=eval_features) |
| logging.info(f"Created static evaluation dataset with {len(eval_ds_static)} examples.") |
|
|
| except Exception as e: |
| logging.error(f"Error creating static evaluation dataset from buffer: {e}\n{traceback.format_exc()}. Evaluation dataset will be None.") |
| eval_ds_static = None |
| train_ds = shuffled_ds |
|
|
| else: |
| logging.warning("Evaluation buffer is empty (requested size might be too large or dataset too small). Training will continue without evaluation.") |
|
|
| logging.info(f"Dataset splitting completed in {time.time()-T:.2f}s") |
| return train_ds, eval_ds_static |
|
|
|
|
| def compute_perplexity(loss): |
| if loss is None or not isinstance(loss, (int, float)) or not math.isfinite(loss): |
| return float("inf") |
| try: |
| clamped_loss = min(max(loss, -700.0), 700.0) |
| perplexity = math.exp(clamped_loss) |
| if not math.isfinite(perplexity): |
| logging.warning(f"Perplexity calculation resulted in infinity for loss {loss} (clamped: {clamped_loss}).") |
| return float("inf") |
| return perplexity |
| except OverflowError: |
| logging.warning(f"OverflowError computing perplexity for loss {loss}. Returning infinity.") |
| return float("inf") |
| except Exception as e: |
| logging.warning(f"Error computing perplexity for loss {loss}: {e}. Returning infinity.") |
| return float("inf") |
|
|
|
|
| def merge_model_parameters(original_model, trained_model, alpha=MERGE_ALPHA): |
| if not (0 <= alpha <= 1): |
| logging.error(f"Merge alpha must be between 0 and 1. Got {alpha}. Defaulting to 0.5") |
| alpha = 0.5 |
|
|
| logging.info(f"Merging model parameters with alpha={alpha:.2f} (alpha*original + (1-alpha)*trained using linear interpolation)..."); T = time.time(); |
| device = get_device() |
|
|
| original_model = original_model.to(device) |
| trained_model = trained_model.to(device) |
|
|
| merged_model = copy.deepcopy(original_model).to(device) |
|
|
| merged_params_count = 0 |
| skipped_params_count = 0 |
|
|
| orig_params = dict(original_model.named_parameters()) |
| trained_params = dict(trained_model.named_parameters()) |
| merged_params = dict(merged_model.named_parameters()) |
|
|
| with torch.no_grad(): |
| for name, trained_param in trained_params.items(): |
| if name in orig_params and name in merged_params: |
| orig_param = orig_params[name] |
| merged_param = merged_params[name] |
|
|
| if orig_param.data.shape == trained_param.data.shape: |
| merged_tensor = torch.lerp(trained_param.data.float(), orig_param.data.float(), alpha) |
| merged_param.copy_(merged_tensor.to(merged_param.dtype)) |
| merged_params_count += 1 |
| else: |
| logging.warning(f"Size mismatch for parameter '{name}'. Original: {orig_param.data.shape}, Trained: {trained_param.data.shape}. Skipping merge for this parameter.") |
| skipped_params_count += 1 |
| else: |
| if name not in orig_params: |
| logging.warning(f"Parameter '{name}' from trained model not found in original model structure. Skipping.") |
| if name not in merged_params: |
| logging.warning(f"Parameter '{name}' from trained model not found in merged model structure (shouldn't happen). Skipping.") |
| skipped_params_count += 1 |
|
|
| logging.info(f"Parameter merging finished in {time.time()-T:.2f}s. Merged {merged_params_count} parameters, skipped {skipped_params_count}.") |
| return merged_model |
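| # torch.lerp(trained, original, alpha) computes trained + alpha * (original - trained), |
| # i.e. alpha * original + (1 - alpha) * trained, so alpha=1.0 reproduces the original |
| # weights and alpha=0.0 keeps the fully trained weights. |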
|
|
|
|
| def preserve_model_quality(original_model, trained_model, eval_dataset, tokenizer): |
| if eval_dataset is None: |
| logging.warning("No evaluation data provided (eval_dataset is None). Cannot perform quality check. Returning trained model.") |
| return trained_model |
|
|
| is_iterable = isinstance(eval_dataset, IterableDataset) |
| if is_iterable: |
| logging.warning("Evaluation dataset is iterable. Loss comparison might not be on the exact same data. Proceeding with caution.") |
| try: |
| _ = next(iter(eval_dataset.take(1))) |
| except StopIteration: |
| logging.warning("Iterable evaluation dataset appears empty. Returning trained model.") |
| return trained_model |
| except Exception as e: |
| logging.warning(f"Could not peek into iterable eval dataset: {e}. Assuming not empty.") |
| elif isinstance(eval_dataset, Dataset): |
| if len(eval_dataset) == 0: |
| logging.warning("Evaluation dataset is empty (length 0). Returning trained model.") |
| return trained_model |
| else: |
| logging.warning(f"Unknown evaluation dataset type: {type(eval_dataset)}. Cannot perform quality check. Returning trained model.") |
| return trained_model |
|
|
| data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) |
| eval_batch_size = max(1, BATCH_SIZE // 2) |
| device = get_device() |
|
|
| was_training_orig = original_model.training |
| was_training_trained = trained_model.training |
| original_model.to(device).eval() |
| trained_model.to(device).eval() |
|
|
| temp_eval_dir = "./tmp_eval_quality_check" |
| eval_args = TrainingArguments( |
| output_dir=temp_eval_dir, |
| per_device_eval_batch_size=eval_batch_size, |
| report_to=[], |
| dataloader_num_workers=max(1, (NUM_CPU_CORES if NUM_CPU_CORES > 0 else (os.cpu_count() or 2) // 2)), |
| fp16=torch.cuda.is_available() and not USE_CPU and original_model.dtype == torch.float16, |
| bf16=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()) and original_model.dtype == torch.bfloat16, |
| use_cpu=USE_CPU, |
| log_level='error', |
| remove_unused_columns=False, |
| ) |
|
|
| results = {} |
| eval_error = False |
| for model_name, model_instance in [("Original", original_model), ("Trained", trained_model)]: |
| logging.info(f"Evaluating {model_name} model for quality check..."); T_eval = time.time() |
|
|
| current_eval_dataset = eval_dataset |
| if is_iterable: |
| current_eval_dataset = eval_dataset.take(1000) |
| try: |
| if len(list(iter(current_eval_dataset.take(1)))) == 0: |
| logging.warning(f"Iterable eval sample for {model_name} is empty. Skipping eval.") |
| results[model_name] = {"loss": float('inf'), "ppl": float('inf')} |
| continue |
| current_eval_dataset = eval_dataset.take(1000) |
| except Exception as e: |
| logging.error(f"Error handling iterable dataset sample for {model_name}: {e}") |
| results[model_name] = {"loss": float('inf'), "ppl": float('inf')} |
| eval_error = True; break |
|
|
| trainer = Trainer( |
| model=model_instance, |
| args=eval_args, |
| data_collator=data_collator, |
| eval_dataset=current_eval_dataset |
| ) |
| try: |
| model_instance.eval() |
| metrics = trainer.evaluate() |
| loss = metrics.get("eval_loss") |
| ppl = compute_perplexity(loss) |
| results[model_name] = {"loss": loss if loss is not None else float('inf'), "ppl": ppl} |
| logging.info(f"{model_name} Eval Loss: {loss if loss is not None else 'N/A':.4f}, PPL: {ppl:.4f} (Eval time: {time.time()-T_eval:.2f}s)") |
| except StopIteration: |
| logging.error(f"Evaluation dataset exhausted unexpectedly during evaluation of {model_name}. Comparison may be incomplete.") |
| results[model_name] = {"loss": float('inf'), "ppl": float('inf')} |
| eval_error = True; break |
| except Exception as e: |
| logging.error(f"Error evaluating {model_name} model: {e}\n{traceback.format_exc()}") |
| results[model_name] = {"loss": float('inf'), "ppl": float('inf')} |
| eval_error = True; break |
|
|
| if os.path.exists(temp_eval_dir): |
| try: |
| shutil.rmtree(temp_eval_dir) |
| except Exception as e: |
| logging.warning(f"Could not remove temporary eval directory {temp_eval_dir}: {e}") |
|
|
| original_model.train(mode=was_training_orig) |
| trained_model.train(mode=was_training_trained) |
|
|
| original_loss = results.get("Original", {}).get("loss", float('inf')) |
| trained_loss = results.get("Trained", {}).get("loss", float('inf')) |
|
|
| if eval_error: |
| logging.error("Evaluation encountered errors. Cannot reliably compare models. Returning trained model.") |
| return trained_model |
|
|
| valid_comparison = math.isfinite(original_loss) and math.isfinite(trained_loss) |
|
|
| if valid_comparison: |
| loss_threshold = original_loss * 1.05 |
| if trained_loss > loss_threshold: |
| logging.warning(f"Trained model loss ({trained_loss:.4f}) is significantly worse (>5%) than original ({original_loss:.4f}). Reverting to original model state based on quality check.") |
| return original_model.to(device) |
| elif trained_loss > original_loss: |
| logging.info(f"Trained model loss ({trained_loss:.4f}) is slightly worse than original ({original_loss:.4f}), but within threshold. Keeping trained model.") |
| return trained_model.to(device) |
| else: |
| logging.info(f"Trained model loss ({trained_loss:.4f}) is better than or equal to original ({original_loss:.4f}). Keeping trained model.") |
| return trained_model.to(device) |
| else: |
| logging.warning("Could not perform valid loss comparison (one or both evaluations failed or yielded non-finite loss). Returning trained model.") |
| return trained_model.to(device) |
|
|
|
|
| def _merge_architectures(model_ids_str, hf_token=None, bypass_limits_state=False): |
| global global_model, global_tokenizer, config, global_pipe, BYPASS_RESOURCE_LIMITS |
| BYPASS_RESOURCE_LIMITS = bypass_limits_state |
|
|
| if not isinstance(model_ids_str, str) or not model_ids_str.strip(): |
| return "[Error] Model IDs string cannot be empty.", "{}", *get_error_filter_updates() |
|
|
| resources_ok, res_msg = check_resources() |
| if not resources_ok: |
| error_msg = f"[Error] Resource limits exceeded, cannot proceed with merge. {res_msg}" |
| logging.error(error_msg) |
| return error_msg, "{}", *get_error_filter_updates() |
| else: |
| logging.info(res_msg) |
|
|
| model_ids = [m.strip() for m in model_ids_str.split(',') if m.strip()] |
| if len(model_ids) < 2: |
| return "[Error] Need at least two valid model IDs/paths separated by commas to merge.", "{}", *get_error_filter_updates() |
|
|
| logging.info(f"Starting architecture merge (parameter averaging) for models: {model_ids}") |
| device = get_device() |
| merged_model = None |
| t_merge_start = time.time() |
| base_model_id = model_ids[0] |
|
|
| try: |
| logging.info(f"Loading base config and tokenizer from: {base_model_id}") |
| base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) |
| base_config = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) |
| if base_tokenizer.pad_token is None and base_tokenizer.eos_token is not None: |
| base_tokenizer.pad_token = base_tokenizer.eos_token |
| base_config.pad_token_id = base_config.eos_token_id |
| logging.info("Set base tokenizer pad_token to eos_token for consistency.") |
|
|
| except Exception as e: |
| logging.error(f"Failed to load base config/tokenizer for {base_model_id}: {e}") |
| return f"[Error] Failed to load base model config/tokenizer: {e}", "{}", *get_error_filter_updates() |
|
|
| try: |
| logging.info(f"Loading base model state dict (CPU, float32) for merging: {base_model_id}") |
| base_model = AutoModelForCausalLM.from_pretrained( |
| base_model_id, |
| trust_remote_code=True, |
| token=hf_token, |
| torch_dtype=torch.float32, |
| low_cpu_mem_usage=True |
| ) |
| base_state_dict = base_model.state_dict() |
| merged_state_dict = OrderedDict((k, v.clone()) for k, v in base_state_dict.items()) |
| param_counts = OrderedDict((k, 1) for k in base_state_dict) |
| num_models_processed = 1 |
| del base_model, base_state_dict |
| clean_memory() |
|
|
| except Exception as e: |
| logging.error(f"Failed to load base model state dict for {base_model_id}: {e}") |
| return f"[Error] Failed to load base model state dict: {e}", "{}", *get_error_filter_updates() |
|
|
| for i, model_id in enumerate(model_ids[1:]): |
| logging.info(f"Processing model {i+2}/{len(model_ids)}: {model_id}") |
| try: |
| model_i = AutoModelForCausalLM.from_pretrained( |
| model_id, |
| trust_remote_code=True, |
| token=hf_token, |
| torch_dtype=torch.float32, |
| low_cpu_mem_usage=True |
| ) |
| state_dict_i = model_i.state_dict() |
|
|
| for name, param_i in state_dict_i.items(): |
| if name in merged_state_dict: |
| if merged_state_dict[name].shape == param_i.shape: |
| merged_state_dict[name].add_(param_i) |
| param_counts[name] += 1 |
| else: |
| logging.warning(f"Shape mismatch for parameter '{name}' between base and {model_id}. Base: {merged_state_dict[name].shape}, Current: {param_i.shape}. Parameter '{name}' will NOT include contribution from {model_id}.") |
| else: |
| logging.warning(f"Parameter '{name}' found in {model_id} but not in base model {base_model_id}. Skipping this parameter.") |
|
|
| num_models_processed += 1 |
| del model_i, state_dict_i |
| clean_memory() |
|
|
| except Exception as e: |
| logging.error(f"Failed to load or process model {model_id}: {e}. Skipping this model for merge.") |
| continue |
|
|
| if num_models_processed < 2: |
| msg = "Merge failed: Fewer than two models were successfully loaded and processed." |
| logging.error(msg) |
| return f"[Error] {msg}", "{}", *get_error_filter_updates() |
|
|
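| # Element-wise averaging: each accumulated tensor is divided by the number of models that |
| # actually contributed to it (shape mismatches were skipped above), i.e. |
| #   merged[name] = (1 / count[name]) * sum_i state_dict_i[name] |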
| averaged_count = 0 |
| for name in merged_state_dict: |
| count = param_counts.get(name, 0) |
| if count > 0: |
| merged_state_dict[name].div_(count) |
| averaged_count +=1 |
| else: |
| logging.error(f"Parameter '{name}' has count {count} <= 0 during averaging. This indicates a logic error.") |
|
|
| logging.info(f"Averaged {averaged_count} parameters across {num_models_processed} successfully processed models.") |
|
|
| try: |
| logging.info("Creating final merged model from base config and averaged weights...") |
| merged_model = AutoModelForCausalLM.from_config(base_config, trust_remote_code=True) |
| load_results = merged_model.load_state_dict(merged_state_dict, strict=False) |
|
|
| if load_results.missing_keys: |
| logging.warning(f"Load state dict results: Missing keys: {load_results.missing_keys}") |
| if load_results.unexpected_keys: |
| logging.warning(f"Load state dict results: Unexpected keys: {load_results.unexpected_keys}") |
|
|
| final_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32 |
| logging.info(f"Converting merged model to {final_dtype} for use on device {device}.") |
|
|
| global_model = merged_model.to(device=device, dtype=final_dtype) |
| global_tokenizer = base_tokenizer |
| config = initialize_config_flags(global_model.config) |
| config.architecture_merged = True |
| config.merged_from_models = model_ids |
| config.merged_models_processed = num_models_processed |
|
|
| update_pipeline() |
| clean_memory() |
|
|
| final_status_json, *filter_updates = get_detailed_status_and_filter_states() |
| merge_time = time.time() - t_merge_start |
| msg = f"Successfully merged architectures (averaged parameters) from {num_models_processed} models in {merge_time:.2f}s. Base config/tokenizer from: {base_model_id}." |
| logging.info(msg) |
| return msg, final_status_json, *filter_updates |
|
|
| except Exception as e: |
| logging.error(f"Architecture merging failed during final model creation or state update: {e}\n{traceback.format_exc()}") |
| global_model = None |
| global_tokenizer = None |
| config = None |
| global_pipe = None |
| clean_memory() |
| return f"[Error] Architecture merging failed: {e}", "{}", *get_error_filter_updates() |
|
|
|
|
| def get_user_id(token): |
| if not token: |
| logging.warning("No Hugging Face token provided for user ID check.") |
| return "unknown_user" |
| try: |
| api = HfApi() |
| user_info = api.whoami(token=token) |
| return user_info.get("name", "unknown_user") |
| except requests.exceptions.HTTPError as http_err: |
| if http_err.response.status_code == 401: |
| logging.error("Hugging Face authentication failed (401 Unauthorized). Check your token.") |
| return "auth_error_user" |
| else: |
| logging.error(f"HTTP error retrieving Hugging Face user ID: {http_err}") |
| return "http_error_user" |
| except Exception as e: |
| logging.error(f"Could not retrieve Hugging Face user ID: {e}") |
| return "unknown_user" |
|
|
| def decode_model_details(model): |
| if model is None: |
| return json.dumps({"Error": "Model not loaded."}, indent=2) |
| if not hasattr(model, 'config'): |
| logging.warning("Model object lacks a 'config' attribute.") |
| details = OrderedDict() |
| details["Model Class"] = type(model).__name__ |
| details["Error"] = "Model config attribute not found." |
| return json.dumps(details, indent=2) |
|
|
| details = OrderedDict() |
| config_obj = model.config |
| t_start_decode = time.time() |
| logging.info("Decoding model details...") |
|
|
| try: |
| details["Model Class"] = type(model).__name__ |
| details["Config Class"] = getattr(config_obj, 'config_class', type(config_obj).__name__) |
| details["Model Type"] = getattr(config_obj, 'model_type', 'N/A') |
|
|
| total_params = 0 |
| trainable_params = 0 |
| param_dtypes = set() |
| param_devices = set() |
| try: |
| for name, param in model.named_parameters(): |
| num_elements = param.numel() |
| total_params += num_elements |
| param_dtypes.add(str(param.dtype).replace('torch.', '')) |
| param_devices.add(str(param.device)) |
| if param.requires_grad: |
| trainable_params += num_elements |
| if not param_devices: |
| device_str = "N/A (No parameters)" |
| elif len(param_devices) == 1: |
| device_str = param_devices.pop() |
| else: |
| device_str = f"Multiple ({', '.join(param_devices)})" |
| except Exception as e: |
| logging.warning(f"Could not fully analyze parameters: {e}") |
| device_str = "Error analyzing params" |
|
|
| details["Device(s)"] = device_str |
| trainable_perc = (100 * trainable_params / total_params) if total_params > 0 else 0.00 |
| details["Params Summary"] = (f"Total: {total_params:,}, Trainable: {trainable_params:,} " |
| f"({trainable_perc:.2f}%), Dtypes: {list(param_dtypes)}") |
|
|
| try: |
| layer_counts = Counter(type(m).__name__ for m in model.modules() if not isinstance(m, nn.Sequential)) |
| details["Layer Types Count"] = dict(layer_counts.most_common(15)) |
| except Exception as e: |
| logging.warning(f"Could not count layer types: {e}") |
| details["Layer Types Count"] = "Error counting layers" |
|
|
| details["Modification Flags"] = {} |
| all_flags = initialize_config_flags(None).__dict__.keys() |
| for flag in sorted(all_flags): |
| if hasattr(config_obj, flag): |
| value = getattr(config_obj, flag) |
| details["Modification Flags"][flag] = value |
|
|
| details["Key Config Attributes"] = {} |
| key_attrs = ['vocab_size', 'hidden_size', 'num_hidden_layers', 'num_attention_heads', |
| 'intermediate_size', 'max_position_embeddings', 'hidden_act', 'layer_norm_eps', |
| 'rms_norm_eps', 'attention_dropout', 'hidden_dropout_prob', 'initializer_range', |
| 'tie_word_embeddings', 'rope_scaling', 'sliding_window', 'attn_implementation'] |
| for attr in key_attrs: |
| if hasattr(config_obj, attr): |
| details["Key Config Attributes"][attr] = getattr(config_obj, attr) |
|
|
|
|
| logging.info(f"Model details decoded in {time.time() - t_start_decode:.2f}s") |
| return json.dumps(details, indent=2, default=str) |
| except Exception as e: |
| logging.error(f"Error decoding model details: {e} \n{traceback.format_exc()}") |
| details["Error"] = f"Failed during detail decoding: {e}" |
| return json.dumps(details, indent=2, default=str) |
|
|
|
|
| def update_pipeline(): |
| global global_model, global_tokenizer, global_pipe |
| if global_model and global_tokenizer: |
| device = get_device() |
|
|
| pipeline_device_arg = None |
| device_map = None |
|
|
| if device.type == 'cpu': |
| pipeline_device_arg = -1 |
| logging.info("Configuring pipeline for CPU.") |
| elif device.type == 'cuda': |
| if torch.cuda.device_count() > 1: |
| device_map = "auto" |
| pipeline_device_arg = None |
| logging.info("Multiple GPUs detected, configuring pipeline with device_map='auto'.") |
| else: |
| pipeline_device_arg = 0 |
| logging.info("Configuring pipeline for single CUDA device (device=0).") |
| elif device.type == 'mps': |
| pipeline_device_arg = 0 |
| logging.info("Configuring pipeline for MPS device (device=0).") |
| else: |
| pipeline_device_arg = -1 |
| logging.warning(f"Unknown device type '{device.type}', configuring pipeline for CPU.") |
|
|
| logging.info(f"Updating text generation pipeline (Device Arg: {pipeline_device_arg}, Device Map: {device_map})..."); T=time.time() |
| try: |
| if device_map is None and pipeline_device_arg is not None: |
| if pipeline_device_arg == -1: |
| global_model.to('cpu') |
| elif device.type == 'cuda': |
| global_model.to(f'cuda:{pipeline_device_arg}') |
| elif device.type == 'mps': |
| global_model.to('mps:0') |
|
|
| task = "text-generation" |
|
|
| global_pipe = pipeline( |
| task=task, |
| model=global_model, |
| tokenizer=global_tokenizer, |
| device=pipeline_device_arg, |
| device_map=device_map |
| ) |
|
|
| pipe_device_str = "N/A" |
| if getattr(global_pipe, "device_map", None): |
| pipe_device_str = f"device_map: {global_pipe.device_map}" |
| elif global_pipe.device: |
| pipe_device_str = str(global_pipe.device) |
| logging.info(f"Text generation pipeline created/updated. Effective device(s): {pipe_device_str}") |
|
|
| if device_map is None and global_pipe.device != device: |
| logging.warning(f"Pipeline created on {global_pipe.device}, but target device was {device}. This might happen with device_map issues or insufficient VRAM.") |
|
|
| msg = f"Text generation pipeline updated successfully in {time.time()-T:.2f}s."; logging.info(msg) |
| return msg |
| except Exception as e: |
| msg=f"Pipeline update failed: {e}\n{traceback.format_exc()}"; logging.error(msg); |
| global_pipe = None; |
| return f"[Error] Pipeline update failed: {e}" |
| else: |
| msg = "Cannot update pipeline: Global model or tokenizer not loaded."; logging.warning(msg); |
| global_pipe = None; |
| return msg |
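| # Device conventions as used above follow the transformers pipeline API: device=-1 targets |
| # CPU, device=0 the first CUDA/MPS device, and device_map="auto" defers placement to |
| # accelerate when multiple GPUs are available. |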
|
|
|
|
| def get_detailed_status_and_filter_states(): |
| global global_model, config |
| t_start = time.time() |
|
|
| if global_model is None: |
| logging.warning("Cannot get status: Model not loaded.") |
| return json.dumps({"Error": "Model not loaded."}, indent=2), *get_error_filter_updates() |
|
|
| if not hasattr(global_model, 'config') or global_model.config is None: |
| logging.warning("Model config missing. Initializing default flags for status check.") |
| temp_config = initialize_config_flags(None) |
| status_json = json.dumps({"Warning": "Model config missing, status reflects defaults.", **json.loads(decode_model_details(global_model))}, indent=2) |
| config_to_check = temp_config |
| else: |
| config = global_model.config |
| config = initialize_config_flags(config) |
| global_model.config = config |
| status_json = decode_model_details(global_model) |
| config_to_check = config |
|
|
| logging.info("Refreshing detailed model status and filter checkbox states...") |
|
|
| filter_states = {} |
| for name in filter_names_ui: |
| attr_name = filter_attr_map.get(name) |
| if attr_name: |
| filter_states[name] = getattr(config_to_check, attr_name, False) |
| else: |
| logging.error(f"Filter name '{name}' not found in attribute map. Setting state to False.") |
| filter_states[name] = False |
|
|
| updates = [gr.update(value=filter_states.get(name, False)) for name in filter_names_ui] |
|
|
| logging.info(f"Refreshed status and filter states in {time.time()-t_start:.2f}s."); |
| return status_json, *updates |
|
|
| def get_error_filter_updates(): |
| return [gr.update(value=False) for _ in filter_names_ui] |
|
|
| def base_toggle_function(func_enable, func_disable, enable, success_msg_enable, success_msg_disable, *args): |
| global global_model, config |
| t_start = time.time() |
|
|
| if not global_model: |
| return "[Error] Model not loaded. Load a model first." |
|
|
| if not hasattr(global_model, 'config') or global_model.config is None: |
| logging.warning("Model config missing. Initializing default flags before toggle.") |
| global_model.config = initialize_config_flags(None) |
| config = initialize_config_flags(global_model.config) |
| global_model.config = config |
|
|
| msg = "" |
| func_to_call = func_enable if enable else func_disable |
| action_name = "Enable" if enable else "Disable" |
| func_name = getattr(func_enable, '__name__', 'unknown_enable').replace('_', ' ').title() if enable else \ |
| getattr(func_disable, '__name__', 'unknown_disable').replace('_', ' ').title() |
|
|
| logging.info(f"Executing toggle: {action_name} {func_name}...") |
|
|
| try: |
| sig = inspect.signature(func_to_call) |
| pass_args = [] |
| if 'model' in sig.parameters or 'base_model' in sig.parameters or 'module' in sig.parameters: |
| pass_args.append(global_model) |
| if 'config' in sig.parameters: |
| pass_args.append(config) |
| pass_args.extend(args) |
|
|
| result = func_to_call(*pass_args) |
|
|
| if isinstance(result, str) and "[Error]" not in result: |
| msg = result |
| elif isinstance(result, str): |
| msg = result |
| else: |
| msg = success_msg_enable if enable else success_msg_disable |
|
|
| logging.info(f"Toggle Action ({func_name} -> {action_name}) Result: {msg} (Took {time.time()-t_start:.2f}s)") |
| if "[Error]" not in msg: |
| update_pipeline() |
|
|
| except Exception as e: |
| msg = f"[Error] during toggle ({action_name} {func_name}): {e}" |
| logging.error(f"{msg}\n{traceback.format_exc()}") |
|
|
| clean_memory() |
| return msg |
|
|
| def specific_action_function(action_func, *args, success_msg="Action completed successfully."): |
| global global_model, global_tokenizer, config |
| t_start=time.time() |
|
|
| if not global_model: |
| return "[Error] Model not loaded. Load a model first." |
|
|
| if not hasattr(global_model, 'config') or global_model.config is None: |
| logging.warning("Model config missing. Initializing default flags before action.") |
| global_model.config = initialize_config_flags(None) |
| config = initialize_config_flags(global_model.config) |
| global_model.config = config |
|
|
| msg = "" |
| func_name = getattr(action_func, '__name__', 'unknown_action') |
|
|
| logging.info(f"Executing action: {func_name}...") |
|
|
| try: |
| sig = inspect.signature(action_func) |
| pass_args = [] |
| if 'model' in sig.parameters or 'base_model' in sig.parameters or 'module' in sig.parameters: |
| pass_args.append(global_model) |
| if 'config' in sig.parameters: |
| pass_args.append(config) |
| if 'tokenizer' in sig.parameters: |
| if global_tokenizer: |
| pass_args.append(global_tokenizer) |
| else: |
| return f"[Error] Action '{func_name}' requires tokenizer, but it's not loaded." |
| pass_args.extend(args) |
|
|
| result = action_func(*pass_args) |
|
|
| if isinstance(result, str):
| msg = result
| else:
| msg = success_msg
|
|
| logging.info(f"Action ({func_name}) Result: {msg} (Took {time.time()-t_start:.2f}s)") |
| if "[Error]" not in msg: |
| update_pipeline() |
|
|
| except Exception as e: |
| msg = f"[Error] during action ({func_name}): {e}" |
| logging.error(f"{msg}\n{traceback.format_exc()}") |
|
|
| clean_memory() |
| return msg |
|
|
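| # Thin lambda adapters: each binds a specific modification helper (defined
| # elsewhere in this script) to the generic runners above, so every Gradio
| # callback stays a one-liner. The *_wrapper names are what the UI event
| # handlers invoke.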
| toggle_bias_removal_wrapper = lambda enable: base_toggle_function(_replace_linear_without_bias, _enable_bias_in_linear, enable, "Bias removal applied.", "Bias addition applied (reverted removal).") |
| toggle_embeddings_untie_wrapper = lambda enable: base_toggle_function(_untie_embeddings, _retie_embeddings, enable, "Embeddings untied.", "Embeddings re-tied.") |
| toggle_layer_reduction_wrapper = lambda enable, layers: specific_action_function(_reduce_layers_to_one if enable else _enable_full_layers, layers if enable else None, success_msg=f"Layer reduction {'applied' if enable else 'reverted'}.") |
| apply_norm_swap_wrapper = lambda norm_type: specific_action_function(_swap_normalization_layer, norm_type, success_msg=f"Normalization swapped to {norm_type}") |
| apply_activation_change_wrapper = lambda name: specific_action_function(_swap_activation_function, name, success_msg=f"Activation Function Swapped to {name}") |
| revert_activation_change_wrapper = lambda: specific_action_function(_revert_activation_function, success_msg="Activation Function Reverted to Default") |
| toggle_bitnet_wrapper = lambda enable: base_toggle_function(convert_to_bitnet, revert_bitnet, enable, "BitNet conversion applied.", "BitNet conversion reverted.") |
| apply_multimodal_wrapper = lambda modalities: specific_action_function(_setup_multimodal, modalities, success_msg="Multi-modal setup attempted.") |
| revert_multimodal_wrapper = lambda: specific_action_function(_revert_multimodal, success_msg="Multi-modal setup reverted.") |
|
|
| toggle_token_speed_optimization_wrapper = lambda enable: specific_action_function(_optimize_token_generation_speed if enable else _revert_token_generation_speed_optimization, success_msg="Token Speed Opt Flags Updated") |
| toggle_coherence_improvement_wrapper = lambda enable: specific_action_function(_enable_coherence_improvement if enable else _disable_coherence_improvement, success_msg="Coherence Flag Updated") |
| toggle_layer_norm_bypass_wrapper = lambda enable: specific_action_function(_enable_layer_norm_bypass if enable else _disable_layer_norm_bypass, success_msg="LN Bypass Updated") |
| toggle_dropout_bypass_wrapper = lambda enable: specific_action_function(_enable_dropout_bypass if enable else _disable_dropout_bypass, success_msg="Dropout Bypass Updated") |
| toggle_fp32_precision_wrapper = lambda enable: specific_action_function(_recover_perfect_precision if enable else _revert_perfect_precision, success_msg="FP32 Precision Updated") |
| toggle_embedding_normalization_wrapper = lambda enable: specific_action_function(_normalize_embeddings if enable else _revert_embedding_normalization, success_msg="Embedding Normalization Updated") |
| toggle_gradient_checkpointing_wrapper = lambda enable: specific_action_function(_enable_gradient_checkpointing if enable else _disable_gradient_checkpointing, success_msg="Grad Checkpointing Updated") |
| toggle_flash_attention_wrapper = lambda enable: specific_action_function(_set_attention_variant_config, "flash_attention_2" if enable else "auto", success_msg=f"Flash Attention 2 {'Enabled' if enable else 'Disabled'} (via attn_implementation)") |
| apply_quantization_wrapper = lambda mode: specific_action_function(_quantize_model, mode, success_msg=f"Quantization Attempted: {mode}") |
| revert_quantization_wrapper = lambda: specific_action_function(_revert_quantization, success_msg="Quantization Reverted to FP32") |
|
|
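| # UI inputs arrive as strings; these small parsers coerce them and fall back to
| # safe defaults instead of raising. For example, _parse_pruning_amount("0.35")
| # returns 0.35, while _parse_pruning_amount("oops") logs a warning and returns
| # PRUNING_AMOUNT.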
| def _parse_pruning_amount(amount_str): |
| try: |
| amount = float(amount_str) |
| if not (0 < amount < 1): |
| raise ValueError("Pruning amount must be between 0 and 1") |
| return amount |
| except (ValueError, TypeError): |
| logging.warning(f"Invalid pruning amount '{amount_str}', using default {PRUNING_AMOUNT}") |
| return PRUNING_AMOUNT |
|
|
| apply_pruning_wrapper = lambda amount_str: specific_action_function( |
| _prune_weights_magnitude, |
| _parse_pruning_amount(amount_str), |
| success_msg=f"Pruning Applied (Amount: {_parse_pruning_amount(amount_str):.2f})" |
| ) |
| revert_pruning_wrapper = lambda: specific_action_function(_revert_pruning, success_msg="Pruning Flag Reverted") |
|
|
| set_lora_path_wrapper = lambda path: specific_action_function(_set_lora_adapter_path, path, success_msg="LoRA Path Set in Config") |
| add_peft_adapter_wrapper = lambda: specific_action_function( |
| _add_peft_adapter, |
| LoraConfig(**DEFAULT_PEFT_CONFIG_DICT) if _peft_installed else None, |
| success_msg="PEFT Adapter Added" |
| ) |
| merge_peft_adapter_wrapper = lambda: specific_action_function(_apply_lora_merge, success_msg="PEFT Adapter Merged") |
| remove_peft_adapter_wrapper = lambda: specific_action_function(_remove_peft_adapter, success_msg="PEFT Adapter Removed") |
|
|
| apply_layer_freeze_wrapper = lambda layers_str: specific_action_function(_freeze_layers, layers_str, success_msg="Layer Freezing Updated") |
| revert_layer_freeze_wrapper = lambda: specific_action_function(_unfreeze_all_layers, success_msg="All Layers Unfrozen") |
| toggle_limits_wrapper = lambda enable: specific_action_function(_configure_limits if enable else _remove_limits_configuration, success_msg="Limits Config Updated") |
| toggle_qa_restrictions_wrapper = lambda enable: specific_action_function(_remove_qa_restrictions if enable else _enable_qa_restrictions, success_msg="QA Restrictions Flag Updated") |
|
|
| def _parse_int_arg(arg, default, min_val=1): |
| try: |
| val = int(arg) |
| return max(val, min_val) |
| except (ValueError, TypeError): |
| return default |
|
|
| toggle_kd_wrapper = lambda enable, num_labels=2: specific_action_function(
| _setup_knowledge_distillation if enable else _revert_knowledge_distillation,
| *([_parse_int_arg(num_labels, 2, 1)] if enable else []),
| success_msg="KD Setup Updated"
| )
| toggle_reward_modeling_wrapper = lambda enable, num_outputs=1: specific_action_function(
| _setup_reward_modeling if enable else _revert_reward_modeling,
| *([_parse_int_arg(num_outputs, 1, 1)] if enable else []),
| success_msg="Reward Modeling Setup Updated"
| )
| toggle_swa_wrapper = lambda enable: specific_action_function(_apply_swa if enable else _revert_swa, success_msg="SWA Flag Updated") |
|
|
| def _parse_prob_arg(arg, default, min_val=0.0, max_val=1.0): |
| try: |
| val = float(arg) |
| return min(max(val, min_val), max_val) |
| except(ValueError, TypeError): |
| return default |
|
|
| toggle_layerdrop_wrapper = lambda enable, prob=0.1: specific_action_function(
| _enable_layerdrop if enable else _disable_layerdrop,
| *([_parse_prob_arg(prob, 0.1, 0.0, 1.0)] if enable else []),
| success_msg="LayerDrop Flag Updated"
| )
| toggle_rope_scaling_wrapper = lambda enable, scaling_type="linear", factor=2.0: specific_action_function(
| _set_rope_scaling_config if enable else _revert_rope_scaling,
| *([str(scaling_type), _parse_prob_arg(factor, 2.0, 1.0, 100.0)] if enable else []),
| success_msg="RoPE Scaling Config Updated"
| )
| toggle_sliding_window_wrapper = lambda enable, size=4096: specific_action_function(
| _set_sliding_window_config if enable else _revert_sliding_window,
| *([_parse_int_arg(size, 4096, 0)] if enable else []),
| success_msg="Sliding Window Config Updated"
| )
| apply_attention_variant_wrapper = lambda variant="auto": specific_action_function(_set_attention_variant_config, str(variant), success_msg="Attention Variant Config Updated") |
| revert_attention_variant_wrapper = lambda: specific_action_function(_revert_attention_variant, success_msg="Attention Variant Config Reverted") |
|
|
| toggle_gradient_clipping_flag_wrapper = lambda enable: specific_action_function(_enable_gradient_clipping if enable else _disable_gradient_clipping, success_msg="Grad Clipping Flag Updated") |
| toggle_weight_decay_flag_wrapper = lambda enable: specific_action_function(_enable_weight_decay if enable else _disable_weight_decay, success_msg="Weight Decay Flag Updated") |
| toggle_lr_scheduler_flag_wrapper = lambda enable: specific_action_function(_enable_lr_scheduler if enable else _disable_lr_scheduler, success_msg="LR Scheduler Flag Updated") |
| apply_optimizer_change_wrapper = lambda name: specific_action_function(_swap_optimizer, str(name), success_msg=f"Optimizer Pref Set: {name}") |
| revert_optimizer_change_wrapper = lambda: specific_action_function(_revert_optimizer, success_msg="Optimizer Pref Reverted") |
|
|
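| # Writes the gradient accumulation setting both onto the model config (so it is
| # persisted and reported) and into the module-level GRADIENT_ACCUMULATION_STEPS
| # used when TrainingArguments are built in start_training().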
| def _set_grad_accum_config(config, steps): |
| try: |
| s = int(steps) |
| if s < 1: raise ValueError("Steps must be >= 1") |
| config.gradient_accumulation_steps = s |
| global GRADIENT_ACCUMULATION_STEPS |
| GRADIENT_ACCUMULATION_STEPS = s |
| return f"Grad Accum Steps set to {s} in config." |
| except (ValueError, TypeError) as e: |
| logging.error(f"Invalid gradient accumulation steps: {steps}. Error: {e}") |
| return f"[Error] Invalid Grad Accum steps: {e}" |
|
|
| set_gradient_accumulation_wrapper = lambda steps: specific_action_function(_set_grad_accum_config, steps, success_msg="Grad Accum Steps update attempted.")
|
|
| toggle_all_safety_filters_wrapper = lambda enable: specific_action_function(_enable_all_safety_settings if enable else _disable_all_safety_settings, success_msg=f"All Safety Filters {'Enabled (Defaults)' if enable else 'Disabled'}") |
| force_disable_censorship_wrapper = lambda: specific_action_function(_disable_all_safety_settings, success_msg="Attempted Force Disable All Censorship Flags") |
|
|
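| # Receives one boolean per safety-filter checkbox (in filter_names_ui order),
| # maps each to its config attribute via filter_attr_map, and only touches flags
| # whose state actually changed before refreshing the pipeline.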
| def toggle_individual_safety_filter_wrapper(*checkbox_states):
| global global_model, config |
| t_start=time.time() |
| if not global_model: return "[Error] Model not loaded." |
|
|
| if not hasattr(global_model, 'config') or global_model.config is None: |
| logging.warning("Model config missing. Initializing default flags for filter toggle.") |
| global_model.config = initialize_config_flags(None) |
| config = initialize_config_flags(global_model.config) |
| global_model.config = config |
|
|
| results = [] |
| updated_count = 0 |
|
|
| if len(checkbox_states) != len(filter_names_ui):
| return f"[Error] Mismatch between filter UI elements ({len(filter_names_ui)}) and received states ({len(checkbox_states)})."
|
|
| ui_state = dict(zip(filter_names_ui, checkbox_states))
|
|
| for name, checkbox_state in ui_state.items(): |
| filter_attr = filter_attr_map.get(name) |
| if filter_attr: |
| current_state = getattr(config, filter_attr, False) |
| new_state = bool(checkbox_state) |
| if current_state != new_state: |
| setattr(config, filter_attr, new_state) |
| results.append(f"{name}: {'ON' if new_state else 'OFF'}") |
| updated_count += 1 |
| else: |
| logging.warning(f"UI filter name '{name}' not found in attribute map filter_attr_map. Skipping.") |
|
|
| if updated_count > 0: |
| msg = f"Applied {updated_count} individual filter toggle(s): {', '.join(results)}" |
| update_pipeline() |
| else: |
| msg = "No individual filter states were changed." |
|
|
| logging.info(f"Individual filter toggle action took {time.time()-t_start:.2f}s. Status: {msg}"); |
| return msg |
|
|
|
|
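| # "Coherence improvement" here is simply a deterministic beam-search pass: the
| # sampling knobs (temperature/top_k/top_p) are stripped and num_beams is forced
| # to at least 4. If beam search fails, the original generation arguments are
| # retried as a fallback.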
| def _improve_coherence(model, tokenizer, generation_args): |
| logging.info("Applying coherence improvement using beam search...") |
| coherence_beams = generation_args.get("num_beams", 1) |
| if coherence_beams <= 1: coherence_beams = 4 |
|
|
| coherence_args = generation_args.copy() |
| coherence_args["num_beams"] = coherence_beams |
| coherence_args["do_sample"] = False |
| coherence_args["num_return_sequences"] = 1 |
| coherence_args["early_stopping"] = True |
| coherence_args.pop("temperature", None) |
| coherence_args.pop("top_k", None) |
| coherence_args.pop("top_p", None) |
|
|
| input_ids = coherence_args.get("input_ids") |
| if input_ids is None: |
| logging.error("Coherence improvement failed: input_ids missing.") |
| return "[Error: input_ids missing in generation_args]" |
|
|
| try: |
| with torch.no_grad(): |
| outputs = model.generate(**coherence_args) |
| response_ids = outputs[0][input_ids.shape[-1]:] |
| response_text = tokenizer.decode(response_ids, skip_special_tokens=True) |
| logging.info("Coherence improvement (beam search) successful.") |
| return response_text |
| except Exception as e: |
| logging.error(f"Error during coherence improvement (beam search): {e}. Falling back to original generation settings.") |
| fallback_args = generation_args.copy() |
| if fallback_args.get("num_beams", 1) > 1 and fallback_args.get("do_sample", False): |
| fallback_args["num_beams"] = 1 |
| fallback_args["early_stopping"] = False |
|
|
| try: |
| with torch.no_grad(): |
| outputs = model.generate(**fallback_args) |
| response_ids = outputs[0][input_ids.shape[-1]:] |
| response_text = tokenizer.decode(response_ids, skip_special_tokens=True) |
| logging.info("Fallback generation successful after coherence error.") |
| return response_text |
| except Exception as fallback_e: |
| logging.error(f"Fallback generation also failed: {fallback_e}") |
| return f"[Error during generation (fallback): {fallback_e}]" |
|
|
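| # Post-generation string filters. These are lightweight keyword/regex checks on
| # the decoded text only (no classifier models); each enabled config flag maps to
| # one lambda in filter_actions, and matched content is replaced with a bracketed
| # placeholder such as '[censored_pii]'.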
| def _apply_output_filters(text, config): |
| if not getattr(config, "response_filters", True) or getattr(config, "no_response_filters", False): |
| return text |
|
|
| logging.debug("Applying output filters based on config...") |
| original_text = text |
| if not isinstance(text, str): text = str(text) |
|
|
| filter_actions = { |
| "profanity_filter": lambda t: re.sub(r'\b(damn|fuck|shit|cunt|bitch|asshole|bastard|hell)\b', '[profanity]', t, flags=re.IGNORECASE), |
| "pii_filter": lambda t: re.sub(r'\b(\d{3}[-.\s]?\d{3}[-.\s]?\d{4})\b|\b([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,})\b', '[censored_pii]', t), |
| "hate_filter": lambda t: "[censored_hate_speech]" if any(word in t.lower() for word in ["kill all", "attack the", "inferior race", "exterminate"]) else t, |
| "sexually_explicit_filter": lambda t: "[censored_explicit]" if any(word in t.lower() for word in [" rape ", " naked ", " porn ", " xxx ", " blowjob ", " sex toy"]) else t, |
| "dangerous_content_filter": lambda t: "[censored_dangerous]" if any(word in t.lower() for word in ["make a bomb", "build weapon", "illegal drug recipe", "how to kill"]) else t, |
| "medical_advice_filter": lambda t: "[discouraged_medical_advice]" if any(word in t.lower() for word in ["diagnose", "prescribe", "cure for", "medical treatment for", "symptoms suggest"]) else t, |
| "legal_advice_filter": lambda t: "[discouraged_legal_advice]" if any(word in t.lower() for word in ["legal advice", "sue", "represent me", "is this legal", "contract law"]) else t, |
| "financial_advice_filter": lambda t: "[discouraged_financial_advice]" if any(word in t.lower() for word in ["guaranteed investment", "stock tip", "financial plan", "buy bitcoin", "investment advice"]) else t, |
| "stereotype_filter": lambda t: "[censored_stereotype]" if re.search(r'\b(all|every)\s+([A-Za-z]+(\s+)?){1,3}\s+(are|always)\b', t.lower()) else t, |
| "misinfo_filter": lambda t: "[potential_misinfo]" if any(phrase in t.lower() for phrase in ["5g causes covid", "earth is flat", "vaccines cause autism", "chemtrails"]) else t, |
| "self_harm_filter": lambda t: "[censored_self_harm]" if any(phrase in t.lower() for phrase in ["commit suicide", "hurt myself", "painless death", "kill myself"]) else t, |
| } |
|
|
| active_filters_count = 0 |
| for filter_ui_name, filter_attr in filter_attr_map.items(): |
| if getattr(config, filter_attr, False): |
| filter_func = filter_actions.get(filter_attr) |
| if filter_func: |
| try: |
| filtered_text = filter_func(text) |
| if filtered_text != text: |
| active_filters_count +=1 |
| logging.debug(f"Filter '{filter_attr}' potentially applied modification.") |
| text = filtered_text |
| except Exception as e: |
| logging.warning(f"Error applying filter '{filter_attr}': {e}") |
|
|
| if not getattr(config, "no_advert_warning", False): |
| if re.search(r'\b(advertisement|sponsored|promo code|discount code|special offer)\b', text, re.IGNORECASE): |
| if "[Note: This response may contain promotional content.]" not in text: |
| text += "\n[Note: This response may contain promotional content.]" |
| active_filters_count +=1 |
|
|
| if active_filters_count > 0: |
| logging.debug(f"Output filtering potentially applied {active_filters_count} modifications.") |
|
|
| return text |
|
|
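| # End-to-end inference path for the UI: validate and clamp the sampling
| # parameters, build generate() kwargs (falling back to greedy decoding when the
| # settings or config flags demand it), run either plain generation or the
| # beam-search coherence pass, then apply the output filters above.
| # Hypothetical UI call, mirroring the Gradio input order:
| #   run_inference("Explain KV caching.", 128, 0.7, 50, 0.95, 1.1)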
| def run_inference(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty): |
| global global_pipe, global_model, global_tokenizer, config |
| if not all([global_model, global_tokenizer]): |
| return "[Error] Model or Tokenizer not loaded. Please load a model first." |
| if global_pipe is None: |
| pipe_msg = update_pipeline() |
| if global_pipe is None: |
| return f"[Error] Text generation pipeline could not be initialized. Load/Reload model. Status: {pipe_msg}" |
|
|
| if not hasattr(global_model, 'config'): |
| logging.warning("Model config missing during inference. Initializing default flags.") |
| global_model.config = initialize_config_flags(None) |
| config = initialize_config_flags(global_model.config) |
| global_model.config = config |
|
|
| logging.info("Starting inference run..."); t_start_inf = time.time() |
| try: |
| use_filters = getattr(config, "response_filters", True) and not getattr(config, "no_response_filters", False) |
| apply_coherence = getattr(config, "coherence_improvement_enabled", False) |
|
|
| try: max_new_tokens = int(max_new_tokens); assert max_new_tokens > 0
| except (ValueError, TypeError, AssertionError): max_new_tokens = 256; logging.warning("Invalid max_new_tokens, using 256.")
| try: temperature = float(temperature); assert temperature >= 0.0
| except (ValueError, TypeError, AssertionError): temperature = 0.7; logging.warning("Invalid temperature, using 0.7.")
| try: top_k = int(top_k); assert top_k >= 0
| except (ValueError, TypeError, AssertionError): top_k = 50; logging.warning("Invalid top_k, using 50.")
| try: top_p = float(top_p); assert 0.0 <= top_p <= 1.0
| except (ValueError, TypeError, AssertionError): top_p = 0.95; logging.warning("Invalid top_p, using 0.95.")
| try: repetition_penalty = float(repetition_penalty); assert repetition_penalty >= 1.0
| except (ValueError, TypeError, AssertionError): repetition_penalty = 1.1; logging.warning("Invalid repetition_penalty, using 1.1.")
|
|
| # top_p == 1.0 only disables nucleus filtering; it should not force greedy decoding.
| is_greedy = (temperature < 1e-6) or \
| (top_k == 1) or \
| (top_p <= 0.0) or \
| getattr(config, "token_gen_speed_maximized", False)
|
|
|
|
| gen_kwargs = { |
| "max_new_tokens": max_new_tokens, |
| "temperature": temperature if not is_greedy else None, |
| "top_k": top_k if top_k > 0 and not is_greedy else None, |
| "top_p": top_p if top_p > 0.0 and top_p < 1.0 and not is_greedy else None, |
| "repetition_penalty": repetition_penalty if repetition_penalty > 1.0 else None, |
| "do_sample": not is_greedy, |
| "use_cache": getattr(config, "use_cache", True), |
| "num_beams": (max(getattr(config, "num_beams", 1), 4) if apply_coherence else getattr(config, "num_beams", 1)), |
| "pad_token_id": global_tokenizer.pad_token_id if global_tokenizer.pad_token_id is not None else getattr(config, 'pad_token_id', None), |
| "eos_token_id": global_tokenizer.eos_token_id if global_tokenizer.eos_token_id is not None else getattr(config, 'eos_token_id', None), |
| "early_stopping": True if (apply_coherence or getattr(config, "num_beams", 1) > 1) else False |
| } |
| gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None} |
|
|
| if gen_kwargs.get("num_beams", 1) > 1 and gen_kwargs.get("pad_token_id") is None: |
| if gen_kwargs.get("eos_token_id") is not None: |
| gen_kwargs["pad_token_id"] = gen_kwargs["eos_token_id"] |
| logging.warning(f"Using eos_token_id ({gen_kwargs['eos_token_id']}) as pad_token_id for beam search.") |
| else: |
| logging.error("Beam search requires pad_token_id, but none found (and eos_token_id missing). Generation might fail.") |
| return "[Error] Beam search failed: pad_token_id is required." |
|
|
| response_text = "" |
| device = get_device() |
| logging.debug(f"Generation arguments: {gen_kwargs}") |
|
|
| inputs = global_tokenizer(prompt, return_tensors="pt", padding=False, truncation=True, max_length=CONTEXT_LENGTH).to(device) |
| gen_kwargs["input_ids"] = inputs["input_ids"] |
| if "attention_mask" in inputs: |
| gen_kwargs["attention_mask"] = inputs["attention_mask"] |
|
|
| global_model.eval() |
|
|
| if apply_coherence: |
| response_text = _improve_coherence(global_model, global_tokenizer, gen_kwargs) |
| else: |
| with torch.no_grad(): |
| outputs = global_model.generate(**gen_kwargs) |
| output_sequence = outputs[0] |
| response_ids = output_sequence[inputs.input_ids.shape[-1]:] |
| response_text = global_tokenizer.decode(response_ids, skip_special_tokens=True) |
|
|
| if use_filters: |
| filtered_response = _apply_output_filters(response_text, config) |
| if filtered_response != response_text: |
| logging.info("Output filters applied modifications.") |
| response_text = filtered_response |
|
|
| final_response = response_text.strip() |
| logging.info(f"Inference finished in {time.time()-t_start_inf:.2f}s. Response length: {len(final_response)}") |
| return final_response |
|
|
| except Exception as e: |
| logging.error(f"Error during inference: {e}\n{traceback.format_exc()}") |
| return f"[Error during inference: {e}]" |
| finally:
| # Restore training mode only if the model was actually training before this call.
| if global_model is not None and was_training:
| global_model.train()
|
|
|
|
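| # Full training entry point driven by the Gradio form. Rough stages:
| #   1) resource check and parameter validation,
| #   2) optional W&B / Hugging Face Hub login and repo creation,
| #   3) load base model + tokenizer, optionally wrap with a LoRA adapter (PEFT),
| #   4) load/tokenize/group the datasets and split off an eval set,
| #   5) Trainer run with checkpoint resume, then merge any adapter,
| #   6) post-training layer/activation modifications, save, evaluate, upload,
| #      and finally swap the result into the global model used for inference.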
| def start_training( |
| base_model_id: str, new_model_name: str, hf_token: str, |
| datasets_input_str: str, |
| activation_fn_name: str, target_layers_int: int, |
| grad_accum_ui: int, lr: float, epochs: int, max_steps: int, batch_size: int, |
| optimizer_name: str, scheduler_type: str, weight_decay: float, warmup_ratio: float, |
| use_peft: bool, peft_r: int, peft_alpha: int, peft_dropout: float, peft_target_modules_str: str, |
| wandb_token: str, use_cpu_flag: bool, bypass_limits_state: bool |
| ): |
| global global_model, global_tokenizer, global_pipe, original_num_layers_global, config, target_layers |
| global USE_CPU, BATCH_SIZE, LEARNING_RATE, EPOCHS, MAX_STEPS, DEFAULT_OPTIMIZER, DEFAULT_SCHEDULER, GRADIENT_ACCUMULATION_STEPS |
| global BYPASS_RESOURCE_LIMITS |
| BYPASS_RESOURCE_LIMITS = bypass_limits_state |
|
|
| start_overall_time = time.time() |
| logging.info("="*50) |
| logging.info("🚀 STARTING TRAINING PROCESS 🚀") |
|
|
| resources_ok, res_msg = check_resources() |
| if not resources_ok: |
| error_msg = f"[Error] Resource limits exceeded, cannot start training. {res_msg}" |
| logging.error(error_msg) |
| return error_msg |
| else: |
| logging.info(res_msg) |
|
|
| errors = [] |
| if not base_model_id: errors.append("Base Model ID/Path is required.") |
| if not new_model_name: errors.append("New Model Name (for saving/Hub) is required.") |
| if not datasets_input_str: errors.append("At least one dataset must be provided.") |
| try: target_layers_int = int(target_layers_int); assert target_layers_int >= 1
| except (ValueError, TypeError, AssertionError): errors.append("Target Layers must be a positive integer.")
| try: grad_accum_ui = int(grad_accum_ui); assert grad_accum_ui >= 1
| except (ValueError, TypeError, AssertionError): errors.append("Gradient Accumulation Steps must be a positive integer.")
| try: lr = float(lr); assert lr > 0
| except (ValueError, TypeError, AssertionError): errors.append("Learning Rate must be a positive float.")
| try: epochs = int(epochs); assert epochs >= 0
| except (ValueError, TypeError, AssertionError): errors.append("Epochs must be an integer >= 0.")
| try: max_steps = int(max_steps); assert max_steps >= 0
| except (ValueError, TypeError, AssertionError): errors.append("Max Steps must be an integer >= 0.")
|
|
| if epochs <= 0 and max_steps <= 0: |
| errors.append("Training requires at least one of Epochs or Max Steps to be positive.") |
| elif epochs > 0 and max_steps > 0: |
| logging.info(f"Both Epochs ({epochs}) and Max Steps ({max_steps}) are set (> 0). Max Steps will take precedence.") |
| epochs = -1 |
| elif epochs <= 0 and max_steps > 0: |
| epochs = -1 |
| elif epochs > 0 and max_steps <= 0: |
| logging.info(f"Using Epochs ({epochs}) for training termination as Max Steps <= 0.") |
| max_steps = -1 |
| else: |
| logging.error("Logic error in epoch/max_step handling. Defaulting Max Steps to 1.") |
| max_steps = 1 |
| epochs = -1 |
|
|
|
|
| try: batch_size = int(batch_size); assert batch_size >= 1
| except (ValueError, TypeError, AssertionError): errors.append("Batch Size must be a positive integer.")
| if optimizer_name not in OPTIMIZERS: errors.append(f"Invalid Optimizer. Choose from: {list(OPTIMIZERS.keys())}") |
| if scheduler_type not in SCHEDULER_TYPES: errors.append(f"Invalid Scheduler. Choose from: {SCHEDULER_TYPES}") |
| try: weight_decay = float(weight_decay); assert weight_decay >= 0.0
| except (ValueError, TypeError, AssertionError): errors.append("Weight Decay must be a non-negative float.")
| try: warmup_ratio = float(warmup_ratio); assert 0.0 <= warmup_ratio <= 1.0
| except (ValueError, TypeError, AssertionError): errors.append("Warmup Ratio must be between 0.0 and 1.0.")
| if activation_fn_name not in ACTIVATION_FUNCTIONS: errors.append(f"Invalid Activation Function. Choose from: {list(ACTIVATION_FUNCTIONS.keys())}") |
| if use_peft and not _peft_installed: errors.append("PEFT requested, but library not installed (`pip install peft`).") |
| peft_config_dict = {} |
| if use_peft: |
| try: |
| peft_r = int(peft_r); assert peft_r >= 1 |
| peft_alpha = int(peft_alpha); assert peft_alpha >= 1 |
| peft_dropout = float(peft_dropout); assert 0.0 <= peft_dropout <= 1.0 |
| peft_config_dict = { |
| "task_type": TaskType.CAUSAL_LM, |
| "inference_mode": False, |
| "r": peft_r, |
| "lora_alpha": peft_alpha, |
| "lora_dropout": peft_dropout, |
| } |
| if peft_target_modules_str: |
| modules = [m.strip() for m in peft_target_modules_str.split(',') if m.strip()] |
| if modules: |
| peft_config_dict["target_modules"] = modules |
| except Exception as peft_e: |
| errors.append(f"Invalid PEFT configuration: {peft_e}") |
|
|
|
|
| if errors: |
| error_msg = "[Error] Invalid training parameters:\n- " + "\n- ".join(errors) |
| logging.error(error_msg) |
| return error_msg |
|
|
| logging.info(f"Base Model: {base_model_id}, New Name: {new_model_name}") |
| logging.info(f"Use PEFT: {use_peft}") |
| if use_peft: logging.info(f"PEFT Config: r={peft_r}, alpha={peft_alpha}, dropout={peft_dropout}, targets={peft_target_modules_str or 'Auto'}") |
| logging.info(f"Datasets: \n{datasets_input_str}") |
| logging.info(f"LR: {lr}, Effective Epochs: {epochs if epochs > 0 else 'N/A'}, MaxSteps: {max_steps if max_steps > 0 else 'N/A'}, BS: {batch_size}, GradAccum: {grad_accum_ui}") |
| logging.info(f"Optim: {optimizer_name}, Scheduler: {scheduler_type}, WD: {weight_decay}, Warmup: {warmup_ratio}") |
| logging.info(f"Post-Mod Target Layers: {target_layers_int}, Post-Mod ActFn: {activation_fn_name}") |
| logging.info(f"Use CPU: {use_cpu_flag}, W&B: {'Enabled' if wandb_token else 'Disabled'}, Bypass Limits: {BYPASS_RESOURCE_LIMITS}") |
| logging.info("="*50) |
|
|
| USE_CPU = use_cpu_flag |
| BATCH_SIZE = batch_size |
| LEARNING_RATE = lr |
| EPOCHS = epochs if epochs > 0 else 1 |
| MAX_STEPS = max_steps |
| DEFAULT_OPTIMIZER = optimizer_name |
| DEFAULT_SCHEDULER = scheduler_type |
| GRADIENT_ACCUMULATION_STEPS = grad_accum_ui |
| target_layers = target_layers_int |
|
|
| logging.info("Setting up environment...") |
| clean_memory() |
| device = get_device() |
| logging.info(f"Using device: {device}") |
| num_cpu_cores_os = os.cpu_count() or 1 |
| global NUM_CPU_CORES |
| if NUM_CPU_CORES <= 0: NUM_CPU_CORES = num_cpu_cores_os |
| else: NUM_CPU_CORES = min(NUM_CPU_CORES, num_cpu_cores_os) |
| logging.info(f"Using {NUM_CPU_CORES} CPU cores for dataloading.") |
|
|
| wandb_run = None |
| use_wandb_reporting = False |
| if wandb_token: |
| logging.info("Attempting WandB login...") |
| try: |
| wandb.login(key=wandb_token) |
| logging.info("WandB login successful.") |
| use_wandb_reporting = True |
| except Exception as e: |
| logging.warning(f"WandB login failed: {e}. Proceeding without WandB logging.") |
| report_to = ["wandb"] if use_wandb_reporting else [] |
|
|
| user_id = "local_user"; repo_id_str = new_model_name; repo_link = "N/A (Upload skipped or failed)" |
| upload_to_hub = False |
| if hf_token: |
| logging.info("Attempting Hugging Face login...") |
| user_id = get_user_id(hf_token) |
| if user_id not in ["unknown_user", "http_error_user", "auth_error_user"]: |
| try: |
| login(token=hf_token, add_to_git_credential=False) |
| repo_id_str = f"{user_id}/{new_model_name}" |
| logging.info(f"Hugging Face login successful. User: {user_id}, Target Repo: {repo_id_str}") |
| create_repo(repo_id=repo_id_str, repo_type="model", exist_ok=True, token=hf_token) |
| logging.info(f"Hub repository '{repo_id_str}' ensured.") |
| repo_link = f"https://huggingface.co/{repo_id_str}" |
| upload_to_hub = True |
| except Exception as e: |
| logging.warning(f"Hugging Face login or repo creation failed: {e}. Upload will be skipped.") |
| hf_token = None |
| repo_id_str = new_model_name |
| repo_link = "N/A (Login/Repo Failed)" |
| else: |
| logging.warning(f"Could not get valid Hugging Face user ID ({user_id}). Upload will be skipped.") |
| hf_token = None |
| repo_id_str = new_model_name |
| repo_link = "N/A (Login Failed)" |
| else: |
| logging.info("No HF write token provided, Hub upload will be skipped.") |
|
|
| logging.info(f"Loading base model '{base_model_id}' and tokenizer...") |
| try: |
| tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) |
| if tokenizer.pad_token is None: |
| if tokenizer.eos_token is not None: |
| tokenizer.pad_token = tokenizer.eos_token |
| logging.info(f"Set tokenizer pad_token to eos_token ('{tokenizer.eos_token}')") |
| else: |
| added_pad = tokenizer.add_special_tokens({'pad_token': '[PAD]'}) |
| if added_pad > 0: |
| logging.warning("Tokenizer missing pad_token and eos_token. Added '[PAD]' as pad_token.") |
| else: |
| logging.error("Tokenizer missing pad/eos and failed to add '[PAD]'. Training may fail.") |
|
|
| base_config_obj = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) |
| base_config_obj = initialize_config_flags(base_config_obj) |
|
|
| original_num_layers_global = getattr(base_config_obj, 'num_hidden_layers', LAYERS) |
| if getattr(base_config_obj, 'original_num_layers', None) is None: |
| base_config_obj.original_num_layers = original_num_layers_global |
|
|
| if getattr(base_config_obj, 'vocab_size', -1) != len(tokenizer): |
| logging.warning(f"Config vocab size ({getattr(base_config_obj, 'vocab_size', 'N/A')}) differs from tokenizer ({len(tokenizer)}). Updating config.") |
| base_config_obj.vocab_size = len(tokenizer) |
| if getattr(base_config_obj, 'pad_token_id', -999) != tokenizer.pad_token_id: |
| base_config_obj.pad_token_id = tokenizer.pad_token_id |
|
|
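| # Pick the cheapest safe dtype: bfloat16 on CUDA hardware that supports it,
| # float16 on other CUDA devices, and float32 on CPU. The same dtype later
| # drives the fp16/bf16 flags in TrainingArguments.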
| load_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32 |
| attn_impl_load = getattr(base_config_obj, 'attn_implementation', 'auto') |
| if attn_impl_load == "flash_attention_2": base_config_obj.use_flash_attention_2 = True |
| elif getattr(base_config_obj,'use_flash_attention_2', False): attn_impl_load = "flash_attention_2"; base_config_obj.attn_implementation = "flash_attention_2" |
|
|
| logging.info(f"Loading model with dtype={load_dtype}, attn_implementation='{attn_impl_load}'...") |
| model = AutoModelForCausalLM.from_pretrained( |
| base_model_id, |
| config=base_config_obj, |
| trust_remote_code=True, |
| token=hf_token, |
| torch_dtype=load_dtype, |
| low_cpu_mem_usage=True if device.type != 'cpu' else False, |
| attn_implementation=attn_impl_load if attn_impl_load != 'auto' else None |
| ) |
|
|
| if model.get_input_embeddings().weight.shape[0] != len(tokenizer): |
| logging.info(f"Resizing model token embeddings from {model.get_input_embeddings().weight.shape[0]} to tokenizer size {len(tokenizer)}") |
| model.resize_token_embeddings(len(tokenizer)) |
| if getattr(model.config, 'vocab_size', -1) != len(tokenizer): |
| model.config.vocab_size = len(tokenizer) |
|
|
| logging.info(f"Base model '{base_model_id}' loaded. Original Layers: {original_num_layers_global}, Current Layers: {model.config.num_hidden_layers}, Dtype: {model.dtype}") |
| if device.type == 'cpu':
| model.to(device)
| logging.info(f"Model moved to device: {device}")
| else:
| logging.info("Model loaded with low_cpu_mem_usage; device placement is handled by the Trainer when training starts.")
|
|
| config = model.config |
|
|
| except Exception as e: |
| logging.error(f"Failed to load base model or tokenizer '{base_model_id}': {e} \n{traceback.format_exc()}") |
| return f"[Error] Load failed for '{base_model_id}': {e}" |
|
|
| if use_peft: |
| logging.info("Applying PEFT adapter to the model for training...") |
| try: |
| lora_config = LoraConfig(**peft_config_dict) |
| peft_add_msg = _add_peft_adapter(model, config, peft_config_obj=lora_config) |
| except Exception as peft_e: |
| logging.error(f"Failed to configure or add PEFT adapter: {peft_e}") |
| return f"[Error] Failed to prepare PEFT model: {peft_e}" |
|
|
| if "[Error]" in peft_add_msg or "[Warning]" in peft_add_msg: |
| logging.error(f"Failed adding PEFT adapter: {peft_add_msg}") |
| return f"[Error] Failed adding PEFT adapter: {peft_add_msg}" |
|
|
| model = global_model |
| config = global_model.get_base_model().config |
| logging.info("PEFT adapter added successfully.") |
| else: |
| logging.info("Proceeding with full fine-tuning (PEFT not selected).") |
|
|
| logging.info("Loading and processing datasets...") |
| train_ds_processed = None |
| eval_ds_processed = None |
| try: |
| datasets_config_list = parse_datasets(datasets_input_str) |
| interleaved_ds = load_datasets_from_config(datasets_config_list) |
| if interleaved_ds is None: |
| raise ValueError("Dataset loading and interleaving resulted in None. Check logs.") |
|
|
| tokenize_partial = partial(tokenize_function, tokenizer=tokenizer, context_length=CONTEXT_LENGTH) |
| tokenized_ds = interleaved_ds.map( |
| tokenize_partial, |
| batched=True, |
| batch_size=1000, |
| ) |
|
|
| group_partial = partial(group_texts, block_size=CONTEXT_LENGTH) |
| lm_dataset = tokenized_ds.map( |
| group_partial, |
| batched=True, |
| batch_size=1000, |
| ) |
|
|
| try: |
| peek_final = next(iter(lm_dataset)) |
| final_cols = list(peek_final.keys()) |
| logging.info(f"Sample processed record structure: { {k: type(v).__name__ for k, v in peek_final.items()} }") |
| if not all(k in final_cols for k in ['input_ids', 'attention_mask', 'labels']): |
| raise ValueError(f"Final dataset structure after tokenizing/grouping is missing required keys. Found: {final_cols}") |
| except StopIteration: |
| raise ValueError("Dataset appears empty after processing and grouping.") |
|
|
| logging.info("Dataset tokenization and grouping complete.") |
|
|
| train_ds_processed, eval_ds_processed = split_dataset(lm_dataset) |
|
|
| if isinstance(train_ds_processed, IterableDataset): |
| logging.info("Training dataset is iterable (streaming).") |
| elif isinstance(train_ds_processed, Dataset): |
| logging.info(f"Training dataset size: {len(train_ds_processed):,} examples.") |
| else: |
| logging.warning("Could not determine training dataset type or size.") |
|
|
| if eval_ds_processed is not None: |
| logging.info(f"Created static evaluation dataset with {len(eval_ds_processed)} examples.") |
| else: |
| logging.info("No evaluation dataset created (buffer empty or error occurred).") |
|
|
| except Exception as e: |
| logging.error(f"Dataset loading, processing, or splitting failed: {e} \n{traceback.format_exc()}") |
| return f"[Error] Dataset preparation failed: {e}" |
|
|
| logging.info("Setting up Training Arguments...") |
| final_weight_decay = weight_decay if not getattr(config, 'weight_decay_disabled', False) else 0.0 |
| final_lr_scheduler = scheduler_type if not getattr(config, 'lr_scheduler_disabled', False) else "constant" |
| max_grad_norm_val = 1.0 if not getattr(config, 'gradient_clipping_disabled', False) else None |
|
|
| output_dir = f"./{new_model_name}_training_output" |
|
|
| training_args = TrainingArguments( |
| output_dir=output_dir, |
| overwrite_output_dir=True, |
| report_to=report_to, |
| per_device_train_batch_size=BATCH_SIZE, |
| per_device_eval_batch_size=max(1, BATCH_SIZE * 2), |
| gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, |
| num_train_epochs=EPOCHS if EPOCHS > 0 else 1, |
| max_steps=MAX_STEPS, |
| optim=optimizer_name, |
| learning_rate=LEARNING_RATE, |
| weight_decay=final_weight_decay, |
| warmup_ratio=warmup_ratio, |
| lr_scheduler_type=final_lr_scheduler, |
| max_grad_norm=max_grad_norm_val if max_grad_norm_val is not None else 1e9, |
| fp16=load_dtype == torch.float16 and device.type == 'cuda', |
| bf16=load_dtype == torch.bfloat16 and device.type == 'cuda', |
| gradient_checkpointing=getattr(config, 'gradient_checkpointing_enabled', False), |
| gradient_checkpointing_kwargs={'use_reentrant': False} if getattr(config, 'gradient_checkpointing_enabled', False) else None, |
| dataloader_num_workers=NUM_CPU_CORES, |
| dataloader_pin_memory=True if device.type == 'cuda' else False, |
| evaluation_strategy="steps" if eval_ds_processed is not None else "no", |
| eval_steps=EVAL_STEPS if eval_ds_processed is not None else None, |
| save_strategy="steps", |
| save_steps=SAVE_STEPS, |
| save_total_limit=2, |
| load_best_model_at_end=LOAD_BEST_MODEL_AT_END if eval_ds_processed is not None else False, |
| metric_for_best_model=METRIC_FOR_BEST_MODEL if eval_ds_processed is not None else None, |
| logging_strategy="steps", |
| logging_steps=LOGGING_STEPS, |
| push_to_hub=upload_to_hub, |
| hub_model_id=repo_id_str if upload_to_hub else None, |
| hub_token=hf_token if upload_to_hub else None, |
| hub_strategy="checkpoint", |
| use_cpu=USE_CPU, |
| seed=42, |
| remove_unused_columns=False, |
| log_level="info", |
| ) |
|
|
| logging.info("Initializing Trainer...") |
| data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) |
| callbacks = [] |
| if LOAD_BEST_MODEL_AT_END and eval_ds_processed is not None: |
| callbacks.append(EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE, early_stopping_threshold=0.001)) |
|
|
| if use_wandb_reporting: |
| try: |
| wandb_run = wandb.init( |
| project=f"llm-modify-train-{new_model_name.replace('/', '-')}", |
| config=training_args.to_dict(), |
| name=f"run-{new_model_name.replace('/', '-')}-{int(time.time())}", |
| reinit=True |
| ) |
| logging.info(f"WandB run initialized: {wandb_run.name if wandb_run else 'Failed'}") |
| except Exception as wandb_e: |
| logging.error(f"Failed to initialize WandB run: {wandb_e}") |
| wandb_run = None |
| training_args.report_to = [] |
|
|
| trainer = Trainer( |
| model=model, |
| args=training_args, |
| tokenizer=tokenizer, |
| train_dataset=train_ds_processed, |
| eval_dataset=eval_ds_processed, |
| data_collator=data_collator, |
| callbacks=callbacks |
| ) |
|
|
| start_train_time = time.time() |
| logging.info(f"🚀 Starting model training (Using {type(trainer.model).__name__}). Effective steps: {training_args.max_steps if training_args.max_steps > 0 else 'N/A'}. Effective epochs: {training_args.num_train_epochs if training_args.num_train_epochs > 0 else 'N/A'}.") |
| train_result = None |
| training_successful = False |
| try: |
| last_checkpoint = None |
| if os.path.isdir(training_args.output_dir): |
| from transformers.trainer_utils import get_last_checkpoint |
| last_checkpoint = get_last_checkpoint(training_args.output_dir) |
| if last_checkpoint: |
| logging.info(f"*** Resuming training from checkpoint: {last_checkpoint} ***") |
|
|
| train_result = trainer.train(resume_from_checkpoint=last_checkpoint) |
| logging.info("✅ Training finished successfully.") |
| training_successful = True |
|
|
| trainer.save_model() |
| trainer.save_state() |
| if not use_peft or isinstance(trainer.model, PeftModel):
| tokenizer.save_pretrained(training_args.output_dir)
|
|
| except Exception as e: |
| logging.error(f"❌ Training failed: {e}\n{traceback.format_exc()}") |
| if wandb_run: wandb_run.finish(exit_code=1) |
| return f"[Error] Training failed: {e}" |
| finally: |
| end_train_time = time.time() |
| clean_memory() |
| training_time = end_train_time - start_train_time |
| logging.info(f"🕒 Training phase took {training_time:.2f} seconds.") |
|
|
| if not training_successful: |
| return "[Error] Training did not complete successfully." |
|
|
| final_trained_model = trainer.model |
| model_to_save = final_trained_model |
|
|
| merged_model_for_mods = None |
| if use_peft and isinstance(final_trained_model, PeftModel): |
| logging.info("Merging PEFT adapter into the base model for modification and final save...") |
| try: |
| merged_model_for_mods = final_trained_model.merge_and_unload() |
| logging.info("PEFT adapter merged successfully.") |
| merged_model_for_mods.config.peft_adapter_added = False |
| merged_model_for_mods.config.peft_config = None |
| merged_model_for_mods.config.lora_merged = True |
| except Exception as e: |
| logging.error(f"Failed to merge PEFT adapter after training: {e}. Saving adapter separately.") |
| adapter_save_path = os.path.join(training_args.output_dir, "final_adapter") |
| try: |
| final_trained_model.save_pretrained(adapter_save_path) |
| base_model_for_saving = final_trained_model.get_base_model() |
| base_model_for_saving.save_pretrained(training_args.output_dir) |
| tokenizer.save_pretrained(training_args.output_dir) |
| logging.info(f"PEFT adapter saved separately to {adapter_save_path}, base model to {training_args.output_dir}") |
| merged_model_for_mods = final_trained_model |
| except Exception as save_e: |
| logging.error(f"Failed to save adapter or base model separately: {save_e}. Proceeding with potentially unmerged PEFT model.") |
| merged_model_for_mods = final_trained_model |
| else: |
| merged_model_for_mods = final_trained_model |
|
|
| if merged_model_for_mods is None: |
| logging.error("Model state after training/merging is None. Cannot proceed.") |
| return "[Error] Lost model reference after training/merging." |
|
|
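| # Post-training surgery requested from the UI: shrink or restore the number of
| # transformer layers (bounded by the original layer count recorded at load time)
| # and swap the activation function, reusing the same helpers exposed as toggles.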
| def modify_model_post_train(model_obj, act_fn_name, target_layer_count): |
| logging.info(f"Applying post-training modifications: Target Layers={target_layer_count}, Activation={act_fn_name}") |
| if not hasattr(model_obj, 'config'): |
| logging.error("Cannot modify model: Missing config.") |
| return model_obj |
|
|
| current_config = initialize_config_flags(model_obj.config) |
| model_obj.config = current_config |
|
|
| current_layers = getattr(current_config, 'num_hidden_layers', None) |
| original_layers = getattr(current_config, 'original_num_layers', original_num_layers_global) |
|
|
| if current_layers is not None and original_layers is not None: |
| if target_layer_count != current_layers: |
| logging.info(f"Adjusting layers post-training: {current_layers} -> {target_layer_count} (Original: {original_layers})") |
| if target_layer_count < current_layers: |
| _reduce_layers_to_one(model_obj, current_config, target_layers=target_layer_count) |
| else: |
| restore_target = min(target_layer_count, original_layers) if original_layers else target_layer_count |
| if restore_target > current_layers: |
| logging.info(f"Attempting to restore layers: {current_layers} -> {restore_target}") |
| _enable_full_layers(model_obj, current_config, original_num_layers=restore_target) |
| else: |
| logging.info(f"Target layers ({target_layer_count}) >= current layers ({current_layers}). No layer increase needed or possible beyond original.") |
| else: |
| logging.info(f"Target layers ({target_layer_count}) matches current layers after training. No layer adjustment needed.") |
| else: |
| logging.warning("Could not determine current or original layer count from config post-training. Skipping layer adjustment.") |
|
|
| current_act_fn = getattr(current_config, 'current_activation_function', DEFAULT_ACTIVATION_FUNCTION) |
| if act_fn_name != current_act_fn: |
| logging.info(f"Setting activation function post-training to: {act_fn_name}") |
| _swap_activation_function(model_obj, current_config, act_fn_name) |
| else: |
| logging.info(f"Target activation function ({act_fn_name}) already matches current. No change needed.") |
|
|
| logging.info("Post-training modifications applied.") |
| return model_obj |
|
|
| logging.info("Applying final post-training modifications specified in UI...") |
| final_model_modified = modify_model_post_train(merged_model_for_mods, activation_fn_name, target_layers_int) |
| if merged_model_for_mods is not final_model_modified: |
| del merged_model_for_mods |
| clean_memory() |
|
|
| final_model_path = training_args.output_dir |
| logging.info(f"Saving final modified model state to {final_model_path}...") |
| try: |
| save_kwargs = {"safe_serialization": True} |
| final_model_modified.config = initialize_config_flags(final_model_modified.config) |
|
|
| if use_peft and isinstance(final_model_modified, PeftModel): |
| logging.warning("Saving unmerged PEFT model state again after modifications (adapter separate).") |
| adapter_save_dir = os.path.join(final_model_path, "final_adapter_modified") |
| final_model_modified.save_pretrained(adapter_save_dir) |
| logging.info(f"PEFT adapter saved to {adapter_save_dir}") |
| try: |
| base_model_final = final_model_modified.get_base_model() |
| base_model_final.save_pretrained(final_model_path, **save_kwargs) |
| tokenizer.save_pretrained(final_model_path) |
| logging.info(f"Base model saved to {final_model_path}") |
| except Exception as base_save_e: |
| logging.error(f"Failed to save base model separately after modification: {base_save_e}. Only adapter might be saved.") |
|
|
| else: |
| final_model_modified.save_pretrained(final_model_path, **save_kwargs) |
| tokenizer.save_pretrained(final_model_path) |
|
|
| logging.info("Final modified model saved locally.") |
| except Exception as e: |
| logging.error(f"Failed to save final modified model locally: {e}\n{traceback.format_exc()}") |
| if wandb_run: wandb_run.finish(exit_code=1) |
| global_model = final_model_modified.to(device) |
| global_tokenizer = tokenizer |
| config = final_model_modified.config |
| update_pipeline() |
| clean_memory() |
| return f"[Error] Failed to save final model locally: {e}. Training logs/checkpoints might be in {output_dir}." |
|
|
| final_eval_results = {}; final_eval_loss = None; final_perplexity = float('inf') |
| if eval_ds_processed is not None: |
| logging.info("Evaluating final modified model..."); T_final_eval = time.time() |
| try: |
| final_trainer = Trainer( |
| model=final_model_modified, |
| args=training_args, |
| tokenizer=tokenizer, |
| data_collator=data_collator, |
| eval_dataset=eval_ds_processed, |
| ) |
| final_eval_results = final_trainer.evaluate() |
| final_eval_loss = final_eval_results.get("eval_loss") |
| final_perplexity = compute_perplexity(final_eval_loss) |
| logging.info(f"✅ Final Model Evaluation Results: {final_eval_results}") |
| logging.info(f"Final Model Perplexity: {final_perplexity:.4f} (Eval time: {time.time() - T_final_eval:.2f}s)") |
| if use_wandb_reporting and wandb_run: |
| wandb_run.log({"final_eval_loss": final_eval_loss if final_eval_loss is not None else -1.0, |
| "final_perplexity": final_perplexity if final_perplexity != float('inf') else -1.0, |
| **final_eval_results}) |
| except Exception as e: |
| logging.error(f"Final evaluation failed: {e}\n{traceback.format_exc()}") |
| if use_wandb_reporting and wandb_run: |
| wandb_run.log({"final_eval_status": "Failed", "final_eval_error": str(e)}) |
| else: |
| logging.info("Skipping final evaluation as no evaluation dataset was available.") |
|
|
| upload_successful_final = False |
| if upload_to_hub: |
| logging.info(f"Attempting final upload of '{final_model_path}' to Hugging Face Hub: {repo_id_str}...") |
| try: |
| api = HfApi() |
| api.upload_folder( |
| folder_path=final_model_path, |
| repo_id=repo_id_str, |
| repo_type="model", |
| token=hf_token, |
| commit_message=f"Upload final trained model: {new_model_name} (Base: {base_model_id}, PPL: {final_perplexity:.2f})", |
| commit_description=(f"Training completed. Eval Loss: {final_eval_loss:.4f if final_eval_loss is not None else 'N/A'}, Perplexity: {final_perplexity:.4f if final_perplexity != float('inf') else 'N/A'}. " |
| f"Config: PEFT={use_peft}, Layers={target_layers_int}, ActFn={activation_fn_name}. Training time: {training_time:.2f}s.") |
| ) |
| repo_link = f"https://huggingface.co/{repo_id_str}" |
| logging.info(f"✅ Final model upload complete: {repo_link}") |
| if use_wandb_reporting and wandb_run: wandb_run.log({"hf_repo_link": repo_link, "hf_upload_status": "Success"}) |
| upload_successful_final = True |
| except Exception as e: |
| logging.error(f"Final Hugging Face upload failed: {e}\n{traceback.format_exc()}") |
| repo_link = "[Upload Failed]" |
| if use_wandb_reporting and wandb_run: wandb_run.log({"hf_upload_status": "Failed", "hf_upload_error": str(e)}) |
| else: |
| logging.info(f"Skipping final Hugging Face Hub upload based on initial setup.") |
|
|
| logging.info("Updating global state with the final model...") |
| global_model = final_model_modified.to(device) |
| global_tokenizer = tokenizer |
| config = global_model.config |
| update_pipeline() |
| clean_memory() |
|
|
| final_status_report_json = decode_model_details(global_model) |
| total_script_time = time.time() - start_overall_time |
|
|
| final_message = ( |
| f"✅ Training & Modification Process Complete!\n" |
| f"{'='*40}\n" |
| f"New Model Name: {new_model_name}\n" |
| f"Base Model: {base_model_id}\n" |
| f"Total Time: {total_script_time:.2f}s | Training Phase Time: {training_time:.2f}s\n" |
| f"{'='*40}\n" |
| f"Training Results:\n" |
| ) |
| if train_result: |
| final_message += f" - Steps Completed: {train_result.global_step}\n" |
| train_loss = train_result.training_loss |
| final_message += f" - Training Loss: {train_loss:.4f if train_loss is not None else 'N/A'}\n" |
| train_metrics = train_result.metrics |
| for metric, value in train_metrics.items(): |
| if "loss" in metric.lower() or "perplexity" in metric.lower() or "epoch" in metric.lower() or "step" in metric.lower(): |
| value_str = f"{value:.4f}" if isinstance(value, float) else str(value) |
| final_message += f" - {metric.replace('_', ' ').title()}: {value_str}\n" |
|
|
| final_message += ( |
| f"Final Evaluation:\n" |
| f" - Eval Loss: {final_eval_loss:.4f if final_eval_loss is not None else 'N/A'}\n" |
| f" - Perplexity: {final_perplexity:.4f if final_perplexity != float('inf') else 'N/A'}\n" |
| f"{'='*40}\n" |
| f"Saving & Upload:\n" |
| f" - Local Path: {final_model_path}\n" |
| f" - Hub Repo: {repo_link}\n" |
| f"{'='*40}\n" |
| f"Final Model Status Summary:\n" |
| ) |
| try: |
| status_data = json.loads(final_status_report_json) |
| summary_keys = ["Model Class", "Config Class", "Device(s)", "Params Summary", "Layer Types Count", "Key Config Attributes", "Modification Flags"] |
| for key in summary_keys: |
| if key in status_data: |
| value = status_data[key] |
| if isinstance(value, dict): |
| value_str = json.dumps(value, indent=4) |
| elif isinstance(value, list): |
| value_str = ", ".join(map(str, value)) |
| else: |
| value_str = str(value) |
| if len(value_str) > 200: value_str = value_str[:200] + "..." |
| final_message += f" - {key}: {value_str}\n" |
| final_message += f"(Full status logged and available in 'Model Controls' tab after refresh)\n" |
| except Exception as json_e: |
| logging.warning(f"Could not parse final status JSON for summary: {json_e}") |
| final_message += "(Could not generate status summary from JSON)\n" |
| final_message += f"{'='*40}" |
|
|
| if use_wandb_reporting and wandb_run: |
| try: |
| wandb_final_log = { |
| "total_time_seconds": total_script_time, |
| "training_time_seconds": training_time, |
| "final_eval_loss": final_eval_loss if final_eval_loss is not None else -1.0, |
| "final_perplexity": final_perplexity if final_perplexity != float('inf') else -1.0, |
| "upload_successful": upload_successful_final, |
| "final_steps_completed": train_result.global_step if train_result else -1, |
| "final_train_loss": train_result.training_loss if train_result and train_result.training_loss else -1.0, |
| } |
| wandb_run.log(wandb_final_log) |
| wandb_run.finish() |
| logging.info("WandB run finished.") |
| except Exception as e: |
| logging.warning(f"Error finishing WandB run: {e}") |
|
|
| logging.info("🏁 Full training and modification process finished. 🏁") |
| return final_message |
|
|
|
|
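| # Loads an arbitrary checkpoint into the global model/tokenizer/pipeline for the
| # Model Controls tab: ensures a pad token, records the original layer count,
| # aligns vocab sizes, honors any saved attention-implementation preference, and
| # returns a status message, a detailed-status JSON, and the filter checkbox states.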
| def load_model_for_control(model_id_or_path, hf_token=None, bypass_limits_state=False): |
| global global_model, global_tokenizer, global_pipe, config, original_num_layers_global, BYPASS_RESOURCE_LIMITS |
| BYPASS_RESOURCE_LIMITS = bypass_limits_state |
| logging.info(f"Attempting to load model for control: {model_id_or_path}") |
| if not model_id_or_path: |
| return "[Error] Model ID or Path cannot be empty.", "{}", *get_error_filter_updates() |
|
|
| resources_ok, res_msg = check_resources() |
| if not resources_ok: |
| error_msg = f"[Error] Resource limits exceeded, cannot load model. {res_msg}" |
| logging.error(error_msg) |
| return error_msg, "{}", *get_error_filter_updates() |
| else: |
| logging.info(res_msg) |
|
|
| t_load_start = time.time() |
| device = get_device() |
| error_return = f"[Error] Failed to load model '{model_id_or_path}'.", "{}", *get_error_filter_updates() |
|
|
| global_model, global_tokenizer, global_pipe, config = None, None, None, None |
| clean_memory() |
|
|
| try: |
| logging.info("Loading tokenizer...") |
| tokenizer_load = AutoTokenizer.from_pretrained( |
| model_id_or_path, |
| trust_remote_code=True, |
| token=hf_token |
| ) |
| if tokenizer_load.pad_token is None: |
| if tokenizer_load.eos_token is not None: |
| tokenizer_load.pad_token = tokenizer_load.eos_token |
| logging.info(f"Set tokenizer pad_token to eos_token ('{tokenizer_load.eos_token}')") |
| else: |
| try: |
| tokenizer_load.add_special_tokens({'pad_token': '[PAD]'}) |
| logging.warning("Added '[PAD]' as pad_token.") |
| except Exception as pad_e: |
| logging.error(f"Could not set PAD token: {pad_e}. Batching or beam search might fail.") |
|
|
|
|
| logging.info("Loading model config...") |
| loaded_config = AutoConfig.from_pretrained( |
| model_id_or_path, |
| trust_remote_code=True, |
| token=hf_token |
| ) |
| config_load = initialize_config_flags(loaded_config) |
|
|
| original_layers_load = getattr(config_load, 'num_hidden_layers', LAYERS) |
| if getattr(config_load, 'original_num_layers', None) is None: |
| config_load.original_num_layers = original_layers_load |
| logging.info(f"Set original_num_layers in loaded config to {original_layers_load}") |
| original_num_layers_global = config_load.original_num_layers |
|
|
|
|
| if getattr(config_load, 'vocab_size', -1) != len(tokenizer_load): |
| config_load.vocab_size = len(tokenizer_load) |
| if getattr(config_load, 'pad_token_id', -999) != tokenizer_load.pad_token_id: |
| config_load.pad_token_id = tokenizer_load.pad_token_id |
|
|
| logging.info("Loading model weights...") |
| attn_impl_load = getattr(config_load, 'attn_implementation', 'auto') |
| if attn_impl_load == "flash_attention_2": config_load.use_flash_attention_2 = True |
| elif getattr(config_load,'use_flash_attention_2', False): attn_impl_load = "flash_attention_2"; config_load.attn_implementation = "flash_attention_2" |
|
|
| load_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32 |
| logging.info(f"Using dtype {load_dtype} and attn_implementation '{attn_impl_load}' for loading.") |
|
|
| model_load = AutoModelForCausalLM.from_pretrained( |
| model_id_or_path, |
| config=config_load, |
| trust_remote_code=True, |
| token=hf_token, |
| torch_dtype=load_dtype, |
| low_cpu_mem_usage=True if device.type != 'cpu' else False, |
| attn_implementation=attn_impl_load if attn_impl_load != 'auto' else None, |
| ) |
|
|
| if model_load.get_input_embeddings().weight.shape[0] != len(tokenizer_load): |
| logging.info(f"Resizing loaded model embeddings from {model_load.get_input_embeddings().weight.shape[0]} to tokenizer size {len(tokenizer_load)}") |
| model_load.resize_token_embeddings(len(tokenizer_load)) |
| if getattr(model_load.config, 'vocab_size', -1) != len(tokenizer_load): |
| model_load.config.vocab_size = len(tokenizer_load) |
|
|
| global_model = model_load.to(device) |
| global_tokenizer = tokenizer_load |
| config = global_model.config |
|
|
| logging.info(f"Model loaded successfully to {device}.") |
|
|
| update_pipeline() |
| clean_memory() |
| logging.info(f"Model '{model_id_or_path}' loaded and pipeline updated in {time.time() - t_load_start:.2f}s.") |
| status_json, *filter_updates = get_detailed_status_and_filter_states() |
| return f"Model '{model_id_or_path}' loaded successfully.", status_json, *filter_updates |
|
|
| except Exception as e: |
| logging.error(f"Failed to load model '{model_id_or_path}': {e}\n{traceback.format_exc()}") |
| global_model, global_tokenizer, global_pipe, config = None, None, None, None |
| clean_memory() |
| return error_return |
|
|
|
|
| def save_current_model(save_path, hf_token=None, hub_repo_id=None): |
| global global_model, global_tokenizer, config |
| if not global_model or not global_tokenizer: |
| return "[Error] No model loaded to save." |
| if not save_path and not hub_repo_id: |
| return "[Error] Please provide a local save path or a Hub Repo ID (or both)." |
|
|
| t_save_start = time.time() |
| model_to_save = global_model |
| tokenizer_to_save = global_tokenizer |
| base_config = config if config else getattr(model_to_save, 'config', None)
| if base_config is None:
| logging.error("Cannot save: Model config is missing.")
| return "[Error] Model config is missing, cannot save."
| config_to_save = initialize_config_flags(base_config)
| model_to_save.config = config_to_save
|
|
| is_peft_model = _peft_installed and isinstance(model_to_save, PeftModel) |
| save_adapter_only = is_peft_model |
| logging.info(f"Save mode: {'Adapter Only (PEFT model detected)' if save_adapter_only else 'Full Model'}") |
|
|
| temp_save_dir = None |
| effective_save_path = save_path.strip() if save_path else None |
| if not effective_save_path and hub_repo_id: |
| temp_save_dir = f"./hub_upload_temp_{hub_repo_id.replace('/', '_')}_{int(time.time())}" |
| effective_save_path = temp_save_dir |
| logging.info(f"No local path provided, saving temporarily to '{effective_save_path}' for Hub upload.") |
| elif not effective_save_path: |
| return "[Error] Cannot determine save location (missing local path and Hub ID)." |
|
|
| try: |
| os.makedirs(effective_save_path, exist_ok=True) |
| except OSError as e: |
| logging.error(f"Failed to create save directory '{effective_save_path}': {e}") |
| return f"[Error] Failed to create save directory: {e}" |
|
|
| local_save_message = "" |
| try: |
| logging.info(f"Saving current model state to {effective_save_path}...") |
| save_kwargs = {"safe_serialization": True} |
|
|
| if save_adapter_only: |
| logging.info("Saving PEFT adapter weights and tokenizer.") |
| model_to_save.save_pretrained(effective_save_path) |
| tokenizer_to_save.save_pretrained(effective_save_path) |
| try: |
| base_model_config = model_to_save.get_base_model().config |
| base_model_config.save_pretrained(effective_save_path) |
| except Exception as config_e: |
| logging.warning(f"Could not save base model config alongside adapter: {config_e}") |
| else: |
| logging.info("Saving full model weights and tokenizer.") |
| model_to_save.save_pretrained(effective_save_path, **save_kwargs) |
| tokenizer_to_save.save_pretrained(effective_save_path) |
|
|
| save_local_time = time.time() - t_save_start |
| logging.info(f"Model state saved locally to {effective_save_path} in {save_local_time:.2f}s") |
| local_save_message = f"Model saved locally to '{effective_save_path}'." |
|
|
| except Exception as e: |
| logging.error(f"Failed to save model locally to {effective_save_path}: {e}\n{traceback.format_exc()}") |
| if temp_save_dir and os.path.exists(temp_save_dir): |
| try: shutil.rmtree(temp_save_dir); logging.info("Cleaned up temporary directory after local save error.") |
| except Exception as clean_e: logging.warning(f"Could not remove temp dir {temp_save_dir} after error: {clean_e}") |
| return f"[Error] Failed to save model locally: {e}" |
|
|
| hub_message = "" |
| upload_successful = False |
| if hub_repo_id: |
| if not hf_token: |
| hub_message = "[Warning] Hub upload skipped: Hugging Face Write Token required." |
| logging.warning(hub_message) |
| else: |
| logging.info(f"Attempting to upload '{effective_save_path}' to Hub repo: {hub_repo_id}") |
| try: |
| api = HfApi()
| create_repo(repo_id=hub_repo_id, repo_type="model", exist_ok=True, token=hf_token) |
| api.upload_folder( |
| folder_path=effective_save_path, |
| repo_id=hub_repo_id, |
| repo_type="model", |
| token=hf_token, |
| commit_message=f"Upload model state ({'Adapter' if save_adapter_only else 'Full'}) via LLM Platform", |
| commit_description=f"Saved from LLM Platform UI. Model class: {type(global_model).__name__}. State: {'PEFT Adapter' if save_adapter_only else 'Full Model'}.", |
| ) |
| hub_link = f"https://huggingface.co/{hub_repo_id}" |
| hub_message = f"Successfully uploaded to Hub: {hub_link}" |
| upload_successful = True |
| logging.info(hub_message) |
| except Exception as e: |
| hub_message = f"[Error] Hub upload failed: {e}" |
| logging.error(f"Hub upload failed: {e}\n{traceback.format_exc()}") |
|
|
| if temp_save_dir and os.path.exists(temp_save_dir): |
| try: |
| shutil.rmtree(temp_save_dir) |
| logging.info("Cleaned up temporary save directory.") |
| except Exception as e: |
| logging.warning(f"Could not remove temporary directory {temp_save_dir}: {e}") |
|
|
| final_message = local_save_message if (save_path and save_path.strip()) else ""
| if hub_message: |
| if final_message: final_message += f" | {hub_message}" |
| else: final_message = hub_message |
| if not final_message: |
| final_message = "[Info] No local save path provided and Hub upload failed or was skipped." |
|
|
| total_save_time = time.time() - t_save_start |
| logging.info(f"Total save operation took {total_save_time:.2f}s") |
| return final_message |
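| # Example usage (hypothetical paths and repo id), saving locally and pushing to the Hub in one call:
| #   save_current_model("./saved_models/my_model", hf_token="hf_...", hub_repo_id="username/my-model")
| # A PEFT-wrapped model saves only its adapter weights (plus tokenizer and, where possible, the base
| # config); a plain model saves the full weights with safetensors serialization.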
|
|
|
|
| filter_names_ui = [ |
| "Harassment", "Hate Speech", "Sexually Explicit", "Dangerous Content", |
| "Civic Integrity", "Harmful Code", "Medical Advice", "Legal Advice", |
| "Financial Advice", "PII (Basic)", "Political Content", "Religious Content", |
| "Profanity", "Stereotype", "Misinfo", "Self Harm", |
| "Personal Attack", "Toxicity", "Spam", "Off Topic", |
| "Tone", "Min Max Length", "Repetition Filter", "Factuality Filter" |
| ] |
| filter_attr_map = {name: name.lower().replace(" ", "_").replace("(", "").replace(")", "") + "_filter" for name in filter_names_ui} |
| filter_attr_map["PII (Basic)"] = "pii_filter" |
| filter_attr_map["Harmful Code"] = "code_filter" |
| filter_attr_map["Min Max Length"] = "min_max_length_filter" |
| filter_attr_map["Repetition Filter"] = "repetition_filter_enabled" |
| filter_attr_map["Factuality Filter"] = "factuality_filter_enabled" |
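| # The comprehension derives config attribute names from the UI labels (e.g. "Hate Speech" ->
| # "hate_speech_filter"); the explicit overrides above cover labels whose auto-derived name does
| # not match the attribute actually used in the config (e.g. "PII (Basic)" would otherwise
| # become "pii_basic_filter").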
|
|
| custom_theme = gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky).set( |
| button_primary_background_fill="*primary_500", |
| button_primary_background_fill_hover="*primary_400", |
| button_secondary_background_fill="*secondary_500", |
| button_secondary_background_fill_hover="*secondary_400", |
| button_cancel_background_fill="*neutral_200", |
| button_cancel_background_fill_hover="*neutral_300", |
| ) |
|
|
| with gr.Blocks(theme=custom_theme, title="Advanced LLM Training & Modification Platform") as demo: |
| gr.Markdown("# 🤖 Advanced LLM Training & Modification Platform v1.2") |
| gr.Markdown("Load, modify, filter, train (Full/PEFT), test, merge, and save Large Language Models. Includes PEFT, experimental multi-modal capabilities, reward modeling setup, and resource checks.") |
|
|
| with gr.Accordion("🔑 Authentication & Settings", open=False): |
| with gr.Row(): |
| hf_token_read = gr.Textbox(label="🤗 HF Token (Read - Optional, for private models)", type="password", interactive=True, placeholder="hf_...") |
| hf_token_write = gr.Textbox(label="🤗 HF Token (Write - Optional, for Hub upload/training)", type="password", interactive=True, placeholder="hf_...") |
| train_wandb_token_inp = gr.Textbox(label="📊 WandB Token (Optional, for logging runs)", type="password", interactive=True) |
| with gr.Row(): |
| bypass_limits_chk = gr.Checkbox(label="Bypass RAM/Disk Limits (Use with Caution!)", value=False, interactive=True) |
|
|
|
|
| with gr.Tabs(): |
| with gr.TabItem("💾 Load, Save & Merge"): |
| with gr.Row(): |
| with gr.Column(scale=2): |
| gr.Markdown("### Load Model for Modification & Inference") |
| load_model_selector = HuggingfaceHubSearch(label="Search Hub or Enter Path/ID", placeholder="google/gemma-2b") |
| load_button = gr.Button("Load Model", variant="primary") |
| load_status_output = gr.Textbox(label="Load Status", interactive=False, lines=1) |
| with gr.Column(scale=2): |
| gr.Markdown("### Save Current Model State") |
| save_path_inp = gr.Textbox(label="Local Save Path (Optional)", placeholder="./saved_models/my_modified_model", interactive=True) |
| save_hub_repo_inp = gr.Textbox(label="Hub Repo ID (Optional, e.g., user/repo)", placeholder="username/my-cool-llm", interactive=True) |
| save_button = gr.Button("Save Model", variant="secondary") |
| save_status_output = gr.Textbox(label="Save Status", interactive=False, lines=1) |
|
|
| gr.Markdown("---") |
| gr.Markdown("### Merge Model Architectures (Parameter Averaging - Experimental)") |
| gr.Markdown("⚠️ **Experimental:** Averages parameters of models with compatible layers. Enter comma-separated Model IDs/Paths. The first model's config and tokenizer will be used as the base.") |
| merge_model_ids_inp = gr.Textbox(label="Model IDs/Paths to Merge (comma-separated)", placeholder="org/model-a, org/model-b, ./local-model-c") |
| merge_button = gr.Button("Merge Architectures", variant="primary") |
| merge_status_output = gr.Textbox(label="Merge Status", interactive=False, lines=2) |
|
|
| with gr.TabItem("🚀 Training"): |
| gr.Markdown("Fine-tune a selected base model. Supports full fine-tuning and PEFT (LoRA). Optional modifications are applied after training completes.")
| gr.Markdown("### 1. Base Model & Output Name") |
| with gr.Row(): |
| train_model_selector = HuggingfaceHubSearch(label="Search & Select Base Model for Training", placeholder="Type to search Hugging Face Hub...") |
| with gr.Row(): |
| train_new_model_inp = gr.Textbox(label="New Model Name (for saving locally and optionally on Hub)", placeholder="MyTunedModel-v1", interactive=True) |
|
|
| gr.Markdown("### 2. Training Data") |
| with gr.Row(): |
| train_dataset_selector = HuggingfaceHubSearch(label="Search Datasets on Hub (or specify local below)") |
| train_datasets_inp = gr.Textbox( |
| label="Datasets (one per line: 'id[,config[,split[,weight]]]')", |
| placeholder="Example:\nopenwebtext\nwikitext,wikitext-103-raw-v1,train,0.5\nmy_local_dataset_path,,train,1.5\nusername/my_dataset,my_config,validation,2.0", |
| lines=5, interactive=True) |
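| # Each dataset line follows 'id[,config[,split[,weight]]]'; empty fields fall back to defaults,
| # and the optional weight presumably controls how heavily that dataset is sampled when several
| # datasets are combined (see the dataset parsing in start_training).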
|
|
| gr.Markdown("### 3. Training Configuration") |
| with gr.Accordion("Training Mode & Hyperparameters", open=True): |
| train_use_peft_chk = gr.Checkbox(label="Enable PEFT (LoRA) Training", value=True, interactive=True) |
| with gr.Row(): |
| train_lr_inp = gr.Number(value=LEARNING_RATE, label="Learning Rate", interactive=True, minimum=1e-8, step=1e-6, precision=8) |
| train_epochs_inp = gr.Number(value=EPOCHS, label="Epochs (Set <= 0 if using Max Steps)", precision=0, minimum=-1, interactive=True) |
| train_max_steps_inp = gr.Number(value=MAX_STEPS, label="Max Steps (Set <= 0 if using Epochs)", precision=0, minimum=-1, interactive=True) |
| with gr.Row(): |
| train_batch_size_inp = gr.Number(value=BATCH_SIZE, label="Batch Size (Per Device)", precision=0, minimum=1, interactive=True) |
| train_grad_accum_inp = gr.Number(value=GRADIENT_ACCUMULATION_STEPS, label="Grad Accum Steps", precision=0, minimum=1, interactive=True) |
| train_optim_selector = gr.Dropdown(choices=list(OPTIMIZERS.keys()), value=DEFAULT_OPTIMIZER, label="Optimizer", interactive=True) |
| with gr.Row(): |
| train_scheduler_selector = gr.Dropdown(choices=SCHEDULER_TYPES, value=DEFAULT_SCHEDULER, label="LR Scheduler", interactive=True) |
| train_wd_inp = gr.Number(value=0.01, label="Weight Decay", minimum=0.0, interactive=True, step=0.001, precision=4) |
| train_warmup_ratio_inp = gr.Slider(0.0, 0.5, value=0.03, step=0.01, label="Warmup Ratio", interactive=True) |
|
|
| with gr.Accordion("PEFT Configuration (if PEFT enabled)", open=False, visible=True) as peft_config_accordion: |
| peft_r_inp = gr.Slider(label="LoRA r (Rank)", minimum=1, maximum=256, value=8, step=1, interactive=True) |
| peft_alpha_inp = gr.Slider(label="LoRA alpha", minimum=1, maximum=512, value=32, step=1, interactive=True) |
| peft_dropout_inp = gr.Slider(label="LoRA Dropout", minimum=0.0, maximum=0.5, value=0.1, step=0.01, interactive=True) |
| peft_target_modules_inp = gr.Textbox(label="Target Modules (comma-sep, optional, e.g., q_proj,v_proj)", placeholder="Leave empty for auto-detection (recommended)", interactive=True) |
|
|
| train_use_peft_chk.change(lambda x: gr.update(visible=x), inputs=train_use_peft_chk, outputs=peft_config_accordion) |
|
|
| with gr.Accordion("Post-Training Modifications (Applied After Training)", open=False): |
| with gr.Row(): |
| train_post_activation_fn_selector = gr.Dropdown(choices=list(ACTIVATION_FUNCTIONS.keys()), value=DEFAULT_ACTIVATION_FUNCTION, label="Target Activation Fn") |
| train_post_target_layers_inp = gr.Number(value=LAYERS, label="Target Layer Count", precision=0, minimum=1) |
|
|
| with gr.Accordion("Hardware & Logging", open=False): |
| train_use_cpu_chk = gr.Checkbox(value=USE_CPU, label="Force Use CPU (Very Slow!)", interactive=True) |
|
|
| gr.Markdown("### 4. Start Training") |
| train_button = gr.Button("✨ Start Training Process", variant="primary") |
| train_output = gr.Textbox(label="Training Log & Status", interactive=False, lines=20, max_lines=50) |
|
|
| with gr.TabItem("🔧 Model Controls"): |
| gr.Markdown("Interactively toggle modifications and filters for the **currently loaded** model. Refresh status after changes.") |
| with gr.Row(): |
| refresh_status_button = gr.Button("🔄 Refresh Status & Filter Checkboxes") |
| control_output = gr.Textbox(label="Control Action Status", interactive=False, lines=1) |
| status_output = gr.TextArea(label="Current Model Status (JSON)", interactive=False, lines=20, max_lines=60) |
|
|
| with gr.Tabs(): |
| with gr.TabItem("Core & Structure"): |
| with gr.Row(): |
| with gr.Column(min_width=150): bias_on = gr.Button("Bias Rem. ✅"); bias_off = gr.Button("Bias Rem. ❌") |
| with gr.Column(min_width=150): emb_on = gr.Button("Emb. Untie ✅"); emb_off = gr.Button("Emb. Untie ❌") |
| layer_target_inp = gr.Number(value=LAYERS, label="Target Layers", precision=0, minimum=1, interactive=True, scale=1) |
| layer_red_on = gr.Button("Apply Layer Red.", scale=1) |
| layer_red_off = gr.Button("Revert Layer Red.", scale=1) |
| with gr.Row(): |
| with gr.Column(min_width=150): norm_swap_rms = gr.Button("Use RMSNorm"); norm_swap_ln = gr.Button("Use LayerNorm") |
| act_select = gr.Dropdown(choices=list(ACTIVATION_FUNCTIONS.keys()), value=DEFAULT_ACTIVATION_FUNCTION, label="Change ActFn") |
| act_revert = gr.Button("Revert ActFn") |
| with gr.Column(min_width=150): bitnet_on = gr.Button("BitNet ✅"); bitnet_off = gr.Button("BitNet ❌") |
|
|
| with gr.Accordion("Multi-Modal Conversion (Experimental)", open=False): |
| gr.Markdown("⚠️ **Experimental:** Adds modality-specific encoders (e.g., ViT, Whisper) and projection layers. **Requires manual `forward` pass adaptation & multi-modal data/training.**") |
| modality_checkboxes_ui = gr.CheckboxGroup(choices=AVAILABLE_MODALITIES, label="Select Modalities") |
| with gr.Row(): |
| apply_multimodal_button = gr.Button("Apply Multi-Modal Setup") |
| revert_multimodal_button = gr.Button("Revert Multi-Modal Setup") |
|
|
| with gr.TabItem("Performance & Opt."): |
| with gr.Row(): |
| with gr.Column(min_width=150): speed_on = gr.Button("Speed Opt. ✅"); speed_off = gr.Button("Speed Opt. ❌") |
| with gr.Column(min_width=150): coher_on = gr.Button("Coherence ✅"); coher_off = gr.Button("Coherence ❌") |
| with gr.Column(min_width=150): ln_bypass_on = gr.Button("LN Bypass ✅"); ln_bypass_off = gr.Button("LN Bypass ❌") |
| with gr.Row(): |
| with gr.Column(min_width=150): do_bypass_on = gr.Button("Dropout Bypass ✅"); do_bypass_off = gr.Button("Dropout Bypass ❌") |
| with gr.Column(min_width=150): prec_on = gr.Button("FP32 Prec. ✅"); prec_off = gr.Button("FP32 Prec. ❌") |
| with gr.Column(min_width=150): norm_emb_on = gr.Button("Emb. Norm. ✅"); norm_emb_off = gr.Button("Emb. Norm. ❌") |
| with gr.Row(): |
| with gr.Column(min_width=150): gc_cp_on = gr.Button("Grad Checkpoint ✅"); gc_cp_off = gr.Button("Grad Checkpoint ❌") |
| with gr.Column(min_width=150): flash_attn_on = gr.Button("Flash Attn 2 ✅"); flash_attn_off = gr.Button("Flash Attn 2 ❌") |
|
|
| with gr.Accordion("Quantization & Pruning", open=False): |
| with gr.Row(): |
| quant_select = gr.Dropdown(choices=QUANTIZATION_MODES, value=DEFAULT_QUANTIZATION, label="Quantize To") |
| quant_apply = gr.Button("Apply Quant.") |
| quant_revert = gr.Button("Revert Quant.") |
| with gr.Row(): |
| prune_amount_inp = gr.Slider(0.01, 0.95, value=PRUNING_AMOUNT, step=0.01, label="Prune Amount") |
| prune_apply = gr.Button("Apply Pruning") |
| prune_revert = gr.Button("Revert Pruning") |
|
|
| with gr.TabItem("PEFT Adapters"): |
| gr.Markdown("Add, merge, or remove LoRA/PEFT adapters from the currently loaded model.") |
| peft_lora_path_input = gr.Textbox(label="LoRA/PEFT Adapter Path or Hub ID", placeholder="username/my-lora-adapter") |
| with gr.Row(): |
| peft_set_path_btn = gr.Button("Set Path in Config") |
| peft_add_adapter_btn = gr.Button("Add Default Adapter") |
| peft_merge_btn = gr.Button("Merge Active Adapter") |
| peft_remove_adapter_btn = gr.Button("Remove/Unload Adapter") |
|
|
| with gr.TabItem("Advanced Config & Layers"): |
| with gr.Row(): |
| freeze_input = gr.Textbox(label="Layers to Freeze (e.g., '0-3, 7, 10-11')") |
| freeze_apply = gr.Button("🧊 Freeze") |
| freeze_revert = gr.Button("🔥 Unfreeze All") |
| with gr.Row(): |
| with gr.Column(min_width=150): lim_on = gr.Button("Limits Cfg ✅"); lim_off = gr.Button("Limits Cfg ❌") |
| with gr.Column(min_width=150): qa_on = gr.Button("QA Restrict Rem. ✅"); qa_off = gr.Button("QA Restrict Rem. ❌") |
| layerdrop_prob_inp = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="LayerDrop Prob") |
| layerdrop_on = gr.Button("LayerDrop Flag ✅") |
| layerdrop_off = gr.Button("LayerDrop Flag ❌") |
| with gr.Accordion("RoPE, Sliding Window, Attention Variant (Require Model Reload)", open=False): |
| gr.Markdown("**Warning:** These settings modify the config but require reloading the model to take effect.") |
| with gr.Row(): |
| rope_type_inp = gr.Dropdown(label="RoPE Type", choices=["linear", "dynamic"], value="linear") |
| rope_factor_inp = gr.Number(label="RoPE Factor (>=1.0)", value=2.0, minimum=1.0, step=0.1) |
| rope_apply_btn = gr.Button("Set RoPE") |
| rope_revert_btn = gr.Button("Revert RoPE") |
| with gr.Row(): |
| sw_size_inp = gr.Number(label="Sliding Window Size (0=disable)", value=4096, minimum=0, step=64) |
| sw_apply_btn = gr.Button("Set Sliding Window") |
| sw_revert_btn = gr.Button("Revert Sliding Window") |
| with gr.Row(): |
| attn_variant_inp = gr.Dropdown(label="Attention Implementation", choices=["auto", "eager", "sdpa", "flash_attention_2"], value="auto") |
| attn_apply_btn = gr.Button("Set Attention Variant") |
| attn_revert_btn = gr.Button("Revert Attention Variant") |
|
|
| with gr.Accordion("KD & Reward Heads (Experimental - Requires Training Changes)", open=False): |
| with gr.Row(): |
| kd_labels_inp = gr.Number(label="KD Num Labels", value=2, minimum=1, precision=0) |
| kd_setup_btn = gr.Button("Setup KD Head") |
| kd_revert_btn = gr.Button("Revert KD Head") |
| with gr.Row(): |
| rm_outputs_inp = gr.Number(label="RM Num Outputs", value=1, minimum=1, precision=0) |
| rm_setup_btn = gr.Button("Setup RM Head") |
| rm_revert_btn = gr.Button("Revert RM Head") |
|
|
| with gr.Accordion("Other Flags (Symbolic - May Require Specific Training Logic)", open=False): |
| with gr.Row(): |
| swa_on = gr.Button("SWA Flag ✅"); swa_off = gr.Button("SWA Flag ❌") |
| ke_on = gr.Button("Know. Edit Flag ✅"); ke_off = gr.Button("Know. Edit Flag ❌") |
| hp_on = gr.Button("Head Prune Flag ✅"); hp_off = gr.Button("Head Prune Flag ❌") |
| with gr.Row(): |
| qat_on = gr.Button("QAT Flag ✅"); qat_off = gr.Button("QAT Flag ❌") |
| gn_on = gr.Button("Grad Noise Flag ✅"); gn_off = gr.Button("Grad Noise Flag ❌") |
| wi_on = gr.Button("Weight Init Flag ✅"); wi_off = gr.Button("Weight Init Flag ❌") |
|
|
| with gr.TabItem("Training Param Flags"): |
| gr.Markdown("Toggle flags in the config that affect **subsequent** Trainer initialization (won't affect current training).") |
| with gr.Row(): |
| with gr.Column(min_width=150): gc_flag_on = gr.Button("GradClip Flg ✅"); gc_flag_off = gr.Button("GradClip Flg ❌") |
| with gr.Column(min_width=150): wd_flag_on = gr.Button("WD Flg ✅"); wd_flag_off = gr.Button("WD Flg ❌") |
| with gr.Column(min_width=150): lr_flag_on = gr.Button("LR Sched. Flg ✅"); lr_flag_off = gr.Button("LR Sched. Flg ❌") |
| with gr.Row(): |
| optim_flag_select = gr.Dropdown(choices=list(OPTIMIZERS.keys()), value=DEFAULT_OPTIMIZER, label="Set Optim. Pref") |
| optim_flag_apply = gr.Button("Apply Optim.") |
| optim_flag_revert = gr.Button("Revert Optim.") |
| with gr.Row(): |
| grad_accum_ui_inp_config = gr.Number(value=GRADIENT_ACCUMULATION_STEPS, label="Grad Accum Steps (Config)", precision=0, minimum=1) |
| grad_accum_set_btn = gr.Button("Set Grad Accum") |
|
|
| with gr.TabItem("🔒 Safety & Content Filters"): |
| gr.Markdown("Control safety filter flags in the model's config. Actual filtering effectiveness depends on the inference implementation.") |
| with gr.Row(): |
| safety_all_on = gr.Button("🔒 Enable ALL Filters (Defaults)", variant="secondary") |
| safety_all_off = gr.Button("🔓 Disable ALL Filters", variant="stop") |
| gr.Markdown("Individual Filter Toggles:") |
| filter_checkboxes = [] |
| num_cols = 4 |
| for i in range(0, len(filter_names_ui), num_cols): |
| with gr.Row(): |
| for j in range(num_cols): |
| idx = i + j |
| if idx < len(filter_names_ui): |
| name = filter_names_ui[idx] |
| cb = gr.Checkbox(label=name, value=False, interactive=True) |
| filter_checkboxes.append(cb) |
| else: |
| gr.HTML("") |
| apply_filters_button = gr.Button("Apply Individual Filter Toggles", variant="secondary") |
|
|
| with gr.TabItem("💬 Inference"): |
| gr.Markdown("Test the **currently loaded and configured** model.") |
| with gr.Row(): |
| inference_prompt = gr.Textbox(label="Enter Prompt", lines=4, placeholder="Once upon a time...") |
| inference_output = gr.Textbox(label="Model Response", interactive=False, lines=15) |
| with gr.Accordion("Generation Parameters", open=True): |
| with gr.Row(): |
| max_new_tokens_slider = gr.Slider(10, 4096, value=256, step=10, label="Max New Tokens", interactive=True) |
| temperature_slider = gr.Slider(0.0, 2.0, value=0.7, step=0.01, label="Temperature (0=greedy)", interactive=True) |
| with gr.Row(): |
| top_k_slider = gr.Slider(0, 200, value=50, step=1, label="Top-K (0=disable)", interactive=True) |
| top_p_slider = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-P (0 or 1=disable)", interactive=True) |
| repetition_penalty_slider = gr.Slider(1.0, 3.0, value=1.1, step=0.05, label="Repetition Penalty (1=disable)", interactive=True) |
| generate_button = gr.Button("Generate Response", variant="primary") |
|
|
| with gr.TabItem("🚫 Censor Control"): |
| gr.Markdown("## Force Disable Censorship Flags") |
| gr.Markdown("Click the button below to attempt to set all known censorship/filter flags in the loaded model's configuration to `False`. This uses the `Disable ALL Filters` function.") |
| censor_off_button = gr.Button("🔓 Attempt Force Disable All Censorship Flags", variant="stop") |
| censor_status = gr.Textbox(label="Censorship Flag Status", interactive=False, lines=2) |
|
|
| load_button.click( |
| fn=load_model_for_control, |
| inputs=[load_model_selector, hf_token_read, bypass_limits_chk], |
| outputs=[load_status_output, status_output] + filter_checkboxes |
| ) |
| save_button.click( |
| fn=save_current_model, |
| inputs=[save_path_inp, hf_token_write, save_hub_repo_inp], |
| outputs=save_status_output |
| ) |
| merge_button.click( |
| fn=_merge_architectures, |
| inputs=[merge_model_ids_inp, hf_token_read, bypass_limits_chk], |
| outputs=[merge_status_output, status_output] + filter_checkboxes |
| ) |
|
|
| train_button.click( |
| fn=start_training, |
| inputs=[ |
| train_model_selector, train_new_model_inp, hf_token_write, |
| train_datasets_inp, |
| train_post_activation_fn_selector, train_post_target_layers_inp, |
| train_grad_accum_inp, train_lr_inp, train_epochs_inp, train_max_steps_inp, train_batch_size_inp, |
| train_optim_selector, train_scheduler_selector, train_wd_inp, train_warmup_ratio_inp, |
| train_use_peft_chk, peft_r_inp, peft_alpha_inp, peft_dropout_inp, peft_target_modules_inp, |
| train_wandb_token_inp, train_use_cpu_chk, bypass_limits_chk |
| ], |
| outputs=train_output |
| ).then( |
| fn=get_detailed_status_and_filter_states, |
| inputs=None, |
| outputs=[status_output] + filter_checkboxes |
| ) |
|
|
| refresh_outputs = [status_output] + filter_checkboxes |
| refresh_status_button.click(fn=get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) |
|
|
| def link_control(button, func, inputs=None): |
| processed_inputs = inputs if inputs else [] |
| click_event = button.click(func, inputs=processed_inputs, outputs=control_output) |
| click_event.then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) |
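| # link_control wires a button to its wrapper and then chains a status refresh, e.g.
| #   link_control(bias_on, lambda: toggle_bias_removal_wrapper(True))
| # behaves like button.click(fn, ..., outputs=control_output).then(get_detailed_status_and_filter_states, ...).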
|
|
| link_control(bias_on, lambda: toggle_bias_removal_wrapper(True)) |
| link_control(bias_off, lambda: toggle_bias_removal_wrapper(False)) |
| link_control(emb_on, lambda: toggle_embeddings_untie_wrapper(True)) |
| link_control(emb_off, lambda: toggle_embeddings_untie_wrapper(False)) |
| link_control(layer_red_on, lambda layers: toggle_layer_reduction_wrapper(True, layers), inputs=[layer_target_inp]) |
| link_control(layer_red_off, lambda: toggle_layer_reduction_wrapper(False, None)) |
| link_control(norm_swap_rms, lambda: apply_norm_swap_wrapper('RMSNorm')) |
| link_control(norm_swap_ln, lambda: apply_norm_swap_wrapper('LayerNorm')) |
| act_select.change(lambda name: apply_activation_change_wrapper(name), inputs=[act_select], outputs=control_output).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) |
| link_control(act_revert, revert_activation_change_wrapper) |
| link_control(bitnet_on, lambda: toggle_bitnet_wrapper(True)) |
| link_control(bitnet_off, lambda: toggle_bitnet_wrapper(False)) |
| link_control(apply_multimodal_button, apply_multimodal_wrapper, inputs=[modality_checkboxes_ui]) |
| link_control(revert_multimodal_button, revert_multimodal_wrapper) |
|
|
| link_control(speed_on, lambda: toggle_token_speed_optimization_wrapper(True)) |
| link_control(speed_off, lambda: toggle_token_speed_optimization_wrapper(False)) |
| link_control(coher_on, lambda: toggle_coherence_improvement_wrapper(True)) |
| link_control(coher_off, lambda: toggle_coherence_improvement_wrapper(False)) |
| link_control(ln_bypass_on, lambda: toggle_layer_norm_bypass_wrapper(True)) |
| link_control(ln_bypass_off, lambda: toggle_layer_norm_bypass_wrapper(False)) |
| link_control(do_bypass_on, lambda: toggle_dropout_bypass_wrapper(True)) |
| link_control(do_bypass_off, lambda: toggle_dropout_bypass_wrapper(False)) |
| link_control(prec_on, lambda: toggle_fp32_precision_wrapper(True)) |
| link_control(prec_off, lambda: toggle_fp32_precision_wrapper(False)) |
| link_control(norm_emb_on, lambda: toggle_embedding_normalization_wrapper(True)) |
| link_control(norm_emb_off, lambda: toggle_embedding_normalization_wrapper(False)) |
| link_control(gc_cp_on, lambda: toggle_gradient_checkpointing_wrapper(True)) |
| link_control(gc_cp_off, lambda: toggle_gradient_checkpointing_wrapper(False)) |
| link_control(flash_attn_on, lambda: toggle_flash_attention_wrapper(True)) |
| link_control(flash_attn_off, lambda: toggle_flash_attention_wrapper(False)) |
| link_control(quant_apply, apply_quantization_wrapper, inputs=[quant_select]) |
| link_control(quant_revert, revert_quantization_wrapper) |
| link_control(prune_apply, apply_pruning_wrapper, inputs=[prune_amount_inp]) |
| link_control(prune_revert, revert_pruning_wrapper) |
|
|
| link_control(peft_set_path_btn, set_lora_path_wrapper, inputs=[peft_lora_path_input]) |
| link_control(peft_add_adapter_btn, add_peft_adapter_wrapper) |
| link_control(peft_merge_btn, merge_peft_adapter_wrapper) |
| link_control(peft_remove_adapter_btn, remove_peft_adapter_wrapper) |
|
|
| link_control(freeze_apply, apply_layer_freeze_wrapper, inputs=[freeze_input]) |
| link_control(freeze_revert, revert_layer_freeze_wrapper) |
| link_control(lim_on, lambda: toggle_limits_wrapper(True)) |
| link_control(lim_off, lambda: toggle_limits_wrapper(False)) |
| link_control(qa_on, lambda: toggle_qa_restrictions_wrapper(True)) |
| link_control(qa_off, lambda: toggle_qa_restrictions_wrapper(False)) |
| link_control(layerdrop_on, lambda prob: toggle_layerdrop_wrapper(True, prob), inputs=[layerdrop_prob_inp]) |
| link_control(layerdrop_off, lambda: toggle_layerdrop_wrapper(False)) |
| link_control(rope_apply_btn, lambda rope_type, factor: toggle_rope_scaling_wrapper(True, rope_type, factor), inputs=[rope_type_inp, rope_factor_inp])
| link_control(rope_revert_btn, lambda: toggle_rope_scaling_wrapper(False)) |
| link_control(sw_apply_btn, lambda size: toggle_sliding_window_wrapper(True, size), inputs=[sw_size_inp]) |
| link_control(sw_revert_btn, lambda: toggle_sliding_window_wrapper(False)) |
| link_control(attn_apply_btn, apply_attention_variant_wrapper, inputs=[attn_variant_inp]) |
| link_control(attn_revert_btn, revert_attention_variant_wrapper) |
| link_control(kd_setup_btn, lambda labels: toggle_kd_wrapper(True, labels), inputs=[kd_labels_inp]) |
| link_control(kd_revert_btn, lambda: toggle_kd_wrapper(False)) |
| link_control(rm_setup_btn, lambda outputs: toggle_reward_modeling_wrapper(True, outputs), inputs=[rm_outputs_inp]) |
| link_control(rm_revert_btn, lambda: toggle_reward_modeling_wrapper(False)) |
| link_control(swa_on, lambda: specific_action_function(_apply_swa)) |
| link_control(swa_off, lambda: specific_action_function(_revert_swa)) |
| link_control(ke_on, lambda: specific_action_function(_apply_knowledge_editing)) |
| link_control(ke_off, lambda: specific_action_function(_revert_knowledge_editing)) |
| link_control(hp_on, lambda: specific_action_function(_apply_head_pruning)) |
| link_control(hp_off, lambda: specific_action_function(_revert_head_pruning)) |
| link_control(qat_on, lambda: specific_action_function(_apply_qat)) |
| link_control(qat_off, lambda: specific_action_function(_revert_qat)) |
| link_control(gn_on, lambda: specific_action_function(_apply_gradient_noise)) |
| link_control(gn_off, lambda: specific_action_function(_revert_gradient_noise)) |
| link_control(wi_on, lambda: specific_action_function(_apply_weight_init)) |
| link_control(wi_off, lambda: specific_action_function(_revert_weight_init)) |
|
|
| link_control(gc_flag_on, lambda: toggle_gradient_clipping_flag_wrapper(True)) |
| link_control(gc_flag_off, lambda: toggle_gradient_clipping_flag_wrapper(False)) |
| link_control(wd_flag_on, lambda: toggle_weight_decay_flag_wrapper(True)) |
| link_control(wd_flag_off, lambda: toggle_weight_decay_flag_wrapper(False)) |
| link_control(lr_flag_on, lambda: toggle_lr_scheduler_flag_wrapper(True)) |
| link_control(lr_flag_off, lambda: toggle_lr_scheduler_flag_wrapper(False)) |
| link_control(optim_flag_apply, apply_optimizer_change_wrapper, inputs=[optim_flag_select]) |
| link_control(optim_flag_revert, revert_optimizer_change_wrapper) |
| link_control(grad_accum_set_btn, set_gradient_accumulation_wrapper, inputs=[grad_accum_ui_inp_config]) |
|
|
| link_control(safety_all_on, lambda: toggle_all_safety_filters_wrapper(True)) |
| link_control(safety_all_off, lambda: toggle_all_safety_filters_wrapper(False)) |
| apply_filters_button.click( |
| fn=toggle_individual_safety_filter_wrapper, |
| inputs=filter_checkboxes, |
| outputs=control_output |
| ).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) |
|
|
| generate_button.click( |
| fn=run_inference, |
| inputs=[ |
| inference_prompt, max_new_tokens_slider, temperature_slider, |
| top_k_slider, top_p_slider, repetition_penalty_slider |
| ], |
| outputs=inference_output |
| ) |
|
|
| censor_off_button.click( |
| fn=force_disable_censorship_wrapper, |
| outputs=censor_status |
| ).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) |
|
|
| if __name__ == "__main__": |
| demo.queue().launch(server_name="0.0.0.0", share=True, debug=False) |