# Hugging Face Space "gchghgg/appx.py" — renamed from app.py (commit 6f22cb9).
import os
import time
import math
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import prune
from transformers import (
AutoTokenizer,
AutoConfig,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
AutoModelForCausalLM,
AutoModel,
EarlyStoppingCallback,
pipeline,
get_scheduler,
logging as hf_logging
)
try:
from peft import PeftModel, LoraConfig, get_peft_model, TaskType, PeftConfig
_peft_installed = True
except ImportError:
_peft_installed = False
PeftModel = None
LoraConfig = None
get_peft_model = None
TaskType = None
PeftConfig = None
from datasets import load_dataset, interleave_datasets, concatenate_datasets, Dataset, Features, Value, IterableDataset, DatasetDict
from huggingface_hub import login, create_repo, HfApi, hf_hub_download
import wandb
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
import re
import json
import gc
from accelerate import Accelerator
import logging
import traceback
from collections import Counter, OrderedDict
import requests
import gzip
import inspect
import shutil
from functools import partial
import types
import psutil
# Quiet noisy third-party loggers; keep this app's own logs at INFO.
hf_logging.set_verbosity_error()
logging.getLogger("datasets").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Default tokenizer call options used throughout the app.
padding = True
truncation = True
# Mirrored into the environment so the tokenizers library picks it up.
TOKENIZERS_PARALLELISM = True
os.environ["TOKENIZERS_PARALLELISM"] = str(TOKENIZERS_PARALLELISM)
# --- Training defaults ---
BATCH_SIZE = 8
LEARNING_RATE = 1.5e-4
EPOCHS = 1
MAX_STEPS = 1
USE_CPU = False  # force CPU even when CUDA/MPS is available (see get_device)
NUM_CPU_CORES = -1  # -1: let the framework decide
MERGE_ALPHA = 0.7  # interpolation weight used when merging models
# --- Default architecture hyper-parameters for newly built models ---
CONTEXT_LENGTH = 256
HEADS = 4
DIMENSIONS = 256
LAYERS = 1
INTERMEDIATE_SIZE = 1024
USE_WANDB = False
# Registry of activation modules selectable by name.
ACTIVATION_FUNCTIONS = { "relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU, "mish": nn.Mish, "leaky_relu": nn.LeakyReLU, "elu": nn.ELU, "relu6": nn.ReLU6, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "identity": nn.Identity }
DEFAULT_ACTIVATION_FUNCTION = "gelu"
# Registry of torch optimizer classes selectable by name.
OPTIMIZERS = {
    "adamw_torch": torch.optim.AdamW,
    "adam_torch": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "adamax": torch.optim.Adamax,
    "adagrad": torch.optim.Adagrad,
    "rmsprop": torch.optim.RMSprop
}
DEFAULT_OPTIMIZER = "adamw_torch"
PRUNING_AMOUNT = 0.2  # default fraction of weights to prune
QUANTIZATION_MODES = ['float32', 'float16', 'bfloat16']
DEFAULT_QUANTIZATION = 'float32'
# Names accepted by transformers.get_scheduler.
SCHEDULER_TYPES = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
DEFAULT_SCHEDULER = "cosine"
# --- Trainer defaults ---
GRADIENT_ACCUMULATION_STEPS = 1
EVAL_STEPS = 500
SAVE_STEPS = 500
LOGGING_STEPS = 100
EARLY_STOPPING_PATIENCE = 5
LOAD_BEST_MODEL_AT_END = True
METRIC_FOR_BEST_MODEL = "eval_loss"
# --- Multimodal defaults: extra modalities and their default encoders ---
AVAILABLE_MODALITIES = ['Image', 'Audio']
MODALITY_ENCODERS = {
    'Image': 'google/vit-base-patch16-224-in21k',
    'Audio': 'openai/whisper-base'
}
# LoRA defaults; collapses to an empty dict when peft is not installed.
DEFAULT_PEFT_CONFIG_DICT = {
    "task_type": TaskType.CAUSAL_LM if _peft_installed else None,
    "inference_mode": False,
    "r": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    "target_modules": None
} if _peft_installed else {}
# --- Mutable global state shared across UI callbacks ---
global_model = None
global_tokenizer = None
global_pipe = None
original_num_layers_global = LAYERS  # fallback layer count for _enable_full_layers
config = None
target_layers = LAYERS
current_peft_config = copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT) if _peft_installed else {}
# --- Resource guard thresholds (consumed by check_resources) ---
RAM_LIMIT_PERCENT = 85.0
DISK_LIMIT_GB = 5.0
BYPASS_RESOURCE_LIMITS = False
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction).

    Normalizes the last dimension by its RMS value and, when
    ``elementwise_affine`` is True, applies a learnable per-feature gain.
    """
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(dim, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the gain to ones (identity scaling)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)

    def _norm(self, x):
        # x / RMS(x) over the last dimension, eps inside the sqrt for stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight if self.elementwise_affine else normalized

    def extra_repr(self):
        return f'{self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
def activation_quant(x):
    """Fake-quantize activations to signed 8-bit, per row of the last dim.

    Each row is scaled so its max magnitude maps to 127, rounded to integers
    in [-128, 127], then rescaled back to the original range.
    """
    max_abs = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    scale = 127.0 / max_abs
    return (x * scale).round().clamp(-128, 127) / scale
def weight_quant(w):
    """Fake-quantize weights to the ternary grid {-s, 0, +s}, s = mean |w|."""
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    quantized = (w * scale).round().clamp(-1, 1)
    return quantized / scale
class BitLinear(nn.Linear):
    """Linear layer with BitNet-style fake quantization.

    Activations are RMS-normalized then fake-quantized to 8 bits per token;
    weights are fake-quantized toward a ternary grid. Both use the
    straight-through estimator: quantized values in the forward pass,
    identity gradients in the backward pass. The bias is left unquantized.
    """
    def forward(self, x):
        w = self.weight
        # Make sure the input lives on the same device as the weights.
        if x.device != w.device:
            x = x.to(w.device)
        # Parameter-free RMSNorm over the feature dimension.
        rms_norm_module = RMSNorm(x.shape[-1], eps=1e-6, elementwise_affine=False, device=w.device, dtype=x.dtype)
        x_norm = rms_norm_module(x)
        # Straight-through estimator via the detach trick.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        output = F.linear(x_quant, w_quant, None)
        if self.bias is not None:
            # Single direct cast to the output dtype (the previous code cast
            # bias -> w_quant.dtype -> output.dtype, losing precision when
            # the intermediate dtype was narrower).
            output = output + self.bias.to(output.dtype)
        return output

    def to(self, *args, **kwargs):
        # BUG FIX: the previous override re-assigned ``self.bias`` with the
        # plain Tensor returned by ``Tensor.to``, which raises TypeError on a
        # dtype/device change (nn.Module refuses to bind a non-Parameter to a
        # registered parameter slot). nn.Module.to already moves all
        # parameters — including the bias — in place, so simply delegate.
        return super().to(*args, **kwargs)
class BypassLayerNorm(nn.Module):
    """LayerNorm whose normalization can be toggled off at runtime.

    Behaves like ``nn.LayerNorm`` until ``self.bypass`` is set to True,
    after which the input is returned unchanged.
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Accept either an int or an iterable of dims, like nn.LayerNorm.
        self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.bypass = False
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Identity initialization: unit gain, zero shift."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x):
        if self.bypass:
            return x
        in_dtype = x.dtype
        # Non-standard float dtypes are normalized in float32, cast back below.
        if in_dtype not in (torch.float32, torch.float16, torch.bfloat16):
            x = x.float()
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(in_dtype)

    def extra_repr(self) -> str:
        return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, bypass={self.bypass}'
class BypassDropout(nn.Module):
    """Dropout that can be disabled at runtime via ``self.bypass``."""
    def __init__(self, p=0.5, inplace=False):
        super().__init__()
        self.p = p
        self.inplace = inplace
        self.bypass = False

    def forward(self, x):
        # Identity when bypassed, in eval mode, or with zero drop probability.
        should_drop = self.training and not self.bypass and self.p != 0
        if not should_drop:
            return x
        return F.dropout(x, self.p, self.training, self.inplace)

    def extra_repr(self) -> str:
        return f'p={self.p}, inplace={self.inplace}, bypass={self.bypass}'
def get_device():
    """Pick the best available torch device, honoring the USE_CPU override."""
    if not USE_CPU:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            logging.info("MPS backend detected on Mac. Note: MPS support is experimental and may have limitations.")
            return torch.device("mps")
        # Accelerators requested but unavailable: warn before falling back.
        logging.warning("CUDA/MPS not available or USE_CPU=True. Falling back to CPU.")
    return torch.device("cpu")
def clean_memory():
    """Trigger garbage collection and release cached CUDA memory, if any."""
    gc.collect()
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.empty_cache()
    logging.debug("Cleaned memory.")
def check_resources(ram_limit_pct=RAM_LIMIT_PERCENT, disk_limit_gb=DISK_LIMIT_GB):
    """Verify RAM and disk headroom against the configured limits.

    Returns ``(ok, message)``. ``ok`` is True when both limits are respected,
    when checks are globally bypassed, or when the probe itself fails
    (fail-open so a broken probe never blocks the app).
    """
    if BYPASS_RESOURCE_LIMITS:
        logging.info("Resource limit checks bypassed.")
        return True, "Resource checks bypassed."
    try:
        ram_used_pct = psutil.virtual_memory().percent
        disk_free_gb = psutil.disk_usage('/').free / (1024**3)
        status_msg = (f"Resource Check: RAM Used: {ram_used_pct:.1f}% (Limit: <{ram_limit_pct}%), "
                      f"Disk Free: {disk_free_gb:.1f}GB (Limit: >{disk_limit_gb}GB).")
        within_limits = ram_used_pct < ram_limit_pct and disk_free_gb > disk_limit_gb
        if within_limits:
            logging.info(status_msg + " Status: OK")
            return True, status_msg + " Status: OK"
        warning_msg = status_msg + " Status: LIMIT EXCEEDED!"
        logging.warning(warning_msg)
        return False, warning_msg
    except Exception as e:
        # Fail open: report the failure but do not block the caller.
        logging.error(f"Failed to check resources: {e}")
        return True, f"Resource check failed: {e}"
def initialize_config_flags(existing_config=None):
    """Ensure a config object carries every app-specific feature flag.

    Accepts ``None`` (a fresh ``PretrainedConfig`` is created), a plain dict
    (converted to a ``PretrainedConfig``), or any existing config object.
    Flags missing on the object are added with their defaults; a few derived
    flags are then synchronized with the current attention/quantization
    settings. Returns the (possibly newly created) config object.
    """
    if existing_config is None:
        from transformers import PretrainedConfig
        config_obj = PretrainedConfig()
    elif isinstance(existing_config, dict):
        from transformers import PretrainedConfig
        try:
            config_obj = PretrainedConfig(**existing_config)
        except Exception as e:
            logging.warning(f"Could not initialize PretrainedConfig from dict, using default. Error: {e}")
            config_obj = PretrainedConfig()
    else:
        config_obj = existing_config
    # Default values for every custom flag this app tracks on the config.
    default_flags = {
        "reduced_layers": False, "original_num_layers": None, "removed_bias": False, "untied_embeddings": False,
        "limits_configured": False, "qa_restrictions_removed": False, "additional_mechanisms_applied": False,
        "safety_settings_enabled": True, "perfect_precision_recovered": False, "token_gen_speed_maximized": False,
        "coherence_improvement_enabled": False, "inconsistencies_biases_removed": False,
        "quantization_applied": False, "quantization_mode": DEFAULT_QUANTIZATION,
        "layer_norm_bypassed": False, "replaced_layer_norm": False, "dropout_bypassed": False, "replaced_dropout": False,
        "activation_function_swapped": False, "current_activation_function": DEFAULT_ACTIVATION_FUNCTION,
        "embedding_normalized": False, "gradient_clipping_disabled": False, "weight_decay_disabled": False,
        "lr_scheduler_disabled": False, "bitnet_applied": False, "gradient_checkpointing_enabled": False,
        "pruning_applied": False, "pruning_amount": None, "frozen_layers": None,
        "enhanced_security_enabled": False, "debug_mode_enabled": False, "auto_optimization_enabled": False,
        "internal_logging_enabled": False, "drift_detection_enabled": False, "ultra_fast_mode": False,
        "optimizer": DEFAULT_OPTIMIZER, "rms_norm_applied": False, "layerdrop_enabled": False, "layerdrop_prob": 0.0,
        "lora_merged": False, "lora_adapter_path": None, "knowledge_distillation_setup": False, "kd_num_labels": None,
        "reward_modeling_setup": False, "rm_num_outputs": 1,
        "swa_applied": False, "knowledge_edited": False, "head_pruning_applied": False, "qat_applied": False,
        "architecture_merged": False, "weight_init_applied": False, "gradient_noise_applied": False,
        "rope_scaling_type": None, "rope_scaling_factor": None, "sliding_window_size": None, "attention_variant": None,
        "response_filters": True, "harassment_filter": True, "hate_filter": True, "sexually_explicit_filter": True,
        "dangerous_content_filter": True, "civic_integrity_filter": True, "code_filter": True,
        "medical_advice_filter": True, "legal_advice_filter": True, "financial_advice_filter": True,
        "pii_filter": True, "political_filter": True, "religious_filter": True, "profanity_filter": True,
        "stereotype_filter": True, "misinfo_filter": True, "self_harm_filter": True, "personal_attack_filter": True,
        "toxicity_filter": True, "spam_filter": True, "off_topic_filter": True, "tone_filter": True,
        "min_max_length_filter": True, "repetition_filter_enabled": False, "factuality_filter_enabled": False,
        "baseline_distribution": None, "remove_censorship": False, "no_response_filters": False,
        "no_advert_warning": False, "no_limits": False, "knowledge_date": None, "cutoff_date": None,
        "max_input_tokens": None, "max_output_tokens": None,
        "multimodal_applied": False, "supported_modalities": [], "modality_encoders": {}, "modality_projection_dim": None, "modality_special_tokens": {},
        "use_flash_attention_2": getattr(config_obj, 'attn_implementation', None) == 'flash_attention_2',
        "attn_implementation": getattr(config_obj, 'attn_implementation', 'auto'),
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "peft_adapter_added": False, "peft_config": None
    }
    # Only add flags that are missing; never clobber existing values.
    for flag, default_value in default_flags.items():
        if not hasattr(config_obj, flag):
            setattr(config_obj, flag, default_value)
    # Derived flag: keep the FA2 boolean in sync with attn_implementation.
    if getattr(config_obj, 'attn_implementation', 'auto') == 'flash_attention_2':
        config_obj.use_flash_attention_2 = True
    else:
        config_obj.use_flash_attention_2 = False
    # Derived flags: float32 means "no quantization applied".
    if getattr(config_obj, 'quantization_mode', DEFAULT_QUANTIZATION) == 'float32':
        config_obj.quantization_applied = False
        config_obj.perfect_precision_recovered = True
    else:
        config_obj.quantization_applied = True
        config_obj.perfect_precision_recovered = False
    # PeftConfig here is the module-level import; the short-circuit on
    # _peft_installed guarantees it is not None when isinstance runs.
    if _peft_installed and isinstance(existing_config, PeftConfig):
        config_obj.peft_adapter_added = True
        config_obj.peft_config = existing_config.to_dict()
    return config_obj
def _recursive_setattr(obj, attr_str, value):
    """Set a dotted attribute path on ``obj``; return True on success.

    Walks each intermediate attribute (refusing missing or None links) and
    only assigns when the final attribute already exists on the target.
    """
    parts = attr_str.split('.')
    try:
        target = obj
        for part in parts[:-1]:
            if not hasattr(target, part):
                logging.warning(f"Intermediate attribute {part} not found in {attr_str} for object {type(target)}")
                return False
            target = getattr(target, part)
            if target is None:
                logging.warning(f"Intermediate attribute {part} is None in {attr_str}")
                return False
        leaf = parts[-1]
        if not hasattr(target, leaf):
            logging.warning(f"Final attribute {leaf} not found in {attr_str} on object {type(target)}")
            return False
        setattr(target, leaf, value)
        return True
    except AttributeError as e:
        logging.error(f"AttributeError setting {attr_str}: {e}")
        return False
    except Exception as e:
        logging.error(f"Generic error setting attribute {attr_str}: {e}")
        return False
def _get_encoder_hidden_size(encoder_model_id, trust_remote_code=True):
    """Look up the hidden/embedding size of a Hugging Face encoder model.

    Downloads the model's config and probes common size attribute names,
    first on the top-level config and then on typical nested sub-configs.

    Raises:
        ValueError: if the config cannot be fetched or no size is found.
        (The ValueError raised inside the try block is itself caught by the
        blanket except below and re-raised wrapped — intentional per the
        original control flow.)
    """
    try:
        encoder_config = AutoConfig.from_pretrained(encoder_model_id, trust_remote_code=trust_remote_code)
        possible_keys = ['hidden_size', 'd_model', 'embed_dim']
        for key in possible_keys:
            if hasattr(encoder_config, key):
                size = getattr(encoder_config, key)
                if isinstance(size, int) and size > 0:
                    return size
        # Multimodal configs often nest the size under a sub-config.
        nested_configs = ['vision_config', 'audio_config', 'encoder']
        for nested_name in nested_configs:
            if hasattr(encoder_config, nested_name):
                nested_cfg = getattr(encoder_config, nested_name)
                # NOTE(review): isinstance(nested_cfg, object) is always True,
                # so this condition only filters out falsy values (e.g. None).
                if nested_cfg and isinstance(nested_cfg, object):
                    for key in possible_keys:
                        if hasattr(nested_cfg, key):
                            size = getattr(nested_cfg, key)
                            if isinstance(size, int) and size > 0:
                                return size
        raise ValueError(f"Could not automatically determine hidden/embedding size for encoder {encoder_model_id}. Checked attributes: {possible_keys} and nested configs: {nested_configs}.")
    except Exception as e:
        logging.error(f"Failed to get config or hidden size for encoder {encoder_model_id}: {e}")
        raise ValueError(f"Failed to get config or hidden size for encoder {encoder_model_id}") from e
def convert_to_bitnet(model, config, copy_weights=True):
    """Replace eligible nn.Linear layers in ``model`` with BitLinear, in place.

    Targets linear layers whose qualified names suggest attention/MLP
    projections (attn, mlp, fc, dense, query, key, value, out, wi, wo) while
    skipping anything named like a norm or embedding. Updates
    ``config.bitnet_applied`` to reflect whether any layer was converted.

    Args:
        model: model whose module tree is mutated in place.
        config: config object whose ``bitnet_applied`` flag is updated.
        copy_weights: when True, copy original weights/biases into the new
            layers; when False, Xavier-initialize them instead.

    Returns:
        A human-readable status string.
    """
    if not hasattr(nn, 'RMSNorm'):
        # Older torch has no nn.RMSNorm; BitLinear uses the local RMSNorm class.
        logging.warning("BitNet conversion requires RMSNorm, which might not be standard. Using custom RMSNorm.")
    device = get_device()
    converted_count = 0
    # Snapshot the module list up front: the tree is mutated while iterating.
    modules_to_process = list(model.named_modules())
    processed_names = set()
    with torch.no_grad():
        for name, module in modules_to_process:
            if name in processed_names:
                continue
            # Heuristic name-based selection of projection layers.
            is_target_linear = isinstance(module, nn.Linear) and (
                any(sub in name.lower() for sub in ["attn", "mlp", "fc", "dense", "query", "key", "value", "out", "wi", "wo"])
                and "norm" not in name.lower()
                and "embedding" not in name.lower()
            )
            if is_target_linear:
                try:
                    current_dtype = module.weight.dtype if hasattr(module, 'weight') and module.weight is not None else torch.float32
                    has_bias = module.bias is not None
                    bl = BitLinear(module.in_features, module.out_features, has_bias).to(device=device, dtype=current_dtype)
                    if copy_weights and hasattr(module, 'weight') and module.weight is not None:
                        if bl.weight.shape == module.weight.shape:
                            bl.weight.data.copy_(module.weight.data)
                        else:
                            logging.warning(f"Shape mismatch for weight {name}: Expected {bl.weight.shape}, got {module.weight.shape}. Skipping weight copy.")
                        if has_bias and bl.bias is not None:
                            if bl.bias.shape == module.bias.shape:
                                bl.bias.data.copy_(module.bias.data)
                            else:
                                logging.warning(f"Shape mismatch for bias {name}: Expected {bl.bias.shape}, got {module.bias.shape}. Skipping bias copy.")
                        elif not has_bias and bl.bias is not None:
                            nn.init.zeros_(bl.bias)
                        elif has_bias and bl.bias is None:
                            logging.warning(f"Module {name} had bias, but BitLinear does not. Bias info lost.")
                    elif not copy_weights:
                        nn.init.xavier_uniform_(bl.weight)
                        if bl.bias is not None:
                            nn.init.zeros_(bl.bias)
                    if _recursive_setattr(model, name, bl):
                        converted_count += 1
                        processed_names.add(name)
                        logging.debug(f"Converted layer {name} to BitLinear.")
                    else:
                        logging.warning(f"Failed to set BitLinear for {name}")
                        processed_names.add(name)
                except Exception as e:
                    # Best-effort: a failed layer is logged and skipped.
                    logging.error(f"Error replacing {name} with BitLinear: {e} \n{traceback.format_exc()}")
                    processed_names.add(name)
    if converted_count > 0:
        config.bitnet_applied = True
        logging.info(f"Applied BitNet conversion to {converted_count} linear layers.")
        return f"Applied BitNet conversion to {converted_count} linear layers."
    else:
        logging.info("No applicable linear layers found or converted for BitNet.")
        config.bitnet_applied = False
        return "No applicable layers found for BitNet conversion."
def revert_bitnet(model, config):
    """Replace every BitLinear layer in ``model`` with a plain nn.Linear.

    Weights and biases are copied back when shapes match; otherwise the
    replacement is re-initialized with the standard nn.Linear scheme.
    Clears ``config.bitnet_applied`` in all cases.

    Returns:
        A human-readable status string.
    """
    if not getattr(config, 'bitnet_applied', False):
        return "BitNet not applied according to config, nothing to revert."
    device = get_device()
    model.to(device)
    reverted_count = 0
    # Snapshot first: the module tree is mutated while iterating.
    modules_to_process = list(model.named_modules())
    processed_names = set()
    with torch.no_grad():
        for name, module in modules_to_process:
            if name in processed_names:
                continue
            if isinstance(module, BitLinear):
                try:
                    dtype = module.weight.dtype if hasattr(module, 'weight') and module.weight is not None else torch.float32
                    has_bias = module.bias is not None
                    lin = nn.Linear(module.in_features, module.out_features, bias=has_bias).to(device, dtype=dtype)
                    if hasattr(module, 'weight') and module.weight is not None:
                        if lin.weight.shape == module.weight.shape:
                            lin.weight.data.copy_(module.weight.data)
                        else:
                            logging.warning(f"Shape mismatch reverting weight {name}: Expected {lin.weight.shape}, got {module.weight.shape}. Re-initializing nn.Linear weight.")
                            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
                    else:
                        nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
                    if has_bias and lin.bias is not None:
                        if lin.bias.shape == module.bias.shape:
                            lin.bias.data.copy_(module.bias.data)
                        else:
                            logging.warning(f"Shape mismatch reverting bias {name}: Expected {lin.bias.shape}, got {module.bias.shape}. Re-initializing nn.Linear bias.")
                            # Standard nn.Linear bias init: U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
                            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(lin.weight)
                            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                            nn.init.uniform_(lin.bias, -bound, bound)
                    elif has_bias and lin.bias is None:
                        logging.error(f"BitLinear layer {name} had bias, but reverted nn.Linear does not. Reversion failed for bias.")
                    elif not has_bias and lin.bias is not None:
                        logging.error(f"BitLinear layer {name} lacked bias, but reverted nn.Linear has one. Setting to zero.")
                        nn.init.zeros_(lin.bias)
                    if _recursive_setattr(model, name, lin):
                        reverted_count += 1
                        processed_names.add(name)
                        logging.debug(f"Reverted BitLinear layer {name} to nn.Linear.")
                    else:
                        logging.warning(f"Failed to revert BitLinear for {name}")
                        processed_names.add(name)
                except Exception as e:
                    # Best-effort: a failed layer is logged and skipped.
                    logging.error(f"Error reverting BitLinear {name}: {e} \n{traceback.format_exc()}")
                    processed_names.add(name)
    if reverted_count > 0:
        config.bitnet_applied = False
        logging.info(f"Reverted {reverted_count} BitNet layers to standard nn.Linear.")
        return f"Reverted {reverted_count} BitNet layers."
    else:
        # Flag was set but no BitLinear found: correct the stale flag.
        config.bitnet_applied = False
        logging.info("No BitNet layers found to revert, but flag was true. Resetting flag.")
        return "No BitNet layers found to revert."
def _find_decoder_layers_module(model):
    """Locate the ModuleList of transformer blocks inside ``model``.

    Tries common attribute names directly on ``model.model`` (or ``model``),
    then a list of known container paths covering typical HF architectures.
    Plain lists/tuples of modules are converted to ``nn.ModuleList`` in place.

    Returns:
        (parent_module, attribute_name, layer_module_list), or
        (None, None, None) when no layer container could be identified.
    """
    # (container path, explicit attribute) pairs for known architectures;
    # a None attribute means "try the usual names on that container".
    prefixes = [
        ('model.decoder', 'layers'),
        ('model.layers', None),
        ('transformer.h', None),
        ('transformer.blocks', None),
        ('encoder.layer', None),
        ('model.encoder.layers', None),
        ('', 'layers'),
        ('', 'h'),
        ('model', 'layers'),
        ('decoder.block', None),
        ('decoder.layers', None)
    ]
    if hasattr(model, 'model'):
        base_obj = model.model
    else:
        base_obj = model
    # Fast path: common attribute names directly on the base object.
    direct_attrs = ['layers', 'h', 'blocks', 'block']
    for attr in direct_attrs:
        if hasattr(base_obj, attr):
            layer_list = getattr(base_obj, attr)
            if isinstance(layer_list, nn.ModuleList) and len(layer_list) > 0:
                logging.info(f"Found layer list at 'model.{attr}' or '{attr}' with {len(layer_list)} layers.")
                return base_obj, attr, layer_list
            elif isinstance(layer_list, (list, tuple)) and len(layer_list) > 0 and isinstance(layer_list[0], nn.Module):
                # Normalize to ModuleList so later setattr/len logic works.
                logging.warning(f"Found layers as list/tuple at 'model.{attr}' or '{attr}'. Converting to ModuleList.")
                setattr(base_obj, attr, nn.ModuleList(layer_list))
                return base_obj, attr, getattr(base_obj, attr)
    # Slow path: walk each known container path from the model root.
    for p_base, attr_name_explicit in prefixes:
        mod = model
        valid_path = True
        if p_base:
            for comp in p_base.split('.'):
                if not hasattr(mod, comp):
                    valid_path = False
                    break
                mod = getattr(mod, comp)
                if mod is None:
                    valid_path = False
                    break
        if not valid_path:
            continue
        attrs_to_check = [attr_name_explicit] if attr_name_explicit else ['layers', 'h', 'blocks', 'block', 'layer']
        for attr in attrs_to_check:
            if hasattr(mod, attr):
                layer_list = getattr(mod, attr)
                if isinstance(layer_list, nn.ModuleList) and len(layer_list) > 0:
                    logging.info(f"Found layer list at '{p_base}.{attr}' with {len(layer_list)} layers.")
                    return mod, attr, layer_list
                elif isinstance(layer_list, (list, tuple)) and len(layer_list) > 0 and isinstance(layer_list[0], nn.Module):
                    logging.warning(f"Found layers as list/tuple at '{p_base}.{attr}'. Converting to ModuleList.")
                    setattr(mod, attr, nn.ModuleList(layer_list))
                    return mod, attr, getattr(mod, attr)
    logging.warning("Could not automatically find the standard decoder/transformer layer list module.")
    return None, None, None
def _reduce_layers_to_one(base_model, config, target_layers=1):
    """Truncate the model's transformer stack to ``target_layers`` layers.

    Keeps the first ``target_layers`` layers, records the original count in
    ``config.original_num_layers`` (so ``_enable_full_layers`` can revert),
    and synchronizes the various layer-count aliases on the config.

    Returns:
        A human-readable status string.
    """
    if not isinstance(target_layers, int) or target_layers < 1:
        logging.error(f"Invalid target_layers value: {target_layers}. Must be an integer >= 1.")
        return f"Error: Target layers must be >= 1, got {target_layers}."
    layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model)
    if layer_module and layer_attr and current_layers is not None:
        current_len = len(current_layers)
        if current_len <= 0:
            logging.warning("Found layer attribute but the ModuleList is empty. Cannot reduce.")
            return "Warning: Layer list found but it's empty. Cannot reduce."
        if current_len > target_layers:
            logging.info(f"Reducing layers: {current_len} -> {target_layers}...")
            original_layer_count = current_len
            # Remember the largest layer count ever seen so reversal works.
            if not hasattr(config, 'original_num_layers') or config.original_num_layers is None or config.original_num_layers < current_len:
                config.original_num_layers = original_layer_count
                logging.info(f"Stored/Updated original layer count in config: {original_layer_count}")
            new_layer_list = nn.ModuleList(current_layers[:target_layers])
            setattr(layer_module, layer_attr, new_layer_list)
            config.num_hidden_layers = target_layers
            config.reduced_layers = True
            # Keep the architecture-specific layer-count aliases in sync.
            if hasattr(config, 'n_layer'): config.n_layer = target_layers
            if hasattr(config, 'num_layers'): config.num_layers = target_layers
            if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = target_layers
            if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = target_layers
            logging.info(f"Successfully reduced layers to {target_layers}.")
            clean_memory()
            return f"Layers reduced to {target_layers}. Original count was: {original_layer_count}."
        elif current_len == target_layers:
            logging.info(f"Model already has {current_len} layers, matching the target {target_layers}. No reduction needed.")
            # Still mark as reduced if the original count was larger.
            config.reduced_layers = False if current_len == getattr(config, 'original_num_layers', current_len) else True
            config.num_hidden_layers = current_len
            if hasattr(config, 'n_layer'): config.n_layer = current_len
            if hasattr(config, 'num_layers'): config.num_layers = current_len
            if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len
            if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len
            return f"Model already has {current_len} layers (target {target_layers}). No reduction performed."
        else:
            # Fewer layers than target: nothing to remove.
            logging.info(f"Model has {current_len} layers, which is less than the target {target_layers}. No reduction needed.")
            config.reduced_layers = True
            config.num_hidden_layers = current_len
            if hasattr(config, 'n_layer'): config.n_layer = current_len
            if hasattr(config, 'num_layers'): config.num_layers = current_len
            if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len
            if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len
            return f"Model already has {current_len} layers (< target {target_layers}). No reduction performed."
    else:
        logging.warning("Could not find standard layer structure for reduction.")
        config.reduced_layers = False
        return "Warning: Could not find standard layer structure for reduction."
def _enable_full_layers(base_model, config, original_num_layers=None):
    """Restore a previously reduced model back to its original layer count.

    New layers are deep-copied from the first existing layer and then
    re-initialized (``reset_parameters`` when available, otherwise a sensible
    per-module-type init), so the restored capacity starts from fresh weights.

    Args:
        base_model: model whose layer list is extended in place.
        config: config carrying ``reduced_layers`` / ``original_num_layers``.
        original_num_layers: optional explicit target count; falls back to the
            config value, then to the module-level global.

    Returns:
        A human-readable status string.
    """
    if not getattr(config, 'reduced_layers', False):
        # Not flagged as reduced: sanity-check the counts and fix the flag.
        layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model)
        current_len = len(current_layers) if current_layers is not None else 0
        orig_len_config = getattr(config, 'original_num_layers', None)
        if current_len > 0 and orig_len_config is not None and current_len == orig_len_config:
            config.reduced_layers = False
            return "Layers already seem to be at the original count. Flag corrected if necessary."
        else:
            return "Layers not previously reduced according to config flag, or cannot verify current/original counts."
    # Resolve the target layer count: explicit arg > config > global fallback.
    orig_layers = original_num_layers if original_num_layers is not None else getattr(config, 'original_num_layers', None)
    if orig_layers is None:
        global original_num_layers_global
        orig_layers = original_num_layers_global
        if orig_layers is not None:
            logging.warning(f"Using globally stored original layer count: {orig_layers} as it was missing in config.")
            config.original_num_layers = orig_layers
        else:
            logging.error("Cannot restore layers: Original layer count is missing from config and global state.")
            return "Error: Cannot revert - Original layer count unknown."
    if not isinstance(orig_layers, int) or orig_layers <= 0:
        logging.error(f"Cannot restore layers: Invalid original layer count found ({orig_layers}).")
        return f"Error: Cannot revert - Invalid original layer count ({orig_layers})."
    layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model)
    if layer_module and layer_attr and current_layers is not None:
        current_len = len(current_layers)
        if current_len < orig_layers:
            logging.info(f"Restoring layers: {current_len} -> {orig_layers}...")
            T = time.time()
            try:
                if current_len == 0:
                    logging.error("Cannot restore layers: No existing layers found to copy structure from.")
                    return "Error: Cannot restore layers - no template layer available."
                device = next(iter(current_layers[0].parameters()), torch.tensor([])).device
                # BUG FIX: previously ``current_layers[0].to('cpu')`` moved the
                # ORIGINAL first layer to CPU in place (Module.to is in-place),
                # leaving the model split across devices after restoration.
                # Deep-copy first, then move only the copy to CPU.
                template_layer = copy.deepcopy(current_layers[0]).to('cpu')
                layers_to_add = []
                num_layers_to_add = orig_layers - current_len
                logging.info(f"Need to add {num_layers_to_add} layers.")
                for i in range(num_layers_to_add):
                    new_layer = copy.deepcopy(template_layer)
                    # Fresh weights: prefer each submodule's own reset logic.
                    for _, sub_module in new_layer.named_modules():
                        if hasattr(sub_module, 'reset_parameters'):
                            try:
                                sub_module.reset_parameters()
                            except Exception as reset_e:
                                logging.warning(f"Could not reset parameters for submodule {sub_module} in new layer {i}: {reset_e}")
                        elif isinstance(sub_module, nn.Linear):
                            nn.init.kaiming_uniform_(sub_module.weight, a=math.sqrt(5))
                            if sub_module.bias is not None:
                                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(sub_module.weight)
                                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                                nn.init.uniform_(sub_module.bias, -bound, bound)
                        elif isinstance(sub_module, nn.Embedding):
                            nn.init.normal_(sub_module.weight)
                            if sub_module.padding_idx is not None:
                                with torch.no_grad():
                                    sub_module.weight[sub_module.padding_idx].fill_(0)
                        elif isinstance(sub_module, (nn.LayerNorm, RMSNorm, BypassLayerNorm)):
                            # Identity init: unit gain, zero shift.
                            if sub_module.elementwise_affine:
                                if hasattr(sub_module, 'weight') and sub_module.weight is not None:
                                    nn.init.ones_(sub_module.weight)
                                if hasattr(sub_module, 'bias') and sub_module.bias is not None:
                                    nn.init.zeros_(sub_module.bias)
                    new_layer = new_layer.to(device)
                    layers_to_add.append(new_layer)
                full_layer_list = nn.ModuleList(list(current_layers) + layers_to_add)
                setattr(layer_module, layer_attr, full_layer_list)
                # Keep the various layer-count aliases in sync on the config.
                config.num_hidden_layers = orig_layers
                config.reduced_layers = False
                if hasattr(config, 'n_layer'): config.n_layer = orig_layers
                if hasattr(config, 'num_layers'): config.num_layers = orig_layers
                if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = orig_layers
                if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = orig_layers
                msg = f"Restored layer structure to {orig_layers} layers in {time.time()-T:.2f}s."
                logging.info(msg)
                clean_memory()
                return msg
            except Exception as e:
                logging.error(f"Error restoring layers: {e}\n{traceback.format_exc()}")
                # Roll back to the pre-restoration layer list and flags.
                setattr(layer_module, layer_attr, current_layers)
                config.num_hidden_layers = current_len
                config.reduced_layers = True
                if hasattr(config, 'n_layer'): config.n_layer = current_len
                if hasattr(config, 'num_layers'): config.num_layers = current_len
                if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len
                if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len
                return f"Error restoring layers: {e}. State might be inconsistent."
        else:
            # Already at/above target: normalize flags and counts only.
            config.reduced_layers = False
            config.num_hidden_layers = current_len
            if hasattr(config, 'n_layer'): config.n_layer = current_len
            if hasattr(config, 'num_layers'): config.num_layers = current_len
            if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len
            if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len
            msg = f"Model already has {current_len} layers (>= original {orig_layers}). No restoration needed. Corrected flags if necessary."
            logging.info(msg)
            return msg
    elif layer_module and layer_attr and current_layers is None:
        logging.warning(f"Layer attribute '{layer_attr}' exists but is None or invalid. Cannot restore layers.")
        return "Warning: Layer attribute found but invalid. Cannot restore layers."
    else:
        logging.warning("Could not find standard layer structure for restoration.")
        return "Warning: Could not find standard layer structure for restoration."
def _replace_linear_without_bias(module, config):
    """Swap every biased nn.Linear in ``module`` for a bias-free clone, in place.

    Weights are copied over; on an (unexpected) shape mismatch the new layer
    is re-initialized instead. Sets ``config.removed_bias`` when at least one
    layer was replaced. (The unused ``device = get_device()`` local from the
    previous version was removed — each replacement already uses the source
    layer's own device.)

    Returns:
        A human-readable status string.
    """
    replaced_count = 0
    # Snapshot first: the module tree is mutated while iterating.
    modules_to_process = list(module.named_modules())
    processed_names = set()
    for name, child in modules_to_process:
        if name in processed_names:
            continue
        if isinstance(child, nn.Linear) and child.bias is not None:
            try:
                dtype = child.weight.dtype
                current_device = child.weight.device
                # Keep the replacement on the layer's own device/dtype.
                new_linear = nn.Linear(child.in_features, child.out_features, bias=False).to(device=current_device, dtype=dtype)
                with torch.no_grad():
                    if new_linear.weight.shape == child.weight.shape:
                        new_linear.weight.copy_(child.weight)
                    else:
                        logging.warning(f"Shape mismatch removing bias for weight {name}: Expected {new_linear.weight.shape}, got {child.weight.shape}. Re-initializing.")
                        nn.init.kaiming_uniform_(new_linear.weight, a=math.sqrt(5))
                if _recursive_setattr(module, name, new_linear):
                    replaced_count += 1
                    processed_names.add(name)
                    logging.debug(f"Removed bias from layer {name}")
                else:
                    logging.warning(f"Failed to set bias-less Linear for {name}")
                    processed_names.add(name)
            except Exception as e:
                # Best-effort: a failed layer is logged and skipped.
                logging.error(f"Error removing bias for layer {name}: {e}")
                processed_names.add(name)
    if replaced_count > 0:
        config.removed_bias = True
        logging.info(f"Removed bias from {replaced_count} linear layers.")
        return f"Removed bias from {replaced_count} linear layers."
    else:
        logging.info("No linear layers with bias found to modify.")
        return "No linear layers with bias found to modify."
def _enable_bias_in_linear(module, config):
    """Re-add a bias term to every bias-less nn.Linear in ``module``.

    Only runs when ``config.removed_bias`` is True (i.e. bias was previously
    stripped by ``_replace_linear_without_bias``). Existing weights are kept
    when shapes match; the new bias is initialized with nn.Linear's default
    fan-in uniform scheme. Resets ``config.removed_bias`` to False.

    Args:
        module: Root nn.Module rewritten in place.
        config: Model config object carrying the ``removed_bias`` flag.

    Returns:
        Human-readable status message.
    """
    if not getattr(config, 'removed_bias', False):
        return "Bias not previously removed according to config flag. Cannot enable (revert)."
    # Fix: dropped the unused `device = get_device()` local — replacements
    # are placed on each child's own weight device below.
    enabled_count = 0
    # Snapshot first: _recursive_setattr mutates the module tree while we iterate.
    modules_to_process = list(module.named_modules())
    processed_names = set()
    for name, child in modules_to_process:
        if name in processed_names:
            continue
        if isinstance(child, nn.Linear) and child.bias is None:
            try:
                dtype = child.weight.dtype
                current_device = child.weight.device
                new_linear = nn.Linear(child.in_features, child.out_features, bias=True).to(device=current_device, dtype=dtype)
                with torch.no_grad():
                    if new_linear.weight.shape == child.weight.shape:
                        new_linear.weight.copy_(child.weight)
                    else:
                        logging.warning(f"Shape mismatch enabling bias for weight {name}: Expected {new_linear.weight.shape}, got {child.weight.shape}. Re-initializing weight.")
                        nn.init.kaiming_uniform_(new_linear.weight, a=math.sqrt(5))
                    # Match nn.Linear's default bias init: U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(new_linear.weight)
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    nn.init.uniform_(new_linear.bias, -bound, bound)
                if _recursive_setattr(module, name, new_linear):
                    enabled_count += 1
                    processed_names.add(name)
                    logging.debug(f"Enabled bias for layer {name}")
                else:
                    logging.warning(f"Failed to set biased Linear for {name}")
                    processed_names.add(name)
            except Exception as e:
                logging.error(f"Error enabling bias for layer {name}: {e}")
                processed_names.add(name)
    if enabled_count > 0:
        config.removed_bias = False
        logging.info(f"Enabled (restored) bias for {enabled_count} linear layers.")
        return f"Enabled bias for {enabled_count} linear layers."
    else:
        config.removed_bias = False
        logging.info("No bias-less linear layers found to enable bias for. Resetting flag.")
        return "No bias-less linear layers found to enable bias for."
def _replace_layer_norm_with_bypass(module, config):
    """Swap every plain nn.LayerNorm in ``module`` for a BypassLayerNorm.

    Existing BypassLayerNorm and RMSNorm instances are skipped. Affine
    weight/bias tensors are copied when shapes match, otherwise re-initialized
    to identity (weight=1, bias=0). On success sets
    ``config.replaced_layer_norm = True`` and
    ``config.layer_norm_bypassed = False`` (replaced, but bypass not yet on).

    Returns:
        Human-readable status message.
    """
    replaced_count = 0
    # Snapshot the module list: _recursive_setattr mutates the tree while iterating.
    modules_to_process = list(module.named_modules())
    processed_names = set()
    for name, child in modules_to_process:
        if name in processed_names:
            continue
        if isinstance(child, nn.LayerNorm) and not isinstance(child, (BypassLayerNorm, RMSNorm)):
            try:
                # Probe the layer's device/dtype from weight, then bias, then any
                # other parameter; fall back to the global device and float32.
                child_device = get_device()
                child_dtype = torch.float32
                if hasattr(child, 'weight') and child.weight is not None:
                    child_device = child.weight.device
                    child_dtype = child.weight.dtype
                elif hasattr(child, 'bias') and child.bias is not None:
                    child_device = child.bias.device
                    child_dtype = child.bias.dtype
                elif hasattr(child, '_parameters') and child._parameters:
                    first_param = next(iter(child.parameters()), None)
                    if first_param is not None:
                        child_device = first_param.device
                        child_dtype = first_param.dtype
                norm_shape = child.normalized_shape
                eps = child.eps
                affine = child.elementwise_affine
                new_layer_norm = BypassLayerNorm(norm_shape, eps, affine, device=child_device, dtype=child_dtype)
                if affine:
                    with torch.no_grad():
                        # Copy affine params; on shape mismatch or a missing source
                        # param, fall back to identity init (ones/zeros).
                        if hasattr(child, 'weight') and child.weight is not None and new_layer_norm.weight is not None:
                            if new_layer_norm.weight.shape == child.weight.shape:
                                new_layer_norm.weight.copy_(child.weight)
                            else:
                                logging.warning(f"Shape mismatch replacing LN weight {name}. Expected {new_layer_norm.weight.shape}, got {child.weight.shape}. Initializing BypassLN weight.")
                                nn.init.ones_(new_layer_norm.weight)
                        elif new_layer_norm.weight is not None:
                            nn.init.ones_(new_layer_norm.weight)
                        if hasattr(child, 'bias') and child.bias is not None and new_layer_norm.bias is not None:
                            if new_layer_norm.bias.shape == child.bias.shape:
                                new_layer_norm.bias.copy_(child.bias)
                            else:
                                logging.warning(f"Shape mismatch replacing LN bias {name}. Expected {new_layer_norm.bias.shape}, got {child.bias.shape}. Initializing BypassLN bias.")
                                nn.init.zeros_(new_layer_norm.bias)
                        elif new_layer_norm.bias is not None:
                            nn.init.zeros_(new_layer_norm.bias)
                if _recursive_setattr(module, name, new_layer_norm):
                    replaced_count += 1
                    processed_names.add(name)
                    logging.debug(f"Replaced LayerNorm {name} with BypassLayerNorm.")
                else:
                    logging.warning(f"Failed to set BypassLayerNorm for {name}")
                    processed_names.add(name)
            except Exception as e:
                logging.error(f"Error replacing LayerNorm {name} with Bypass version: {e}\n{traceback.format_exc()}")
                processed_names.add(name)
    if replaced_count > 0:
        config.replaced_layer_norm = True
        config.layer_norm_bypassed = False
        logging.info(f"Replaced {replaced_count} LayerNorm layers with Bypass version.")
        return f"Replaced {replaced_count} LayerNorm layers with Bypass version."
    else:
        logging.info("No standard nn.LayerNorm layers found to replace with BypassLayerNorm.")
        return "No standard LayerNorm layers found to replace."
def _revert_bypass_layer_norm(module, config):
    """Swap every BypassLayerNorm in ``module`` back to a standard nn.LayerNorm.

    No-op unless ``config.replaced_layer_norm`` is set. Affine weight/bias are
    copied when shapes match, otherwise re-initialized to identity. Clears
    both ``replaced_layer_norm`` and ``layer_norm_bypassed`` flags.

    Returns:
        Human-readable status message.
    """
    if not getattr(config, 'replaced_layer_norm', False):
        return "BypassLayerNorm not previously applied according to config flag. Cannot revert."
    reverted_count = 0
    # Snapshot first: _recursive_setattr mutates the module tree during iteration.
    modules_to_process = list(module.named_modules())
    processed_names = set()
    for name, child in modules_to_process:
        if name in processed_names:
            continue
        if isinstance(child, BypassLayerNorm):
            try:
                # Default placement; overridden below when affine params exist.
                child_device = get_device()
                child_dtype = torch.float32
                if child.elementwise_affine:
                    if child.weight is not None:
                        child_device = child.weight.device
                        child_dtype = child.weight.dtype
                    elif child.bias is not None:
                        child_device = child.bias.device
                        child_dtype = child.bias.dtype
                    else:
                        pass
                norm_shape = child.normalized_shape
                eps = child.eps
                affine = child.elementwise_affine
                # nn.LayerNorm accepts an int or a list; unwrap 1-element tuples.
                if isinstance(norm_shape, tuple) and len(norm_shape) == 1:
                    norm_arg = norm_shape[0]
                elif isinstance(norm_shape, (list, tuple)):
                    norm_arg = list(norm_shape)
                elif isinstance(norm_shape, int):
                    norm_arg = norm_shape
                else:
                    raise ValueError(f"Unsupported normalized_shape type for nn.LayerNorm: {type(norm_shape)}")
                new_layer_norm = nn.LayerNorm(norm_arg, eps, affine, device=child_device, dtype=child_dtype)
                if affine:
                    with torch.no_grad():
                        # Copy affine params; identity init on mismatch/missing source.
                        if hasattr(child, 'weight') and child.weight is not None and new_layer_norm.weight is not None:
                            if new_layer_norm.weight.shape == child.weight.shape:
                                new_layer_norm.weight.copy_(child.weight)
                            else:
                                logging.warning(f"Shape mismatch reverting BypassLN weight {name}. Expected {new_layer_norm.weight.shape}, got {child.weight.shape}. Initializing LayerNorm weight.")
                                nn.init.ones_(new_layer_norm.weight)
                        elif new_layer_norm.weight is not None:
                            nn.init.ones_(new_layer_norm.weight)
                        if hasattr(child, 'bias') and child.bias is not None and new_layer_norm.bias is not None:
                            if new_layer_norm.bias.shape == child.bias.shape:
                                new_layer_norm.bias.copy_(child.bias)
                            else:
                                logging.warning(f"Shape mismatch reverting BypassLN bias {name}. Expected {new_layer_norm.bias.shape}, got {child.bias.shape}. Initializing LayerNorm bias.")
                                nn.init.zeros_(new_layer_norm.bias)
                        elif new_layer_norm.bias is not None:
                            nn.init.zeros_(new_layer_norm.bias)
                if _recursive_setattr(module, name, new_layer_norm):
                    reverted_count += 1
                    processed_names.add(name)
                    logging.debug(f"Reverted BypassLayerNorm {name} to standard nn.LayerNorm.")
                else:
                    logging.warning(f"Failed to revert BypassLayerNorm for {name}")
                    processed_names.add(name)
            except Exception as e:
                logging.error(f"Error reverting BypassLayerNorm {name} to standard LayerNorm: {e}\n{traceback.format_exc()}")
                processed_names.add(name)
    if reverted_count > 0:
        config.replaced_layer_norm = False
        config.layer_norm_bypassed = False
        logging.info(f"Reverted {reverted_count} BypassLayerNorm layers back to standard nn.LayerNorm.")
        return f"Reverted {reverted_count} BypassLayerNorm layers."
    else:
        config.replaced_layer_norm = False
        config.layer_norm_bypassed = False
        logging.info("No BypassLayerNorm layers found to revert. Resetting flags.")
        return "No BypassLayerNorm layers found to revert."
def _enable_layer_norm_bypass(model):
    """Turn bypass mode on for every BypassLayerNorm module in ``model``.

    Updates ``model.config.layer_norm_bypassed`` and returns a status string.
    """
    bypass_layers = [m for m in model.modules() if isinstance(m, BypassLayerNorm)]
    if not bypass_layers:
        if getattr(model.config, 'replaced_layer_norm', False):
            logging.warning("Config indicates LN were replaced with BypassLN, but none found. Cannot enable bypass.")
            model.config.layer_norm_bypassed = False
            return "Replaced LN flag is true, but no BypassLN layers found. Run 'Replace LN' first or revert."
        return "No BypassLayerNorm layers found in the model. Replace standard LayerNorm first to enable bypass functionality."
    toggled = 0
    for layer in bypass_layers:
        if not layer.bypass:
            layer.bypass = True
            toggled += 1
    # Flag is set whether or not anything changed (layers may already bypass).
    model.config.layer_norm_bypassed = True
    if toggled > 0:
        logging.info(f"Enabled bypass for {toggled} BypassLayerNorm layers.")
        return f"Enabled bypass for {toggled} LN layers."
    logging.info("All existing BypassLayerNorm layers already have bypass enabled.")
    return "No changes made (layers might already be bypassed)."
def _disable_layer_norm_bypass(model):
    """Turn bypass mode off for every BypassLayerNorm module in ``model``.

    Updates ``model.config.layer_norm_bypassed`` and returns a status string.
    """
    bypass_layers = [m for m in model.modules() if isinstance(m, BypassLayerNorm)]
    if not bypass_layers:
        if getattr(model.config, 'replaced_layer_norm', False):
            model.config.layer_norm_bypassed = False
            return "Replaced LN flag is true, but no BypassLN layers found to disable bypass on."
        return "No BypassLayerNorm layers found in the model to disable bypass on."
    toggled = 0
    for layer in bypass_layers:
        if layer.bypass:
            layer.bypass = False
            toggled += 1
    # Flag is cleared whether or not anything changed.
    model.config.layer_norm_bypassed = False
    if toggled > 0:
        logging.info(f"Disabled bypass for {toggled} BypassLayerNorm layers.")
        return f"Disabled bypass for {toggled} LN layers."
    logging.info("All existing BypassLayerNorm layers already have bypass disabled.")
    return "No changes made (layers might already have bypass disabled)."
def _replace_dropout_with_bypass(module, config):
    """Swap every plain nn.Dropout (exact type) for a BypassDropout.

    The exact-type check deliberately skips nn.Dropout subclasses. Sets
    ``config.replaced_dropout`` and clears ``config.dropout_bypassed`` when at
    least one layer was swapped. Returns a human-readable status message.
    """
    swapped = 0
    seen = set()
    for name, child in list(module.named_modules()):
        if name in seen:
            continue
        if type(child) is not nn.Dropout:
            continue
        try:
            replacement = BypassDropout(child.p, child.inplace)
            try:
                # Place the new module on the same device as its parent's params.
                parent_path = name.rsplit('.', 1)[0] if '.' in name else ''
                parent = module.get_submodule(parent_path) if parent_path else module
                anchor = next(iter(parent.parameters()), None)
                if anchor is not None:
                    replacement.to(device=anchor.device)
            except Exception:
                replacement.to(device=get_device())
            if _recursive_setattr(module, name, replacement):
                swapped += 1
                seen.add(name)
                logging.debug(f"Replaced Dropout {name} with BypassDropout.")
            else:
                logging.warning(f"Failed to set BypassDropout for {name}")
                seen.add(name)
        except Exception as e:
            logging.error(f"Error replacing Dropout {name} with Bypass version: {e}")
            seen.add(name)
    if swapped > 0:
        config.replaced_dropout = True
        config.dropout_bypassed = False
        logging.info(f"Replaced {swapped} nn.Dropout layers with BypassDropout version.")
        return f"Replaced {swapped} Dropout layers."
    logging.info("No standard nn.Dropout layers found to replace with BypassDropout.")
    return "No standard Dropout layers found to replace."
def _revert_bypass_dropout(module, config):
    """Swap every BypassDropout back to a standard nn.Dropout.

    No-op unless ``config.replaced_dropout`` is set. Clears both the
    ``replaced_dropout`` and ``dropout_bypassed`` flags afterwards and
    returns a human-readable status message.
    """
    if not getattr(config, 'replaced_dropout', False):
        return "BypassDropout not previously applied according to config flag. Cannot revert."
    reverted = 0
    seen = set()
    for name, child in list(module.named_modules()):
        if name in seen or not isinstance(child, BypassDropout):
            continue
        try:
            replacement = nn.Dropout(child.p, child.inplace)
            try:
                # Place the new module on the same device as its parent's params.
                parent_path = name.rsplit('.', 1)[0] if '.' in name else ''
                parent = module.get_submodule(parent_path) if parent_path else module
                anchor = next(iter(parent.parameters()), None)
                if anchor is not None:
                    replacement.to(device=anchor.device)
            except Exception:
                replacement.to(device=get_device())
            if _recursive_setattr(module, name, replacement):
                reverted += 1
                seen.add(name)
                logging.debug(f"Reverted BypassDropout {name} to standard nn.Dropout.")
            else:
                logging.warning(f"Failed to revert BypassDropout for {name}")
                seen.add(name)
        except Exception as e:
            logging.error(f"Error reverting BypassDropout {name} to standard nn.Dropout: {e}")
            seen.add(name)
    # Flags are cleared on both outcomes, exactly as before.
    config.replaced_dropout = False
    config.dropout_bypassed = False
    if reverted > 0:
        logging.info(f"Reverted {reverted} BypassDropout layers back to standard nn.Dropout.")
        return f"Reverted {reverted} BypassDropout layers."
    logging.info("No BypassDropout layers found to revert. Resetting flags.")
    return "No BypassDropout layers found to revert."
def _enable_dropout_bypass(model):
    """Turn bypass mode on for every BypassDropout module in ``model``.

    Updates ``model.config.dropout_bypassed`` and returns a status string.
    """
    dropout_layers = [m for m in model.modules() if isinstance(m, BypassDropout)]
    if not dropout_layers:
        if getattr(model.config, 'replaced_dropout', False):
            model.config.dropout_bypassed = False
            return "Replaced Dropout flag is true, but no BypassDropout layers found. Run 'Replace DO' first or revert."
        return "No BypassDropout layers found in the model. Replace standard Dropout first to enable bypass."
    toggled = 0
    for layer in dropout_layers:
        if not layer.bypass:
            layer.bypass = True
            toggled += 1
    # Flag is set whether or not anything changed.
    model.config.dropout_bypassed = True
    if toggled > 0:
        logging.info(f"Enabled bypass for {toggled} BypassDropout layers.")
        return f"Enabled bypass for {toggled} Dropout layers."
    logging.info("All existing BypassDropout layers already have bypass enabled.")
    return "No changes made (layers might already be bypassed)."
def _disable_dropout_bypass(model):
    """Turn bypass mode off for every BypassDropout module in ``model``.

    Updates ``model.config.dropout_bypassed`` and returns a status string.
    """
    dropout_layers = [m for m in model.modules() if isinstance(m, BypassDropout)]
    if not dropout_layers:
        if getattr(model.config, 'replaced_dropout', False):
            model.config.dropout_bypassed = False
            return "Replaced Dropout flag is true, but no BypassDropout layers found to disable bypass on."
        return "No BypassDropout layers found in the model to disable bypass on."
    toggled = 0
    for layer in dropout_layers:
        if layer.bypass:
            layer.bypass = False
            toggled += 1
    # Flag is cleared whether or not anything changed.
    model.config.dropout_bypassed = False
    if toggled > 0:
        logging.info(f"Disabled bypass for {toggled} BypassDropout layers.")
        return f"Disabled bypass for {toggled} Dropout layers."
    logging.info("All existing BypassDropout layers already have bypass disabled.")
    return "No changes made (layers might already have bypass disabled)."
def _swap_activation_function(model, config, activation_fn_name):
    """Replace every recognized activation module in ``model`` with ``activation_fn_name``.

    Any module whose exact type appears in ACTIVATION_FUNCTIONS is swapped for
    a fresh instance of the requested class; modules already of the target
    type are skipped. Falls back to DEFAULT_ACTIVATION_FUNCTION when the
    requested name is unknown. On success updates
    ``config.current_activation_function`` (plus ``hidden_act`` /
    ``activation_function`` when those attributes exist).

    Returns:
        Human-readable status message.
    """
    activation_fn_class = ACTIVATION_FUNCTIONS.get(activation_fn_name)
    if not activation_fn_class:
        msg = f"Warning: Activation function '{activation_fn_name}' not found or invalid. Using default '{DEFAULT_ACTIVATION_FUNCTION}'."
        logging.warning(msg)
        activation_fn_class = ACTIVATION_FUNCTIONS[DEFAULT_ACTIVATION_FUNCTION]
        activation_fn_name = DEFAULT_ACTIVATION_FUNCTION
    if not activation_fn_class:
        logging.error(f"Default activation function '{DEFAULT_ACTIVATION_FUNCTION}' is also missing! Cannot swap.")
        return f"Error: Cannot find '{activation_fn_name}' or the default '{DEFAULT_ACTIVATION_FUNCTION}'."
    msg = ""
    replaced_count = 0
    # All activation classes we recognize; any exact instance of one of these
    # (other than the target class itself) is eligible for swapping.
    current_act_classes = tuple(f for f in ACTIVATION_FUNCTIONS.values() if f is not None and inspect.isclass(f) and issubclass(f, nn.Module))
    target_act_class = activation_fn_class
    modules_to_process = list(model.named_modules())
    processed_names = set()
    for name, child in modules_to_process:
        if name in processed_names:
            continue
        if type(child) in current_act_classes:
            if type(child) == target_act_class:
                processed_names.add(name)
                continue
            try:
                new_activation = target_act_class()
                try:
                    parent_name = '.'.join(name.split('.')[:-1])
                    # Bug fix: was `module` (undefined name in this function,
                    # silently swallowed by the except below) — use `model`.
                    parent_module = model.get_submodule(parent_name) if parent_name else model
                    first_param = next(iter(parent_module.parameters()), None)
                    if first_param is not None:
                        new_activation.to(device=first_param.device)
                except Exception:
                    new_activation.to(device=get_device())
                if _recursive_setattr(model, name, new_activation):
                    replaced_count += 1
                    processed_names.add(name)
                    logging.debug(f"Swapped activation {name} from {type(child).__name__} to {target_act_class.__name__}")
                else:
                    logging.warning(f"Failed to set new activation function for {name}")
                    processed_names.add(name)
            except Exception as e:
                logging.error(f"Error replacing activation function {name} of type {type(child).__name__} with {target_act_class.__name__}: {e}")
                processed_names.add(name)
    if replaced_count > 0:
        msg += f"Swapped {replaced_count} activation functions to {activation_fn_name}."
        config.activation_function_swapped = True
        config.current_activation_function = activation_fn_name
        if hasattr(config, 'hidden_act'):
            config.hidden_act = activation_fn_name
        if hasattr(config, 'activation_function'):
            config.activation_function = activation_fn_name
    else:
        msg += f"No eligible activation functions found to swap to {activation_fn_name} (or already using it)."
        # Robustness fix: the flag may not exist yet on a fresh config, so use
        # getattr with a default instead of a bare attribute access.
        if not getattr(config, 'activation_function_swapped', False):
            current_in_config = getattr(config, 'hidden_act', getattr(config, 'activation_function', DEFAULT_ACTIVATION_FUNCTION))
            config.current_activation_function = current_in_config if current_in_config in ACTIVATION_FUNCTIONS else DEFAULT_ACTIVATION_FUNCTION
    logging.info(msg)
    return msg
def _revert_activation_function(model, config):
    """Swap the model's activations back to DEFAULT_ACTIVATION_FUNCTION.

    Delegates the actual swap to ``_swap_activation_function`` and then
    resets the tracking flags on ``config``. Returns a status message that
    embeds the swap result.
    """
    active = getattr(config, 'current_activation_function', DEFAULT_ACTIVATION_FUNCTION)
    swapped = getattr(config, 'activation_function_swapped', False)
    if not swapped and active == DEFAULT_ACTIVATION_FUNCTION:
        return f"Activation function is already the default ('{DEFAULT_ACTIVATION_FUNCTION}') and was not marked as swapped."
    if swapped:
        logging.info(f"Reverting activation function from '{active}' to default '{DEFAULT_ACTIVATION_FUNCTION}'...")
    else:
        # Non-default activation without the swapped flag: revert anyway.
        logging.info(f"Activation function is '{active}' but wasn't marked as swapped. Attempting to revert to '{DEFAULT_ACTIVATION_FUNCTION}' anyway.")
    result_msg = _swap_activation_function(model, config, DEFAULT_ACTIVATION_FUNCTION)
    config.activation_function_swapped = False
    config.current_activation_function = DEFAULT_ACTIVATION_FUNCTION
    for attr in ('hidden_act', 'activation_function'):
        if hasattr(config, attr):
            setattr(config, attr, DEFAULT_ACTIVATION_FUNCTION)
    return f"Reverted to default activation ('{DEFAULT_ACTIVATION_FUNCTION}'). Result: {result_msg}"
def _swap_normalization_layer(model, config, target_norm_type='RMSNorm'):
    """Swap all normalization layers between nn.LayerNorm and RMSNorm.

    ``target_norm_type`` selects the direction: 'RMSNorm' converts every
    nn.LayerNorm (BypassLayerNorm excluded) to RMSNorm; 'LayerNorm' converts
    every RMSNorm back. Affine weight/bias are copied when shapes match,
    otherwise re-initialized to identity. The single config flag
    ``rms_norm_applied`` tracks the direction, and ``layer_norm_bypassed`` is
    cleared whenever the flag is updated.

    Returns:
        Human-readable status message.
    """
    # NOTE(review): `device` is unused below; per-layer devices are probed instead.
    device = get_device()
    swapped_count = 0
    processed_names = set()
    # Resolve source/target classes and the flag value for the requested direction.
    if target_norm_type == 'RMSNorm':
        current_norm_class = nn.LayerNorm
        new_norm_class = RMSNorm
        config_flag_name = 'rms_norm_applied'
        target_flag_value = True
    elif target_norm_type == 'LayerNorm':
        current_norm_class = RMSNorm
        new_norm_class = nn.LayerNorm
        config_flag_name = 'rms_norm_applied'
        target_flag_value = False
    else:
        msg = f"Error: Unsupported target normalization type '{target_norm_type}'. Use 'RMSNorm' or 'LayerNorm'."
        logging.error(msg)
        return msg
    # Consistency check between the config flag and the modules actually present.
    already_configured = getattr(config, config_flag_name, False) == target_flag_value
    has_current_norm_instances = any(isinstance(m, current_norm_class) for name, m in model.named_modules() if not isinstance(m, (BypassLayerNorm, new_norm_class)))
    if already_configured and not has_current_norm_instances:
        logging.info(f"Model config flag '{config_flag_name}' is already {target_flag_value}, and no instances of {current_norm_class.__name__} found to swap. No action needed.")
        return f"Model already configured for {target_norm_type} (or no swappable layers found)."
    elif already_configured and has_current_norm_instances:
        # Flag says done but stale layers remain: swap anyway to reconcile.
        logging.warning(f"Model config flag '{config_flag_name}' is {target_flag_value}, but instances of {current_norm_class.__name__} were found. Attempting swap anyway to ensure consistency.")
        pass
    elif not already_configured and not has_current_norm_instances:
        # Nothing to swap; just record the requested state in the config.
        logging.info(f"No instances of {current_norm_class.__name__} found to swap to {target_norm_type}. Updating config flag to {target_flag_value}.")
        setattr(config, config_flag_name, target_flag_value)
        if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False
        return f"No {current_norm_class.__name__} layers found to swap. Config flag '{config_flag_name}' set to {target_flag_value}."
    # Snapshot the module list: _recursive_setattr mutates the tree while iterating.
    modules_to_process = list(model.named_modules())
    for name, module in modules_to_process:
        if name in processed_names:
            continue
        if isinstance(module, current_norm_class) and not isinstance(module, BypassLayerNorm):
            try:
                eps = module.eps
                elementwise_affine = module.elementwise_affine
                # Probe the source layer's device/dtype; fall back to defaults.
                module_device = get_device()
                module_dtype = torch.float32
                params = list(module.parameters())
                if params:
                    module_device = params[0].device
                    module_dtype = params[0].dtype
                elif elementwise_affine and hasattr(module, 'weight') and module.weight is not None:
                    module_device = module.weight.device
                    module_dtype = module.weight.dtype
                # Determine the normalized dimension(s) of the source layer.
                dim = None
                if isinstance(module, nn.LayerNorm):
                    dim = module.normalized_shape
                elif isinstance(module, RMSNorm):
                    if elementwise_affine and hasattr(module, 'weight') and module.weight is not None:
                        dim = module.weight.shape[0]
                    else:
                        # An affine-less RMSNorm exposes no shape we can read.
                        logging.warning(f"Cannot determine dimension for affine-less RMSNorm {name}. Cannot swap this layer.")
                        processed_names.add(name)
                        continue
                else:
                    raise ValueError(f"Module {name} is unexpected type {type(module)} during norm swap.")
                # Build the replacement with an argument shape its ctor accepts.
                if new_norm_class == nn.LayerNorm:
                    if isinstance(dim, int):
                        norm_arg = dim
                    elif isinstance(dim, (list, tuple)):
                        norm_arg = list(dim)
                    else:
                        raise ValueError(f"Unsupported dimension type {type(dim)} '{dim}' for creating LayerNorm from {current_norm_class.__name__} layer {name}.")
                    new_norm = new_norm_class(norm_arg, eps=eps, elementwise_affine=elementwise_affine, device=module_device, dtype=module_dtype)
                elif new_norm_class == RMSNorm:
                    if isinstance(dim, int):
                        norm_arg = dim
                    elif isinstance(dim, (list, tuple)):
                        if len(dim) == 1:
                            norm_arg = dim[0]
                        else:
                            # RMSNorm takes a single dim; use the innermost one.
                            logging.warning(f"LayerNorm shape {dim} has multiple dimensions. Using last dim ({dim[-1]}) for RMSNorm {name}.")
                            norm_arg = dim[-1]
                    else:
                        raise ValueError(f"Unsupported dimension type {type(dim)} '{dim}' for creating RMSNorm from {current_norm_class.__name__} layer {name}.")
                    new_norm = new_norm_class(norm_arg, eps=eps, elementwise_affine=elementwise_affine, device=module_device, dtype=module_dtype)
                else:
                    raise ValueError("Invalid new_norm_class.")
                if elementwise_affine:
                    with torch.no_grad():
                        # Copy affine params; identity init on mismatch/missing source.
                        if hasattr(module, 'weight') and module.weight is not None and hasattr(new_norm, 'weight') and new_norm.weight is not None:
                            if new_norm.weight.shape == module.weight.shape:
                                new_norm.weight.copy_(module.weight)
                            else:
                                logging.warning(f"Weight shape mismatch swapping norm {name}: {module.weight.shape} -> {new_norm.weight.shape}. Re-initializing target weight.")
                                nn.init.ones_(new_norm.weight)
                        elif hasattr(new_norm, 'weight') and new_norm.weight is not None:
                            logging.debug(f"Initializing weight for new norm {name} as source lacked it.")
                            nn.init.ones_(new_norm.weight)
                        if hasattr(module, 'bias') and module.bias is not None and hasattr(new_norm, 'bias') and new_norm.bias is not None:
                            if new_norm.bias.shape == module.bias.shape:
                                new_norm.bias.copy_(module.bias)
                            else:
                                logging.warning(f"Bias shape mismatch swapping norm {name}: {module.bias.shape} -> {new_norm.bias.shape}. Re-initializing target bias.")
                                nn.init.zeros_(new_norm.bias)
                        elif hasattr(new_norm, 'bias') and new_norm.bias is not None:
                            logging.debug(f"Initializing bias for new LayerNorm {name} as source RMSNorm lacked it.")
                            nn.init.zeros_(new_norm.bias)
                if _recursive_setattr(model, name, new_norm):
                    swapped_count += 1
                    processed_names.add(name)
                    logging.debug(f"Swapped {current_norm_class.__name__} layer {name} to {new_norm_class.__name__}.")
                else:
                    logging.warning(f"Failed to set swapped normalization layer for {name}")
                    processed_names.add(name)
            except Exception as e:
                logging.error(f"Error swapping norm layer {name} from {current_norm_class.__name__} to {new_norm_class.__name__}: {e}\n{traceback.format_exc()}")
                processed_names.add(name)
    if swapped_count > 0:
        setattr(config, config_flag_name, target_flag_value)
        if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False
        msg = f"Swapped {swapped_count} {current_norm_class.__name__} layers to {new_norm_class.__name__}."
    else:
        if not already_configured:
            setattr(config, config_flag_name, target_flag_value)
            if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False
            msg = f"No {current_norm_class.__name__} layers found or matched criteria to swap to {new_norm_class.__name__}. Updated config flag."
        else:
            msg = f"No {current_norm_class.__name__} layers were swapped (already configured or other issue)."
    logging.info(msg)
    return msg
def _normalize_embeddings(model, config):
    """L2-normalize each row of the input embedding matrix in place.

    The embedding layer is located via ``get_input_embeddings()`` when the
    model provides it, otherwise by probing a list of conventional attribute
    names on the model (or its ``.model`` submodule). Sets
    ``config.embedding_normalized`` and returns a status message.
    """
    emb_layer = None
    if hasattr(model, 'get_input_embeddings'):
        try:
            candidate = model.get_input_embeddings()
            if isinstance(candidate, nn.Embedding):
                emb_layer = candidate
                logging.info("Found embedding layer via get_input_embeddings()")
        except Exception as e:
            logging.warning(f"Error calling get_input_embeddings(): {e}")
    if emb_layer is None:
        search_root = getattr(model, 'model', model)
        for attr_path in ['embed_tokens', 'wte', 'word_embeddings', 'embeddings.word_embeddings', 'shared']:
            try:
                node = search_root
                path_ok = True
                for piece in attr_path.split('.'):
                    if not hasattr(node, piece):
                        path_ok = False
                        break
                    node = getattr(node, piece)
                    if node is None:
                        path_ok = False
                        break
                if path_ok and isinstance(node, nn.Embedding) and getattr(node, 'weight', None) is not None:
                    emb_layer = node
                    logging.info(f"Found embedding layer via attribute: '{attr_path}'")
                    break
            except AttributeError:
                continue
            except Exception as e:
                logging.warning(f"Error accessing potential embedding layer '{attr_path}': {e}")
    if emb_layer is None or getattr(emb_layer, 'weight', None) is None:
        msg = "Input embedding layer or its weights not found using common methods. Cannot normalize."
        logging.warning(msg)
        config.embedding_normalized = False
        return msg
    try:
        with torch.no_grad():
            weights = emb_layer.weight.data
            # Clamp tiny norms so all-zero rows don't divide by zero.
            row_norms = torch.norm(weights, p=2, dim=-1, keepdim=True).clamp(min=1e-12)
            weights.div_(row_norms)
        config.embedding_normalized = True
        logging.info("Input embeddings normalized (L2 norm).")
        return "Input embeddings normalized (L2 norm)."
    except Exception as e:
        logging.error(f"Error normalizing embeddings: {e}")
        config.embedding_normalized = False
        return f"Error normalizing embeddings: {e}"
def _revert_embedding_normalization(model, config):
    """Reset ``config.embedding_normalized`` to False.

    Purely a bookkeeping revert: the original (pre-normalization) embedding
    weights cannot be recovered.
    """
    was_normalized = getattr(config, 'embedding_normalized', False)
    if not was_normalized:
        return "Embedding normalization flag is already false (or was never applied)."
    config.embedding_normalized = False
    logging.info("Embedding normalization flag reverted. Note: Original embedding weights are NOT restored.")
    return "Embedding normalization flag reverted (weights NOT restored)."
def _prune_weights_magnitude(model, config, amount=0.2):
    """Globally prune the smallest-magnitude weights and make pruning permanent.

    Gathers the trainable ``weight`` tensors of every nn.Linear / BitLinear,
    applies torch's global unstructured L1 pruning across all of them, then
    calls ``prune.remove`` so the zero mask is baked into the weights.
    Updates ``config.pruning_applied`` / ``config.pruning_amount``.

    Args:
        model: Model pruned in place (moved to the default device first).
        config: Config object used for bookkeeping flags.
        amount: Fraction in (0, 1), exclusive, of weights to zero globally.

    Returns:
        Human-readable status message (includes final sparsity on success).
    """
    if not isinstance(amount, (float, int)) or not (0 < amount < 1):
        msg="Error: Pruning amount must be a float between 0 and 1 (exclusive)."
        logging.error(msg)
        return msg
    logging.info(f"Applying global unstructured L1 magnitude pruning (amount={amount:.2f})...")
    device = get_device()
    model.to(device)
    params_to_prune = []
    for module_name, module in model.named_modules():
        if isinstance(module, (nn.Linear, BitLinear)):
            # Only prune trainable weights (skips frozen layers).
            if hasattr(module, 'weight') and module.weight is not None and module.weight.requires_grad:
                params_to_prune.append((module, 'weight'))
    if not params_to_prune:
        msg="No prunable Linear or BitLinear layers with trainable weights found."
        logging.warning(msg)
        config.pruning_applied = False
        config.pruning_amount = None
        return msg
    try:
        # One global magnitude threshold across all collected tensors (not per-layer).
        prune.global_unstructured(
            parameters=params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=amount
        )
        pruned_count = 0
        total_params = 0
        modules_made_permanent = 0
        for module, name in params_to_prune:
            if prune.is_pruned(module):
                # Fold weight_orig * weight_mask back into .weight and drop hooks.
                prune.remove(module, name)
                modules_made_permanent += 1
            if hasattr(module, name):
                weight = getattr(module, name)
                if weight is not None:
                    # Tally zeros to report the achieved global sparsity.
                    pruned_count += torch.sum(weight == 0).item()
                    total_params += weight.nelement()
        if modules_made_permanent > 0:
            sparsity = 100. * pruned_count / total_params if total_params > 0 else 0
            msg = (f"Pruning applied and made permanent on {modules_made_permanent} parameter groups. "
                   f"Final Sparsity: {sparsity:.2f}% ({pruned_count:,}/{total_params:,} zeros).")
            config.pruning_applied = True
            config.pruning_amount = amount
        elif any(prune.is_pruned(mod) for mod, _ in params_to_prune):
            msg = "Pruning hooks were applied but removal failed or was incomplete. Pruning might not be permanent."
            config.pruning_applied = False
            config.pruning_amount = None
        else:
            msg = "Pruning was attempted, but no modules seem to have been pruned or made permanent."
            config.pruning_applied = False
            config.pruning_amount = None
    except Exception as e:
        msg = f"Error during pruning: {e}\n{traceback.format_exc()}"
        logging.error(msg)
        # Best-effort cleanup: strip any leftover pruning re-parametrizations.
        for module, name in params_to_prune:
            if prune.is_pruned(module):
                try:
                    prune.remove(module, name)
                    logging.info(f"Cleaned up pruning hook from {module}.{name} during error handling.")
                except Exception as remove_e:
                    logging.warning(f"Couldn't remove pruning hook from {module}.{name} during cleanup: {remove_e}")
        config.pruning_applied = False
        config.pruning_amount = None
    logging.info(msg)
    return msg
def _revert_pruning(model, config):
    """Clear pruning bookkeeping flags; zeroed weights stay zeroed."""
    if getattr(config, 'pruning_applied', False):
        config.pruning_applied = False
        config.pruning_amount = None
        logging.info("Pruning flag reverted. Note: Pruned weights (zeros) are NOT restored.")
        return "Pruning flag reverted (weights NOT restored)."
    return "Pruning flag is already false (or pruning was never applied/made permanent)."
def _quantize_model(model, config, mode='bfloat16'):
    """Cast all model parameters/buffers to the dtype named by ``mode``.

    Supported modes: 'bfloat16' (only when CUDA reports bf16 support),
    'float16', 'float32'. Updates ``config.quantization_applied``,
    ``config.quantization_mode`` and ``config.perfect_precision_recovered``
    to reflect the final state; on failure the config flags (and, best
    effort, the model dtype) are restored.

    Returns:
        Human-readable status message.
    """
    logging.info(f"Attempting to change model dtype to {mode}...")
    original_dtype_str = getattr(config, 'quantization_mode', DEFAULT_QUANTIZATION)
    target_dtype = None
    if mode == 'bfloat16':
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            target_dtype = torch.bfloat16
        else:
            msg = "Device does not support bfloat16. Cannot quantize to bfloat16. Keeping current dtype."
            logging.warning(msg)
            return msg
    elif mode == 'float16':
        target_dtype = torch.float16
    elif mode == 'float32':
        target_dtype = torch.float32
    else:
        msg = f"Unsupported quantization mode '{mode}'. Choose from {QUANTIZATION_MODES}."
        logging.error(msg)
        return msg
    # Bug fix: the old sentinel `torch.tensor([])` has dtype float32, so a
    # parameterless model was silently treated as float32 and the
    # "no parameters" branches were unreachable. Use None as the sentinel.
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        msg = "Model has no parameters. Cannot determine or change dtype."
        logging.error(msg)
        return msg
    current_dtype = first_param.dtype
    if current_dtype == target_dtype:
        msg = f"Model is already in {mode} ({target_dtype}). No change needed."
        logging.info(msg)
        config.quantization_applied = (mode != 'float32')
        config.quantization_mode = mode
        config.perfect_precision_recovered = (mode == 'float32')
        return msg
    # Bug fix: resolve the device before the try block so the error-recovery
    # path below can never hit an unbound `device` NameError.
    device = get_device()
    try:
        model.to(device=device, dtype=target_dtype)
        new_param = next(iter(model.parameters()), None)
        new_dtype = new_param.dtype if new_param is not None else None
        if new_dtype == target_dtype:
            config.quantization_applied = (mode != 'float32')
            config.quantization_mode = mode
            config.perfect_precision_recovered = (mode == 'float32')
            msg = f"Model dtype successfully changed to {mode} ({target_dtype}) on device {device}."
            logging.info(msg)
            clean_memory()
            return msg
        else:
            logging.error(f"Model dtype did not change as expected after .to() call. Still {new_dtype}. Reverting config flags.")
            config.quantization_applied = (original_dtype_str != 'float32')
            config.quantization_mode = original_dtype_str
            config.perfect_precision_recovered = (original_dtype_str == 'float32')
            raise RuntimeError(f"Model dtype did not change as expected. Still {new_dtype}.")
    except Exception as e:
        msg = f"Error converting model to {target_dtype}: {e}\n{traceback.format_exc()}"
        logging.error(msg)
        config.quantization_applied = (original_dtype_str != 'float32')
        config.quantization_mode = original_dtype_str
        config.perfect_precision_recovered = (original_dtype_str == 'float32')
        try:
            # Best-effort rollback of the model's dtype.
            original_torch_dtype = getattr(torch, original_dtype_str, torch.float32)
            model.to(device=device, dtype=original_torch_dtype)
            logging.info(f"Attempted to restore model to original dtype {original_dtype_str} after error.")
        except Exception as revert_e:
            logging.error(f"Failed to restore original dtype after error: {revert_e}")
        return msg
def _revert_quantization(model, config):
    """Undo any reduced-precision mode by converting the model back to float32."""
    logging.info("Reverting quantization to float32...")
    result = _quantize_model(model, config, mode='float32')
    return result
def _freeze_layers(model, config, layers_to_freeze_str):
    """Freeze the decoder layers named in `layers_to_freeze_str` (e.g. "0-3, 7, 10-11").

    All parameters are first unfrozen globally, so the spec fully replaces any
    previous freeze state. The indices actually frozen are recorded on
    `config.frozen_layers` (comma-separated string, or None). Returns a
    human-readable status message (also used for error reporting).
    """
    if not layers_to_freeze_str or not isinstance(layers_to_freeze_str, str):
        msg="No layers specified to freeze or invalid input type."
        logging.warning(msg)
        config.frozen_layers = None
        return msg
    # --- Parse the comma-separated spec into a set of layer indices ---
    layer_indices = set()
    try:
        raw_parts = layers_to_freeze_str.split(',')
        for part in raw_parts:
            part = part.strip()
            if not part: continue
            if '-' in part:
                start_end = part.split('-')
                if len(start_end) == 2:
                    s = int(start_end[0].strip())
                    e = int(start_end[1].strip())
                    if s < 0 or e < 0: raise ValueError("Negative indices are not allowed.")
                    if s <= e:
                        layer_indices.update(range(s, e + 1))
                    else:
                        # Descending ranges like "5-2" are tolerated and normalized.
                        layer_indices.update(range(e, s + 1))
                        logging.warning(f"Interpreted range '{part}' as descending: {list(range(e, s + 1))}")
                else:
                    raise ValueError(f"Invalid range format: {part}")
            else:
                idx = int(part)
                if idx < 0: raise ValueError("Negative indices are not allowed.")
                layer_indices.add(idx)
    except ValueError as e:
        msg=f"Error parsing layer specification '{layers_to_freeze_str}': {e}. Use non-negative, comma-separated numbers or ranges (e.g., '0-3, 7, 10-11')."
        logging.error(msg)
        return msg
    # Locate the model's decoder layer list (architecture-dependent helper defined elsewhere in this file).
    layer_module, layer_attr, layer_list = _find_decoder_layers_module(model)
    if not (layer_module and layer_attr and layer_list is not None):
        msg="Could not determine layer structure for freezing. No layers frozen."
        logging.warning(msg)
        return msg
    total_layers = len(layer_list)
    frozen_params_count = 0
    actual_frozen_indices = set()
    unfrozen_globally = 0
    # Reset: make every parameter trainable before applying the new freeze spec.
    for param in model.parameters():
        if not param.requires_grad:
            param.requires_grad = True
            unfrozen_globally += 1
    if unfrozen_globally > 0:
        logging.info(f"Unfroze {unfrozen_globally} parameters globally before applying new freeze spec.")
    else:
        logging.info("No parameters were frozen globally before applying new spec.")
    invalid_indices_skipped = set()
    for i in layer_indices:
        if 0 <= i < total_layers:
            try:
                current_layer = layer_list[i]
                params_in_layer = 0
                for param in current_layer.parameters():
                    if param.requires_grad:
                        param.requires_grad = False
                        frozen_params_count += 1
                        params_in_layer += 1
                # Only record indices where at least one parameter was actually frozen.
                if params_in_layer > 0:
                    actual_frozen_indices.add(i)
                    logging.debug(f"Froze {params_in_layer} parameters in layer {i}.")
                else:
                    logging.debug(f"Layer {i} had no trainable parameters to freeze.")
            except IndexError:
                logging.warning(f"Index {i} seems out of bounds for layer list during freezing loop, although check passed earlier. Skipping.")
                invalid_indices_skipped.add(i)
            except Exception as e:
                logging.error(f"Error accessing or freezing parameters for layer {i}: {e}")
                invalid_indices_skipped.add(i)
        else:
            logging.warning(f"Layer index {i} is out of bounds (0-{total_layers-1}). Skipping.")
            invalid_indices_skipped.add(i)
    # Record the outcome on the config and compose the status message.
    frozen_list_str = ",".join(map(str, sorted(list(actual_frozen_indices))))
    config.frozen_layers = frozen_list_str if actual_frozen_indices else None
    msg = f"Froze {frozen_params_count} parameters in layers: {frozen_list_str} (Total layers: {total_layers})."
    if invalid_indices_skipped:
        msg += f" Skipped invalid indices: {sorted(list(invalid_indices_skipped))}."
    # If nothing was frozen and nothing was skipped, replace with a clearer no-op message.
    if frozen_params_count == 0 and not invalid_indices_skipped:
        msg = f"No parameters were frozen. Specified layers {frozen_list_str} might have already been frozen or had no trainable params."
    logging.info(msg)
    return msg
def _unfreeze_all_layers(model, config):
    """Re-enable gradients on every parameter and clear the frozen-layers record."""
    thawed = 0
    for _name, param in model.named_parameters():
        if param.requires_grad:
            continue
        param.requires_grad = True
        thawed += 1
    config.frozen_layers = None
    if thawed > 0:
        msg = f"Unfroze {thawed} parameters across the entire model."
    else:
        msg = "No parameters needed unfreezing."
    logging.info(msg)
    return msg
def _enable_gradient_checkpointing(model, config):
    """Turn on gradient checkpointing, preferring the model's own method.

    Falls back to setting `model.config.gradient_checkpointing` when the
    standard method is missing or fails. On success `use_cache` is disabled
    (checkpointing is incompatible with it) and
    `config.gradient_checkpointing_enabled` is set. Returns a status message
    (prefixed with "[Error]" on failure).
    """
    enabled = False
    if hasattr(model, 'gradient_checkpointing_enable'):
        try:
            sig = inspect.signature(model.gradient_checkpointing_enable)
            # Newer transformers accept kwargs; use non-reentrant mode when supported.
            if 'gradient_checkpointing_kwargs' in sig.parameters:
                model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
                msg = "Gradient Checkpointing enabled via model method (non-reentrant)."
            else:
                model.gradient_checkpointing_enable()
                msg = "Gradient Checkpointing enabled via model method."
            enabled = True
            logging.info(msg)
        except Exception as e:
            logging.warning(f"Failed to enable gradient checkpointing via standard method: {e}. Trying config attribute.")
    has_cfg_attr = hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing')
    if has_cfg_attr:
        if not enabled:
            logging.info("Enabling gradient checkpointing via model config attribute.")
        # Always mirror the state onto the model config when the attribute exists.
        model.config.gradient_checkpointing = True
        enabled = True
    if enabled:
        if hasattr(model.config, 'use_cache'):
            if model.config.use_cache:
                model.config.use_cache = False
                logging.info("Set model.config.use_cache = False (required for Gradient Checkpointing).")
        else:
            logging.warning("Model config missing 'use_cache' attribute. Gradient checkpointing might not work correctly or efficiently.")
        config.gradient_checkpointing_enabled = True
        final_msg = "Gradient Checkpointing enabled."
        if not hasattr(model, 'gradient_checkpointing_enable') and not has_cfg_attr:
            final_msg += " (Set via main config flag only; ensure Trainer args/model support it)."
        return final_msg
    config.gradient_checkpointing_enabled = False
    msg = "Could not enable Gradient Checkpointing via model methods or config attributes."
    logging.error(msg)
    return f"[Error] {msg}"
def _disable_gradient_checkpointing(model, config):
    """Turn off gradient checkpointing via the model method and/or config attribute.

    Restores `use_cache = True` when it had been disabled, and clears
    `config.gradient_checkpointing_enabled`. Returns a status message.
    """
    disabled = False
    if hasattr(model, 'gradient_checkpointing_disable'):
        try:
            model.gradient_checkpointing_disable()
            disabled = True
            logging.info("Gradient Checkpointing disabled via model method.")
        except Exception as e:
            logging.warning(f"Failed to disable gradient checkpointing via standard method: {e}. Trying config attribute.")
    has_cfg_attr = hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing')
    if has_cfg_attr:
        if not disabled:
            logging.info("Disabling gradient checkpointing via model config attribute.")
        model.config.gradient_checkpointing = False
        disabled = True
    if disabled:
        if hasattr(model.config, 'use_cache') and not model.config.use_cache:
            model.config.use_cache = True
            logging.info("Set model.config.use_cache = True (restored after disabling Gradient Checkpointing).")
        config.gradient_checkpointing_enabled = False
        final_msg = "Gradient Checkpointing disabled."
        if not hasattr(model, 'gradient_checkpointing_disable') and not has_cfg_attr:
            final_msg += " (Set via main config flag only)."
        return final_msg
    config.gradient_checkpointing_enabled = False
    msg = "Could not disable Gradient Checkpointing via model methods or config attributes (may not have been enabled)."
    logging.warning(msg)
    return msg
def _swap_optimizer(config, optimizer_name):
    """Record the preferred optimizer on the config and update the module default.

    Only a preference is stored; the Trainer consumes it when training starts.
    Returns an error message if the name is not in the OPTIMIZERS registry.
    """
    global DEFAULT_OPTIMIZER
    # Guard clause: unknown names are rejected with the list of valid choices.
    if optimizer_name not in OPTIMIZERS:
        available_opts = ", ".join(OPTIMIZERS.keys())
        msg=f"Error: Optimizer '{optimizer_name}' unknown or not available. Choose from: {available_opts}."
        logging.error(msg)
        return msg
    config.optimizer = optimizer_name
    DEFAULT_OPTIMIZER = optimizer_name
    msg=f"Optimizer preference set to '{optimizer_name}' in config. This will be used by the Trainer if training starts."
    logging.info(msg)
    return msg
def _revert_optimizer(config):
    """Reset the optimizer preference back to the script's hard default ('adamw_torch')."""
    default_name = "adamw_torch"
    logging.info(f"Reverting optimizer preference to script default: '{default_name}'.")
    return _swap_optimizer(config, default_name)
def _untie_embeddings(model, config):
    """Make the output embedding (LM head) weights independent copies of the input embeddings.

    Many causal LMs share ("tie") the input embedding matrix with the LM head.
    This clones the shared tensors so the two layers can train separately, and
    updates `config.untied_embeddings` / `config.tie_word_embeddings`.
    Returns a status message.
    """
    try:
        input_embeddings = model.get_input_embeddings()
        output_embeddings = model.get_output_embeddings()
        if output_embeddings is None:
            # Fall back to a plain `lm_head` Linear if the accessor is not implemented.
            if hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear):
                output_embeddings = model.lm_head
                logging.info("Using 'lm_head' as the output embedding layer for untie check.")
            else:
                msg="Could not get output embedding layer (get_output_embeddings() returned None and 'lm_head' not found/Linear). Cannot untie."
                logging.warning(msg)
                # NOTE(review): flag is set True even though nothing was untied — looks
                # intentional ("no output layer means nothing is tied"), but confirm.
                config.untied_embeddings = True
                if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False
                return msg
        if input_embeddings is None:
            msg="Could not get input embedding layer. Cannot untie."
            logging.warning(msg)
            config.untied_embeddings = False
            if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True
            return msg
        # Tied iff both weights alias the same tensor data (data pointer or storage match).
        are_tied = False
        if hasattr(input_embeddings, "weight") and hasattr(output_embeddings, "weight") and \
           input_embeddings.weight is not None and output_embeddings.weight is not None:
            if input_embeddings.weight.data_ptr() == output_embeddings.weight.data_ptr():
                are_tied = True
            elif input_embeddings.weight.storage().data_ptr() == output_embeddings.weight.storage().data_ptr():
                are_tied = True
                logging.info("Weights appear tied (share storage).")
        if are_tied:
            logging.info("Detected tied input/output embeddings. Attempting to untie...")
            device = input_embeddings.weight.device
            dtype = input_embeddings.weight.dtype
            # Detached clone yields a distinct tensor; preserve the original requires_grad.
            new_output_weight = input_embeddings.weight.clone().detach()
            new_output_weight.requires_grad_(output_embeddings.weight.requires_grad)
            output_embeddings.weight = nn.Parameter(new_output_weight.to(device, dtype=dtype))
            # Bias can also be tied on some architectures; clone it too when shared.
            if hasattr(input_embeddings, "bias") and input_embeddings.bias is not None and \
               hasattr(output_embeddings, "bias") and output_embeddings.bias is not None and \
               input_embeddings.bias.data_ptr() == output_embeddings.bias.data_ptr():
                logging.info("Detected tied bias, untying as well.")
                new_output_bias = input_embeddings.bias.clone().detach()
                new_output_bias.requires_grad_(output_embeddings.bias.requires_grad)
                output_embeddings.bias = nn.Parameter(new_output_bias.to(device, dtype=dtype))
            if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False
            config.untied_embeddings = True
            msg="Embeddings untied successfully (output layer weights/bias are now distinct copies)."
            logging.info(msg)
            clean_memory()
            return msg
        else:
            config.untied_embeddings = True
            if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False
            msg="Embeddings are already untied (or weights are missing/different objects)."
            logging.info(msg)
            return msg
    except Exception as e:
        msg=f"Error untying embeddings: {e}\n{traceback.format_exc()}"
        logging.error(msg)
        return msg
def _retie_embeddings(model, config):
    """Re-share the input embedding weight with the output layer (inverse of _untie_embeddings).

    Only acts when `config.untied_embeddings` is True; otherwise it just
    verifies and reports the current state. Re-tying drops any output-layer
    bias. Returns a status message.
    """
    if not getattr(config, 'untied_embeddings', False):
        # Flag says embeddings were never untied — verify and report instead of acting.
        try:
            input_emb = model.get_input_embeddings()
            output_emb = model.get_output_embeddings()
            if output_emb is None and hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear):
                output_emb = model.lm_head
            # Already tied iff both weights share the same data pointer.
            if input_emb is not None and output_emb is not None and \
               hasattr(input_emb, 'weight') and input_emb.weight is not None and \
               hasattr(output_emb, 'weight') and output_emb.weight is not None and \
               input_emb.weight.data_ptr() == output_emb.weight.data_ptr():
                msg = "Embeddings seem already tied. Resetting flag if needed."
                config.untied_embeddings = False
                if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True
                logging.info(msg)
                return msg
            else:
                msg = "Cannot re-tie: Flag 'untied_embeddings' is false or cannot verify current state."
                logging.info(msg)
                return msg
        except Exception as e:
            msg = f"Cannot re-tie: Error checking current state ({e}). Flag 'untied_embeddings' is false."
            logging.warning(msg)
            return msg
    try:
        input_embeddings = model.get_input_embeddings()
        output_embeddings = model.get_output_embeddings()
        if output_embeddings is None and hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear):
            output_embeddings = model.lm_head
            logging.info("Using 'lm_head' as output layer for re-tying.")
        if input_embeddings is None or output_embeddings is None:
            msg="Could not get both input and output embedding layers for re-tying."
            logging.warning(msg)
            return msg
        if hasattr(input_embeddings, "weight") and input_embeddings.weight is not None and \
           hasattr(output_embeddings, "weight") and output_embeddings.weight is not None:
            # Shapes must match for weight sharing to be valid.
            if input_embeddings.weight.shape == output_embeddings.weight.shape:
                logging.info("Attempting to re-tie embeddings by sharing input embedding weight...")
                device = input_embeddings.weight.device
                dtype = input_embeddings.weight.dtype
                output_embeddings = output_embeddings.to(device=device, dtype=dtype)
                # Share the same Parameter object so both layers update together.
                output_embeddings.weight = input_embeddings.weight
                if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
                    logging.info("Setting output embedding bias to None as part of re-tying.")
                    output_embeddings.bias = None
                if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True
                config.untied_embeddings = False
                msg="Embeddings re-tied successfully (output layer now shares input layer's weight, bias set to None)."
                logging.info(msg)
                clean_memory()
                return msg
            else:
                msg = f"Cannot re-tie embeddings: Weight shapes mismatch. Input: {input_embeddings.weight.shape}, Output: {output_embeddings.weight.shape}."
                logging.warning(msg)
                return msg
        else:
            msg = "Cannot re-tie embeddings: Input or output embedding weights missing or None."
            logging.warning(msg)
            return msg
    except Exception as e:
        msg=f"Error re-tying embeddings: {e}\n{traceback.format_exc()}"
        logging.error(msg)
        return msg
def _configure_limits(config):
    """Stamp symbolic knowledge/cutoff dates and scale max_position_embeddings by 100x.

    These are config flags only; the model needs reloading or RoPE scaling for
    the larger context to take effect. Returns a status message.
    """
    config.knowledge_date = "2045-03-28"
    config.cutoff_date = "2045-03-28"
    previous_limit = getattr(config, 'max_position_embeddings', 512)
    config.max_position_embeddings = previous_limit * 100
    config.limits_configured = True
    config.no_limits = True
    logging.info(f"Set knowledge/cutoff date flags and increased max_position_embeddings in config to {config.max_position_embeddings}.")
    return f"Limit-related flags configured (Knowledge Date: 2045, Max Pos Emb: {config.max_position_embeddings}). Requires model reload or RoPE scaling for actual effect."
def _remove_limits_configuration(config):
    """Reset the symbolic limit flags; max_position_embeddings is intentionally left as-is."""
    if not getattr(config, 'limits_configured', False):
        return "Limit configuration flags are already in their default state."
    for attr in ("knowledge_date", "cutoff_date"):
        setattr(config, attr, None)
    config.limits_configured = False
    config.no_limits = False
    logging.info("Reset knowledge date and cutoff date flags in config. Max position embeddings remain modified.")
    return "Limit-related flags removed/reset. Max position embeddings NOT reverted."
def _remove_qa_restrictions(config):
    """Symbolically mark QA restrictions as removed on the config (flag only)."""
    setattr(config, "qa_restrictions_removed", True)
    logging.info("QA restrictions removal flag set in config. Actual effect depends on model usage/fine-tuning and inference logic.")
    return "QA Restrictions Removal Flag Enabled (symbolic)."
def _enable_qa_restrictions(config):
    """Symbolically re-enable QA restrictions on the config (flag only)."""
    setattr(config, "qa_restrictions_removed", False)
    logging.info("QA restrictions removal flag disabled in config.")
    return "QA Restrictions Removal Flag Disabled (symbolic)."
def _enable_coherence_improvement(config):
    """Turn on the coherence-improvement flag (inference is expected to use beam search)."""
    setattr(config, "coherence_improvement_enabled", True)
    logging.info("Coherence improvement flag enabled. Inference will use beam search if this is active.")
    return "Coherence Improvement Flag ON (uses beam search in inference)."
def _disable_coherence_improvement(config):
    """Turn the coherence-improvement flag back off."""
    setattr(config, "coherence_improvement_enabled", False)
    logging.info("Coherence improvement flag disabled.")
    return "Coherence Improvement Flag OFF."
def _set_flag_only(config, flag_name, value, msg_on, msg_off):
    """Coerce `value` to bool, store it on `config` under `flag_name`, and return the matching message."""
    if not hasattr(config, flag_name):
        logging.warning(f"Config object does not have flag '{flag_name}'. Adding it.")
    flag_state = bool(value)
    setattr(config, flag_name, flag_state)
    chosen_msg = msg_on if flag_state else msg_off
    logging.info(f"Config flag '{flag_name}' set to {flag_state}. Message: {chosen_msg}")
    return chosen_msg
# Thin symbolic flag toggles. Each takes (model, config) for interface
# consistency but only touches the config.
def _apply_swa(model, config):
    return _set_flag_only(config, "swa_applied", True, "SWA flag set. Requires SWA callback/logic during training.", "SWA flag disabled.")
def _revert_swa(model, config):
    return _set_flag_only(config, "swa_applied", False, "SWA flag set.", "SWA flag disabled.")
def _apply_knowledge_editing(model, config):
    return _set_flag_only(config, "knowledge_edited", True, "Knowledge Editing flag set. Indicates manual edits or specific editing techniques were applied (symbolic).", "Knowledge Editing flag disabled.")
def _revert_knowledge_editing(model, config):
    return _set_flag_only(config, "knowledge_edited", False, "Knowledge Editing flag set.", "Knowledge Editing flag disabled.")
def _apply_head_pruning(model, config):
    return _set_flag_only(config, "head_pruning_applied", True, "Head Pruning flag set. Requires specific pruning implementation outside this script (symbolic).", "Head Pruning flag disabled.")
def _revert_head_pruning(model, config):
    return _set_flag_only(config, "head_pruning_applied", False, "Head Pruning flag set.", "Head Pruning flag disabled.")
def _apply_qat(model, config):
    return _set_flag_only(config, "qat_applied", True, "QAT flag set. Requires Quantization-Aware Training setup and execution (symbolic).", "QAT flag disabled.")
def _revert_qat(model, config):
    return _set_flag_only(config, "qat_applied", False, "QAT flag set.", "QAT flag disabled.")
def _apply_architecture_merge_flag(model, config):
    return _set_flag_only(config, "architecture_merged", True, "Architecture Merged flag set. Indicates model is likely a result of parameter averaging.", "Architecture Merged flag disabled.")
def _revert_architecture_merge_flag(model, config):
    return _set_flag_only(config, "architecture_merged", False, "Architecture Merged flag set.", "Architecture Merged flag disabled.")
def _apply_weight_init(model, config):
    return _set_flag_only(config, "weight_init_applied", True, "Weight Initialization flag set. Indicates a specific init strategy was used (symbolic).", "Weight Initialization flag disabled.")
def _revert_weight_init(model, config):
    return _set_flag_only(config, "weight_init_applied", False, "Weight Initialization flag set.", "Weight Initialization flag disabled.")
def _apply_gradient_noise(model, config):
    return _set_flag_only(config, "gradient_noise_applied", True, "Gradient Noise flag set. Requires implementation in optimizer/trainer (symbolic).", "Gradient Noise flag disabled.")
def _revert_gradient_noise(model, config):
    return _set_flag_only(config, "gradient_noise_applied", False, "Gradient Noise flag set.", "Gradient Noise flag disabled.")
def _apply_additional_mechanisms(base_model, config):
    """Flip on a bundle of experimental flags plus the coherence/speed optimizations.

    `base_model` is accepted for interface consistency; only the config is touched.
    Returns a summary string including the coherence/speed sub-messages.
    """
    logging.info("Applying various additional experimental mechanisms flags and simple optimizations...")
    extra_flags = (
        ("enhanced_security_enabled", "Enhanced Security Flag ON.", "Enhanced Security Flag OFF."),
        ("debug_mode_enabled", "Debug Mode Flag ON.", "Debug Mode Flag OFF."),
        ("internal_logging_enabled", "Internal Logging Flag ON.", "Internal Logging Flag OFF."),
        ("drift_detection_enabled", "Drift Detection Flag ON.", "Drift Detection Flag OFF."),
        ("ultra_fast_mode", "Ultra Fast Mode Flag ON.", "Ultra Fast Mode Flag OFF."),
    )
    for flag_name, on_msg, off_msg in extra_flags:
        _set_flag_only(config, flag_name, True, on_msg, off_msg)
    coherence_msg = _enable_coherence_improvement(config)
    speed_msg = _optimize_token_generation_speed(config)
    config.additional_mechanisms_applied = True
    logging.info("Applied various additional mechanism flags and optimizations.")
    return f"Applied Additional Mechanism Flags & Optimizations. Coherence: {coherence_msg}, Speed: {speed_msg}"
def _disable_additional_mechanisms(config):
    """Turn off the experimental-mechanism flags set by _apply_additional_mechanisms.

    No-op (with a message) when the bundle was never applied.
    """
    if not getattr(config, 'additional_mechanisms_applied', False):
        return "Additional mechanisms flag is already off. No changes made."
    logging.info("Disabling various additional experimental mechanisms flags and reverting optimizations...")
    extra_flags = (
        ("enhanced_security_enabled", "Enhanced Security Flag ON.", "Enhanced Security Flag OFF."),
        ("debug_mode_enabled", "Debug Mode Flag ON.", "Debug Mode Flag OFF."),
        ("internal_logging_enabled", "Internal Logging Flag ON.", "Internal Logging Flag OFF."),
        ("drift_detection_enabled", "Drift Detection Flag ON.", "Drift Detection Flag OFF."),
        ("ultra_fast_mode", "Ultra Fast Mode Flag ON.", "Ultra Fast Mode Flag OFF."),
    )
    for flag_name, on_msg, off_msg in extra_flags:
        _set_flag_only(config, flag_name, False, on_msg, off_msg)
    coherence_msg = _disable_coherence_improvement(config)
    speed_msg = _revert_token_generation_speed_optimization(config)
    config.additional_mechanisms_applied = False
    logging.info("Disabled various additional mechanism flags and reverted optimizations.")
    return f"Disabled Additional Mechanism Flags & Reverted Optimizations. Coherence: {coherence_msg}, Speed: {speed_msg}"
def _disable_all_safety_settings(config):
    """Turn OFF every known safety/content-filter flag and turn ON the 'uncensored' flags.

    Flags are only counted as updated when their current value differs (by
    identity) from the target. Returns a short status string.
    """
    filter_flags = [
        "response_filters", "safety_settings_enabled",
        "harassment_filter", "hate_filter", "sexually_explicit_filter",
        "dangerous_content_filter", "civic_integrity_filter", "code_filter",
        "medical_advice_filter", "legal_advice_filter", "financial_advice_filter",
        "pii_filter", "political_filter", "religious_filter", "profanity_filter",
        "stereotype_filter", "misinfo_filter", "self_harm_filter",
        "personal_attack_filter", "toxicity_filter", "spam_filter",
        "off_topic_filter", "tone_filter", "min_max_length_filter",
        "repetition_filter_enabled", "factuality_filter_enabled"
    ]
    uncensor_flags = [
        "remove_censorship", "no_response_filters", "no_advert_warning", "no_limits"
    ]
    # Ensure the config carries the full flag set before toggling.
    config = initialize_config_flags(config)
    updated_flags = 0
    # Single pass: filters -> False, uncensor flags -> True.
    targets = [(name, False) for name in filter_flags] + [(name, True) for name in uncensor_flags]
    for name, desired in targets:
        if hasattr(config, name) and getattr(config, name) is not desired:
            setattr(config, name, desired)
            updated_flags += 1
    config.safety_settings_enabled = False
    config.response_filters = False
    logging.info(f"Disabled all known safety/content filters and related flags in config ({updated_flags} flags updated).")
    return "All safety filter flags disabled in config."
def _enable_all_safety_settings(config):
    """Restore the default safety posture.

    Core and optional filters go ON; the 'uncensored' flags and the opt-in
    (default-off) filters go OFF. Returns a short status string.
    """
    core_filters = [
        "safety_settings_enabled", "response_filters",
        "harassment_filter", "hate_filter", "sexually_explicit_filter",
        "dangerous_content_filter", "self_harm_filter", "pii_filter",
        "min_max_length_filter",
        "toxicity_filter", "personal_attack_filter",
    ]
    optional_filters = [
        "civic_integrity_filter", "code_filter",
        "medical_advice_filter", "legal_advice_filter", "financial_advice_filter",
        "political_filter", "religious_filter", "profanity_filter",
        "stereotype_filter", "misinfo_filter",
        "spam_filter", "off_topic_filter", "tone_filter"
    ]
    uncensor_flags = [
        "remove_censorship", "no_response_filters", "no_advert_warning", "no_limits"
    ]
    opt_in_filters = [
        "repetition_filter_enabled", "factuality_filter_enabled"
    ]
    # Ensure the config carries the full flag set before toggling.
    config = initialize_config_flags(config)
    updated_flags = 0
    targets = [(name, True) for name in core_filters + optional_filters] + \
              [(name, False) for name in uncensor_flags + opt_in_filters]
    for name, desired in targets:
        if hasattr(config, name) and getattr(config, name) is not desired:
            setattr(config, name, desired)
            updated_flags += 1
    config.safety_settings_enabled = True
    config.response_filters = True
    logging.info(f"Enabled default safety/content filters and related flags in config ({updated_flags} flags updated).")
    return "Default safety filter flags enabled in config."
def _remove_inconsistencias_and_biases(base_model, config):
    """Center (zero-mean) the bias vectors of linear/projection-like layers.

    Fix vs. original: the already-applied guard now runs BEFORE the model is
    moved to the target device, so the no-op path no longer performs a
    pointless (potentially expensive) device transfer.

    Only trainable `nn.Parameter` biases whose names suggest a linear layer
    are touched; each is shifted by its own mean when |mean| > 1e-6. Sets
    `config.inconsistencies_biases_removed` and returns a status message.
    """
    if getattr(config, 'inconsistencies_biases_removed', False):
        return "Inconsistencies/Biases removal flag already set. No action taken."
    bias_adjusted_count = 0
    params_adjusted_count = 0
    device = get_device()
    base_model.to(device)
    # Name markers indicating linear/projection layers whose biases we center.
    linear_markers = ('linear', 'dense', 'fc', 'out_proj', 'q_proj', 'k_proj', 'v_proj', 'wi', 'wo', 'lm_head')
    with torch.no_grad():
        for name, param in base_model.named_parameters():
            if "bias" in name and isinstance(param, nn.Parameter) and param.requires_grad:
                if any(marker in name.lower() for marker in linear_markers):
                    try:
                        # Mean computed in float32 for numerical stability with fp16 params.
                        original_mean = torch.mean(param.data.float()).item()
                        if abs(original_mean) > 1e-6:
                            param.sub_(original_mean)
                            bias_adjusted_count += 1
                            params_adjusted_count += param.numel()
                            logging.debug(f"Centered bias for {name} (original mean: {original_mean:.4e})")
                    except Exception as e:
                        logging.warning(f"Could not center bias for {name}: {e}")
    # The flag is set regardless of whether anything changed (matches original behavior).
    config.inconsistencies_biases_removed = True
    if bias_adjusted_count > 0:
        logging.info(f"Centered {bias_adjusted_count} bias terms ({params_adjusted_count} parameters) to potentially reduce inconsistencies.")
        return f"{bias_adjusted_count} bias terms centered."
    logging.info("Attempted bias centering, but no adjustable bias terms with significant mean found or no bias terms present.")
    return "Attempted bias centering (no significant changes made or no biases found)."
def _reenable_inconsistencias_and_biases(config):
    """Clear the bias-centering flag; the centered bias values themselves are NOT restored."""
    if not getattr(config, 'inconsistencies_biases_removed', False):
        return "Inconsistencies/Biases removal flag already disabled."
    setattr(config, "inconsistencies_biases_removed", False)
    logging.info("Inconsistencies/Biases removal flag reverted. Note: Original bias values are NOT restored.")
    return "Inconsistencies/Biases removal flag reverted (biases NOT restored)."
def _enable_layerdrop(config, probability=0.1):
    """Record a LayerDrop probability on the config (flag-level only).

    Rejects values outside [0, 1]. Actual LayerDrop requires architecture /
    Trainer support; this just stores the preference. Returns a status message.
    """
    valid = isinstance(probability, (float, int)) and 0 <= probability <= 1
    if not valid:
        msg=f"Error: LayerDrop probability must be between 0 and 1. Got {probability}."
        logging.error(msg)
        return msg
    prob = float(probability)
    if not hasattr(config, 'layerdrop'):
        logging.warning("Config does not have a standard 'layerdrop' attribute. Setting custom flag only.")
    setattr(config, 'layerdrop', prob)
    config.layerdrop_enabled = (probability > 0)
    config.layerdrop_prob = prob
    logging.info(f"LayerDrop enabled flag set in config with probability {probability}. Actual effect depends on model architecture support during training/inference.")
    return f"LayerDrop flag {'ON' if probability > 0 else 'OFF'} (p={probability:.2f}). Requires model/Trainer support."
def _disable_layerdrop(config):
    """Disable LayerDrop — equivalent to enabling it with probability 0."""
    return _enable_layerdrop(config, probability=0.0)
def _apply_lora_merge(model, config):
    """Merge the LoRA adapter at config.lora_adapter_path into the base weights.

    If `model` is not yet a PeftModel, the adapter is first loaded onto it; if
    it is, the target adapter is loaded/activated as needed. On success the
    merged base model replaces the module-level `global_model`. Returns a
    status message (error text on failure).

    Fix vs. original: `active_adapter` was only assigned on the
    already-a-PeftModel branch, so a successful merge after wrapping a plain
    base model raised NameError when composing log/error messages. It is now
    also defined on the wrapping branch.
    """
    global global_model
    adapter_path = getattr(config, 'lora_adapter_path', None)
    if not adapter_path:
        msg="No LoRA adapter path specified in config ('lora_adapter_path'). Use 'Set Path in Config' first or train/load an adapter."
        logging.warning(msg)
        return msg
    if not _peft_installed:
        msg="Error: PEFT library not installed, cannot merge LoRA."
        logging.error(msg)
        return msg
    current_model = model
    if not isinstance(current_model, PeftModel):
        # Plain base model: wrap it with the adapter before merging.
        logging.warning(f"Model is not a PeftModel. Attempting to load adapter '{adapter_path}' onto it first.")
        try:
            peft_model_instance = PeftModel.from_pretrained(current_model, adapter_path, is_trainable=False)
            current_model = peft_model_instance
            logging.info(f"Successfully loaded adapter '{adapter_path}' onto the base model.")
        except Exception as e:
            msg = f"Error loading adapter '{adapter_path}' onto base model: {e}\n{traceback.format_exc()}"
            logging.error(msg)
            return msg
        # BUGFIX: define active_adapter on this branch too, so the merge
        # section below can reference it without a NameError.
        active_adapter = getattr(current_model, 'active_adapter', 'default')
    else:
        # Already a PeftModel: make sure the requested adapter is present and active.
        active_adapter = getattr(current_model, 'active_adapter', 'default')
        # The adapter name is derived from the last path component of adapter_path.
        target_adapter_name = os.path.basename(os.path.normpath(adapter_path))
        if not target_adapter_name: target_adapter_name = 'default'
        if target_adapter_name not in current_model.peft_config:
            logging.info(f"Adapter '{target_adapter_name}' (from path {adapter_path}) not found in existing PeftModel config. Loading it now.")
            try:
                current_model.load_adapter(adapter_path, adapter_name=target_adapter_name, is_trainable=False)
                logging.info(f"Loaded new adapter '{target_adapter_name}'.")
            except Exception as e:
                msg = f"Error loading adapter '{target_adapter_name}' from path '{adapter_path}' onto existing PeftModel: {e}\n{traceback.format_exc()}"
                logging.error(msg)
                return msg
        if active_adapter != target_adapter_name:
            try:
                current_model.set_adapter(target_adapter_name)
                logging.info(f"Set active adapter to '{target_adapter_name}' for merging.")
                active_adapter = target_adapter_name
            except Exception as e:
                msg = f"Error setting adapter '{target_adapter_name}' active on existing PeftModel: {e}\n{traceback.format_exc()}"
                logging.error(msg)
                return msg
        else:
            active_adapter = target_adapter_name
    try:
        logging.info(f"Merging active LoRA adapter ('{active_adapter}') into the base model..."); T = time.time()
        merged_model = current_model.merge_and_unload()
        merge_time = time.time() - T
        merged_config = merged_model.config
        merged_config = initialize_config_flags(merged_config)
        merged_config.lora_merged = True
        merged_config.lora_adapter_path = adapter_path
        merged_config.peft_adapter_added = False
        merged_config.peft_config = None
        global_model = merged_model
        # NOTE(review): rebinding the local name `config` does not propagate the
        # merged config back to the caller — callers only see the new config via
        # global_model.config. Confirm this is the intended contract.
        config = merged_config
        msg = f"LoRA adapter '{active_adapter}' (from {adapter_path}) merged successfully in {merge_time:.2f}s. Global model updated to the merged base model."
        logging.info(msg)
        clean_memory()
        return msg
    except ValueError as ve:
        msg = f"Error merging LoRA adapter '{active_adapter}': {ve}. Adapter type might not support merging."
        logging.error(msg)
        return msg
    except Exception as e:
        msg = f"Error merging LoRA adapter '{active_adapter}': {e}\n{traceback.format_exc()}"
        logging.error(msg)
        return msg
def _revert_lora_merge(model, config):
    """Clear the LoRA-merge bookkeeping flags; the merged weights themselves stay merged."""
    if not getattr(config, 'lora_merged', False):
        return "LoRA merge flag is already false (or merge never applied/recorded)."
    config.lora_merged = False
    config.lora_adapter_path = None
    msg = "LoRA merge flag reverted. IMPORTANT: Model weights are NOT restored to pre-merge state. Reload the original base model if needed."
    logging.warning(msg)
    return msg
def _set_lora_adapter_path(config, path):
    """Validate, strip, and store a LoRA adapter path on the config."""
    cleaned = path.strip() if isinstance(path, str) else ""
    if not cleaned:
        msg = "Invalid or empty LoRA adapter path provided. Path not set."
        logging.warning(msg)
        return msg
    config.lora_adapter_path = cleaned
    msg = f"LoRA adapter path set in config to: '{cleaned}'"
    logging.info(msg)
    return msg
def _setup_knowledge_distillation(model, config, num_labels=2):
    """Attach a fresh linear head named 'kd_classifier' for knowledge distillation.

    The head maps the model's hidden size to `num_labels` logits (Xavier init,
    zero bias) and is placed on the current device with the model's dtype.
    Training code must still be adapted to compute a loss from this head.
    Returns a status message (error text on failure).
    """
    if not isinstance(num_labels, int) or num_labels <= 0:
        msg = f"Error: Number of labels for KD must be a positive integer, got {num_labels}."
        logging.error(msg)
        return msg
    try:
        device = get_device()
        # Match the new head's dtype to the model parameters; fall back to fp32.
        try:
            dtype = next(iter(model.parameters())).dtype
        except StopIteration:
            dtype = torch.float32
        if not isinstance(dtype, torch.dtype):
            dtype = torch.float32
        head_attr = 'kd_classifier'
        if hasattr(model, head_attr):
            logging.warning(f"Model already has an attribute named '{head_attr}'. Overwriting.")
        hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'embed_dim', None)))
        if not isinstance(hidden_size, int) or hidden_size <= 0:
            raise ValueError("Cannot setup KD: Model config missing valid 'hidden_size', 'd_model', or 'embed_dim' attribute.")
        head = nn.Linear(hidden_size, num_labels).to(device, dtype=dtype)
        nn.init.xavier_uniform_(head.weight)
        if head.bias is not None:
            nn.init.zeros_(head.bias)
        setattr(model, head_attr, head)
        # Only claim num_labels on the config if it is not already in use.
        if not hasattr(config, 'num_labels') or config.num_labels is None:
            config.num_labels = num_labels
        else:
            logging.warning(f"Model config already has 'num_labels'={config.num_labels}. KD setup might conflict if used for other classification tasks.")
        config.knowledge_distillation_setup = True
        config.kd_num_labels = num_labels
        msg = (f"Knowledge Distillation head ('{head_attr}') added with {num_labels} labels (outputs). "
               f"Requires training changes: loss calculation using this head (e.g., cross-entropy on its logits), "
               f"and appropriate data format (e.g., sequence inputs + target labels).")
        logging.info(msg)
        return msg
    except Exception as e:
        msg = f"Error setting up Knowledge Distillation head: {e}\n{traceback.format_exc()}"
        logging.error(msg)
        # Best-effort cleanup so a failed setup leaves no dangling head or flags.
        if hasattr(model, 'kd_classifier'):
            delattr(model, 'kd_classifier')
        config.knowledge_distillation_setup = False
        config.kd_num_labels = None
        return msg
def _revert_knowledge_distillation(model, config):
classifier_name = 'kd_classifier'
if hasattr(model, classifier_name):
delattr(model, classifier_name)
config.knowledge_distillation_setup = False
config.kd_num_labels = None
msg = f"Knowledge Distillation setup reverted (removed '{classifier_name}' head and reset config flags)."
logging.info(msg)
clean_memory()
return msg
else:
config.knowledge_distillation_setup = False
config.kd_num_labels = None
msg = f"Knowledge Distillation head ('{classifier_name}') not found, nothing to revert. Reset flags."
logging.info(msg)
return msg
def _setup_reward_modeling(model, config, num_outputs=1):
if not isinstance(num_outputs, int) or num_outputs <= 0:
msg = f"Error: Number of outputs for Reward Model must be a positive integer, got {num_outputs}."
logging.error(msg)
return msg
try:
device=get_device()
try:
dtype = next(iter(model.parameters())).dtype
except StopIteration:
dtype = torch.float32
if not isinstance(dtype, torch.dtype): dtype = torch.float32
rm_head_name = 'reward_head'
if hasattr(model, rm_head_name):
logging.warning(f"Model already has an attribute named '{rm_head_name}'. Overwriting.")
hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'embed_dim', None)))
if not isinstance(hidden_size, int) or hidden_size <= 0:
raise ValueError("Cannot setup Reward Model head: Model config missing valid 'hidden_size', 'd_model', or 'embed_dim'.")
reward_head = nn.Linear(hidden_size, num_outputs).to(device, dtype=dtype)
nn.init.xavier_uniform_(reward_head.weight)
if reward_head.bias is not None:
nn.init.zeros_(reward_head.bias)
setattr(model, rm_head_name, reward_head)
config.reward_modeling_setup = True
config.rm_num_outputs = num_outputs
msg = (f"Reward Modeling head ('{rm_head_name}') added with {num_outputs} output(s). "
f"Requires training changes: loss targeting rewards (e.g., ranking loss), specific data format (prompt, chosen_resp, rejected_resp), "
f"and likely using the final hidden state of the sequence as input to this head.")
logging.info(msg)
return msg
except Exception as e:
msg = f"Error setting up Reward Modeling head: {e}\n{traceback.format_exc()}"
logging.error(msg)
if hasattr(model, 'reward_head'): delattr(model, 'reward_head')
config.reward_modeling_setup = False
config.rm_num_outputs = None
return msg
def _revert_reward_modeling(model, config):
rm_head_name = 'reward_head'
if hasattr(model, rm_head_name):
delattr(model, rm_head_name)
config.reward_modeling_setup = False
config.rm_num_outputs = None
msg = f"Reward Modeling setup reverted (removed '{rm_head_name}' head and reset config flags)."
logging.info(msg)
clean_memory()
return msg
else:
config.reward_modeling_setup = False
config.rm_num_outputs = None
msg = f"Reward Modeling head ('{rm_head_name}') not found, nothing to revert. Reset flags."
logging.info(msg)
return msg
def _set_rope_scaling_config(model, config, scaling_type="linear", factor=2.0):
valid_types = ["linear", "dynamic"]
if not scaling_type or not isinstance(scaling_type, str) or scaling_type not in valid_types:
msg = f"Error: RoPE scaling type must be one of {valid_types}. Got '{scaling_type}'."
logging.error(msg)
return msg
try:
factor = float(factor)
if factor < 1.0: raise ValueError("Factor must be >= 1.0.")
if factor == 1.0: logging.warning(f"RoPE scaling factor set to {factor}, which implies no scaling.")
except (ValueError, TypeError) as e:
msg=f"Error: Invalid RoPE scaling factor '{factor}'. Must be a number >= 1.0. Error: {e}"
logging.error(msg)
return msg
rope_config = {"type": scaling_type, "factor": factor}
config.rope_scaling = rope_config
config.rope_scaling_type = scaling_type
config.rope_scaling_factor = factor
msg = (f"RoPE Scaling set in config: type='{scaling_type}', factor={factor:.2f}. "
f"Requires model architecture support and **reloading the model** with this config for the changes to take effect.")
logging.warning(msg)
return msg
def _revert_rope_scaling(model, config):
if hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
config.rope_scaling = None
config.rope_scaling_type = None
config.rope_scaling_factor = None
msg = "RoPE Scaling configuration removed from config. Model reload required to revert RoPE behavior."
logging.warning(msg)
return msg
else:
config.rope_scaling_type = None
config.rope_scaling_factor = None
msg = "RoPE Scaling was not configured. No changes made."
logging.info(msg)
return msg
def _set_sliding_window_config(model, config, window_size=4096):
try:
window_size = int(window_size)
if window_size < 0: raise ValueError("Window size must be non-negative (0 or None to disable).")
except (ValueError, TypeError) as e:
msg=f"Error: Invalid sliding window size '{window_size}'. Must be a non-negative integer. Error: {e}"
logging.error(msg)
return msg
effective_window_size = window_size if window_size > 0 else None
config.sliding_window = effective_window_size
config.sliding_window_size = effective_window_size
if effective_window_size:
msg = (f"Sliding Window Attention size set in config to: {effective_window_size}. "
f"Requires model architecture support (e.g., Mistral) and potentially reloading the model.")
else:
msg = "Sliding Window Attention disabled in config (size set to 0 or None). Model reload may be needed."
logging.warning(msg)
return msg
def _revert_sliding_window(model, config):
if hasattr(config, 'sliding_window') and config.sliding_window is not None:
config.sliding_window = None
config.sliding_window_size = None
msg = "Sliding Window Attention configuration removed from config. Model reload may be needed to revert behavior."
logging.warning(msg)
return msg
else:
config.sliding_window_size = None
msg = "Sliding Window Attention was not configured. No changes made."
logging.info(msg)
return msg
def _set_attention_variant_config(model, config, variant="auto"):
valid_variants = ["auto", "eager", "sdpa", "flash_attention_2"]
if not variant or not isinstance(variant, str) or variant not in valid_variants:
msg = f"Error: Invalid attention variant '{variant}'. Choose from: {', '.join(valid_variants)}."
logging.error(msg)
return msg
config.attn_implementation = variant
config.attention_variant = variant
config.use_flash_attention_2 = (variant == "flash_attention_2")
msg = (f"Attention implementation preference set in config to: '{variant}'. "
f"Effective implementation depends on model, hardware, and transformers version. **Requires model reload** to take effect.")
logging.warning(msg)
return msg
def _revert_attention_variant(model, config):
default_variant = "auto"
current_variant = getattr(config, 'attn_implementation', default_variant)
if current_variant != default_variant:
config.attn_implementation = default_variant
config.attention_variant = default_variant
config.use_flash_attention_2 = False
msg = f"Attention implementation preference reverted to '{default_variant}' in config. Model reload required."
logging.warning(msg)
return msg
else:
config.attention_variant = default_variant
config.use_flash_attention_2 = False
msg = f"Attention implementation preference is already '{default_variant}' or was not set. No changes made."
logging.info(msg)
return msg
# One-line toggles built on the shared _set_flag_only helper (defined elsewhere in
# this file). Each writes a single boolean attribute on `config` and returns a
# status string. NOTE(review): the two trailing string arguments look like an
# (enabled-message, disabled-message) pair selected by _set_flag_only from the new
# flag value -- confirm against its signature; the enable/disable variants pass
# the pair in mirrored order. Flags labelled "symbolic" appear advisory only; the
# "(flag for Trainer)" wording suggests the others are consulted during Trainer
# setup -- verify where they are read.
def _enable_gradient_clipping(config): return _set_flag_only(config, "gradient_clipping_disabled", False, "Gradient Clipping Enabled (flag for Trainer).", "Gradient Clipping Disabled.")
def _disable_gradient_clipping(config): return _set_flag_only(config, "gradient_clipping_disabled", True, "Gradient Clipping Enabled.", "Gradient Clipping Disabled (flag for Trainer).")
def _enable_weight_decay(config): return _set_flag_only(config, "weight_decay_disabled", False, "Weight Decay Enabled (flag for Trainer).", "Weight Decay Disabled.")
def _disable_weight_decay(config): return _set_flag_only(config, "weight_decay_disabled", True, "Weight Decay Enabled.", "Weight Decay Disabled (flag for Trainer).")
def _enable_lr_scheduler(config): return _set_flag_only(config, "lr_scheduler_disabled", False, "LR Scheduler Enabled (flag for Trainer).", "LR Scheduler Disabled.")
def _disable_lr_scheduler(config): return _set_flag_only(config, "lr_scheduler_disabled", True, "LR Scheduler Enabled.", "LR Scheduler Disabled (flag for Trainer).")
def _enable_enhanced_security(config): return _set_flag_only(config, "enhanced_security_enabled", True, "Enhanced Security Enabled (symbolic flag).", "Enhanced Security Disabled.")
def _disable_enhanced_security(config): return _set_flag_only(config, "enhanced_security_enabled", False, "Enhanced Security Enabled.", "Enhanced Security Disabled (symbolic flag).")
def _enable_debug_mode(config): return _set_flag_only(config, "debug_mode_enabled", True, "Debug Mode Enabled (symbolic flag).", "Debug Mode Disabled.")
def _disable_debug_mode(config): return _set_flag_only(config, "debug_mode_enabled", False, "Debug Mode Enabled.", "Debug Mode Disabled (symbolic flag).")
def _enable_internal_usage_logging(config): return _set_flag_only(config, "internal_logging_enabled", True, "Internal Usage Logging Enabled (symbolic flag).", "Internal Logging Disabled.")
def _disable_internal_usage_logging(config): return _set_flag_only(config, "internal_logging_enabled", False, "Internal Logging Enabled.", "Internal Logging Disabled (symbolic flag).")
def _enable_drift_detection(config): return _set_flag_only(config, "drift_detection_enabled", True, "Drift Detection Enabled (symbolic flag).", "Drift Detection Disabled.")
def _disable_drift_detection(config): return _set_flag_only(config, "drift_detection_enabled", False, "Drift Detection Enabled.", "Drift Detection Disabled (symbolic flag).")
def _enable_auto_optimization(base_model, config):
msg = ""
if getattr(config, 'auto_optimization_enabled', False):
msg = "Auto Optimization already enabled (flag was true)."
logging.info(msg)
return msg
logging.info("Enabling Auto Optimization: Applying Quantization and Gradient Checkpointing...")
device = get_device()
quant_mode = 'bfloat16' if (device.type == 'cuda' and torch.cuda.is_bf16_supported()) else 'float16'
if device.type == 'cpu': quant_mode = 'float32'
quant_msg = _quantize_model(base_model, config, mode=quant_mode)
gc_msg = _enable_gradient_checkpointing(base_model, config)
config.auto_optimization_enabled = True
msg = f"Auto Optimization Enabled. Quantization ({quant_mode}): {quant_msg}. Gradient Checkpointing: {gc_msg}"
logging.info(msg)
return msg
def _disable_auto_optimization(config):
if getattr(config, 'auto_optimization_enabled', False):
config.auto_optimization_enabled = False
logging.info("Auto Optimization Disabled (flag only). Applied optimizations (like quantization, GC) remain active unless manually reverted.")
return "Auto Optimization Disabled (flag only)."
else:
logging.info("Auto Optimization was already disabled.")
return "Auto Optimization already disabled."
def _recover_perfect_precision(base_model, config):
    """Cast the model back to full FP32 precision via _quantize_model."""
    logging.info("Attempting to recover FP32 precision...")
    status = _quantize_model(base_model, config, mode='float32')
    # _quantize_model is expected to set this flag when the cast succeeded
    # (presumed contract -- the helper is defined elsewhere in this file).
    if getattr(config, 'perfect_precision_recovered', False):
        logging.info(f"Successfully recovered FP32 precision. Status: {status}")
        return "Recovered FP32 Precision. " + status
    logging.warning(f"FP32 precision recovery might have failed or model was already FP32. Status: {status}")
    return "Attempted FP32 Precision Recovery. " + status
def _revert_perfect_precision(base_model, config):
if not getattr(config, 'perfect_precision_recovered', False):
return "Model not currently in FP32 mode according to flag (or flag is inconsistent)."
device = get_device()
mode_to_revert_to = 'bfloat16' if (device.type=='cuda' and torch.cuda.is_bf16_supported()) else 'float16' if device.type=='cuda' else 'float32'
if mode_to_revert_to == 'float32':
logging.info("Cannot revert from FP32 as the target revert type is also FP32 (e.g., on CPU).")
return "Cannot revert from FP32 to lower precision on current device."
logging.info(f"Reverting precision from FP32 (target: {mode_to_revert_to})...")
msg = _quantize_model(base_model, config, mode=mode_to_revert_to)
logging.info(f"Attempted precision revert from FP32: {msg}")
return f"Reverted Precision from FP32 (attempted {mode_to_revert_to}). " + msg
def _optimize_token_generation_speed(config):
if not hasattr(config, '_original_do_sample'):
config._original_do_sample = getattr(config, 'do_sample', True)
if not hasattr(config, '_original_num_beams'):
config._original_num_beams = getattr(config, 'num_beams', 1)
if not hasattr(config, '_original_use_cache'):
default_use_cache = True
if hasattr(config, 'model_type'):
if config.model_type == "t5" and getattr(config, 'gradient_checkpointing', False):
default_use_cache = False
config._original_use_cache = getattr(config, 'use_cache', default_use_cache)
config.do_sample = False
config.num_beams = 1
config.use_cache = True
config.token_gen_speed_maximized = True
logging.info("Token Generation Speed Optimized (Flags set for greedy decoding, num_beams=1, use_cache=True).")
return "Token Speed Opt flags set (greedy, cache on)."
def _revert_token_generation_speed_optimization(config):
if not getattr(config, 'token_gen_speed_maximized', False):
return "Token speed optimization not active according to flag."
config.do_sample = getattr(config, '_original_do_sample', True)
config.num_beams = getattr(config, '_original_num_beams', 1)
config.use_cache = getattr(config, '_original_use_cache', True)
config.token_gen_speed_maximized = False
if hasattr(config, '_original_do_sample'): del config._original_do_sample
if hasattr(config, '_original_num_beams'): del config._original_num_beams
if hasattr(config, '_original_use_cache'): del config._original_use_cache
logging.info("Token Generation Speed Optimization Reverted to previous/default flags.")
return "Token Speed Optimization Reverted."
def _add_peft_adapter(model, config, peft_config_obj=None):
    """Wrap the model with a PEFT (LoRA) adapter for parameter-efficient training.

    Uses `peft_config_obj` when it is a LoraConfig/PeftConfig instance, otherwise
    falls back to DEFAULT_PEFT_CONFIG_DICT (defined elsewhere in this file).
    Forces the task type to CAUSAL_LM, replaces the module-level `global_model`
    with the wrapped PeftModel, records the adapter config on the base model's
    config, and returns a status string ("[Error]/[Warning] ..." instead of
    raising).
    """
    global global_model, current_peft_config
    if not _peft_installed:
        return "[Error] PEFT library (pip install peft) is not installed."
    # Refuse to stack adapters or to re-adapt a model whose LoRA was merged in.
    if isinstance(model, PeftModel):
        return "[Warning] Model is already a PEFT model. Merge or remove existing adapters before adding a new one via this button."
    if getattr(config, 'lora_merged', False):
        return "[Warning] LoRA adapters were previously merged into this model state. Adding new adapters might have unintended effects without reloading the original base model."
    try:
        if peft_config_obj and isinstance(peft_config_obj, (LoraConfig, PeftConfig)):
            peft_conf = peft_config_obj
            logging.info(f"Using provided PEFT config object: {peft_conf}")
        else:
            default_config_dict = copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT)
            if not default_config_dict:
                raise ValueError("Default PEFT config is not available and no valid config provided.")
            peft_conf = LoraConfig(**default_config_dict)
            logging.info(f"Using default PEFT config: {peft_conf}")
        # This platform only trains causal LMs, so override any other task type.
        if hasattr(peft_conf, 'task_type') and peft_conf.task_type != TaskType.CAUSAL_LM:
            logging.warning(f"PEFT config task type is {peft_conf.task_type}, overriding to CAUSAL_LM for this platform.")
            peft_conf.task_type = TaskType.CAUSAL_LM
        elif not hasattr(peft_conf, 'task_type'):
            if isinstance(peft_conf, PeftConfig) and not isinstance(peft_conf, LoraConfig):
                peft_conf.task_type = TaskType.CAUSAL_LM
        peft_model = get_peft_model(model, peft_conf)
        # Bookkeeping lives on the *base* model's config so it survives unwrapping.
        base_model_config = peft_model.get_base_model().config
        base_model_config.peft_adapter_added = True
        base_model_config.peft_config = peft_conf.to_dict()
        base_model_config.lora_merged = False
        current_peft_config = peft_conf
        global_model = peft_model
        # NOTE(review): rebinding the local `config` has no effect outside this
        # function; the flags above were already written to base_model_config.
        config = base_model_config
        trainable_params, all_params = peft_model.get_nb_trainable_parameters()
        logging.info(
            f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params:.4f}"
        )
        msg = f"PEFT adapter ({type(peft_conf).__name__}) added successfully. Model is ready for PEFT training."
        logging.info(msg)
        return msg
    except Exception as e:
        logging.error(f"Error adding PEFT adapter: {e}\n{traceback.format_exc()}")
        # Best-effort flag reset on the (unwrapped) model's own config.
        if hasattr(model, 'config'):
            model.config.peft_adapter_added = False
            model.config.peft_config = None
        return f"[Error] Failed to add PEFT adapter: {e}"
def _remove_peft_adapter(model, config):
    """Strip the PEFT adapter layers and restore the plain base model as `global_model`.

    If the model is not actually a PeftModel but the config flag says otherwise,
    only the flags are reset. Returns a status string; never raises.
    """
    global global_model, current_peft_config
    if not _peft_installed:
        return "[Error] PEFT library not installed."
    if not isinstance(model, PeftModel):
        if getattr(config, 'peft_adapter_added', False):
            # Inconsistent state: flag set without a PeftModel wrapper.
            logging.warning("Model is not a PeftModel instance, but PEFT flag was set. Resetting flags.")
            config.peft_adapter_added = False
            config.peft_config = None
            current_peft_config = {}
            return "[Warning] Reset PEFT flags as model was not a PeftModel instance."
        else:
            return "[Info] No PEFT adapter currently applied to the model."
    try:
        # get_base_model() yields the underlying model; adapter weights are NOT
        # merged here, only detached from the active graph.
        base_model = model.get_base_model()
        global_model = base_model
        config = base_model.config
        config.peft_adapter_added = False
        config.peft_config = None
        current_peft_config = {}
        msg = "PEFT adapter layers removed. Restored base model and reset PEFT config flags."
        logging.info(msg)
        clean_memory()
        return msg
    except Exception as e:
        logging.error(f"Error removing PEFT adapter: {e}\n{traceback.format_exc()}")
        return f"[Error] Failed to remove PEFT adapter: {e}"
def _setup_multimodal(model, config, selected_modalities):
global global_tokenizer
if not selected_modalities:
return "[Info] No modalities selected for setup."
if getattr(config, 'multimodal_applied', False):
current_modalities = getattr(config, 'supported_modalities', [])
return f"[Warning] Multi-modal setup already applied for modalities: {current_modalities}. Revert first to change."
logging.info(f"Attempting multi-modal setup for: {selected_modalities}")
device = get_device()
llm_hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None))
if not llm_hidden_size:
return "[Error] Cannot setup multi-modal: LLM config missing 'hidden_size' or 'd_model'."
if global_tokenizer is None:
return "[Error] Cannot setup multi-modal: Global tokenizer not loaded."
try:
added_encoders = {}
added_projections = {}
added_special_tokens = {}
new_tokens_added_to_tokenizer = []
current_modality_config = {}
current_special_tokens_map = {}
tokens_to_add_struct = []
for modality in selected_modalities:
if modality not in MODALITY_ENCODERS:
logging.warning(f"Skipping modality '{modality}': No predefined encoder found.")
continue
special_token = f"<{modality.upper()}>"
if special_token not in global_tokenizer.get_vocab():
tokens_to_add_struct.append({'token': special_token, 'modality': modality})
else:
token_id = global_tokenizer.convert_tokens_to_ids(special_token)
current_special_tokens_map[modality] = {"token": special_token, "id": token_id}
logging.info(f"Special token '{special_token}' for {modality} already exists (ID: {token_id}).")
if tokens_to_add_struct:
num_added = global_tokenizer.add_tokens([t['token'] for t in tokens_to_add_struct], special_tokens=True)
if num_added > 0:
logging.info(f"Added {num_added} new special tokens to tokenizer: {[t['token'] for t in tokens_to_add_struct]}")
logging.info(f"Resizing LLM token embeddings from {model.config.vocab_size} to {len(global_tokenizer)}.")
model.resize_token_embeddings(len(global_tokenizer))
if hasattr(config, 'vocab_size'):
config.vocab_size = len(global_tokenizer)
with torch.no_grad():
input_embeddings = model.get_input_embeddings()
if input_embeddings is not None and hasattr(input_embeddings, 'weight'):
avg_weight = input_embeddings.weight[:-num_added,:].mean(dim=0)
input_embeddings.weight[-num_added:,:] = avg_weight
logging.info(f"Initialized {num_added} new token embeddings with average weight.")
for t_info in tokens_to_add_struct:
modality = t_info['modality']
special_token = t_info['token']
token_id = global_tokenizer.convert_tokens_to_ids(special_token)
current_special_tokens_map[modality] = {"token": special_token, "id": token_id}
new_tokens_added_to_tokenizer.append(special_token)
else:
logging.error(f"Failed to add special tokens: {[t['token'] for t in tokens_to_add_struct]}. Aborting multi-modal setup.")
return "[Error] Failed to add required special tokens to tokenizer."
successful_modalities = []
for modality in selected_modalities:
if modality not in MODALITY_ENCODERS: continue
encoder_id = MODALITY_ENCODERS[modality]
encoder_attr_name = f"{modality.lower()}_encoder"
projection_attr_name = f"{modality.lower()}_projection"
try:
logging.info(f"Loading {modality} encoder: {encoder_id}")
encoder = AutoModel.from_pretrained(encoder_id, trust_remote_code=True)
encoder = encoder.to(device).eval()
for param in encoder.parameters():
param.requires_grad = False
added_encoders[encoder_attr_name] = encoder
setattr(model, encoder_attr_name, encoder)
encoder_hidden_size = _get_encoder_hidden_size(encoder_id, trust_remote_code=True)
logging.info(f"Creating projection layer for {modality}: {encoder_hidden_size} -> {llm_hidden_size}")
projection = nn.Linear(encoder_hidden_size, llm_hidden_size).to(device)
nn.init.xavier_uniform_(projection.weight)
if projection.bias is not None: nn.init.zeros_(projection.bias)
added_projections[projection_attr_name] = projection
setattr(model, projection_attr_name, projection)
current_modality_config[modality] = encoder_id
successful_modalities.append(modality)
except Exception as mod_e:
logging.error(f"Failed to setup modality '{modality}' with encoder '{encoder_id}': {mod_e}")
if hasattr(model, encoder_attr_name): delattr(model, encoder_attr_name)
if hasattr(model, projection_attr_name): delattr(model, projection_attr_name)
if successful_modalities:
config.multimodal_applied = True
config.supported_modalities = successful_modalities
config.modality_encoders = current_modality_config
config.modality_projection_dim = llm_hidden_size
config.modality_special_tokens = current_special_tokens_map
msg = (f"Multi-modal setup partially/fully applied for: {successful_modalities}. "
f"Added {len(added_encoders)} encoders and {len(added_projections)} projections. "
f"Added/mapped {len(current_special_tokens_map)} special tokens. ")
logging.warning(msg)
return msg
else:
config.multimodal_applied = False
return "[Error] Multi-modal setup failed for all selected modalities."
except Exception as e:
logging.error(f"Error during multi-modal setup: {e}\n{traceback.format_exc()}")
for name in added_encoders.keys():
if hasattr(model, name): delattr(model, name)
for name in added_projections.keys():
if hasattr(model, name): delattr(model, name)
config.multimodal_applied = False
config.supported_modalities = []
config.modality_encoders = {}
config.modality_projection_dim = None
config.modality_special_tokens = {}
return (f"[Error] Multi-modal setup failed: {e}. Attempted cleanup, state might be inconsistent "
"(tokenizer/embeddings may remain changed). Reload original model/tokenizer for full reset.")
def _revert_multimodal(model, config):
if not getattr(config, 'multimodal_applied', False):
return "[Info] Multi-modal setup not applied according to config."
modalities_to_revert = getattr(config, 'supported_modalities', [])
if not modalities_to_revert:
config.multimodal_applied = False
config.modality_encoders = {}
config.modality_projection_dim = None
config.modality_special_tokens = {}
return "[Info] No supported modalities listed in config to revert, but flag was true. Resetting flags."
logging.info(f"Reverting multi-modal setup for modalities: {modalities_to_revert}")
removed_count = 0
errors = []
try:
for modality in modalities_to_revert:
encoder_attr_name = f"{modality.lower()}_encoder"
projection_attr_name = f"{modality.lower()}_projection"
try:
if hasattr(model, encoder_attr_name):
delattr(model, encoder_attr_name)
logging.info(f"Removed encoder: {encoder_attr_name}")
removed_count += 1
if hasattr(model, projection_attr_name):
delattr(model, projection_attr_name)
logging.info(f"Removed projection: {projection_attr_name}")
removed_count += 1
except Exception as del_e:
error_msg = f"Error removing components for modality '{modality}': {del_e}"
logging.error(error_msg)
errors.append(error_msg)
config.multimodal_applied = False
config.supported_modalities = []
config.modality_encoders = {}
config.modality_projection_dim = None
config.modality_special_tokens = {}
logging.warning("Multi-modal components removed. **Special tokens added to tokenizer and potentially resized embeddings remain.** Reload original model/tokenizer if full reversion needed.")
clean_memory()
final_msg = f"Multi-modal setup reverted ({removed_count} components removed, flags reset). Embeddings/tokenizer not shrunk."
if errors:
final_msg += f" Errors encountered: {'; '.join(errors)}"
return final_msg
except Exception as e:
logging.error(f"Error reverting multi-modal setup: {e}\n{traceback.format_exc()}")
config.multimodal_applied = False
config.supported_modalities = []
config.modality_encoders = {}
config.modality_projection_dim = None
config.modality_special_tokens = {}
return f"[Error] Reverting multi-modal setup failed: {e}. Flags reset."
def auto_extract_text_universal(data_item):
    """Best-effort extraction of human-readable text from arbitrarily nested data.

    Strings/bytes are stripped and literal "\\n" sequences become real newlines;
    lists/tuples join their items' text with spaces; dicts prefer well-known
    text-bearing keys (falling back to all other values) joined by newlines with
    duplicates removed; bare numbers/bools/None yield "".
    """
    if isinstance(data_item, str):
        return data_item.strip().replace('\\n', '\n')
    if isinstance(data_item, bytes):
        try:
            decoded = data_item.decode('utf-8', errors='replace')
        except Exception:
            return ""
        return decoded.strip().replace('\\n', '\n')
    if isinstance(data_item, (list, tuple)):
        pieces = (auto_extract_text_universal(entry) for entry in data_item)
        return " ".join(p for p in pieces if p)
    if isinstance(data_item, dict):
        preferred_keys = (
            'text', 'content', 'sentence', 'paragraph', 'article', 'abstract',
            'summary', 'body', 'passage', 'document', 'script', 'dialogue',
            'instruction', 'input', 'output', 'query', 'response', 'title',
            'question', 'answer', 'prompt', 'completion', 'target', 'label',
            'review', 'comment', 'post', 'code', 'markdown'
        )
        collected = []
        used_keys = set()
        for key in preferred_keys:
            if key in data_item and key not in used_keys:
                piece = auto_extract_text_universal(data_item[key])
                if piece:
                    collected.append(piece)
                    used_keys.add(key)
        if not collected:
            # No known text key produced anything: sweep every remaining value.
            for key, value in data_item.items():
                if key in used_keys:
                    continue
                piece = auto_extract_text_universal(value)
                if piece:
                    collected.append(piece)
                    used_keys.add(key)
        # De-duplicate while preserving first-seen order.
        return "\n".join(dict.fromkeys(t for t in collected if t))
    if isinstance(data_item, (int, float, bool)) or data_item is None:
        return ""
    try:
        return str(data_item).strip().replace('\\n', '\n')
    except Exception:
        return ""
def process_example_universal(example):
    """Map an arbitrary dataset example to a {"text": ...} record.

    Empty extractions get the "[EMPTY_OR_NON_TEXTUAL]" sentinel so they can be
    filtered out downstream.
    """
    text = auto_extract_text_universal(example)
    if not text:
        text = "[EMPTY_OR_NON_TEXTUAL]"
    return {"text": text}
def parse_datasets(dataset_text):
    """Parse a multi-line "name, config, split, weight" spec into config dicts.

    Blank lines and '#' comments are skipped; 'none' (case-insensitive) means no
    config; the split defaults to 'train' and the weight to 1.0 (invalid weights
    warn and fall back). Duplicate (name, config, split) triples are dropped.
    Raises ValueError when nothing valid was found.
    """
    parsed = []
    seen_keys = set()
    for idx, raw_line in enumerate(dataset_text.strip().splitlines()):
        entry = raw_line.strip()
        if not entry or entry.startswith('#'):
            continue
        # NOTE: empty fields are discarded, so "name,,split" shifts fields left.
        fields = [f.strip() for f in entry.split(",") if f.strip()]
        name = fields[0] if fields else None
        cfg = None
        split = 'train'
        weight = 1.0
        if len(fields) >= 2 and fields[1]:
            cfg = None if fields[1].lower() == 'none' else fields[1]
        if len(fields) >= 3 and fields[2]:
            split = fields[2]
        if len(fields) >= 4:
            try:
                weight = float(fields[3])
                if weight <= 0:
                    raise ValueError("Weight must be positive")
            except (ValueError, IndexError):
                logging.warning(f"Invalid or missing weight '{fields[3] if len(fields) >= 4 else ''}' on line {idx+1} ('{entry}'). Using default 1.0.")
                weight = 1.0
        if not name:
            logging.warning(f"Skipping invalid dataset line (no name found): '{entry}' on line {idx+1}")
            continue
        key = f"{name}_{cfg or 'DEFAULT'}_{split}"
        if key in seen_keys:
            logging.warning(f"Skipping duplicate dataset entry: {key} on line {idx+1}")
            continue
        parsed.append({"id": name, "config": cfg, "split": split, "weight": weight})
        seen_keys.add(key)
    if not parsed:
        raise ValueError("No valid dataset configurations were parsed from the input.")
    return parsed
def load_datasets_from_config(datasets_config):
    """Stream-load, normalise, and interleave the datasets described by `datasets_config`.

    Each entry is a dict with keys 'id', 'config', 'split', 'weight' (the shape
    produced by parse_datasets). Every dataset is opened in streaming mode,
    mapped to a single 'text' column via process_example_universal, filtered of
    empty rows, and shuffled; entries that fail to load are logged and skipped.
    Returns the interleaved IterableDataset (sampling probabilities proportional
    to weights, equal on any bookkeeping mismatch) or raises ValueError when no
    dataset could be loaded.
    """
    ds_list = []
    loaded_configs = []
    total_weight = 0.0
    logging.info(f"Attempting to load datasets based on config: {datasets_config}")
    for config_entry in datasets_config:
        ds_name = config_entry['id']
        ds_config = config_entry['config']
        ds_split = config_entry['split']
        ds_weight = config_entry['weight']
        dataset_identifier = f"{ds_name}{'['+ds_config+']' if ds_config else ''} (Split: {ds_split}, Weight: {ds_weight})"
        try:
            logging.info(f"Loading {dataset_identifier}...")
            d = load_dataset(
                ds_name,
                ds_config,
                streaming=True,
                split=ds_split,
                trust_remote_code=True,
            )
            # Peek one sample to learn the column names, then reload the stream:
            # the peek consumes the first element, so a fresh stream is needed to
            # avoid silently dropping it.
            try:
                peek = next(iter(d))
                original_columns = list(peek.keys())
                d = load_dataset(ds_name, ds_config, streaming=True, split=ds_split, trust_remote_code=True)
            except StopIteration:
                logging.warning(f"Dataset stream appears empty after loading: {dataset_identifier}. Skipping.")
                continue
            except Exception as peek_e:
                logging.warning(f"Could not reliably peek into dataset {dataset_identifier} to get columns: {peek_e}. Will attempt processing without column removal.")
                original_columns = None
            logging.info(f"Processing {dataset_identifier} (Original cols: {original_columns or 'unknown'}) -> Map to 'text' field")
            process_partial = partial(process_example_universal)
            # Collapse every example to a single "text" column and drop empties.
            processed_d = d.map(process_partial, remove_columns=original_columns)
            processed_d = processed_d.filter(lambda example: example.get("text") != "[EMPTY_OR_NON_TEXTUAL]")
            shuffled_d = processed_d.shuffle(buffer_size=10000, seed=42)
            ds_list.append(shuffled_d)
            loaded_configs.append(config_entry)
            total_weight += ds_weight
            logging.info(f"Successfully prepared stream: {dataset_identifier}")
        except (requests.exceptions.RequestException, gzip.BadGzipFile) as http_e:
            logging.error(f"Network or File Error loading dataset {dataset_identifier}: {http_e}. Check connection and dataset validity. Skipping.")
        except FileNotFoundError:
            logging.error(f"Dataset or config not found for {dataset_identifier}. Check name/config/path. Skipping.")
        except Exception as e:
            logging.error(f"General Error loading/processing dataset {dataset_identifier}: {e} \n{traceback.format_exc()}. Skipping.")
    if not ds_list:
        raise ValueError("No valid datasets were loaded. Check dataset names, configurations, splits, availability, and network connection.")
    logging.info(f"Successfully loaded {len(ds_list)} dataset streams.")
    # Fall back to uniform sampling when the weights cannot be trusted.
    if total_weight <= 0 or len(loaded_configs) != len(ds_list):
        probabilities = [1.0 / len(ds_list)] * len(ds_list) if ds_list else []
        logging.warning("Using equal probabilities for interleaving due to zero total weight, loading errors, or no datasets.")
    else:
        probabilities = [cfg['weight'] / total_weight for cfg in loaded_configs]
        prob_sum = sum(probabilities)
        # Re-normalise against floating point drift so the probabilities sum to 1.
        if abs(prob_sum - 1.0) > 1e-6:
            probabilities = [p / prob_sum for p in probabilities]
    if not ds_list:
        logging.warning("No datasets to interleave.")
        return None
    logging.info(f"Interleaving {len(ds_list)} datasets with probabilities: {[f'{p:.3f}' for p in probabilities]}")
    interleaved_ds = interleave_datasets(ds_list, probabilities=probabilities, seed=42, stopping_strategy="all_exhausted")
    return interleaved_ds
def tokenize_function(examples, tokenizer, context_length):
    """Tokenize the batch's "text" column without padding or truncation.

    Chunking to fixed length happens later in group_texts; `context_length` is
    accepted for signature compatibility but unused here.
    """
    cleaned = ["" if t is None else str(t) for t in examples["text"]]
    return tokenizer(cleaned, truncation=False, padding=False)
def group_texts(examples, block_size):
    """Concatenate a batch of tokenized columns and re-chunk into blocks of `block_size`.

    Columns whose first entry is a list are flattened first; any ragged tail
    shorter than one full block is dropped. Adds a "labels" column copying
    "input_ids" (labels mirror inputs for causal-LM training).

    Args:
        examples: mapping of column name -> batch of (lists of) token ids.
        block_size: target chunk length in tokens.

    Returns:
        dict of equally-chunked columns plus "labels"; empty columns (no
        "labels" key) when the batch holds less than one full block.
    """
    # BUGFIX: guard empty batches/columns -- the old code indexed examples[k][0]
    # unconditionally and raised IndexError on an empty batch. Mirror the
    # short-batch return shape (no "labels" key).
    if not examples or any(len(v) == 0 for v in examples.values()):
        return {k: [] for k in examples.keys()}
    concatenated_examples = {
        k: sum(v, []) if isinstance(v[0], list) else v
        for k, v in examples.items()
    }
    total_length = len(concatenated_examples[next(iter(examples))])
    # Drop the ragged tail so every emitted block is exactly block_size long.
    total_length = (total_length // block_size) * block_size
    if total_length == 0:
        return {k: [] for k in examples.keys()}
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
def split_dataset(processed_lm_iterable_dataset):
    """Split a processed iterable LM stream into a train stream and a static eval set.

    Shuffles the stream, takes up to `eval_buffer_size` examples into a static
    `Dataset` for evaluation, and returns a training stream that skips them.

    Args:
        processed_lm_iterable_dataset: IterableDataset yielding dicts with
            'input_ids', 'attention_mask' and 'labels' lists of ints.

    Returns:
        (train_ds, eval_ds_static): the (possibly skipped) train stream and a
        static eval Dataset, or None for the eval set on any failure.

    Raises:
        TypeError: if the input is not an IterableDataset.
    """
    eval_buffer_size = 1000
    shuffle_buffer_size = 10000
    logging.info(f"Preparing train/eval split. Eval buffer: {eval_buffer_size}, Shuffle buffer: {shuffle_buffer_size}..."); T = time.time()
    if not isinstance(processed_lm_iterable_dataset, IterableDataset):
        logging.error("Input dataset is not an IterableDataset. Cannot perform stream-based splitting.")
        raise TypeError("Input to split_dataset must be an IterableDataset.")
    shuffled_ds = processed_lm_iterable_dataset.shuffle(seed=42, buffer_size=shuffle_buffer_size)
    logging.info(f"Taking up to {eval_buffer_size} samples for the evaluation buffer...")
    eval_samples_iter = shuffled_ds.take(eval_buffer_size)
    try:
        eval_list = list(eval_samples_iter)
        num_eval_samples = len(eval_list)
    except Exception as e:
        logging.error(f"Error collecting evaluation samples: {e}. Proceeding without evaluation set.")
        num_eval_samples = 0
        eval_list = []
    train_ds = shuffled_ds
    eval_ds_static = None
    if num_eval_samples > 0:
        logging.info(f"Collected {num_eval_samples} samples for evaluation buffer.")
        train_ds = shuffled_ds.skip(num_eval_samples)
        logging.info("Training stream prepared (skipped eval samples).")
        logging.info("Creating static evaluation dataset from buffer...")
        try:
            if not eval_list: raise ValueError("Evaluation buffer list is empty after take().")
            first_example = eval_list[0]
            if not isinstance(first_example, dict): raise ValueError("Eval buffer items are not dictionaries.")
            # `datasets` is a hard top-level dependency of this module, so this
            # import cannot realistically fail (old try/ImportError removed).
            from datasets import Sequence
            expected_keys = ['input_ids', 'attention_mask', 'labels']
            eval_features_dict = {}
            for key in expected_keys:
                if key not in first_example:
                    raise ValueError(f"Eval buffer items missing required key: '{key}'")
                value = first_example[key]
                # BUG FIX: the old fallback used Value(dtype='list'), which is
                # not a valid `datasets` dtype and raised at Features()
                # construction. Raise a descriptive error instead; it is caught
                # below and disables evaluation gracefully.
                if isinstance(value, list) and value and isinstance(value[0], int):
                    eval_features_dict[key] = Sequence(feature=Value(dtype='int64'))
                else:
                    raise ValueError(f"Eval buffer key '{key}' is not a non-empty list of ints; cannot infer features.")
            if not eval_features_dict:
                raise ValueError("Could not define features for evaluation dataset.")
            eval_features = Features(eval_features_dict)
            # Keep only examples whose required keys all exist and are lists.
            valid_eval_list = []
            required_keys_set = set(eval_features.keys())
            for i, ex in enumerate(eval_list):
                if isinstance(ex, dict) and set(ex.keys()) >= required_keys_set:
                    is_valid = all(isinstance(ex.get(k), list) for k in required_keys_set)
                    if is_valid:
                        valid_eval_list.append({k: ex[k] for k in required_keys_set})
                    else:
                        logging.warning(f"Eval buffer item {i} has invalid type for required keys. Skipping.")
                else:
                    logging.warning(f"Eval buffer item {i} is invalid (not dict or missing keys). Skipping.")
            if not valid_eval_list:
                logging.warning("No valid examples remained in the evaluation buffer after validation. Eval dataset will be None.")
                eval_ds_static = None
                # Revert the skip: without an eval set, train on the full stream.
                train_ds = shuffled_ds
            else:
                eval_ds_static = Dataset.from_list(valid_eval_list, features=eval_features)
                logging.info(f"Created static evaluation dataset with {len(eval_ds_static)} examples.")
        except Exception as e:
            logging.error(f"Error creating static evaluation dataset from buffer: {e}\n{traceback.format_exc()}. Evaluation dataset will be None.")
            eval_ds_static = None
            train_ds = shuffled_ds
    else:
        logging.warning("Evaluation buffer is empty (requested size might be too large or dataset too small). Training will continue without evaluation.")
    logging.info(f"Dataset splitting completed in {time.time()-T:.2f}s")
    return train_ds, eval_ds_static
def compute_perplexity(loss):
    """Convert a cross-entropy loss to perplexity (exp(loss)).

    Returns float('inf') for missing, non-numeric, non-finite inputs, and on
    any overflow during exponentiation.
    """
    if loss is None or not isinstance(loss, (int, float)) or not math.isfinite(loss):
        return float("inf")
    try:
        # Clamp the exponent so math.exp cannot overflow on extreme losses.
        clamped_loss = max(-700.0, min(700.0, loss))
        perplexity = math.exp(clamped_loss)
    except OverflowError:
        logging.warning(f"OverflowError computing perplexity for loss {loss}. Returning infinity.")
        return float("inf")
    except Exception as e:
        logging.warning(f"Error computing perplexity for loss {loss}: {e}. Returning infinity.")
        return float("inf")
    if math.isfinite(perplexity):
        return perplexity
    logging.warning(f"Perplexity calculation resulted in infinity for loss {loss} (clamped: {clamped_loss}).")
    return float("inf")
def merge_model_parameters(original_model, trained_model, alpha=MERGE_ALPHA):
    """Linearly interpolate parameters: alpha*original + (1-alpha)*trained.

    Returns a deep copy of `original_model` whose matching parameters have been
    blended with the trained model's. Parameters with mismatched shapes or that
    do not exist in both models are left as the original's values.
    """
    if not (0 <= alpha <= 1):
        logging.error(f"Merge alpha must be between 0 and 1. Got {alpha}. Defaulting to 0.5")
        alpha = 0.5
    logging.info(f"Merging model parameters with alpha={alpha:.2f} (alpha*original + (1-alpha)*trained using linear interpolation)..."); T = time.time();
    target_device = get_device()
    original_model = original_model.to(target_device)
    trained_model = trained_model.to(target_device)
    merged_model = copy.deepcopy(original_model).to(target_device)
    n_merged = 0
    n_skipped = 0
    base_params = dict(original_model.named_parameters())
    dest_params = dict(merged_model.named_parameters())
    with torch.no_grad():
        for pname, tuned_param in trained_model.named_parameters():
            base_param = base_params.get(pname)
            dest_param = dest_params.get(pname)
            if base_param is not None and dest_param is not None:
                if base_param.data.shape == tuned_param.data.shape:
                    # lerp(trained, original, alpha) == alpha*original + (1-alpha)*trained
                    blended = torch.lerp(tuned_param.data.float(), base_param.data.float(), alpha)
                    dest_param.copy_(blended.to(dest_param.dtype))
                    n_merged += 1
                else:
                    logging.warning(f"Size mismatch for parameter '{pname}'. Original: {base_param.data.shape}, Trained: {tuned_param.data.shape}. Skipping merge for this parameter.")
                    n_skipped += 1
            else:
                if base_param is None:
                    logging.warning(f"Parameter '{pname}' from trained model not found in original model structure. Skipping.")
                if dest_param is None:
                    logging.warning(f"Parameter '{pname}' from trained model not found in merged model structure (shouldn't happen). Skipping.")
                n_skipped += 1
    logging.info(f"Parameter merging finished in {time.time()-T:.2f}s. Merged {n_merged} parameters, skipped {n_skipped}.")
    return merged_model
def preserve_model_quality(original_model, trained_model, eval_dataset, tokenizer):
    """Evaluate both models and keep whichever passes the quality bar.

    Runs a throwaway Trainer evaluation on `eval_dataset` for the original and
    trained models. If the trained model's eval loss is more than 5% worse than
    the original's, the ORIGINAL model is returned (quality revert); otherwise
    the trained model is kept. On any evaluation failure, or when no usable
    eval data exists, the trained model is returned unchanged.
    """
    if eval_dataset is None:
        logging.warning("No evaluation data provided (eval_dataset is None). Cannot perform quality check. Returning trained model.")
        return trained_model
    is_iterable = isinstance(eval_dataset, IterableDataset)
    if is_iterable:
        logging.warning("Evaluation dataset is iterable. Loss comparison might not be on the exact same data. Proceeding with caution.")
        try:
            _ = next(iter(eval_dataset.take(1)))
        except StopIteration:
            logging.warning("Iterable evaluation dataset appears empty. Returning trained model.")
            return trained_model
        except Exception as e:
            logging.warning(f"Could not peek into iterable eval dataset: {e}. Assuming not empty.")
    elif isinstance(eval_dataset, Dataset):
        if len(eval_dataset) == 0:
            logging.warning("Evaluation dataset is empty (length 0). Returning trained model.")
            return trained_model
    else:
        logging.warning(f"Unknown evaluation dataset type: {type(eval_dataset)}. Cannot perform quality check. Returning trained model.")
        return trained_model
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    eval_batch_size = max(1, BATCH_SIZE // 2)
    device = get_device()
    # BUG FIX: capture the training flag BEFORE switching to eval(); previously
    # it was read after .eval() and was therefore always False.
    was_training_orig = original_model.training
    original_model.to(device).eval()
    trained_model.to(device).eval()
    # (Removed a stray trained_model.train() that used to sit here: each model
    # is explicitly put in eval mode right before trainer.evaluate() below.)
    temp_eval_dir = "./tmp_eval_quality_check"
    # Bounded sample size for iterable eval streams; matches split_dataset's
    # eval buffer of 1000 (was previously a duplicated hard-coded literal).
    iterable_eval_samples = 1000
    eval_args = TrainingArguments(
        output_dir=temp_eval_dir,
        per_device_eval_batch_size=eval_batch_size,
        report_to=[],
        # os.cpu_count() can return None; fall back to 2 to keep the // safe.
        dataloader_num_workers=max(1, (NUM_CPU_CORES if NUM_CPU_CORES > 0 else (os.cpu_count() or 2) // 2)),
        fp16=torch.cuda.is_available() and not USE_CPU and original_model.dtype == torch.float16,
        bf16=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()) and original_model.dtype == torch.bfloat16,
        use_cpu=USE_CPU,
        log_level='error',
        remove_unused_columns=False,
    )
    results = {}
    eval_error = False
    for model_name, model_instance in [("Original", original_model), ("Trained", trained_model)]:
        logging.info(f"Evaluating {model_name} model for quality check..."); T_eval = time.time()
        current_eval_dataset = eval_dataset
        if is_iterable:
            try:
                # Take a bounded slice of the stream once, then verify it is non-empty.
                current_eval_dataset = eval_dataset.take(iterable_eval_samples)
                if len(list(iter(current_eval_dataset.take(1)))) == 0:
                    logging.warning(f"Iterable eval sample for {model_name} is empty. Skipping eval.")
                    results[model_name] = {"loss": float('inf'), "ppl": float('inf')}
                    continue
            except Exception as e:
                logging.error(f"Error handling iterable dataset sample for {model_name}: {e}")
                results[model_name] = {"loss": float('inf'), "ppl": float('inf')}
                eval_error = True; break
        trainer = Trainer(
            model=model_instance,
            args=eval_args,
            data_collator=data_collator,
            eval_dataset=current_eval_dataset
        )
        try:
            model_instance.eval()
            metrics = trainer.evaluate()
            loss = metrics.get("eval_loss")
            ppl = compute_perplexity(loss)
            results[model_name] = {"loss": loss if loss is not None else float('inf'), "ppl": ppl}
            # BUG FIX: the old f-string applied :.4f to the string 'N/A' when
            # loss was None, raising ValueError inside this try block.
            loss_str = f"{loss:.4f}" if loss is not None else "N/A"
            logging.info(f"{model_name} Eval Loss: {loss_str}, PPL: {ppl:.4f} (Eval time: {time.time()-T_eval:.2f}s)")
        except StopIteration:
            logging.error(f"Evaluation dataset exhausted unexpectedly during evaluation of {model_name}. Comparison may be incomplete.")
            results[model_name] = {"loss": float('inf'), "ppl": float('inf')}
            eval_error = True; break
        except Exception as e:
            logging.error(f"Error evaluating {model_name} model: {e}\n{traceback.format_exc()}")
            results[model_name] = {"loss": float('inf'), "ppl": float('inf')}
            eval_error = True; break
    if os.path.exists(temp_eval_dir):
        try:
            shutil.rmtree(temp_eval_dir)
        except Exception as e:
            logging.warning(f"Could not remove temporary eval directory {temp_eval_dir}: {e}")
    # Restore training modes (original to its pre-check state).
    original_model.train(mode=was_training_orig)
    trained_model.train()
    original_loss = results.get("Original", {}).get("loss", float('inf'))
    trained_loss = results.get("Trained", {}).get("loss", float('inf'))
    if eval_error:
        logging.error("Evaluation encountered errors. Cannot reliably compare models. Returning trained model.")
        return trained_model
    valid_comparison = math.isfinite(original_loss) and math.isfinite(trained_loss)
    if valid_comparison:
        # Allow up to 5% regression before reverting to the original weights.
        loss_threshold = original_loss * 1.05
        if trained_loss > loss_threshold:
            logging.warning(f"Trained model loss ({trained_loss:.4f}) is significantly worse (>5%) than original ({original_loss:.4f}). Reverting to original model state based on quality check.")
            return original_model.to(device)
        elif trained_loss > original_loss:
            logging.info(f"Trained model loss ({trained_loss:.4f}) is slightly worse than original ({original_loss:.4f}), but within threshold. Keeping trained model.")
            return trained_model.to(device)
        else:
            logging.info(f"Trained model loss ({trained_loss:.4f}) is better than or equal to original ({original_loss:.4f}). Keeping trained model.")
            return trained_model.to(device)
    else:
        logging.warning("Could not perform valid loss comparison (one or both evaluations failed or yielded non-finite loss). Returning trained model.")
        return trained_model.to(device)
def _merge_architectures(model_ids_str, hf_token=None, bypass_limits_state=False):
    """Merge 2+ models by parameter averaging and install the result globally.

    Loads the first ID's config/tokenizer as the base, sums matching-shape
    parameters from every successfully loaded model (float32, CPU), divides by
    the per-parameter contribution count, then rebuilds a model from the base
    config with the averaged weights and replaces global_model/tokenizer/pipe.

    Args:
        model_ids_str: comma-separated Hub IDs or local paths (>= 2 required).
        hf_token: optional Hugging Face token for gated/private models.
        bypass_limits_state: stored into BYPASS_RESOURCE_LIMITS before the
            resource check.

    Returns:
        (status message, status JSON string, *filter checkbox updates).
    """
    global global_model, global_tokenizer, config, global_pipe, BYPASS_RESOURCE_LIMITS
    BYPASS_RESOURCE_LIMITS = bypass_limits_state
    if not isinstance(model_ids_str, str) or not model_ids_str.strip():
        return "[Error] Model IDs string cannot be empty.", "{}", *get_error_filter_updates()
    # Abort early if the host is already over its memory/CPU limits.
    resources_ok, res_msg = check_resources()
    if not resources_ok:
        error_msg = f"[Error] Resource limits exceeded, cannot proceed with merge. {res_msg}"
        logging.error(error_msg)
        return error_msg, "{}", *get_error_filter_updates()
    else:
        logging.info(res_msg)
    model_ids = [m.strip() for m in model_ids_str.split(',') if m.strip()]
    if len(model_ids) < 2:
        return "[Error] Need at least two valid model IDs/paths separated by commas to merge.", "{}", *get_error_filter_updates()
    logging.info(f"Starting architecture merge (parameter averaging) for models: {model_ids}")
    device = get_device()
    merged_model = None
    t_merge_start = time.time()
    # The FIRST model in the list supplies the config, tokenizer, and the
    # reference parameter set; other models only contribute where shapes match.
    base_model_id = model_ids[0]
    try:
        logging.info(f"Loading base config and tokenizer from: {base_model_id}")
        base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token)
        base_config = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token)
        if base_tokenizer.pad_token is None and base_tokenizer.eos_token is not None:
            base_tokenizer.pad_token = base_tokenizer.eos_token
            base_config.pad_token_id = base_config.eos_token_id
            logging.info("Set base tokenizer pad_token to eos_token for consistency.")
    except Exception as e:
        logging.error(f"Failed to load base config/tokenizer for {base_model_id}: {e}")
        return f"[Error] Failed to load base model config/tokenizer: {e}", "{}", *get_error_filter_updates()
    try:
        # Accumulate in float32 on CPU to avoid precision loss and GPU OOM.
        logging.info(f"Loading base model state dict (CPU, float32) for merging: {base_model_id}")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            trust_remote_code=True,
            token=hf_token,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
        base_state_dict = base_model.state_dict()
        # Clone so the running sums do not alias the (soon-freed) base tensors.
        merged_state_dict = OrderedDict((k, v.clone()) for k, v in base_state_dict.items())
        # Per-parameter contributor count for the final division.
        param_counts = OrderedDict((k, 1) for k in base_state_dict)
        num_models_processed = 1
        del base_model, base_state_dict
        clean_memory()
    except Exception as e:
        logging.error(f"Failed to load base model state dict for {base_model_id}: {e}")
        return f"[Error] Failed to load base model state dict: {e}", "{}", *get_error_filter_updates()
    # Load remaining models one at a time, adding into the running sums.
    for i, model_id in enumerate(model_ids[1:]):
        logging.info(f"Processing model {i+2}/{len(model_ids)}: {model_id}")
        try:
            model_i = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                token=hf_token,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            )
            state_dict_i = model_i.state_dict()
            for name, param_i in state_dict_i.items():
                if name in merged_state_dict:
                    if merged_state_dict[name].shape == param_i.shape:
                        merged_state_dict[name].add_(param_i)
                        param_counts[name] += 1
                    else:
                        logging.warning(f"Shape mismatch for parameter '{name}' between base and {model_id}. Base: {merged_state_dict[name].shape}, Current: {param_i.shape}. Parameter '{name}' will NOT include contribution from {model_id}.")
                else:
                    logging.warning(f"Parameter '{name}' found in {model_id} but not in base model {base_model_id}. Skipping this parameter.")
            num_models_processed += 1
            del model_i, state_dict_i
            clean_memory()
        except Exception as e:
            # Best-effort: a model that fails to load is simply excluded.
            logging.error(f"Failed to load or process model {model_id}: {e}. Skipping this model for merge.")
            continue
    if num_models_processed < 2:
        msg = "Merge failed: Fewer than two models were successfully loaded and processed."
        logging.error(msg)
        return f"[Error] {msg}", "{}", *get_error_filter_updates()
    # Divide each summed parameter by its actual contributor count.
    averaged_count = 0
    for name in merged_state_dict:
        count = param_counts.get(name, 0)
        if count > 0:
            merged_state_dict[name].div_(count)
            averaged_count +=1
        else:
            logging.error(f"Parameter '{name}' has count {count} <= 0 during averaging. This indicates a logic error.")
    logging.info(f"Averaged {averaged_count} parameters across {num_models_processed} successfully processed models.")
    try:
        logging.info("Creating final merged model from base config and averaged weights...")
        merged_model = AutoModelForCausalLM.from_config(base_config, trust_remote_code=True)
        # strict=False: architecture quirks (e.g. tied/buffered weights) may
        # legitimately leave missing/unexpected keys; they are logged below.
        load_results = merged_model.load_state_dict(merged_state_dict, strict=False)
        if load_results.missing_keys:
            logging.warning(f"Load state dict results: Missing keys: {load_results.missing_keys}")
        if load_results.unexpected_keys:
            logging.warning(f"Load state dict results: Unexpected keys: {load_results.unexpected_keys}")
        # Pick the best half-precision dtype the device supports (fp32 on CPU).
        final_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32
        logging.info(f"Converting merged model to {final_dtype} for use on device {device}.")
        global_model = merged_model.to(device=device, dtype=final_dtype)
        global_tokenizer = base_tokenizer
        config = initialize_config_flags(global_model.config)
        config.architecture_merged = True
        config.merged_from_models = model_ids
        config.merged_models_processed = num_models_processed
        update_pipeline()
        clean_memory()
        final_status_json, *filter_updates = get_detailed_status_and_filter_states()
        merge_time = time.time() - t_merge_start
        msg = f"Successfully merged architectures (averaged parameters) from {num_models_processed} models in {merge_time:.2f}s. Base config/tokenizer from: {base_model_id}."
        logging.info(msg)
        return msg, final_status_json, *filter_updates
    except Exception as e:
        # On failure, clear ALL global state so the app is not left half-loaded.
        logging.error(f"Architecture merging failed during final model creation or state update: {e}\n{traceback.format_exc()}")
        global_model = None
        global_tokenizer = None
        config = None
        global_pipe = None
        clean_memory()
        return f"[Error] Architecture merging failed: {e}", "{}", *get_error_filter_updates()
def get_user_id(token):
    """Resolve the Hugging Face username for a token.

    Returns 'unknown_user' for missing tokens or generic failures,
    'auth_error_user' on a 401, and 'http_error_user' on other HTTP errors.
    """
    if not token:
        logging.warning("No Hugging Face token provided for user ID check.")
        return "unknown_user"
    try:
        return HfApi().whoami(token=token).get("name", "unknown_user")
    except requests.exceptions.HTTPError as http_err:
        if http_err.response.status_code == 401:
            logging.error("Hugging Face authentication failed (401 Unauthorized). Check your token.")
            return "auth_error_user"
        logging.error(f"HTTP error retrieving Hugging Face user ID: {http_err}")
        return "http_error_user"
    except Exception as e:
        logging.error(f"Could not retrieve Hugging Face user ID: {e}")
        return "unknown_user"
def decode_model_details(model):
    """Summarize a loaded model as a pretty-printed JSON string.

    Includes class/config names, parameter counts and dtypes, device placement,
    a histogram of module types, modification flags, and key config attributes.
    Returns an error JSON when the model is None or lacks a config.
    """
    if model is None:
        return json.dumps({"Error": "Model not loaded."}, indent=2)
    if not hasattr(model, 'config'):
        logging.warning("Model object lacks a 'config' attribute.")
        summary = OrderedDict()
        summary["Model Class"] = type(model).__name__
        summary["Error"] = "Model config attribute not found."
        return json.dumps(summary, indent=2)
    summary = OrderedDict()
    cfg = model.config
    t0 = time.time()
    logging.info("Decoding model details...")
    try:
        summary["Model Class"] = type(model).__name__
        summary["Config Class"] = getattr(cfg, 'config_class', type(cfg).__name__)
        summary["Model Type"] = getattr(cfg, 'model_type', 'N/A')
        n_total = 0
        n_trainable = 0
        dtypes_seen = set()
        devices_seen = set()
        try:
            for _name, p in model.named_parameters():
                n_elems = p.numel()
                n_total += n_elems
                dtypes_seen.add(str(p.dtype).replace('torch.', ''))
                devices_seen.add(str(p.device))
                if p.requires_grad:
                    n_trainable += n_elems
            if not devices_seen:
                device_str = "N/A (No parameters)"
            elif len(devices_seen) == 1:
                device_str = devices_seen.pop()
            else:
                device_str = f"Multiple ({', '.join(devices_seen)})"
        except Exception as e:
            logging.warning(f"Could not fully analyze parameters: {e}")
            device_str = "Error analyzing params"
        summary["Device(s)"] = device_str
        trainable_pct = (100 * n_trainable / n_total) if n_total > 0 else 0.00
        summary["Params Summary"] = (f"Total: {n_total:,}, Trainable: {n_trainable:,} "
                                     f"({trainable_pct:.2f}%), Dtypes: {list(dtypes_seen)}")
        try:
            # Histogram of module classes; nn.Sequential containers are noise.
            module_histogram = Counter(type(m).__name__ for m in model.modules() if not isinstance(m, nn.Sequential))
            summary["Layer Types Count"] = dict(module_histogram.most_common(15))
        except Exception as e:
            logging.warning(f"Could not count layer types: {e}")
            summary["Layer Types Count"] = "Error counting layers"
        # Report every known modification flag that exists on this config.
        summary["Modification Flags"] = {}
        for flag in sorted(initialize_config_flags(None).__dict__.keys()):
            if hasattr(cfg, flag):
                summary["Modification Flags"][flag] = getattr(cfg, flag)
        key_attrs = ['vocab_size', 'hidden_size', 'num_hidden_layers', 'num_attention_heads',
                     'intermediate_size', 'max_position_embeddings', 'hidden_act', 'layer_norm_eps',
                     'rms_norm_eps', 'attention_dropout', 'hidden_dropout_prob', 'initializer_range',
                     'tie_word_embeddings', 'rope_scaling', 'sliding_window', 'attn_implementation']
        summary["Key Config Attributes"] = {attr: getattr(cfg, attr) for attr in key_attrs if hasattr(cfg, attr)}
        logging.info(f"Model details decoded in {time.time() - t0:.2f}s")
        return json.dumps(summary, indent=2, default=str)
    except Exception as e:
        logging.error(f"Error decoding model details: {e} \n{traceback.format_exc()}")
        summary["Error"] = f"Failed during detail decoding: {e}"
        return json.dumps(summary, indent=2, default=str)
def update_pipeline():
    """Rebuild the global text-generation pipeline for the current model/tokenizer.

    Chooses the pipeline device from get_device(): CPU -> device=-1, single
    CUDA -> device=0, multi-GPU -> device_map='auto', MPS -> device=0. Moves
    the model onto that device first when no device_map is used. On failure,
    global_pipe is reset to None. Returns a status/error message string.
    """
    global global_model, global_tokenizer, global_pipe
    if global_model and global_tokenizer:
        device = get_device()
        pipeline_device_arg = None
        device_map = None
        if device.type == 'cpu':
            pipeline_device_arg = -1
            logging.info("Configuring pipeline for CPU.")
        elif device.type == 'cuda':
            if torch.cuda.device_count() > 1:
                # Let accelerate shard the model across all visible GPUs.
                device_map = "auto"
                pipeline_device_arg = None
                logging.info("Multiple GPUs detected, configuring pipeline with device_map='auto'.")
            else:
                pipeline_device_arg = 0
                logging.info("Configuring pipeline for single CUDA device (device=0).")
        elif device.type == 'mps':
            pipeline_device_arg = 0
            logging.info("Configuring pipeline for MPS device (device=0).")
        else:
            # Unknown accelerator: fall back to CPU rather than fail.
            pipeline_device_arg = -1
            logging.warning(f"Unknown device type '{device.type}', configuring pipeline for CPU.")
        logging.info(f"Updating text generation pipeline (Device Arg: {pipeline_device_arg}, Device Map: {device_map})..."); T=time.time()
        try:
            # Pre-move the model only when a single explicit device is used;
            # with device_map='auto' the pipeline handles placement itself.
            if device_map is None and pipeline_device_arg is not None:
                if pipeline_device_arg == -1:
                    global_model.to('cpu')
                elif device.type == 'cuda':
                    global_model.to(f'cuda:{pipeline_device_arg}')
                elif device.type == 'mps':
                    global_model.to('mps:0')
            task = "text-generation"
            global_pipe = pipeline(
                task=task,
                model=global_model,
                tokenizer=global_tokenizer,
                device=pipeline_device_arg,
                device_map=device_map
            )
            pipe_device_str = "N/A"
            # NOTE(review): assumes the transformers Pipeline object exposes
            # `device_map`/`device` attributes — confirm for the pinned version.
            if global_pipe.device_map:
                pipe_device_str = f"device_map: {global_pipe.device_map}"
            elif global_pipe.device:
                pipe_device_str = str(global_pipe.device)
            logging.info(f"Text generation pipeline created/updated. Effective device(s): {pipe_device_str}")
            if device_map is None and global_pipe.device != device:
                logging.warning(f"Pipeline created on {global_pipe.device}, but target device was {device}. This might happen with device_map issues or insufficient VRAM.")
            msg = f"Text generation pipeline updated successfully in {time.time()-T:.2f}s."; logging.info(msg)
            return msg
        except Exception as e:
            msg=f"Pipeline update failed: {e}\n{traceback.format_exc()}"; logging.error(msg);
            global_pipe = None;
            return f"[Error] Pipeline update failed: {e}"
    else:
        msg = "Cannot update pipeline: Global model or tokenizer not loaded."; logging.warning(msg);
        global_pipe = None;
        return msg
def get_detailed_status_and_filter_states():
    """Return (status JSON, *checkbox updates) for the current global model.

    Re-initializes modification flags on the model's config, builds the status
    JSON via decode_model_details, and derives one gr.update per UI filter from
    the corresponding config flag (False when the flag or mapping is missing).
    """
    global global_model, config
    started = time.time()
    if global_model is None:
        logging.warning("Cannot get status: Model not loaded.")
        return json.dumps({"Error": "Model not loaded."}, indent=2), *get_error_filter_updates()
    if not hasattr(global_model, 'config') or global_model.config is None:
        logging.warning("Model config missing. Initializing default flags for status check.")
        config_to_check = initialize_config_flags(None)
        status_json = json.dumps({"Warning": "Model config missing, status reflects defaults.", **json.loads(decode_model_details(global_model))}, indent=2)
    else:
        config = initialize_config_flags(global_model.config)
        global_model.config = config
        status_json = decode_model_details(global_model)
        config_to_check = config
    logging.info("Refreshing detailed model status and filter checkbox states...")
    checkbox_states = []
    for ui_name in filter_names_ui:
        attr_name = filter_attr_map.get(ui_name)
        if attr_name is None:
            logging.error(f"Filter name '{ui_name}' not found in attribute map. Setting state to False.")
            checkbox_states.append(False)
        else:
            checkbox_states.append(getattr(config_to_check, attr_name, False))
    updates = [gr.update(value=state) for state in checkbox_states]
    logging.info(f"Refreshed status and filter states in {time.time()-started:.2f}s.");
    return status_json, *updates
def get_error_filter_updates():
    """Return an unchecked gr.update for every UI filter (error/fallback state)."""
    updates = []
    for _ in filter_names_ui:
        updates.append(gr.update(value=False))
    return updates
def base_toggle_function(func_enable, func_disable, enable, success_msg_enable, success_msg_disable, *args):
    """Dispatch an enable/disable toggle to the matching modification function.

    Ensures model/config flags exist, inspects the target's signature to decide
    whether to pass the global model and/or config, invokes it with any extra
    args, refreshes the generation pipeline on success, and returns a message.
    """
    global global_model, config
    t_start = time.time()
    if not global_model:
        return "[Error] Model not loaded. Load a model first."
    if not hasattr(global_model, 'config') or global_model.config is None:
        logging.warning("Model config missing. Initializing default flags before toggle.")
        global_model.config = initialize_config_flags(None)
    config = initialize_config_flags(global_model.config)
    global_model.config = config
    target = func_enable if enable else func_disable
    action_name = "Enable" if enable else "Disable"
    fallback_name = 'unknown_enable' if enable else 'unknown_disable'
    func_name = getattr(target, '__name__', fallback_name).replace('_', ' ').title()
    logging.info(f"Executing toggle: {action_name} {func_name}...")
    msg = ""
    try:
        params = inspect.signature(target).parameters
        call_args = []
        # Pass the model/config only when the target declares matching params.
        if any(p in params for p in ('model', 'base_model', 'module')):
            call_args.append(global_model)
        if 'config' in params:
            call_args.append(config)
        call_args.extend(args)
        outcome = target(*call_args)
        # A string result (error or not) is used verbatim; otherwise fall back
        # to the caller-supplied success message.
        msg = outcome if isinstance(outcome, str) else (success_msg_enable if enable else success_msg_disable)
        logging.info(f"Toggle Action ({func_name} -> {action_name}) Result: {msg} (Took {time.time()-t_start:.2f}s)")
        if "[Error]" not in msg:
            update_pipeline()
    except Exception as e:
        msg = f"[Error] during toggle ({action_name} {func_name}): {e}"
        logging.error(f"{msg}\n{traceback.format_exc()}")
    clean_memory()
    return msg
def specific_action_function(action_func, *args, success_msg="Action completed successfully."):
    """Invoke a one-shot model modification function with injected dependencies.

    Inspects `action_func`'s signature and prepends the global model, config,
    and/or tokenizer as required, then appends `*args`. On success the text
    generation pipeline is refreshed. Returns a status or error message.
    """
    global global_model, global_tokenizer, config
    started = time.time()
    if not global_model:
        return "[Error] Model not loaded. Load a model first."
    if not hasattr(global_model, 'config') or global_model.config is None:
        logging.warning("Model config missing. Initializing default flags before action.")
        global_model.config = initialize_config_flags(None)
    config = initialize_config_flags(global_model.config)
    global_model.config = config
    func_name = getattr(action_func, '__name__', 'unknown_action')
    logging.info(f"Executing action: {func_name}...")
    msg = ""
    try:
        params = inspect.signature(action_func).parameters
        call_args = []
        if any(name in params for name in ('model', 'base_model', 'module')):
            call_args.append(global_model)
        if 'config' in params:
            call_args.append(config)
        if 'tokenizer' in params:
            if not global_tokenizer:
                return f"[Error] Action '{func_name}' requires tokenizer, but it's not loaded."
            call_args.append(global_tokenizer)
        call_args.extend(args)
        outcome = action_func(*call_args)
        msg = outcome if isinstance(outcome, str) else success_msg
        logging.info(f"Action ({func_name}) Result: {msg} (Took {time.time()-started:.2f}s)")
        if "[Error]" not in msg:
            update_pipeline()
    except Exception as e:
        msg = f"[Error] during action ({func_name}): {e}"
        logging.error(f"{msg}\n{traceback.format_exc()}")
    clean_memory()
    return msg
# UI wiring: thin lambda wrappers binding Gradio controls to the model
# modification helpers via base_toggle_function / specific_action_function.
toggle_bias_removal_wrapper = lambda enable: base_toggle_function(_replace_linear_without_bias, _enable_bias_in_linear, enable, "Bias removal applied.", "Bias addition applied (reverted removal).")
toggle_embeddings_untie_wrapper = lambda enable: base_toggle_function(_untie_embeddings, _retie_embeddings, enable, "Embeddings untied.", "Embeddings re-tied.")
# NOTE(review): on disable this passes `None` as a positional arg to
# _enable_full_layers — verify that function tolerates the extra argument.
toggle_layer_reduction_wrapper = lambda enable, layers: specific_action_function(_reduce_layers_to_one if enable else _enable_full_layers, layers if enable else None, success_msg=f"Layer reduction {'applied' if enable else 'reverted'}.")
apply_norm_swap_wrapper = lambda norm_type: specific_action_function(_swap_normalization_layer, norm_type, success_msg=f"Normalization swapped to {norm_type}")
apply_activation_change_wrapper = lambda name: specific_action_function(_swap_activation_function, name, success_msg=f"Activation Function Swapped to {name}")
revert_activation_change_wrapper = lambda: specific_action_function(_revert_activation_function, success_msg="Activation Function Reverted to Default")
toggle_bitnet_wrapper = lambda enable: base_toggle_function(convert_to_bitnet, revert_bitnet, enable, "BitNet conversion applied.", "BitNet conversion reverted.")
apply_multimodal_wrapper = lambda modalities: specific_action_function(_setup_multimodal, modalities, success_msg="Multi-modal setup attempted.")
revert_multimodal_wrapper = lambda: specific_action_function(_revert_multimodal, success_msg="Multi-modal setup reverted.")
# Flag-only toggles: each pair flips a config flag via the corresponding helper.
toggle_token_speed_optimization_wrapper = lambda enable: specific_action_function(_optimize_token_generation_speed if enable else _revert_token_generation_speed_optimization, success_msg="Token Speed Opt Flags Updated")
toggle_coherence_improvement_wrapper = lambda enable: specific_action_function(_enable_coherence_improvement if enable else _disable_coherence_improvement, success_msg="Coherence Flag Updated")
toggle_layer_norm_bypass_wrapper = lambda enable: specific_action_function(_enable_layer_norm_bypass if enable else _disable_layer_norm_bypass, success_msg="LN Bypass Updated")
toggle_dropout_bypass_wrapper = lambda enable: specific_action_function(_enable_dropout_bypass if enable else _disable_dropout_bypass, success_msg="Dropout Bypass Updated")
toggle_fp32_precision_wrapper = lambda enable: specific_action_function(_recover_perfect_precision if enable else _revert_perfect_precision, success_msg="FP32 Precision Updated")
toggle_embedding_normalization_wrapper = lambda enable: specific_action_function(_normalize_embeddings if enable else _revert_embedding_normalization, success_msg="Embedding Normalization Updated")
toggle_gradient_checkpointing_wrapper = lambda enable: specific_action_function(_enable_gradient_checkpointing if enable else _disable_gradient_checkpointing, success_msg="Grad Checkpointing Updated")
# Flash Attention is toggled indirectly through the attn_implementation config key.
toggle_flash_attention_wrapper = lambda enable: specific_action_function(_set_attention_variant_config, "flash_attention_2" if enable else "auto", success_msg=f"Flash Attention 2 {'Enabled' if enable else 'Disabled'} (via attn_implementation)")
apply_quantization_wrapper = lambda mode: specific_action_function(_quantize_model, mode, success_msg=f"Quantization Attempted: {mode}")
revert_quantization_wrapper = lambda: specific_action_function(_revert_quantization, success_msg="Quantization Reverted to FP32")
def _parse_pruning_amount(amount_str):
try:
amount = float(amount_str)
if not (0 < amount < 1):
raise ValueError("Pruning amount must be between 0 and 1")
return amount
except (ValueError, TypeError):
logging.warning(f"Invalid pruning amount '{amount_str}', using default {PRUNING_AMOUNT}")
return PRUNING_AMOUNT
# NOTE(review): _parse_pruning_amount runs twice per call (once for the arg,
# once for the success message) — harmless but redundant; it also logs the
# "invalid amount" warning twice for bad input.
apply_pruning_wrapper = lambda amount_str: specific_action_function(
    _prune_weights_magnitude,
    _parse_pruning_amount(amount_str),
    success_msg=f"Pruning Applied (Amount: {_parse_pruning_amount(amount_str):.2f})"
)
revert_pruning_wrapper = lambda: specific_action_function(_revert_pruning, success_msg="Pruning Flag Reverted")
# PEFT/LoRA controls; _peft_installed guards the optional dependency.
set_lora_path_wrapper = lambda path: specific_action_function(_set_lora_adapter_path, path, success_msg="LoRA Path Set in Config")
add_peft_adapter_wrapper = lambda: specific_action_function(
    _add_peft_adapter,
    LoraConfig(**DEFAULT_PEFT_CONFIG_DICT) if _peft_installed else None,
    success_msg="PEFT Adapter Added"
)
merge_peft_adapter_wrapper = lambda: specific_action_function(_apply_lora_merge, success_msg="PEFT Adapter Merged")
remove_peft_adapter_wrapper = lambda: specific_action_function(_remove_peft_adapter, success_msg="PEFT Adapter Removed")
apply_layer_freeze_wrapper = lambda layers_str: specific_action_function(_freeze_layers, layers_str, success_msg="Layer Freezing Updated")
revert_layer_freeze_wrapper = lambda: specific_action_function(_unfreeze_all_layers, success_msg="All Layers Unfrozen")
toggle_limits_wrapper = lambda enable: specific_action_function(_configure_limits if enable else _remove_limits_configuration, success_msg="Limits Config Updated")
# NOTE(review): enable=True routes to _remove_qa_restrictions — the mapping
# reads inverted relative to the wrapper's name; confirm against the UI label.
toggle_qa_restrictions_wrapper = lambda enable: specific_action_function(_remove_qa_restrictions if enable else _enable_qa_restrictions, success_msg="QA Restrictions Flag Updated")
def _parse_int_arg(arg, default, min_val=1):
try:
val = int(arg)
return max(val, min_val)
except (ValueError, TypeError):
return default
def toggle_kd_wrapper(enable, num_labels=2):
    """Enable or revert knowledge-distillation setup via the action runner.

    NOTE(review): when disabled, an empty tuple is forwarded as the positional
    argument (matching the original lambdas) — presumably the runner treats it
    as "no extra args"; confirm against specific_action_function's signature.
    """
    if enable:
        return specific_action_function(_setup_knowledge_distillation,
                                        _parse_int_arg(num_labels, 2, 1),
                                        success_msg="KD Setup Updated")
    return specific_action_function(_revert_knowledge_distillation, (),
                                    success_msg="KD Setup Updated")

def toggle_reward_modeling_wrapper(enable, num_outputs=1):
    """Enable or revert reward-modeling setup (num_outputs clamped to >= 1)."""
    if enable:
        return specific_action_function(_setup_reward_modeling,
                                        _parse_int_arg(num_outputs, 1, 1),
                                        success_msg="Reward Modeling Setup Updated")
    return specific_action_function(_revert_reward_modeling, (),
                                    success_msg="Reward Modeling Setup Updated")

def toggle_swa_wrapper(enable):
    """Flip the Stochastic Weight Averaging flag on or off."""
    action = _apply_swa if enable else _revert_swa
    return specific_action_function(action, success_msg="SWA Flag Updated")
def _parse_prob_arg(arg, default, min_val=0.0, max_val=1.0):
try:
val = float(arg)
return min(max(val, min_val), max_val)
except(ValueError, TypeError):
return default
def toggle_layerdrop_wrapper(enable, prob=0.1):
    """Enable LayerDrop with a clamped probability, or disable it."""
    if enable:
        return specific_action_function(_enable_layerdrop,
                                        _parse_prob_arg(prob, 0.1, 0.0, 1.0),
                                        success_msg="LayerDrop Flag Updated")
    return specific_action_function(_disable_layerdrop, (),
                                    success_msg="LayerDrop Flag Updated")

def toggle_rope_scaling_wrapper(enable, type="linear", factor=2.0):
    """Set or revert the RoPE scaling configuration.

    `type` keeps its builtin-shadowing name for keyword-call compatibility.
    NOTE(review): the disabled branch forwards two empty tuples, exactly as
    the original lambda did — verify specific_action_function ignores them.
    """
    if enable:
        return specific_action_function(_set_rope_scaling_config,
                                        str(type),
                                        _parse_prob_arg(factor, 2.0, 1.0, 100.0),
                                        success_msg="RoPE Scaling Config Updated")
    return specific_action_function(_revert_rope_scaling, (), (),
                                    success_msg="RoPE Scaling Config Updated")

def toggle_sliding_window_wrapper(enable, size=4096):
    """Set or revert the sliding-window attention size (clamped to >= 0)."""
    if enable:
        return specific_action_function(_set_sliding_window_config,
                                        _parse_int_arg(size, 4096, 0),
                                        success_msg="Sliding Window Config Updated")
    return specific_action_function(_revert_sliding_window, (),
                                    success_msg="Sliding Window Config Updated")

def apply_attention_variant_wrapper(variant="auto"):
    """Request a specific attention implementation in the config."""
    return specific_action_function(_set_attention_variant_config, str(variant),
                                    success_msg="Attention Variant Config Updated")

def revert_attention_variant_wrapper():
    """Restore the default attention implementation choice."""
    return specific_action_function(_revert_attention_variant,
                                    success_msg="Attention Variant Config Reverted")

def toggle_gradient_clipping_flag_wrapper(enable):
    """Toggle the gradient-clipping preference flag."""
    action = _enable_gradient_clipping if enable else _disable_gradient_clipping
    return specific_action_function(action, success_msg="Grad Clipping Flag Updated")

def toggle_weight_decay_flag_wrapper(enable):
    """Toggle the weight-decay preference flag."""
    action = _enable_weight_decay if enable else _disable_weight_decay
    return specific_action_function(action, success_msg="Weight Decay Flag Updated")

def toggle_lr_scheduler_flag_wrapper(enable):
    """Toggle the learning-rate-scheduler preference flag."""
    action = _enable_lr_scheduler if enable else _disable_lr_scheduler
    return specific_action_function(action, success_msg="LR Scheduler Flag Updated")

def apply_optimizer_change_wrapper(name):
    """Record a preferred optimizer by name."""
    return specific_action_function(_swap_optimizer, str(name),
                                    success_msg=f"Optimizer Pref Set: {name}")

def revert_optimizer_change_wrapper():
    """Revert to the default optimizer preference."""
    return specific_action_function(_revert_optimizer,
                                    success_msg="Optimizer Pref Reverted")
def _set_grad_accum_config(config, steps):
try:
s = int(steps)
if s < 1: raise ValueError("Steps must be >= 1")
config.gradient_accumulation_steps = s
global GRADIENT_ACCUMULATION_STEPS
GRADIENT_ACCUMULATION_STEPS = s
return f"Grad Accum Steps set to {s} in config."
except (ValueError, TypeError) as e:
logging.error(f"Invalid gradient accumulation steps: {steps}. Error: {e}")
return f"[Error] Invalid Grad Accum steps: {e}"
def set_gradient_accumulation_wrapper(steps):
    """Push a new gradient-accumulation step count into the model config."""
    return specific_action_function(_set_grad_accum_config, steps, success_msg="Grad Accum Steps update attempted.")

def toggle_all_safety_filters_wrapper(enable):
    """Enable (with defaults) or disable every safety filter at once."""
    state_desc = 'Enabled (Defaults)' if enable else 'Disabled'
    action = _enable_all_safety_settings if enable else _disable_all_safety_settings
    return specific_action_function(action, success_msg=f"All Safety Filters {state_desc}")

def force_disable_censorship_wrapper():
    """Explicitly clear every censorship/safety flag."""
    return specific_action_function(_disable_all_safety_settings, success_msg="Attempted Force Disable All Censorship Flags")
def toggle_individual_safety_filter_wrapper(*checkbox_states):
    """Apply per-filter checkbox states from the UI onto the model config.

    Receives one boolean per entry in filter_names_ui (same order), flips
    only the config flags whose state actually changed, and rebuilds the
    generation pipeline when anything was modified. Returns a status string.
    """
    global global_model, config
    started = time.time()
    if not global_model:
        return "[Error] Model not loaded."
    # Make sure a config object carrying all expected flags exists first.
    if not hasattr(global_model, 'config') or global_model.config is None:
        logging.warning("Model config missing. Initializing default flags for filter toggle.")
        global_model.config = initialize_config_flags(None)
    config = initialize_config_flags(global_model.config)
    global_model.config = config
    if len(checkbox_states) != len(filter_names_ui):
        return f"[Error] Mismatch between filter UI elements ({len(filter_names_ui)}) and received states ({len(checkbox_states)})."
    changes = []
    # dict(zip(...)) preserves the original's duplicate-name collapsing.
    for ui_name, desired in dict(zip(filter_names_ui, checkbox_states)).items():
        attr = filter_attr_map.get(ui_name)
        if not attr:
            logging.warning(f"UI filter name '{ui_name}' not found in attribute map filter_attr_map. Skipping.")
            continue
        wanted = bool(desired)
        if getattr(config, attr, False) != wanted:
            setattr(config, attr, wanted)
            changes.append(f"{ui_name}: {'ON' if wanted else 'OFF'}")
    if changes:
        msg = f"Applied {len(changes)} individual filter toggle(s): {', '.join(changes)}"
        update_pipeline()
    else:
        msg = "No individual filter states were changed."
    logging.info(f"Individual filter toggle action took {time.time()-started:.2f}s. Status: {msg}")
    return msg
def _improve_coherence(model, tokenizer, generation_args):
logging.info("Applying coherence improvement using beam search...")
coherence_beams = generation_args.get("num_beams", 1)
if coherence_beams <= 1: coherence_beams = 4
coherence_args = generation_args.copy()
coherence_args["num_beams"] = coherence_beams
coherence_args["do_sample"] = False
coherence_args["num_return_sequences"] = 1
coherence_args["early_stopping"] = True
coherence_args.pop("temperature", None)
coherence_args.pop("top_k", None)
coherence_args.pop("top_p", None)
input_ids = coherence_args.get("input_ids")
if input_ids is None:
logging.error("Coherence improvement failed: input_ids missing.")
return "[Error: input_ids missing in generation_args]"
try:
with torch.no_grad():
outputs = model.generate(**coherence_args)
response_ids = outputs[0][input_ids.shape[-1]:]
response_text = tokenizer.decode(response_ids, skip_special_tokens=True)
logging.info("Coherence improvement (beam search) successful.")
return response_text
except Exception as e:
logging.error(f"Error during coherence improvement (beam search): {e}. Falling back to original generation settings.")
fallback_args = generation_args.copy()
if fallback_args.get("num_beams", 1) > 1 and fallback_args.get("do_sample", False):
fallback_args["num_beams"] = 1
fallback_args["early_stopping"] = False
try:
with torch.no_grad():
outputs = model.generate(**fallback_args)
response_ids = outputs[0][input_ids.shape[-1]:]
response_text = tokenizer.decode(response_ids, skip_special_tokens=True)
logging.info("Fallback generation successful after coherence error.")
return response_text
except Exception as fallback_e:
logging.error(f"Fallback generation also failed: {fallback_e}")
return f"[Error during generation (fallback): {fallback_e}]"
def _apply_output_filters(text, config):
    """Apply configured content filters to generated text.

    Consults boolean flags on `config` (one per entry in the module-level
    filter_attr_map) and runs the matching regex/keyword heuristic for each
    enabled filter. Returns the (possibly modified) text; returns it
    unchanged when response filtering is globally disabled.
    """
    # Global kill-switch: either flag disables all filtering.
    if not getattr(config, "response_filters", True) or getattr(config, "no_response_filters", False):
        return text
    logging.debug("Applying output filters based on config...")
    original_text = text  # NOTE(review): assigned but never read afterwards
    if not isinstance(text, str): text = str(text)
    # Map of config-flag name -> transformation. Each entry is a heuristic
    # blocklist: it either substitutes a placeholder token inline or replaces
    # the whole text with a censorship marker.
    filter_actions = {
        "profanity_filter": lambda t: re.sub(r'\b(damn|fuck|shit|cunt|bitch|asshole|bastard|hell)\b', '[profanity]', t, flags=re.IGNORECASE),
        "pii_filter": lambda t: re.sub(r'\b(\d{3}[-.\s]?\d{3}[-.\s]?\d{4})\b|\b([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,})\b', '[censored_pii]', t),
        "hate_filter": lambda t: "[censored_hate_speech]" if any(word in t.lower() for word in ["kill all", "attack the", "inferior race", "exterminate"]) else t,
        "sexually_explicit_filter": lambda t: "[censored_explicit]" if any(word in t.lower() for word in [" rape ", " naked ", " porn ", " xxx ", " blowjob ", " sex toy"]) else t,
        "dangerous_content_filter": lambda t: "[censored_dangerous]" if any(word in t.lower() for word in ["make a bomb", "build weapon", "illegal drug recipe", "how to kill"]) else t,
        "medical_advice_filter": lambda t: "[discouraged_medical_advice]" if any(word in t.lower() for word in ["diagnose", "prescribe", "cure for", "medical treatment for", "symptoms suggest"]) else t,
        "legal_advice_filter": lambda t: "[discouraged_legal_advice]" if any(word in t.lower() for word in ["legal advice", "sue", "represent me", "is this legal", "contract law"]) else t,
        "financial_advice_filter": lambda t: "[discouraged_financial_advice]" if any(word in t.lower() for word in ["guaranteed investment", "stock tip", "financial plan", "buy bitcoin", "investment advice"]) else t,
        "stereotype_filter": lambda t: "[censored_stereotype]" if re.search(r'\b(all|every)\s+([A-Za-z]+(\s+)?){1,3}\s+(are|always)\b', t.lower()) else t,
        "misinfo_filter": lambda t: "[potential_misinfo]" if any(phrase in t.lower() for phrase in ["5g causes covid", "earth is flat", "vaccines cause autism", "chemtrails"]) else t,
        "self_harm_filter": lambda t: "[censored_self_harm]" if any(phrase in t.lower() for phrase in ["commit suicide", "hurt myself", "painless death", "kill myself"]) else t,
    }
    active_filters_count = 0
    # filter_attr_map is a module-level UI-name -> config-attribute mapping
    # defined elsewhere in this file; a filter runs only when its flag is
    # truthy on config AND it has an action registered above.
    for filter_ui_name, filter_attr in filter_attr_map.items():
        if getattr(config, filter_attr, False):
            filter_func = filter_actions.get(filter_attr)
            if filter_func:
                try:
                    filtered_text = filter_func(text)
                    if filtered_text != text:
                        active_filters_count +=1
                        logging.debug(f"Filter '{filter_attr}' potentially applied modification.")
                        text = filtered_text
                except Exception as e:
                    # Best-effort: a broken filter must not kill the response.
                    logging.warning(f"Error applying filter '{filter_attr}': {e}")
    # Append a promotional-content disclaimer unless explicitly suppressed.
    if not getattr(config, "no_advert_warning", False):
        if re.search(r'\b(advertisement|sponsored|promo code|discount code|special offer)\b', text, re.IGNORECASE):
            if "[Note: This response may contain promotional content.]" not in text:
                text += "\n[Note: This response may contain promotional content.]"
                active_filters_count +=1
    if active_filters_count > 0:
        logging.debug(f"Output filtering potentially applied {active_filters_count} modifications.")
    return text
def run_inference(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
    """Generate text for `prompt` with the globally loaded model/tokenizer.

    Sampling parameters are sanitized (falling back to safe defaults on
    invalid input), greedy decoding is auto-detected, optional beam-search
    "coherence improvement" and configured output filters are applied.
    Returns the response text, or an "[Error] ..." string on failure.
    """
    global global_pipe, global_model, global_tokenizer, config
    if not all([global_model, global_tokenizer]):
        return "[Error] Model or Tokenizer not loaded. Please load a model first."
    if global_pipe is None:
        pipe_msg = update_pipeline()
        if global_pipe is None:
            return f"[Error] Text generation pipeline could not be initialized. Load/Reload model. Status: {pipe_msg}"
    if not hasattr(global_model, 'config'):
        logging.warning("Model config missing during inference. Initializing default flags.")
        global_model.config = initialize_config_flags(None)
    config = initialize_config_flags(global_model.config)
    global_model.config = config
    logging.info("Starting inference run..."); t_start_inf = time.time()
    # BUGFIX: capture the mode BEFORE eval(). The old code checked
    # `global_model.training` inside `finally`, which is always False after
    # eval(), so training mode was never actually restored.
    was_training = bool(getattr(global_model, 'training', False))
    try:
        use_filters = getattr(config, "response_filters", True) and not getattr(config, "no_response_filters", False)
        apply_coherence = getattr(config, "coherence_improvement_enabled", False)
        # Sanitize user-supplied generation parameters, falling back to
        # defaults. `except Exception` (not bare `except`) so that
        # KeyboardInterrupt/SystemExit still propagate.
        try: max_new_tokens = int(max_new_tokens); assert max_new_tokens > 0
        except Exception: max_new_tokens = 256; logging.warning("Invalid max_new_tokens, using 256.")
        try: temperature = float(temperature); assert temperature >= 0.0
        except Exception: temperature = 0.7; logging.warning("Invalid temperature, using 0.7.")
        try: top_k = int(top_k); assert top_k >= 0
        except Exception: top_k = 50; logging.warning("Invalid top_k, using 50.")
        try: top_p = float(top_p); assert 0.0 <= top_p <= 1.0
        except Exception: top_p = 0.95; logging.warning("Invalid top_p, using 0.95.")
        try: repetition_penalty = float(repetition_penalty); assert repetition_penalty >= 1.0
        except Exception: repetition_penalty = 1.1; logging.warning("Invalid repetition_penalty, using 1.1.")
        # Greedy when temperature ~ 0, top_k pins a single candidate, top_p
        # disables nucleus sampling (NOTE(review): top_p >= 1.0 is treated as
        # greedy here, preserving the original flow — confirm intent), or
        # the config requests maximum token-generation speed.
        is_greedy = (temperature < 1e-6) or \
                    (top_k == 1) or \
                    (top_p <= 0.0 or top_p >= 1.0) or \
                    getattr(config, "token_gen_speed_maximized", False)
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature if not is_greedy else None,
            "top_k": top_k if top_k > 0 and not is_greedy else None,
            "top_p": top_p if top_p > 0.0 and top_p < 1.0 and not is_greedy else None,
            "repetition_penalty": repetition_penalty if repetition_penalty > 1.0 else None,
            "do_sample": not is_greedy,
            "use_cache": getattr(config, "use_cache", True),
            "num_beams": (max(getattr(config, "num_beams", 1), 4) if apply_coherence else getattr(config, "num_beams", 1)),
            "pad_token_id": global_tokenizer.pad_token_id if global_tokenizer.pad_token_id is not None else getattr(config, 'pad_token_id', None),
            "eos_token_id": global_tokenizer.eos_token_id if global_tokenizer.eos_token_id is not None else getattr(config, 'eos_token_id', None),
            "early_stopping": True if (apply_coherence or getattr(config, "num_beams", 1) > 1) else False
        }
        # Drop unset options so model.generate sees only explicit arguments.
        gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
        # Beam search requires a pad token; borrow eos_token_id if necessary.
        if gen_kwargs.get("num_beams", 1) > 1 and gen_kwargs.get("pad_token_id") is None:
            if gen_kwargs.get("eos_token_id") is not None:
                gen_kwargs["pad_token_id"] = gen_kwargs["eos_token_id"]
                logging.warning(f"Using eos_token_id ({gen_kwargs['eos_token_id']}) as pad_token_id for beam search.")
            else:
                logging.error("Beam search requires pad_token_id, but none found (and eos_token_id missing). Generation might fail.")
                return "[Error] Beam search failed: pad_token_id is required."
        response_text = ""
        device = get_device()
        logging.debug(f"Generation arguments: {gen_kwargs}")
        inputs = global_tokenizer(prompt, return_tensors="pt", padding=False, truncation=True, max_length=CONTEXT_LENGTH).to(device)
        gen_kwargs["input_ids"] = inputs["input_ids"]
        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs["attention_mask"]
        global_model.eval()
        if apply_coherence:
            response_text = _improve_coherence(global_model, global_tokenizer, gen_kwargs)
        else:
            with torch.no_grad():
                outputs = global_model.generate(**gen_kwargs)
            output_sequence = outputs[0]
            # Strip the prompt tokens; decode only the newly generated tail.
            response_ids = output_sequence[inputs.input_ids.shape[-1]:]
            response_text = global_tokenizer.decode(response_ids, skip_special_tokens=True)
        if use_filters:
            filtered_response = _apply_output_filters(response_text, config)
            if filtered_response != response_text:
                logging.info("Output filters applied modifications.")
                response_text = filtered_response
        final_response = response_text.strip()
        logging.info(f"Inference finished in {time.time()-t_start_inf:.2f}s. Response length: {len(final_response)}")
        return final_response
    except Exception as e:
        logging.error(f"Error during inference: {e}\n{traceback.format_exc()}")
        return f"[Error during inference: {e}]"
    finally:
        # Restore training mode only if the model was training on entry.
        if global_model is not None and was_training:
            global_model.train()
def start_training(
base_model_id: str, new_model_name: str, hf_token: str,
datasets_input_str: str,
activation_fn_name: str, target_layers_int: int,
grad_accum_ui: int, lr: float, epochs: int, max_steps: int, batch_size: int,
optimizer_name: str, scheduler_type: str, weight_decay: float, warmup_ratio: float,
use_peft: bool, peft_r: int, peft_alpha: int, peft_dropout: float, peft_target_modules_str: str,
wandb_token: str, use_cpu_flag: bool, bypass_limits_state: bool
):
global global_model, global_tokenizer, global_pipe, original_num_layers_global, config, target_layers
global USE_CPU, BATCH_SIZE, LEARNING_RATE, EPOCHS, MAX_STEPS, DEFAULT_OPTIMIZER, DEFAULT_SCHEDULER, GRADIENT_ACCUMULATION_STEPS
global BYPASS_RESOURCE_LIMITS
BYPASS_RESOURCE_LIMITS = bypass_limits_state
start_overall_time = time.time()
logging.info("="*50)
logging.info("🚀 STARTING TRAINING PROCESS 🚀")
resources_ok, res_msg = check_resources()
if not resources_ok:
error_msg = f"[Error] Resource limits exceeded, cannot start training. {res_msg}"
logging.error(error_msg)
return error_msg
else:
logging.info(res_msg)
errors = []
if not base_model_id: errors.append("Base Model ID/Path is required.")
if not new_model_name: errors.append("New Model Name (for saving/Hub) is required.")
if not datasets_input_str: errors.append("At least one dataset must be provided.")
try: target_layers_int = int(target_layers_int); assert target_layers_int >= 1
except: errors.append("Target Layers must be a positive integer.")
try: grad_accum_ui = int(grad_accum_ui); assert grad_accum_ui >= 1
except: errors.append("Gradient Accumulation Steps must be a positive integer.")
try: lr = float(lr); assert lr > 0
except: errors.append("Learning Rate must be a positive float.")
try: epochs = int(epochs); assert epochs >= 0
except: errors.append("Epochs must be an integer >= 0.")
try: max_steps = int(max_steps); assert max_steps >= 0
except: errors.append("Max Steps must be an integer >= 0.")
if epochs <= 0 and max_steps <= 0:
errors.append("Training requires at least one of Epochs or Max Steps to be positive.")
elif epochs > 0 and max_steps > 0:
logging.info(f"Both Epochs ({epochs}) and Max Steps ({max_steps}) are set (> 0). Max Steps will take precedence.")
epochs = -1
elif epochs <= 0 and max_steps > 0:
epochs = -1
elif epochs > 0 and max_steps <= 0:
logging.info(f"Using Epochs ({epochs}) for training termination as Max Steps <= 0.")
max_steps = -1
else:
logging.error("Logic error in epoch/max_step handling. Defaulting Max Steps to 1.")
max_steps = 1
epochs = -1
try: batch_size = int(batch_size); assert batch_size >= 1
except: errors.append("Batch Size must be a positive integer.")
if optimizer_name not in OPTIMIZERS: errors.append(f"Invalid Optimizer. Choose from: {list(OPTIMIZERS.keys())}")
if scheduler_type not in SCHEDULER_TYPES: errors.append(f"Invalid Scheduler. Choose from: {SCHEDULER_TYPES}")
try: weight_decay = float(weight_decay); assert weight_decay >= 0.0
except: errors.append("Weight Decay must be a non-negative float.")
try: warmup_ratio = float(warmup_ratio); assert 0.0 <= warmup_ratio <= 1.0
except: errors.append("Warmup Ratio must be between 0.0 and 1.0.")
if activation_fn_name not in ACTIVATION_FUNCTIONS: errors.append(f"Invalid Activation Function. Choose from: {list(ACTIVATION_FUNCTIONS.keys())}")
if use_peft and not _peft_installed: errors.append("PEFT requested, but library not installed (`pip install peft`).")
peft_config_dict = {}
if use_peft:
try:
peft_r = int(peft_r); assert peft_r >= 1
peft_alpha = int(peft_alpha); assert peft_alpha >= 1
peft_dropout = float(peft_dropout); assert 0.0 <= peft_dropout <= 1.0
peft_config_dict = {
"task_type": TaskType.CAUSAL_LM,
"inference_mode": False,
"r": peft_r,
"lora_alpha": peft_alpha,
"lora_dropout": peft_dropout,
}
if peft_target_modules_str:
modules = [m.strip() for m in peft_target_modules_str.split(',') if m.strip()]
if modules:
peft_config_dict["target_modules"] = modules
except Exception as peft_e:
errors.append(f"Invalid PEFT configuration: {peft_e}")
if errors:
error_msg = "[Error] Invalid training parameters:\n- " + "\n- ".join(errors)
logging.error(error_msg)
return error_msg
logging.info(f"Base Model: {base_model_id}, New Name: {new_model_name}")
logging.info(f"Use PEFT: {use_peft}")
if use_peft: logging.info(f"PEFT Config: r={peft_r}, alpha={peft_alpha}, dropout={peft_dropout}, targets={peft_target_modules_str or 'Auto'}")
logging.info(f"Datasets: \n{datasets_input_str}")
logging.info(f"LR: {lr}, Effective Epochs: {epochs if epochs > 0 else 'N/A'}, MaxSteps: {max_steps if max_steps > 0 else 'N/A'}, BS: {batch_size}, GradAccum: {grad_accum_ui}")
logging.info(f"Optim: {optimizer_name}, Scheduler: {scheduler_type}, WD: {weight_decay}, Warmup: {warmup_ratio}")
logging.info(f"Post-Mod Target Layers: {target_layers_int}, Post-Mod ActFn: {activation_fn_name}")
logging.info(f"Use CPU: {use_cpu_flag}, W&B: {'Enabled' if wandb_token else 'Disabled'}, Bypass Limits: {BYPASS_RESOURCE_LIMITS}")
logging.info("="*50)
USE_CPU = use_cpu_flag
BATCH_SIZE = batch_size
LEARNING_RATE = lr
EPOCHS = epochs if epochs > 0 else 1
MAX_STEPS = max_steps
DEFAULT_OPTIMIZER = optimizer_name
DEFAULT_SCHEDULER = scheduler_type
GRADIENT_ACCUMULATION_STEPS = grad_accum_ui
target_layers = target_layers_int
logging.info("Setting up environment...")
clean_memory()
device = get_device()
logging.info(f"Using device: {device}")
num_cpu_cores_os = os.cpu_count() or 1
global NUM_CPU_CORES
if NUM_CPU_CORES <= 0: NUM_CPU_CORES = num_cpu_cores_os
else: NUM_CPU_CORES = min(NUM_CPU_CORES, num_cpu_cores_os)
logging.info(f"Using {NUM_CPU_CORES} CPU cores for dataloading.")
wandb_run = None
use_wandb_reporting = False
if wandb_token:
logging.info("Attempting WandB login...")
try:
wandb.login(key=wandb_token)
logging.info("WandB login successful.")
use_wandb_reporting = True
except Exception as e:
logging.warning(f"WandB login failed: {e}. Proceeding without WandB logging.")
report_to = ["wandb"] if use_wandb_reporting else []
user_id = "local_user"; repo_id_str = new_model_name; repo_link = "N/A (Upload skipped or failed)"
upload_to_hub = False
if hf_token:
logging.info("Attempting Hugging Face login...")
user_id = get_user_id(hf_token)
if user_id not in ["unknown_user", "http_error_user", "auth_error_user"]:
try:
login(token=hf_token, add_to_git_credential=False)
repo_id_str = f"{user_id}/{new_model_name}"
logging.info(f"Hugging Face login successful. User: {user_id}, Target Repo: {repo_id_str}")
create_repo(repo_id=repo_id_str, repo_type="model", exist_ok=True, token=hf_token)
logging.info(f"Hub repository '{repo_id_str}' ensured.")
repo_link = f"https://huggingface.co/{repo_id_str}"
upload_to_hub = True
except Exception as e:
logging.warning(f"Hugging Face login or repo creation failed: {e}. Upload will be skipped.")
hf_token = None
repo_id_str = new_model_name
repo_link = "N/A (Login/Repo Failed)"
else:
logging.warning(f"Could not get valid Hugging Face user ID ({user_id}). Upload will be skipped.")
hf_token = None
repo_id_str = new_model_name
repo_link = "N/A (Login Failed)"
else:
logging.info("No HF write token provided, Hub upload will be skipped.")
logging.info(f"Loading base model '{base_model_id}' and tokenizer...")
try:
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token)
if tokenizer.pad_token is None:
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
logging.info(f"Set tokenizer pad_token to eos_token ('{tokenizer.eos_token}')")
else:
added_pad = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if added_pad > 0:
logging.warning("Tokenizer missing pad_token and eos_token. Added '[PAD]' as pad_token.")
else:
logging.error("Tokenizer missing pad/eos and failed to add '[PAD]'. Training may fail.")
base_config_obj = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token)
base_config_obj = initialize_config_flags(base_config_obj)
original_num_layers_global = getattr(base_config_obj, 'num_hidden_layers', LAYERS)
if getattr(base_config_obj, 'original_num_layers', None) is None:
base_config_obj.original_num_layers = original_num_layers_global
if getattr(base_config_obj, 'vocab_size', -1) != len(tokenizer):
logging.warning(f"Config vocab size ({getattr(base_config_obj, 'vocab_size', 'N/A')}) differs from tokenizer ({len(tokenizer)}). Updating config.")
base_config_obj.vocab_size = len(tokenizer)
if getattr(base_config_obj, 'pad_token_id', -999) != tokenizer.pad_token_id:
base_config_obj.pad_token_id = tokenizer.pad_token_id
load_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32
attn_impl_load = getattr(base_config_obj, 'attn_implementation', 'auto')
if attn_impl_load == "flash_attention_2": base_config_obj.use_flash_attention_2 = True
elif getattr(base_config_obj,'use_flash_attention_2', False): attn_impl_load = "flash_attention_2"; base_config_obj.attn_implementation = "flash_attention_2"
logging.info(f"Loading model with dtype={load_dtype}, attn_implementation='{attn_impl_load}'...")
model = AutoModelForCausalLM.from_pretrained(
base_model_id,
config=base_config_obj,
trust_remote_code=True,
token=hf_token,
torch_dtype=load_dtype,
low_cpu_mem_usage=True if device.type != 'cpu' else False,
attn_implementation=attn_impl_load if attn_impl_load != 'auto' else None
)
if model.get_input_embeddings().weight.shape[0] != len(tokenizer):
logging.info(f"Resizing model token embeddings from {model.get_input_embeddings().weight.shape[0]} to tokenizer size {len(tokenizer)}")
model.resize_token_embeddings(len(tokenizer))
if getattr(model.config, 'vocab_size', -1) != len(tokenizer):
model.config.vocab_size = len(tokenizer)
logging.info(f"Base model '{base_model_id}' loaded. Original Layers: {original_num_layers_global}, Current Layers: {model.config.num_hidden_layers}, Dtype: {model.dtype}")
if device.type == 'cpu' or not (device.type != 'cpu' and True):
model.to(device)
logging.info(f"Model moved to device: {device}")
else:
logging.info(f"Model loaded with low_cpu_mem_usage, should be on target device(s).")
config = model.config
except Exception as e:
logging.error(f"Failed to load base model or tokenizer '{base_model_id}': {e} \n{traceback.format_exc()}")
return f"[Error] Load failed for '{base_model_id}': {e}"
if use_peft:
logging.info("Applying PEFT adapter to the model for training...")
try:
lora_config = LoraConfig(**peft_config_dict)
peft_add_msg = _add_peft_adapter(model, config, peft_config_obj=lora_config)
except Exception as peft_e:
logging.error(f"Failed to configure or add PEFT adapter: {peft_e}")
return f"[Error] Failed to prepare PEFT model: {peft_e}"
if "[Error]" in peft_add_msg or "[Warning]" in peft_add_msg:
logging.error(f"Failed adding PEFT adapter: {peft_add_msg}")
return f"[Error] Failed adding PEFT adapter: {peft_add_msg}"
model = global_model
config = global_model.get_base_model().config
logging.info("PEFT adapter added successfully.")
else:
logging.info("Proceeding with full fine-tuning (PEFT not selected).")
logging.info("Loading and processing datasets...")
train_ds_processed = None
eval_ds_processed = None
try:
datasets_config_list = parse_datasets(datasets_input_str)
interleaved_ds = load_datasets_from_config(datasets_config_list)
if interleaved_ds is None:
raise ValueError("Dataset loading and interleaving resulted in None. Check logs.")
tokenize_partial = partial(tokenize_function, tokenizer=tokenizer, context_length=CONTEXT_LENGTH)
tokenized_ds = interleaved_ds.map(
tokenize_partial,
batched=True,
batch_size=1000,
)
group_partial = partial(group_texts, block_size=CONTEXT_LENGTH)
lm_dataset = tokenized_ds.map(
group_partial,
batched=True,
batch_size=1000,
)
try:
peek_final = next(iter(lm_dataset))
final_cols = list(peek_final.keys())
logging.info(f"Sample processed record structure: { {k: type(v).__name__ for k, v in peek_final.items()} }")
if not all(k in final_cols for k in ['input_ids', 'attention_mask', 'labels']):
raise ValueError(f"Final dataset structure after tokenizing/grouping is missing required keys. Found: {final_cols}")
except StopIteration:
raise ValueError("Dataset appears empty after processing and grouping.")
logging.info("Dataset tokenization and grouping complete.")
train_ds_processed, eval_ds_processed = split_dataset(lm_dataset)
if isinstance(train_ds_processed, IterableDataset):
logging.info("Training dataset is iterable (streaming).")
elif isinstance(train_ds_processed, Dataset):
logging.info(f"Training dataset size: {len(train_ds_processed):,} examples.")
else:
logging.warning("Could not determine training dataset type or size.")
if eval_ds_processed is not None:
logging.info(f"Created static evaluation dataset with {len(eval_ds_processed)} examples.")
else:
logging.info("No evaluation dataset created (buffer empty or error occurred).")
except Exception as e:
logging.error(f"Dataset loading, processing, or splitting failed: {e} \n{traceback.format_exc()}")
return f"[Error] Dataset preparation failed: {e}"
logging.info("Setting up Training Arguments...")
final_weight_decay = weight_decay if not getattr(config, 'weight_decay_disabled', False) else 0.0
final_lr_scheduler = scheduler_type if not getattr(config, 'lr_scheduler_disabled', False) else "constant"
max_grad_norm_val = 1.0 if not getattr(config, 'gradient_clipping_disabled', False) else None
output_dir = f"./{new_model_name}_training_output"
training_args = TrainingArguments(
output_dir=output_dir,
overwrite_output_dir=True,
report_to=report_to,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=max(1, BATCH_SIZE * 2),
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
num_train_epochs=EPOCHS if EPOCHS > 0 else 1,
max_steps=MAX_STEPS,
optim=optimizer_name,
learning_rate=LEARNING_RATE,
weight_decay=final_weight_decay,
warmup_ratio=warmup_ratio,
lr_scheduler_type=final_lr_scheduler,
max_grad_norm=max_grad_norm_val if max_grad_norm_val is not None else 1e9,
fp16=load_dtype == torch.float16 and device.type == 'cuda',
bf16=load_dtype == torch.bfloat16 and device.type == 'cuda',
gradient_checkpointing=getattr(config, 'gradient_checkpointing_enabled', False),
gradient_checkpointing_kwargs={'use_reentrant': False} if getattr(config, 'gradient_checkpointing_enabled', False) else None,
dataloader_num_workers=NUM_CPU_CORES,
dataloader_pin_memory=True if device.type == 'cuda' else False,
evaluation_strategy="steps" if eval_ds_processed is not None else "no",
eval_steps=EVAL_STEPS if eval_ds_processed is not None else None,
save_strategy="steps",
save_steps=SAVE_STEPS,
save_total_limit=2,
load_best_model_at_end=LOAD_BEST_MODEL_AT_END if eval_ds_processed is not None else False,
metric_for_best_model=METRIC_FOR_BEST_MODEL if eval_ds_processed is not None else None,
logging_strategy="steps",
logging_steps=LOGGING_STEPS,
push_to_hub=upload_to_hub,
hub_model_id=repo_id_str if upload_to_hub else None,
hub_token=hf_token if upload_to_hub else None,
hub_strategy="checkpoint",
use_cpu=USE_CPU,
seed=42,
remove_unused_columns=False,
log_level="info",
)
# --- Trainer setup, optional WandB run, training loop, and PEFT merge ---
logging.info("Initializing Trainer...")
# Causal-LM collation: mlm=False means labels are the (shifted) input ids.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
callbacks = []
# Early stopping only makes sense when there is an eval set to monitor.
if LOAD_BEST_MODEL_AT_END and eval_ds_processed is not None:
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE, early_stopping_threshold=0.001))
if use_wandb_reporting:
    try:
        wandb_run = wandb.init(
            project=f"llm-modify-train-{new_model_name.replace('/', '-')}",
            config=training_args.to_dict(),
            name=f"run-{new_model_name.replace('/', '-')}-{int(time.time())}",
            reinit=True
        )
        logging.info(f"WandB run initialized: {wandb_run.name if wandb_run else 'Failed'}")
    except Exception as wandb_e:
        # WandB failure is non-fatal: drop reporting and continue training.
        logging.error(f"Failed to initialize WandB run: {wandb_e}")
        wandb_run = None
        training_args.report_to = []
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_ds_processed,
    eval_dataset=eval_ds_processed,
    data_collator=data_collator,
    callbacks=callbacks
)
start_train_time = time.time()
logging.info(f"🚀 Starting model training (Using {type(trainer.model).__name__}). Effective steps: {training_args.max_steps if training_args.max_steps > 0 else 'N/A'}. Effective epochs: {training_args.num_train_epochs if training_args.num_train_epochs > 0 else 'N/A'}.")
train_result = None
training_successful = False
try:
    # Resume from the most recent checkpoint in output_dir, if any.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        from transformers.trainer_utils import get_last_checkpoint
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint:
        logging.info(f"*** Resuming training from checkpoint: {last_checkpoint} ***")
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    logging.info("✅ Training finished successfully.")
    training_successful = True
    trainer.save_model()
    trainer.save_state()
    # Save the tokenizer alongside the weights for full fine-tunes, and also
    # for PEFT models so the output dir is self-contained.
    # NOTE(review): if peft is not installed, PeftModel is None and this
    # isinstance would raise TypeError — presumably unreachable because
    # use_peft implies peft is installed; confirm upstream guard.
    if not use_peft:
        tokenizer.save_pretrained(training_args.output_dir)
    elif isinstance(trainer.model, PeftModel):
        tokenizer.save_pretrained(training_args.output_dir)
except Exception as e:
    # NOTE(review): assumes wandb_run was initialized (likely to None)
    # earlier in this routine even when reporting is off — confirm.
    logging.error(f"❌ Training failed: {e}\n{traceback.format_exc()}")
    if wandb_run: wandb_run.finish(exit_code=1)
    return f"[Error] Training failed: {e}"
finally:
    # Timing and memory cleanup happen whether or not training succeeded.
    end_train_time = time.time()
    clean_memory()
    training_time = end_train_time - start_train_time
    logging.info(f"🕒 Training phase took {training_time:.2f} seconds.")
if not training_successful:
    return "[Error] Training did not complete successfully."
final_trained_model = trainer.model
model_to_save = final_trained_model
merged_model_for_mods = None
if use_peft and isinstance(final_trained_model, PeftModel):
    # Fold LoRA weights into the base model so post-training structural
    # modifications and the final save operate on a plain model.
    logging.info("Merging PEFT adapter into the base model for modification and final save...")
    try:
        merged_model_for_mods = final_trained_model.merge_and_unload()
        logging.info("PEFT adapter merged successfully.")
        # Clear adapter bookkeeping flags on the merged config.
        merged_model_for_mods.config.peft_adapter_added = False
        merged_model_for_mods.config.peft_config = None
        merged_model_for_mods.config.lora_merged = True
    except Exception as e:
        # Merge failed: fall back to saving the adapter and base separately,
        # and continue with the (still wrapped) PEFT model.
        logging.error(f"Failed to merge PEFT adapter after training: {e}. Saving adapter separately.")
        adapter_save_path = os.path.join(training_args.output_dir, "final_adapter")
        try:
            final_trained_model.save_pretrained(adapter_save_path)
            base_model_for_saving = final_trained_model.get_base_model()
            base_model_for_saving.save_pretrained(training_args.output_dir)
            tokenizer.save_pretrained(training_args.output_dir)
            logging.info(f"PEFT adapter saved separately to {adapter_save_path}, base model to {training_args.output_dir}")
            merged_model_for_mods = final_trained_model
        except Exception as save_e:
            logging.error(f"Failed to save adapter or base model separately: {save_e}. Proceeding with potentially unmerged PEFT model.")
            merged_model_for_mods = final_trained_model
else:
    merged_model_for_mods = final_trained_model
if merged_model_for_mods is None:
    logging.error("Model state after training/merging is None. Cannot proceed.")
    return "[Error] Lost model reference after training/merging."
def modify_model_post_train(model_obj, act_fn_name, target_layer_count):
    """Apply UI-selected structural changes to a freshly trained model.

    Adjusts the number of hidden layers toward ``target_layer_count`` (never
    growing past the recorded original layer count) and swaps the activation
    function to ``act_fn_name`` if it differs from the current one. The model
    is modified in place via helpers defined elsewhere in this file
    (_reduce_layers_to_one, _enable_full_layers, _swap_activation_function).

    Returns the (possibly modified) model object; returns it unchanged when
    it has no ``config`` attribute.
    """
    logging.info(f"Applying post-training modifications: Target Layers={target_layer_count}, Activation={act_fn_name}")
    if not hasattr(model_obj, 'config'):
        logging.error("Cannot modify model: Missing config.")
        return model_obj
    # Ensure the custom bookkeeping flags exist on the config before use.
    current_config = initialize_config_flags(model_obj.config)
    model_obj.config = current_config
    current_layers = getattr(current_config, 'num_hidden_layers', None)
    # Fall back to the module-level original layer count if not in config.
    original_layers = getattr(current_config, 'original_num_layers', original_num_layers_global)
    if current_layers is not None and original_layers is not None:
        if target_layer_count != current_layers:
            logging.info(f"Adjusting layers post-training: {current_layers} -> {target_layer_count} (Original: {original_layers})")
            if target_layer_count < current_layers:
                _reduce_layers_to_one(model_obj, current_config, target_layers=target_layer_count)
            else:
                # Restoration is capped at the model's original depth.
                restore_target = min(target_layer_count, original_layers) if original_layers else target_layer_count
                if restore_target > current_layers:
                    logging.info(f"Attempting to restore layers: {current_layers} -> {restore_target}")
                    _enable_full_layers(model_obj, current_config, original_num_layers=restore_target)
                else:
                    logging.info(f"Target layers ({target_layer_count}) >= current layers ({current_layers}). No layer increase needed or possible beyond original.")
        else:
            logging.info(f"Target layers ({target_layer_count}) matches current layers after training. No layer adjustment needed.")
    else:
        logging.warning("Could not determine current or original layer count from config post-training. Skipping layer adjustment.")
    current_act_fn = getattr(current_config, 'current_activation_function', DEFAULT_ACTIVATION_FUNCTION)
    if act_fn_name != current_act_fn:
        logging.info(f"Setting activation function post-training to: {act_fn_name}")
        _swap_activation_function(model_obj, current_config, act_fn_name)
    else:
        logging.info(f"Target activation function ({act_fn_name}) already matches current. No change needed.")
    logging.info("Post-training modifications applied.")
    return model_obj
# --- Apply UI-selected post-training modifications and save locally ---
logging.info("Applying final post-training modifications specified in UI...")
final_model_modified = modify_model_post_train(merged_model_for_mods, activation_fn_name, target_layers_int)
# Drop the pre-modification reference if the helper returned a new object.
if merged_model_for_mods is not final_model_modified:
    del merged_model_for_mods
    clean_memory()
final_model_path = training_args.output_dir
logging.info(f"Saving final modified model state to {final_model_path}...")
try:
    save_kwargs = {"safe_serialization": True}
    final_model_modified.config = initialize_config_flags(final_model_modified.config)
    if use_peft and isinstance(final_model_modified, PeftModel):
        # Model is still a PEFT wrapper (merge must have failed earlier):
        # save the adapter and the base model side by side.
        logging.warning("Saving unmerged PEFT model state again after modifications (adapter separate).")
        adapter_save_dir = os.path.join(final_model_path, "final_adapter_modified")
        final_model_modified.save_pretrained(adapter_save_dir)
        logging.info(f"PEFT adapter saved to {adapter_save_dir}")
        try:
            base_model_final = final_model_modified.get_base_model()
            base_model_final.save_pretrained(final_model_path, **save_kwargs)
            tokenizer.save_pretrained(final_model_path)
            logging.info(f"Base model saved to {final_model_path}")
        except Exception as base_save_e:
            logging.error(f"Failed to save base model separately after modification: {base_save_e}. Only adapter might be saved.")
    else:
        final_model_modified.save_pretrained(final_model_path, **save_kwargs)
        tokenizer.save_pretrained(final_model_path)
        logging.info("Final modified model saved locally.")
except Exception as e:
    # Save failed: still publish the in-memory model to the globals so the
    # UI can use it, then return the error to the caller.
    logging.error(f"Failed to save final modified model locally: {e}\n{traceback.format_exc()}")
    if wandb_run: wandb_run.finish(exit_code=1)
    global_model = final_model_modified.to(device)
    global_tokenizer = tokenizer
    config = final_model_modified.config
    update_pipeline()
    clean_memory()
    # NOTE(review): assumes `output_dir` (distinct from final_model_path)
    # was defined earlier in this routine — confirm.
    return f"[Error] Failed to save final model locally: {e}. Training logs/checkpoints might be in {output_dir}."
# --- Evaluate the final (modified) model, if an eval set exists ---
final_eval_results = {}; final_eval_loss = None; final_perplexity = float('inf')
if eval_ds_processed is not None:
    logging.info("Evaluating final modified model..."); T_final_eval = time.time()
    try:
        # A fresh Trainer is used so evaluation runs against the modified
        # model rather than the one held by the training Trainer.
        final_trainer = Trainer(
            model=final_model_modified,
            args=training_args,
            tokenizer=tokenizer,
            data_collator=data_collator,
            eval_dataset=eval_ds_processed,
        )
        final_eval_results = final_trainer.evaluate()
        final_eval_loss = final_eval_results.get("eval_loss")
        final_perplexity = compute_perplexity(final_eval_loss)
        logging.info(f"✅ Final Model Evaluation Results: {final_eval_results}")
        logging.info(f"Final Model Perplexity: {final_perplexity:.4f} (Eval time: {time.time() - T_final_eval:.2f}s)")
        if use_wandb_reporting and wandb_run:
            # -1.0 is the sentinel for "metric unavailable" in WandB logs.
            wandb_run.log({"final_eval_loss": final_eval_loss if final_eval_loss is not None else -1.0,
                           "final_perplexity": final_perplexity if final_perplexity != float('inf') else -1.0,
                           **final_eval_results})
    except Exception as e:
        # Evaluation failure is non-fatal; the run proceeds to upload/report.
        logging.error(f"Final evaluation failed: {e}\n{traceback.format_exc()}")
        if use_wandb_reporting and wandb_run:
            wandb_run.log({"final_eval_status": "Failed", "final_eval_error": str(e)})
else:
    logging.info("Skipping final evaluation as no evaluation dataset was available.")
# --- Optional final upload of the saved model folder to the HF Hub ---
upload_successful_final = False
if upload_to_hub:
    logging.info(f"Attempting final upload of '{final_model_path}' to Hugging Face Hub: {repo_id_str}...")
    try:
        # BUGFIX: the old commit_description embedded a conditional expression
        # inside an f-string format spec ("{x:.4f if ... else 'N/A'}"), which
        # is an invalid format specifier and raised ValueError at runtime.
        # Pre-format the optional metrics instead.
        eval_loss_str = f"{final_eval_loss:.4f}" if final_eval_loss is not None else "N/A"
        ppl_str = f"{final_perplexity:.4f}" if final_perplexity != float('inf') else "N/A"
        api = HfApi()
        api.upload_folder(
            folder_path=final_model_path,
            repo_id=repo_id_str,
            repo_type="model",
            token=hf_token,
            commit_message=f"Upload final trained model: {new_model_name} (Base: {base_model_id}, PPL: {final_perplexity:.2f})",
            commit_description=(f"Training completed. Eval Loss: {eval_loss_str}, Perplexity: {ppl_str}. "
                               f"Config: PEFT={use_peft}, Layers={target_layers_int}, ActFn={activation_fn_name}. Training time: {training_time:.2f}s.")
        )
        repo_link = f"https://huggingface.co/{repo_id_str}"
        logging.info(f"✅ Final model upload complete: {repo_link}")
        if use_wandb_reporting and wandb_run: wandb_run.log({"hf_repo_link": repo_link, "hf_upload_status": "Success"})
        upload_successful_final = True
    except Exception as e:
        logging.error(f"Final Hugging Face upload failed: {e}\n{traceback.format_exc()}")
        repo_link = "[Upload Failed]"
        if use_wandb_reporting and wandb_run: wandb_run.log({"hf_upload_status": "Failed", "hf_upload_error": str(e)})
else:
    # NOTE(review): `repo_link` is not assigned on this path but is used in
    # the final report below — presumably initialized earlier; confirm.
    logging.info(f"Skipping final Hugging Face Hub upload based on initial setup.")
logging.info("Updating global state with the final model...")
global_model = final_model_modified.to(device)
global_tokenizer = tokenizer
config = global_model.config
update_pipeline()
clean_memory()
final_status_report_json = decode_model_details(global_model)
total_script_time = time.time() - start_overall_time
final_message = (
f"✅ Training & Modification Process Complete!\n"
f"{'='*40}\n"
f"New Model Name: {new_model_name}\n"
f"Base Model: {base_model_id}\n"
f"Total Time: {total_script_time:.2f}s | Training Phase Time: {training_time:.2f}s\n"
f"{'='*40}\n"
f"Training Results:\n"
)
if train_result:
final_message += f" - Steps Completed: {train_result.global_step}\n"
train_loss = train_result.training_loss
final_message += f" - Training Loss: {train_loss:.4f if train_loss is not None else 'N/A'}\n"
train_metrics = train_result.metrics
for metric, value in train_metrics.items():
if "loss" in metric.lower() or "perplexity" in metric.lower() or "epoch" in metric.lower() or "step" in metric.lower():
value_str = f"{value:.4f}" if isinstance(value, float) else str(value)
final_message += f" - {metric.replace('_', ' ').title()}: {value_str}\n"
final_message += (
f"Final Evaluation:\n"
f" - Eval Loss: {final_eval_loss:.4f if final_eval_loss is not None else 'N/A'}\n"
f" - Perplexity: {final_perplexity:.4f if final_perplexity != float('inf') else 'N/A'}\n"
f"{'='*40}\n"
f"Saving & Upload:\n"
f" - Local Path: {final_model_path}\n"
f" - Hub Repo: {repo_link}\n"
f"{'='*40}\n"
f"Final Model Status Summary:\n"
)
try:
status_data = json.loads(final_status_report_json)
summary_keys = ["Model Class", "Config Class", "Device(s)", "Params Summary", "Layer Types Count", "Key Config Attributes", "Modification Flags"]
for key in summary_keys:
if key in status_data:
value = status_data[key]
if isinstance(value, dict):
value_str = json.dumps(value, indent=4)
elif isinstance(value, list):
value_str = ", ".join(map(str, value))
else:
value_str = str(value)
if len(value_str) > 200: value_str = value_str[:200] + "..."
final_message += f" - {key}: {value_str}\n"
final_message += f"(Full status logged and available in 'Model Controls' tab after refresh)\n"
except Exception as json_e:
logging.warning(f"Could not parse final status JSON for summary: {json_e}")
final_message += "(Could not generate status summary from JSON)\n"
final_message += f"{'='*40}"
if use_wandb_reporting and wandb_run:
try:
wandb_final_log = {
"total_time_seconds": total_script_time,
"training_time_seconds": training_time,
"final_eval_loss": final_eval_loss if final_eval_loss is not None else -1.0,
"final_perplexity": final_perplexity if final_perplexity != float('inf') else -1.0,
"upload_successful": upload_successful_final,
"final_steps_completed": train_result.global_step if train_result else -1,
"final_train_loss": train_result.training_loss if train_result and train_result.training_loss else -1.0,
}
wandb_run.log(wandb_final_log)
wandb_run.finish()
logging.info("WandB run finished.")
except Exception as e:
logging.warning(f"Error finishing WandB run: {e}")
logging.info("🏁 Full training and modification process finished. 🏁")
return final_message
def load_model_for_control(model_id_or_path, hf_token=None, bypass_limits_state=False):
    """Load a model + tokenizer into the module-level globals for the UI.

    Replaces any currently loaded model. On success returns a status string,
    the detailed-status JSON, and unpacked filter-checkbox updates; on any
    failure returns an error string, "{}", and error filter updates.

    Args:
        model_id_or_path: Hub model ID or local path to load.
        hf_token: Optional HF token for private models.
        bypass_limits_state: Stored in BYPASS_RESOURCE_LIMITS; presumably
            consulted by check_resources() — confirm.
    """
    global global_model, global_tokenizer, global_pipe, config, original_num_layers_global, BYPASS_RESOURCE_LIMITS
    BYPASS_RESOURCE_LIMITS = bypass_limits_state
    logging.info(f"Attempting to load model for control: {model_id_or_path}")
    if not model_id_or_path:
        return "[Error] Model ID or Path cannot be empty.", "{}", *get_error_filter_updates()
    # Refuse to load if RAM/disk limits are exceeded (unless bypassed).
    resources_ok, res_msg = check_resources()
    if not resources_ok:
        error_msg = f"[Error] Resource limits exceeded, cannot load model. {res_msg}"
        logging.error(error_msg)
        return error_msg, "{}", *get_error_filter_updates()
    else:
        logging.info(res_msg)
    t_load_start = time.time()
    device = get_device()
    error_return = f"[Error] Failed to load model '{model_id_or_path}'.", "{}", *get_error_filter_updates()
    # Drop the previous model before loading to free memory.
    global_model, global_tokenizer, global_pipe, config = None, None, None, None
    clean_memory()
    try:
        logging.info("Loading tokenizer...")
        tokenizer_load = AutoTokenizer.from_pretrained(
            model_id_or_path,
            trust_remote_code=True,
            token=hf_token
        )
        # Ensure a pad token exists: prefer eos, else add a new '[PAD]'.
        if tokenizer_load.pad_token is None:
            if tokenizer_load.eos_token is not None:
                tokenizer_load.pad_token = tokenizer_load.eos_token
                logging.info(f"Set tokenizer pad_token to eos_token ('{tokenizer_load.eos_token}')")
            else:
                try:
                    tokenizer_load.add_special_tokens({'pad_token': '[PAD]'})
                    logging.warning("Added '[PAD]' as pad_token.")
                except Exception as pad_e:
                    logging.error(f"Could not set PAD token: {pad_e}. Batching or beam search might fail.")
        logging.info("Loading model config...")
        loaded_config = AutoConfig.from_pretrained(
            model_id_or_path,
            trust_remote_code=True,
            token=hf_token
        )
        # Attach the platform's custom bookkeeping flags to the config.
        config_load = initialize_config_flags(loaded_config)
        original_layers_load = getattr(config_load, 'num_hidden_layers', LAYERS)
        if getattr(config_load, 'original_num_layers', None) is None:
            config_load.original_num_layers = original_layers_load
            logging.info(f"Set original_num_layers in loaded config to {original_layers_load}")
        original_num_layers_global = config_load.original_num_layers
        # Keep config vocab size / pad id in sync with the tokenizer.
        if getattr(config_load, 'vocab_size', -1) != len(tokenizer_load):
            config_load.vocab_size = len(tokenizer_load)
        if getattr(config_load, 'pad_token_id', -999) != tokenizer_load.pad_token_id:
            config_load.pad_token_id = tokenizer_load.pad_token_id
        logging.info("Loading model weights...")
        # Reconcile the two ways flash-attention can be flagged in a config.
        attn_impl_load = getattr(config_load, 'attn_implementation', 'auto')
        if attn_impl_load == "flash_attention_2": config_load.use_flash_attention_2 = True
        elif getattr(config_load,'use_flash_attention_2', False): attn_impl_load = "flash_attention_2"; config_load.attn_implementation = "flash_attention_2"
        # Prefer bf16 on capable GPUs, fp16 on other GPUs, fp32 on CPU.
        load_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32
        logging.info(f"Using dtype {load_dtype} and attn_implementation '{attn_impl_load}' for loading.")
        model_load = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            config=config_load,
            trust_remote_code=True,
            token=hf_token,
            torch_dtype=load_dtype,
            low_cpu_mem_usage=True if device.type != 'cpu' else False,
            attn_implementation=attn_impl_load if attn_impl_load != 'auto' else None,
        )
        # Resize embeddings if the tokenizer grew (e.g. '[PAD]' was added).
        if model_load.get_input_embeddings().weight.shape[0] != len(tokenizer_load):
            logging.info(f"Resizing loaded model embeddings from {model_load.get_input_embeddings().weight.shape[0]} to tokenizer size {len(tokenizer_load)}")
            model_load.resize_token_embeddings(len(tokenizer_load))
            if getattr(model_load.config, 'vocab_size', -1) != len(tokenizer_load):
                model_load.config.vocab_size = len(tokenizer_load)
        # Publish the loaded objects to the process-wide globals.
        global_model = model_load.to(device)
        global_tokenizer = tokenizer_load
        config = global_model.config
        logging.info(f"Model loaded successfully to {device}.")
        update_pipeline()
        clean_memory()
        logging.info(f"Model '{model_id_or_path}' loaded and pipeline updated in {time.time() - t_load_start:.2f}s.")
        status_json, *filter_updates = get_detailed_status_and_filter_states()
        return f"Model '{model_id_or_path}' loaded successfully.", status_json, *filter_updates
    except Exception as e:
        # Reset globals so the UI does not reference a half-loaded model.
        logging.error(f"Failed to load model '{model_id_or_path}': {e}\n{traceback.format_exc()}")
        global_model, global_tokenizer, global_pipe, config = None, None, None, None
        clean_memory()
        return error_return
def save_current_model(save_path, hf_token=None, hub_repo_id=None):
    """Save the currently loaded model locally and/or upload it to the Hub.

    If the model is a PEFT wrapper, only the adapter (plus tokenizer and,
    best-effort, the base config) is saved; otherwise the full model is
    saved. When only a Hub repo is given, a temporary directory is used for
    the upload and removed afterwards.

    Args:
        save_path: Optional local directory to save into.
        hf_token: Write token, required for Hub upload.
        hub_repo_id: Optional Hub repo ("user/repo") to upload to.

    Returns a human-readable status/error string.
    """
    global global_model, global_tokenizer, config
    if not global_model or not global_tokenizer:
        return "[Error] No model loaded to save."
    if not save_path and not hub_repo_id:
        return "[Error] Please provide a local save path or a Hub Repo ID (or both)."
    t_save_start = time.time()
    model_to_save = global_model
    tokenizer_to_save = global_tokenizer
    # Prefer the module-level config; fall back to the model's own.
    config_to_save = initialize_config_flags(config if config else getattr(model_to_save, 'config', None))
    if config_to_save is None:
        logging.error("Cannot save: Model config is missing.")
        return "[Error] Model config is missing, cannot save."
    model_to_save.config = config_to_save
    is_peft_model = _peft_installed and isinstance(model_to_save, PeftModel)
    save_adapter_only = is_peft_model
    logging.info(f"Save mode: {'Adapter Only (PEFT model detected)' if save_adapter_only else 'Full Model'}")
    temp_save_dir = None
    effective_save_path = save_path.strip() if save_path else None
    if not effective_save_path and hub_repo_id:
        # Hub-only save: stage files in a unique temporary directory.
        temp_save_dir = f"./hub_upload_temp_{hub_repo_id.replace('/', '_')}_{int(time.time())}"
        effective_save_path = temp_save_dir
        logging.info(f"No local path provided, saving temporarily to '{effective_save_path}' for Hub upload.")
    elif not effective_save_path:
        return "[Error] Cannot determine save location (missing local path and Hub ID)."
    try:
        os.makedirs(effective_save_path, exist_ok=True)
    except OSError as e:
        logging.error(f"Failed to create save directory '{effective_save_path}': {e}")
        return f"[Error] Failed to create save directory: {e}"
    local_save_message = ""
    try:
        logging.info(f"Saving current model state to {effective_save_path}...")
        save_kwargs = {"safe_serialization": True}
        if save_adapter_only:
            logging.info("Saving PEFT adapter weights and tokenizer.")
            model_to_save.save_pretrained(effective_save_path)
            tokenizer_to_save.save_pretrained(effective_save_path)
            # Best effort: also ship the base model's config for reloading.
            try:
                base_model_config = model_to_save.get_base_model().config
                base_model_config.save_pretrained(effective_save_path)
            except Exception as config_e:
                logging.warning(f"Could not save base model config alongside adapter: {config_e}")
        else:
            logging.info("Saving full model weights and tokenizer.")
            model_to_save.save_pretrained(effective_save_path, **save_kwargs)
            tokenizer_to_save.save_pretrained(effective_save_path)
        save_local_time = time.time() - t_save_start
        logging.info(f"Model state saved locally to {effective_save_path} in {save_local_time:.2f}s")
        local_save_message = f"Model saved locally to '{effective_save_path}'."
    except Exception as e:
        # Local save failed: clean up any staging dir and abort.
        logging.error(f"Failed to save model locally to {effective_save_path}: {e}\n{traceback.format_exc()}")
        if temp_save_dir and os.path.exists(temp_save_dir):
            try: shutil.rmtree(temp_save_dir); logging.info("Cleaned up temporary directory after local save error.")
            except Exception as clean_e: logging.warning(f"Could not remove temp dir {temp_save_dir} after error: {clean_e}")
        return f"[Error] Failed to save model locally: {e}"
    hub_message = ""
    # NOTE(review): upload_successful is set below but never read afterwards.
    upload_successful = False
    if hub_repo_id:
        if not hf_token:
            hub_message = "[Warning] Hub upload skipped: Hugging Face Write Token required."
            logging.warning(hub_message)
        else:
            logging.info(f"Attempting to upload '{effective_save_path}' to Hub repo: {hub_repo_id}")
            try:
                api = HfApi();
                create_repo(repo_id=hub_repo_id, repo_type="model", exist_ok=True, token=hf_token)
                api.upload_folder(
                    folder_path=effective_save_path,
                    repo_id=hub_repo_id,
                    repo_type="model",
                    token=hf_token,
                    commit_message=f"Upload model state ({'Adapter' if save_adapter_only else 'Full'}) via LLM Platform",
                    commit_description=f"Saved from LLM Platform UI. Model class: {type(global_model).__name__}. State: {'PEFT Adapter' if save_adapter_only else 'Full Model'}.",
                )
                hub_link = f"https://huggingface.co/{hub_repo_id}"
                hub_message = f"Successfully uploaded to Hub: {hub_link}"
                upload_successful = True
                logging.info(hub_message)
            except Exception as e:
                hub_message = f"[Error] Hub upload failed: {e}"
                logging.error(f"Hub upload failed: {e}\n{traceback.format_exc()}")
    # Remove the staging directory whether or not the upload succeeded.
    if temp_save_dir and os.path.exists(temp_save_dir):
        try:
            shutil.rmtree(temp_save_dir)
            logging.info("Cleaned up temporary save directory.")
        except Exception as e:
            logging.warning(f"Could not remove temporary directory {temp_save_dir}: {e}")
    # Only mention the local path when the caller explicitly asked for one
    # (temp-dir saves are an implementation detail).
    final_message = local_save_message if save_path else ""
    if hub_message:
        if final_message: final_message += f" | {hub_message}"
        else: final_message = hub_message
    if not final_message:
        final_message = "[Info] No local save path provided and Hub upload failed or was skipped."
    total_save_time = time.time() - t_save_start
    logging.info(f"Total save operation took {total_save_time:.2f}s")
    return final_message
# Human-readable names for the content-filter toggles shown in the UI.
filter_names_ui = [
    "Harassment", "Hate Speech", "Sexually Explicit", "Dangerous Content",
    "Civic Integrity", "Harmful Code", "Medical Advice", "Legal Advice",
    "Financial Advice", "PII (Basic)", "Political Content", "Religious Content",
    "Profanity", "Stereotype", "Misinfo", "Self Harm",
    "Personal Attack", "Toxicity", "Spam", "Off Topic",
    "Tone", "Min Max Length", "Repetition Filter", "Factuality Filter"
]
# Map each UI label to the config attribute backing it: lowercase, spaces
# become underscores, parentheses are dropped, and "_filter" is appended.
filter_attr_map = {
    label: label.lower().replace(" ", "_").replace("(", "").replace(")", "") + "_filter"
    for label in filter_names_ui
}
# A few attributes don't follow the derived naming scheme; override them.
filter_attr_map.update({
    "PII (Basic)": "pii_filter",
    "Harmful Code": "code_filter",
    "Min Max Length": "min_max_length_filter",
    "Repetition Filter": "repetition_filter_enabled",
    "Factuality Filter": "factuality_filter_enabled",
})
# Shared Gradio theme: Soft base with blue/sky hues; button fills are pinned
# to palette slots so primary/secondary/cancel buttons track the theme.
custom_theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
).set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_400",
    button_secondary_background_fill="*secondary_500",
    button_secondary_background_fill_hover="*secondary_400",
    button_cancel_background_fill="*neutral_200",
    button_cancel_background_fill_hover="*neutral_300",
)
with gr.Blocks(theme=custom_theme, title="Advanced LLM Training & Modification Platform") as demo:
gr.Markdown("# 🤖 Advanced LLM Training & Modification Platform v1.2")
gr.Markdown("Load, modify, filter, train (Full/PEFT), test, merge, and save Large Language Models. Includes PEFT, experimental multi-modal capabilities, reward modeling setup, and resource checks.")
with gr.Accordion("🔑 Authentication & Settings", open=False):
with gr.Row():
hf_token_read = gr.Textbox(label="🤗 HF Token (Read - Optional, for private models)", type="password", interactive=True, placeholder="hf_...")
hf_token_write = gr.Textbox(label="🤗 HF Token (Write - Optional, for Hub upload/training)", type="password", interactive=True, placeholder="hf_...")
train_wandb_token_inp = gr.Textbox(label="📊 WandB Token (Optional, for logging runs)", type="password", interactive=True)
with gr.Row():
bypass_limits_chk = gr.Checkbox(label="Bypass RAM/Disk Limits (Use with Caution!)", value=False, interactive=True)
with gr.Tabs():
with gr.TabItem("💾 Load, Save & Merge"):
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### Load Model for Modification & Inference")
load_model_selector = HuggingfaceHubSearch(label="Search Hub or Enter Path/ID", placeholder="google/gemma-2b")
load_button = gr.Button("Load Model", variant="primary")
load_status_output = gr.Textbox(label="Load Status", interactive=False, lines=1)
with gr.Column(scale=2):
gr.Markdown("### Save Current Model State")
save_path_inp = gr.Textbox(label="Local Save Path (Optional)", placeholder="./saved_models/my_modified_model", interactive=True)
save_hub_repo_inp = gr.Textbox(label="Hub Repo ID (Optional, e.g., user/repo)", placeholder="username/my-cool-llm", interactive=True)
save_button = gr.Button("Save Model", variant="secondary")
save_status_output = gr.Textbox(label="Save Status", interactive=False, lines=1)
gr.Markdown("---")
gr.Markdown("### Merge Model Architectures (Parameter Averaging - Experimental)")
gr.Markdown("⚠️ **Experimental:** Averages parameters of models with compatible layers. Enter comma-separated Model IDs/Paths. The first model's config and tokenizer will be used as the base.")
merge_model_ids_inp = gr.Textbox(label="Model IDs/Paths to Merge (comma-separated)", placeholder="org/model-a, org/model-b, ./local-model-c")
merge_button = gr.Button("Merge Architectures", variant="primary")
merge_status_output = gr.Textbox(label="Merge Status", interactive=False, lines=2)
with gr.TabItem("🚀 Training"):
gr.Markdown("Fine-tune a model based on a selected base model. Supports Full fine-tuning and PEFT (LoRA). Apply modifications post-training.")
gr.Markdown("### 1. Base Model & Output Name")
with gr.Row():
train_model_selector = HuggingfaceHubSearch(label="Search & Select Base Model for Training", placeholder="Type to search Hugging Face Hub...")
with gr.Row():
train_new_model_inp = gr.Textbox(label="New Model Name (for saving locally and optionally on Hub)", placeholder="MyTunedModel-v1", interactive=True)
gr.Markdown("### 2. Training Data")
with gr.Row():
train_dataset_selector = HuggingfaceHubSearch(label="Search Datasets on Hub (or specify local below)")
train_datasets_inp = gr.Textbox(
label="Datasets (one per line: 'id[,config[,split[,weight]]]')",
placeholder="Example:\nopenwebtext\nwikitext,wikitext-103-raw-v1,train,0.5\nmy_local_dataset_path,,train,1.5\nusername/my_dataset,my_config,validation,2.0",
lines=5, interactive=True)
gr.Markdown("### 3. Training Configuration")
with gr.Accordion("Training Mode & Hyperparameters", open=True):
train_use_peft_chk = gr.Checkbox(label="Enable PEFT (LoRA) Training", value=True, interactive=True)
with gr.Row():
train_lr_inp = gr.Number(value=LEARNING_RATE, label="Learning Rate", interactive=True, minimum=1e-8, step=1e-6, precision=8)
train_epochs_inp = gr.Number(value=EPOCHS, label="Epochs (Set <= 0 if using Max Steps)", precision=0, minimum=-1, interactive=True)
train_max_steps_inp = gr.Number(value=MAX_STEPS, label="Max Steps (Set <= 0 if using Epochs)", precision=0, minimum=-1, interactive=True)
with gr.Row():
train_batch_size_inp = gr.Number(value=BATCH_SIZE, label="Batch Size (Per Device)", precision=0, minimum=1, interactive=True)
train_grad_accum_inp = gr.Number(value=GRADIENT_ACCUMULATION_STEPS, label="Grad Accum Steps", precision=0, minimum=1, interactive=True)
train_optim_selector = gr.Dropdown(choices=list(OPTIMIZERS.keys()), value=DEFAULT_OPTIMIZER, label="Optimizer", interactive=True)
with gr.Row():
train_scheduler_selector = gr.Dropdown(choices=SCHEDULER_TYPES, value=DEFAULT_SCHEDULER, label="LR Scheduler", interactive=True)
train_wd_inp = gr.Number(value=0.01, label="Weight Decay", minimum=0.0, interactive=True, step=0.001, precision=4)
train_warmup_ratio_inp = gr.Slider(0.0, 0.5, value=0.03, step=0.01, label="Warmup Ratio", interactive=True)
with gr.Accordion("PEFT Configuration (if PEFT enabled)", open=False, visible=True) as peft_config_accordion:
peft_r_inp = gr.Slider(label="LoRA r (Rank)", minimum=1, maximum=256, value=8, step=1, interactive=True)
peft_alpha_inp = gr.Slider(label="LoRA alpha", minimum=1, maximum=512, value=32, step=1, interactive=True)
peft_dropout_inp = gr.Slider(label="LoRA Dropout", minimum=0.0, maximum=0.5, value=0.1, step=0.01, interactive=True)
peft_target_modules_inp = gr.Textbox(label="Target Modules (comma-sep, optional, e.g., q_proj,v_proj)", placeholder="Leave empty for auto-detection (recommended)", interactive=True)
train_use_peft_chk.change(lambda x: gr.update(visible=x), inputs=train_use_peft_chk, outputs=peft_config_accordion)
with gr.Accordion("Post-Training Modifications (Applied After Training)", open=False):
with gr.Row():
train_post_activation_fn_selector = gr.Dropdown(choices=list(ACTIVATION_FUNCTIONS.keys()), value=DEFAULT_ACTIVATION_FUNCTION, label="Target Activation Fn")
train_post_target_layers_inp = gr.Number(value=LAYERS, label="Target Layer Count", precision=0, minimum=1)
with gr.Accordion("Hardware & Logging", open=False):
train_use_cpu_chk = gr.Checkbox(value=USE_CPU, label="Force Use CPU (Very Slow!)", interactive=True)
gr.Markdown("### 4. Start Training")
train_button = gr.Button("✨ Start Training Process", variant="primary")
train_output = gr.Textbox(label="Training Log & Status", interactive=False, lines=20, max_lines=50)
with gr.TabItem("🔧 Model Controls"):
gr.Markdown("Interactively toggle modifications and filters for the **currently loaded** model. Refresh status after changes.")
with gr.Row():
refresh_status_button = gr.Button("🔄 Refresh Status & Filter Checkboxes")
control_output = gr.Textbox(label="Control Action Status", interactive=False, lines=1)
status_output = gr.TextArea(label="Current Model Status (JSON)", interactive=False, lines=20, max_lines=60)
with gr.Tabs():
with gr.TabItem("Core & Structure"):
with gr.Row():
with gr.Column(min_width=150): bias_on = gr.Button("Bias Rem. ✅"); bias_off = gr.Button("Bias Rem. ❌")
with gr.Column(min_width=150): emb_on = gr.Button("Emb. Untie ✅"); emb_off = gr.Button("Emb. Untie ❌")
layer_target_inp = gr.Number(value=LAYERS, label="Target Layers", precision=0, minimum=1, interactive=True, scale=1)
layer_red_on = gr.Button("Apply Layer Red.", scale=1)
layer_red_off = gr.Button("Revert Layer Red.", scale=1)
with gr.Row():
with gr.Column(min_width=150): norm_swap_rms = gr.Button("Use RMSNorm"); norm_swap_ln = gr.Button("Use LayerNorm")
act_select = gr.Dropdown(choices=list(ACTIVATION_FUNCTIONS.keys()), value=DEFAULT_ACTIVATION_FUNCTION, label="Change ActFn")
act_revert = gr.Button("Revert ActFn")
with gr.Column(min_width=150): bitnet_on = gr.Button("BitNet ✅"); bitnet_off = gr.Button("BitNet ❌")
with gr.Accordion("Multi-Modal Conversion (Experimental)", open=False):
gr.Markdown("⚠️ **Experimental:** Adds modality-specific encoders (e.g., ViT, Whisper) and projection layers. **Requires manual `forward` pass adaptation & multi-modal data/training.**")
modality_checkboxes_ui = gr.CheckboxGroup(choices=AVAILABLE_MODALITIES, label="Select Modalities")
with gr.Row():
apply_multimodal_button = gr.Button("Apply Multi-Modal Setup")
revert_multimodal_button = gr.Button("Revert Multi-Modal Setup")
with gr.TabItem("Performance & Opt."):
with gr.Row():
with gr.Column(min_width=150): speed_on = gr.Button("Speed Opt. ✅"); speed_off = gr.Button("Speed Opt. ❌")
with gr.Column(min_width=150): coher_on = gr.Button("Coherence ✅"); coher_off = gr.Button("Coherence ❌")
with gr.Column(min_width=150): ln_bypass_on = gr.Button("LN Bypass ✅"); ln_bypass_off = gr.Button("LN Bypass ❌")
with gr.Row():
with gr.Column(min_width=150): do_bypass_on = gr.Button("Dropout Bypass ✅"); do_bypass_off = gr.Button("Dropout Bypass ❌")
with gr.Column(min_width=150): prec_on = gr.Button("FP32 Prec. ✅"); prec_off = gr.Button("FP32 Prec. ❌")
with gr.Column(min_width=150): norm_emb_on = gr.Button("Emb. Norm. ✅"); norm_emb_off = gr.Button("Emb. Norm. ❌")
with gr.Row():
with gr.Column(min_width=150): gc_cp_on = gr.Button("Grad Checkpoint ✅"); gc_cp_off = gr.Button("Grad Checkpoint ❌")
with gr.Column(min_width=150): flash_attn_on = gr.Button("Flash Attn 2 ✅"); flash_attn_off = gr.Button("Flash Attn 2 ❌")
with gr.Accordion("Quantization & Pruning", open=False):
with gr.Row():
quant_select = gr.Dropdown(choices=QUANTIZATION_MODES, value=DEFAULT_QUANTIZATION, label="Quantize To")
quant_apply = gr.Button("Apply Quant.")
quant_revert = gr.Button("Revert Quant.")
with gr.Row():
prune_amount_inp = gr.Slider(0.01, 0.95, value=PRUNING_AMOUNT, step=0.01, label="Prune Amount")
prune_apply = gr.Button("Apply Pruning")
prune_revert = gr.Button("Revert Pruning")
# --- "PEFT Adapters" tab: attach/merge/detach LoRA adapters on the loaded model ---
with gr.TabItem("PEFT Adapters"):
    gr.Markdown("Add, merge, or remove LoRA/PEFT adapters from the currently loaded model.")
    peft_lora_path_input = gr.Textbox(label="LoRA/PEFT Adapter Path or Hub ID", placeholder="username/my-lora-adapter")
    with gr.Row():
        peft_set_path_btn = gr.Button("Set Path in Config")
        peft_add_adapter_btn = gr.Button("Add Default Adapter")
        peft_merge_btn = gr.Button("Merge Active Adapter")
        peft_remove_adapter_btn = gr.Button("Remove/Unload Adapter")
# --- "Advanced Config & Layers" tab: freezing, layerdrop, RoPE/attention config, KD/RM heads ---
with gr.TabItem("Advanced Config & Layers"):
    with gr.Row():
        # Layer ranges accept a mix of single indices and dash ranges, e.g. "0-3, 7, 10-11".
        freeze_input = gr.Textbox(label="Layers to Freeze (e.g., '0-3, 7, 10-11')")
        freeze_apply = gr.Button("🧊 Freeze")
        freeze_revert = gr.Button("🔥 Unfreeze All")
    with gr.Row():
        with gr.Column(min_width=150): lim_on = gr.Button("Limits Cfg ✅"); lim_off = gr.Button("Limits Cfg ❌")
        with gr.Column(min_width=150): qa_on = gr.Button("QA Restrict Rem. ✅"); qa_off = gr.Button("QA Restrict Rem. ❌")
    # NOTE(review): layerdrop widgets appear to sit at tab level (not in the row above) — confirm original layout.
    layerdrop_prob_inp = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="LayerDrop Prob")
    layerdrop_on = gr.Button("LayerDrop Flag ✅")
    layerdrop_off = gr.Button("LayerDrop Flag ❌")
    with gr.Accordion("RoPE, Sliding Window, Attention Variant (Require Model Reload)", open=False):
        gr.Markdown("**Warning:** These settings modify the config but require reloading the model to take effect.")
        with gr.Row():
            rope_type_inp = gr.Dropdown(label="RoPE Type", choices=["linear", "dynamic"], value="linear")
            rope_factor_inp = gr.Number(label="RoPE Factor (>=1.0)", value=2.0, minimum=1.0, step=0.1)
            rope_apply_btn = gr.Button("Set RoPE")
            rope_revert_btn = gr.Button("Revert RoPE")
        with gr.Row():
            sw_size_inp = gr.Number(label="Sliding Window Size (0=disable)", value=4096, minimum=0, step=64)
            sw_apply_btn = gr.Button("Set Sliding Window")
            sw_revert_btn = gr.Button("Revert Sliding Window")
        with gr.Row():
            attn_variant_inp = gr.Dropdown(label="Attention Implementation", choices=["auto", "eager", "sdpa", "flash_attention_2"], value="auto")
            attn_apply_btn = gr.Button("Set Attention Variant")
            attn_revert_btn = gr.Button("Revert Attention Variant")
    with gr.Accordion("KD & Reward Heads (Experimental - Requires Training Changes)", open=False):
        with gr.Row():
            # precision=0 forces integer input for label/output counts.
            kd_labels_inp = gr.Number(label="KD Num Labels", value=2, minimum=1, precision=0)
            kd_setup_btn = gr.Button("Setup KD Head")
            kd_revert_btn = gr.Button("Revert KD Head")
        with gr.Row():
            rm_outputs_inp = gr.Number(label="RM Num Outputs", value=1, minimum=1, precision=0)
            rm_setup_btn = gr.Button("Setup RM Head")
            rm_revert_btn = gr.Button("Revert RM Head")
    with gr.Accordion("Other Flags (Symbolic - May Require Specific Training Logic)", open=False):
        # These flags are symbolic toggles only; actual effects depend on training code elsewhere.
        with gr.Row():
            swa_on = gr.Button("SWA Flag ✅"); swa_off = gr.Button("SWA Flag ❌")
            ke_on = gr.Button("Know. Edit Flag ✅"); ke_off = gr.Button("Know. Edit Flag ❌")
            hp_on = gr.Button("Head Prune Flag ✅"); hp_off = gr.Button("Head Prune Flag ❌")
        with gr.Row():
            qat_on = gr.Button("QAT Flag ✅"); qat_off = gr.Button("QAT Flag ❌")
            gn_on = gr.Button("Grad Noise Flag ✅"); gn_off = gr.Button("Grad Noise Flag ❌")
            wi_on = gr.Button("Weight Init Flag ✅"); wi_off = gr.Button("Weight Init Flag ❌")
# --- "Training Param Flags" tab: config flags consumed by the NEXT Trainer init, not the current run ---
with gr.TabItem("Training Param Flags"):
    gr.Markdown("Toggle flags in the config that affect **subsequent** Trainer initialization (won't affect current training).")
    with gr.Row():
        with gr.Column(min_width=150): gc_flag_on = gr.Button("GradClip Flg ✅"); gc_flag_off = gr.Button("GradClip Flg ❌")
        with gr.Column(min_width=150): wd_flag_on = gr.Button("WD Flg ✅"); wd_flag_off = gr.Button("WD Flg ❌")
        with gr.Column(min_width=150): lr_flag_on = gr.Button("LR Sched. Flg ✅"); lr_flag_off = gr.Button("LR Sched. Flg ❌")
    with gr.Row():
        # OPTIMIZERS / DEFAULT_OPTIMIZER are module-level constants defined earlier in the file.
        optim_flag_select = gr.Dropdown(choices=list(OPTIMIZERS.keys()), value=DEFAULT_OPTIMIZER, label="Set Optim. Pref")
        optim_flag_apply = gr.Button("Apply Optim.")
        optim_flag_revert = gr.Button("Revert Optim.")
    with gr.Row():
        grad_accum_ui_inp_config = gr.Number(value=GRADIENT_ACCUMULATION_STEPS, label="Grad Accum Steps (Config)", precision=0, minimum=1)
        grad_accum_set_btn = gr.Button("Set Grad Accum")
# --- "Safety & Content Filters" tab: per-filter checkboxes laid out in a grid of num_cols columns ---
with gr.TabItem("🔒 Safety & Content Filters"):
    gr.Markdown("Control safety filter flags in the model's config. Actual filtering effectiveness depends on the inference implementation.")
    with gr.Row():
        safety_all_on = gr.Button("🔒 Enable ALL Filters (Defaults)", variant="secondary")
        safety_all_off = gr.Button("🔓 Disable ALL Filters", variant="stop")
    gr.Markdown("Individual Filter Toggles:")
    # Collect every checkbox component; this list is reused below as click-handler
    # inputs/outputs so filter states round-trip through the wrapper functions.
    filter_checkboxes = []
    num_cols = 4
    # filter_names_ui is defined earlier in the file; lay the names out row by row.
    for i in range(0, len(filter_names_ui), num_cols):
        with gr.Row():
            for j in range(num_cols):
                idx = i + j
                if idx < len(filter_names_ui):
                    name = filter_names_ui[idx]
                    cb = gr.Checkbox(label=name, value=False, interactive=True)
                    filter_checkboxes.append(cb)
                else:
                    # Pad the last row with an empty cell so columns stay aligned.
                    gr.HTML("")
    apply_filters_button = gr.Button("Apply Individual Filter Toggles", variant="secondary")
# --- "Inference" tab: prompt box, generation-parameter sliders, and the generate button ---
with gr.TabItem("💬 Inference"):
    gr.Markdown("Test the **currently loaded and configured** model.")
    with gr.Row():
        inference_prompt = gr.Textbox(label="Enter Prompt", lines=4, placeholder="Once upon a time...")
        inference_output = gr.Textbox(label="Model Response", interactive=False, lines=15)
    with gr.Accordion("Generation Parameters", open=True):
        with gr.Row():
            max_new_tokens_slider = gr.Slider(10, 4096, value=256, step=10, label="Max New Tokens", interactive=True)
            # Temperature 0 is treated as greedy decoding (per the label).
            temperature_slider = gr.Slider(0.0, 2.0, value=0.7, step=0.01, label="Temperature (0=greedy)", interactive=True)
        with gr.Row():
            top_k_slider = gr.Slider(0, 200, value=50, step=1, label="Top-K (0=disable)", interactive=True)
            top_p_slider = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-P (0 or 1=disable)", interactive=True)
            repetition_penalty_slider = gr.Slider(1.0, 3.0, value=1.1, step=0.05, label="Repetition Penalty (1=disable)", interactive=True)
    generate_button = gr.Button("Generate Response", variant="primary")
# --- "Censor Control" tab: single action that forces all known filter flags to False ---
with gr.TabItem("🚫 Censor Control"):
    gr.Markdown("## Force Disable Censorship Flags")
    gr.Markdown("Click the button below to attempt to set all known censorship/filter flags in the loaded model's configuration to `False`. This uses the `Disable ALL Filters` function.")
    censor_off_button = gr.Button("🔓 Attempt Force Disable All Censorship Flags", variant="stop")
    censor_status = gr.Textbox(label="Censorship Flag Status", interactive=False, lines=2)
load_button.click(
fn=load_model_for_control,
inputs=[load_model_selector, hf_token_read, bypass_limits_chk],
outputs=[load_status_output, status_output] + filter_checkboxes
)
save_button.click(
fn=save_current_model,
inputs=[save_path_inp, hf_token_write, save_hub_repo_inp],
outputs=save_status_output
)
merge_button.click(
fn=_merge_architectures,
inputs=[merge_model_ids_inp, hf_token_read, bypass_limits_chk],
outputs=[merge_status_output, status_output] + filter_checkboxes
)
train_button.click(
fn=start_training,
inputs=[
train_model_selector, train_new_model_inp, hf_token_write,
train_datasets_inp,
train_post_activation_fn_selector, train_post_target_layers_inp,
train_grad_accum_inp, train_lr_inp, train_epochs_inp, train_max_steps_inp, train_batch_size_inp,
train_optim_selector, train_scheduler_selector, train_wd_inp, train_warmup_ratio_inp,
train_use_peft_chk, peft_r_inp, peft_alpha_inp, peft_dropout_inp, peft_target_modules_inp,
train_wandb_token_inp, train_use_cpu_chk, bypass_limits_chk
],
outputs=train_output
).then(
fn=get_detailed_status_and_filter_states,
inputs=None,
outputs=[status_output] + filter_checkboxes
)
refresh_outputs = [status_output] + filter_checkboxes
refresh_status_button.click(fn=get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs)
def link_control(button, func, inputs=None):
    """Wire *button* to *func*, then auto-refresh the status panel.

    The click result goes to ``control_output``; a chained ``.then`` re-runs
    ``get_detailed_status_and_filter_states`` so the status text and filter
    checkboxes (``refresh_outputs``) stay in sync after every control action.

    Args:
        button: Gradio button component to attach the handler to.
        func: Callable invoked on click.
        inputs: Optional list of Gradio input components; defaults to none.
    """
    button.click(
        func,
        inputs=inputs or [],
        outputs=control_output,
    ).then(
        get_detailed_status_and_filter_states,
        inputs=None,
        outputs=refresh_outputs,
    )
# --- Bulk control wiring: each button routes through link_control so every
# --- action refreshes the status panel and filter checkboxes afterwards.
link_control(bias_on, lambda: toggle_bias_removal_wrapper(True))
link_control(bias_off, lambda: toggle_bias_removal_wrapper(False))
link_control(emb_on, lambda: toggle_embeddings_untie_wrapper(True))
link_control(emb_off, lambda: toggle_embeddings_untie_wrapper(False))
link_control(layer_red_on, lambda layers: toggle_layer_reduction_wrapper(True, layers), inputs=[layer_target_inp])
link_control(layer_red_off, lambda: toggle_layer_reduction_wrapper(False, None))
link_control(norm_swap_rms, lambda: apply_norm_swap_wrapper('RMSNorm'))
link_control(norm_swap_ln, lambda: apply_norm_swap_wrapper('LayerNorm'))
# Activation dropdown fires on change (not via link_control), but chains the same refresh.
act_select.change(lambda name: apply_activation_change_wrapper(name), inputs=[act_select], outputs=control_output).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs)
link_control(act_revert, revert_activation_change_wrapper)
link_control(bitnet_on, lambda: toggle_bitnet_wrapper(True))
link_control(bitnet_off, lambda: toggle_bitnet_wrapper(False))
link_control(apply_multimodal_button, apply_multimodal_wrapper, inputs=[modality_checkboxes_ui])
link_control(revert_multimodal_button, revert_multimodal_wrapper)
link_control(speed_on, lambda: toggle_token_speed_optimization_wrapper(True))
link_control(speed_off, lambda: toggle_token_speed_optimization_wrapper(False))
link_control(coher_on, lambda: toggle_coherence_improvement_wrapper(True))
link_control(coher_off, lambda: toggle_coherence_improvement_wrapper(False))
link_control(ln_bypass_on, lambda: toggle_layer_norm_bypass_wrapper(True))
link_control(ln_bypass_off, lambda: toggle_layer_norm_bypass_wrapper(False))
link_control(do_bypass_on, lambda: toggle_dropout_bypass_wrapper(True))
link_control(do_bypass_off, lambda: toggle_dropout_bypass_wrapper(False))
link_control(prec_on, lambda: toggle_fp32_precision_wrapper(True))
link_control(prec_off, lambda: toggle_fp32_precision_wrapper(False))
link_control(norm_emb_on, lambda: toggle_embedding_normalization_wrapper(True))
link_control(norm_emb_off, lambda: toggle_embedding_normalization_wrapper(False))
link_control(gc_cp_on, lambda: toggle_gradient_checkpointing_wrapper(True))
link_control(gc_cp_off, lambda: toggle_gradient_checkpointing_wrapper(False))
link_control(flash_attn_on, lambda: toggle_flash_attention_wrapper(True))
link_control(flash_attn_off, lambda: toggle_flash_attention_wrapper(False))
link_control(quant_apply, apply_quantization_wrapper, inputs=[quant_select])
link_control(quant_revert, revert_quantization_wrapper)
link_control(prune_apply, apply_pruning_wrapper, inputs=[prune_amount_inp])
link_control(prune_revert, revert_pruning_wrapper)
link_control(peft_set_path_btn, set_lora_path_wrapper, inputs=[peft_lora_path_input])
link_control(peft_add_adapter_btn, add_peft_adapter_wrapper)
link_control(peft_merge_btn, merge_peft_adapter_wrapper)
link_control(peft_remove_adapter_btn, remove_peft_adapter_wrapper)
link_control(freeze_apply, apply_layer_freeze_wrapper, inputs=[freeze_input])
link_control(freeze_revert, revert_layer_freeze_wrapper)
link_control(lim_on, lambda: toggle_limits_wrapper(True))
link_control(lim_off, lambda: toggle_limits_wrapper(False))
link_control(qa_on, lambda: toggle_qa_restrictions_wrapper(True))
link_control(qa_off, lambda: toggle_qa_restrictions_wrapper(False))
link_control(layerdrop_on, lambda prob: toggle_layerdrop_wrapper(True, prob), inputs=[layerdrop_prob_inp])
link_control(layerdrop_off, lambda: toggle_layerdrop_wrapper(False))
# NOTE(review): the lambda parameter `type` shadows the builtin — harmless here, but
# worth renaming (e.g. `rope_type`) in a future cleanup.
link_control(rope_apply_btn, lambda type, factor: toggle_rope_scaling_wrapper(True, type, factor), inputs=[rope_type_inp, rope_factor_inp])
link_control(rope_revert_btn, lambda: toggle_rope_scaling_wrapper(False))
link_control(sw_apply_btn, lambda size: toggle_sliding_window_wrapper(True, size), inputs=[sw_size_inp])
link_control(sw_revert_btn, lambda: toggle_sliding_window_wrapper(False))
link_control(attn_apply_btn, apply_attention_variant_wrapper, inputs=[attn_variant_inp])
link_control(attn_revert_btn, revert_attention_variant_wrapper)
link_control(kd_setup_btn, lambda labels: toggle_kd_wrapper(True, labels), inputs=[kd_labels_inp])
link_control(kd_revert_btn, lambda: toggle_kd_wrapper(False))
link_control(rm_setup_btn, lambda outputs: toggle_reward_modeling_wrapper(True, outputs), inputs=[rm_outputs_inp])
link_control(rm_revert_btn, lambda: toggle_reward_modeling_wrapper(False))
# Symbolic flags route through specific_action_function with the apply/revert callable.
link_control(swa_on, lambda: specific_action_function(_apply_swa))
link_control(swa_off, lambda: specific_action_function(_revert_swa))
link_control(ke_on, lambda: specific_action_function(_apply_knowledge_editing))
link_control(ke_off, lambda: specific_action_function(_revert_knowledge_editing))
link_control(hp_on, lambda: specific_action_function(_apply_head_pruning))
link_control(hp_off, lambda: specific_action_function(_revert_head_pruning))
link_control(qat_on, lambda: specific_action_function(_apply_qat))
link_control(qat_off, lambda: specific_action_function(_revert_qat))
link_control(gn_on, lambda: specific_action_function(_apply_gradient_noise))
link_control(gn_off, lambda: specific_action_function(_revert_gradient_noise))
link_control(wi_on, lambda: specific_action_function(_apply_weight_init))
link_control(wi_off, lambda: specific_action_function(_revert_weight_init))
link_control(gc_flag_on, lambda: toggle_gradient_clipping_flag_wrapper(True))
link_control(gc_flag_off, lambda: toggle_gradient_clipping_flag_wrapper(False))
link_control(wd_flag_on, lambda: toggle_weight_decay_flag_wrapper(True))
link_control(wd_flag_off, lambda: toggle_weight_decay_flag_wrapper(False))
link_control(lr_flag_on, lambda: toggle_lr_scheduler_flag_wrapper(True))
link_control(lr_flag_off, lambda: toggle_lr_scheduler_flag_wrapper(False))
link_control(optim_flag_apply, apply_optimizer_change_wrapper, inputs=[optim_flag_select])
link_control(optim_flag_revert, revert_optimizer_change_wrapper)
link_control(grad_accum_set_btn, set_gradient_accumulation_wrapper, inputs=[grad_accum_ui_inp_config])
link_control(safety_all_on, lambda: toggle_all_safety_filters_wrapper(True))
link_control(safety_all_off, lambda: toggle_all_safety_filters_wrapper(False))
# Individual-filter apply: passes every checkbox value in; refresh chained manually
# because this handler takes the full checkbox list rather than going via link_control.
apply_filters_button.click(
    fn=toggle_individual_safety_filter_wrapper,
    inputs=filter_checkboxes,
    outputs=control_output
).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs)
generate_button.click(
    fn=run_inference,
    inputs=[
        inference_prompt, max_new_tokens_slider, temperature_slider,
        top_k_slider, top_p_slider, repetition_penalty_slider
    ],
    outputs=inference_output
)
censor_off_button.click(
    fn=force_disable_censorship_wrapper,
    outputs=censor_status
).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs)
# Script entry point: start the Gradio app with request queuing enabled.
# NOTE(review): server_name="0.0.0.0" binds all interfaces AND share=True opens a
# public tunnel — confirm this exposure is intended for the deployment environment.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", share=True, debug=False)