| import gradio as gr | |
| import os | |
| import unicodedata | |
| import logging | |
| import warnings | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import time | |
| from transformers import get_scheduler, pipeline, AutoTokenizer | |
| from huggingface_hub import HfApi | |
| from tqdm import tqdm | |
| from torch.utils.tensorboard import SummaryWriter | |
| import numpy as np | |
| import psutil | |
| import signal | |
| import sys | |
| import gc | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
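| # Gradio Space that fuses one or two teacher language models into a student model and then runs | |
| # multi-objective knowledge distillation on CPU before pushing the result to the Hugging Face Hub. | |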
| warnings.filterwarnings("ignore", message="cannot set number of interop threads after parallel work has started or set_num_interop_threads called") | |
| torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2)) | |
| device = torch.device("cpu") | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
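| # Global runtime state (loaded pipelines, log buffer, Hub credentials) followed by the distillation configuration. | |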
| MODELS = {} | |
| LOG_TEXT = "" | |
| GRADIO_LOG_OUTPUT = None | |
| HF_API = HfApi() | |
| AUTH_TOKEN = None | |
| USER_NAME = None | |
| SAVE_ON_RESOURCE_TERMINATION = True | |
| CHECKPOINT_DIR = "./fusion_distillation_checkpoints" | |
| TENSORBOARD_LOG_DIR = "./tensorboard_logs" | |
| SAVE_CHECKPOINTS = True | |
| CHECKPOINT_INTERVAL = 1000 | |
| LOG_INTERVAL = 100 | |
| KD_STEPS = 10000 | |
| ACCUMULATION_STEPS = 1 | |
| FREEZE_STUDENT_STEPS = 2000 | |
| ENABLE_MIXED_PRECISION = False | |
| KD_LOSS_FACTOR = 1.0 | |
| CE_LOSS_FACTOR = 0.0 | |
| KD_TEMPERATURE = 2.0 | |
| ENABLE_ATTENTION_KD = True | |
| ENABLE_HIDDEN_STATE_KD = True | |
| ENABLE_INTERMEDIATE_KD = True | |
| ENABLE_LAYER_NORM_KD = True | |
| ENABLE_EMBEDDING_KD = True | |
| ENABLE_PARAMETER_KD = True | |
| ENABLE_ACTIVATION_KD = True | |
| ENABLE_LOGIT_MASKING_KD = True | |
| ENABLE_SPARCITY_REGULARIZATION = True | |
| ENABLE_FEATURE_MAP_KD = True | |
| ENABLE_OUTPUT_LOGIT_KD = True | |
| ENABLE_LAYERWISE_PARAMETER_KD = True | |
| ENABLE_VOCAB_PROJECTION_KD = True | |
| ENABLE_CONTRASTIVE_KD = True | |
| ENABLE_RDROP_KD = True | |
| ENABLE_ADAPTIVE_TEMPERATURE_KD = True | |
| ENABLE_LAYER_WISE_KD = True | |
| ENABLE_ACTIVATION_REGULARIZATION = True | |
| ENABLE_NEURON_SELECTIVITY_KD = True | |
| ENABLE_WEIGHTED_PARAMETER_KD = True | |
| ENABLE_FSP_KD = True | |
| ENABLE_ATTENTION_ALIGNMENT_KD = True | |
| ENABLE_GRAM_MATRIX_KD = True | |
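| # Per-technique loss weights. The output-logit term carries most of the signal (factor 1.0); the | |
| # auxiliary terms use small factors so they regularize rather than dominate the total loss. | |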
| ATTENTION_KD_FACTOR = 0.1 | |
| HIDDEN_STATE_KD_FACTOR = 0.1 | |
| INTERMEDIATE_KD_FACTOR = 0.1 | |
| LAYER_NORM_KD_FACTOR = 0.01 | |
| EMBEDDING_KD_FACTOR = 0.01 | |
| PARAMETER_KD_FACTOR = 0.001 | |
| ACTIVATION_KD_FACTOR = 0.0001 | |
| LOGIT_MASKING_FACTOR = 0.01 | |
| SPARCITY_REGULARIZATION_FACTOR = 1e-5 | |
| FEATURE_MAP_KD_FACTOR = 0.005 | |
| OUTPUT_LOGIT_KD_FACTOR = 1.0 | |
| LAYERWISE_PARAMETER_KD_FACTOR = 0.0005 | |
| VOCAB_PROJECTION_KD_FACTOR = 0.001 | |
| CONTRASTIVE_KD_FACTOR = 0.05 | |
| RDROP_KD_FACTOR = 0.02 | |
| ADAPTIVE_TEMPERATURE_KD_FACTOR = 0.01 | |
| LAYER_WISE_KD_FACTOR = 0.05 | |
| ACTIVATION_REG_LAMBDA = 1e-7 | |
| NEURON_SELECTIVITY_KD_FACTOR = 0.001 | |
| WEIGHTED_PARAMETER_KD_FACTOR = 0.0002 | |
| FSP_KD_FACTOR = 0.001 | |
| ATTENTION_ALIGNMENT_KD_FACTOR = 0.01 | |
| GRAM_MATRIX_KD_FACTOR = 0.0005 | |
| ATTENTION_KD_LOSS_TYPE = 'mse' | |
| HIDDEN_STATE_KD_LOSS_TYPE = 'mse' | |
| OUTPUT_LOGIT_KD_LOSS_TYPE = 'kl' | |
| CONTRASTIVE_KD_LOSS_TYPE = 'cosine' | |
| RDROP_KD_LOSS_TYPE = 'kl' | |
| LAYER_WISE_LOSS_TYPE = 'mse' | |
| NEURON_SELECTIVITY_LOSS_TYPE = 'mse' | |
| WEIGHTED_PARAMETER_LOSS_TYPE = 'mse' | |
| FSP_KD_LOSS_TYPE = 'mse' | |
| ATTENTION_ALIGNMENT_LOSS_TYPE = 'mse' | |
| GRAM_MATRIX_LOSS_TYPE = 'mse' | |
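| # Module selectors, matched against module class names or attribute paths. The defaults assume a | |
| # GPT-2-style layout (transformer.h, lm_head); adjust them for other architectures. Some selectors | |
| # (FSP, Gram-matrix, attention-alignment, weighted-parameter) are configured here but not yet | |
| # consumed by the training loop below. | |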
| INTERMEDIATE_LAYERS = [2, 5, 8] | |
| LAYER_NORM_MODULES = ['LayerNorm'] | |
| ACTIVATION_MODULES = ['Linear'] | |
| FEATURE_MAP_MODULES = ['Linear'] | |
| LAYERWISE_PARAMETER_MODULES = ['transformer.h'] | |
| VOCAB_PROJECTION_MODULES = ['lm_head'] | |
| LAYER_WISE_MODULES = ['transformer.h'] | |
| NEURON_SELECTIVITY_MODULES = ['Linear'] | |
| WEIGHTED_PARAMETER_MODULES = ['Linear'] | |
| FSP_MODULES = ['Linear'] | |
| ATTENTION_ALIGNMENT_MODULES = ['SelfAttention'] | |
| GRAM_MATRIX_MODULES = ['Linear'] | |
| ADAPTIVE_TEMPERATURE_INITIAL = 2.0 | |
| ADAPTIVE_TEMPERATURE_DECAY_RATE = 0.999 | |
| LR_INITIAL = 5e-5 | |
| WEIGHT_DECAY = 1e-4 | |
| SCHEDULER_TYPE = "linear" | |
| WARMUP_STEPS = 500 | |
| ENABLE_LR_SCHEDULER = True | |
| ENABLE_STUDENT_PARAMETER_FREEZE = True | |
| ENABLE_DYNAMIC_FREEZE = False | |
| FREEZE_THRESHOLD = 0.1 | |
| ENABLE_GRADIENT_CLIPPING = True | |
| GRAD_CLIP_VALUE = 1.0 | |
| ENABLE_EARLY_STOPPING = True | |
| EARLY_STOPPING_PATIENCE = 5 | |
| ENABLE_PARAMETER_COUNT_CHECK = False | |
| ENABLE_EMBEDDING_NOISE = True | |
| EMBEDDING_NOISE_STD = 1e-4 | |
| ENABLE_L2_REGULARIZATION = False | |
| L2_LAMBDA = 1e-6 | |
| ENABLE_L1_REGULARIZATION = False | |
| L1_LAMBDA = 1e-6 | |
| def cleanup_and_exit(signal_code): | |
| log_message("Initiating cleanup...", level="info") | |
| if MODELS.get('student') and SAVE_ON_RESOURCE_TERMINATION: | |
| try: | |
| student_model = MODELS['student'] | |
| os.makedirs(CHECKPOINT_DIR, exist_ok=True) | |
| checkpoint_path = os.path.join(CHECKPOINT_DIR, "student_model_terminated.pt") | |
| save_checkpoint(student_model, checkpoint_path) | |
| log_message(f"Model checkpoint saved to {checkpoint_path} due to termination signal.", level="info") | |
| except Exception as e: | |
| log_message(f"Error during checkpoint save on cleanup: {e}", level="error") | |
| if MODELS.get('writer'): | |
| try: | |
| MODELS['writer'].close() | |
| log_message("TensorBoard writer closed.", level="info") | |
| except Exception as e: | |
| log_message(f"Error closing TensorBoard writer during cleanup: {e}", level="error") | |
| log_message(f"Cleanup completed.", level="info") | |
| def signal_handler(sig, frame): | |
| log_message(f"Received signal {sig}. Continuing process...", level="warning") | |
| signal.signal(signal.SIGTERM, signal_handler) | |
| signal.signal(signal.SIGINT, signal_handler) | |
| def log_message(message, level="info"): | |
| global LOG_TEXT, GRADIO_LOG_OUTPUT | |
| log_str = f"{message}\n" | |
| LOG_TEXT += log_str | |
| if GRADIO_LOG_OUTPUT: | |
| GRADIO_LOG_OUTPUT.value = LOG_TEXT | |
| if level == "info": | |
| logger.info(message) | |
| elif level == "warning": | |
| logger.warning(message) | |
| elif level == "error": | |
| logger.error(message) | |
| tqdm.write(message) | |
| def log_to_file(message, log_file="training_log.txt"): | |
| try: | |
| with open(log_file, "a") as f: | |
| f.write(f"{message}\n") | |
| except Exception as e: | |
| log_message(f"Error writing to log file: {e}", level="warning") | |
| def unify_parameters(student_model, teacher_model, exclude_layers=None): | |
| try: | |
| teacher_state = teacher_model.model.state_dict() | |
| student_state = student_model.model.state_dict() | |
| excluded_names = exclude_layers or [] | |
| for name, param in student_state.items(): | |
| if any(excluded_name in name for excluded_name in excluded_names): | |
| continue | |
| if name in teacher_state: | |
| try: | |
| if student_state[name].shape == teacher_state[name].shape: | |
| student_state[name].copy_(teacher_state[name]) | |
| else: | |
| min_shape = [min(s, t) for s, t in zip(student_state[name].shape, teacher_state[name].shape)] | |
| student_slice = tuple([slice(0, s) for s in min_shape]) | |
| teacher_slice = tuple([slice(0, s) for s in min_shape]) | |
| student_state[name][student_slice].copy_(teacher_state[name][teacher_slice]) | |
| except Exception as e: | |
| log_message(f"Parameter copy error for {name}: {e}", level="warning") | |
| student_model.model.load_state_dict(student_state, strict=False) | |
| except Exception as e: | |
| log_message(f"Error in unify_parameters: {e}", level="warning") | |
| def unify_embeddings(student_model, teacher_model, project_embeddings=True, mean_resizing=False): | |
| try: | |
| if hasattr(student_model.model, "get_input_embeddings") and hasattr(teacher_model.model, "get_input_embeddings"): | |
| student_emb = student_model.model.get_input_embeddings() | |
| teacher_emb = teacher_model.model.get_input_embeddings() | |
| if project_embeddings: | |
| in_dim = teacher_emb.weight.shape[1] | |
| out_dim = student_emb.weight.shape[1] | |
| projection = nn.Linear(in_dim, out_dim).to(device) | |
| teacher_emb_projected = projection(teacher_emb.weight) | |
| student_vocab_size = student_emb.weight.shape[0] | |
| teacher_vocab_size_proj = teacher_emb_projected.shape[0] | |
| if mean_resizing and student_vocab_size < teacher_vocab_size_proj: | |
| try: | |
| student_model.model.resize_token_embeddings(teacher_vocab_size_proj) | |
| student_emb = student_model.model.get_input_embeddings() | |
| except Exception as e: | |
| log_message(f"Error resizing student embeddings: {e}", level="warning") | |
| min_vocab_size = min(student_emb.weight.shape[0], teacher_vocab_size_proj) | |
| try: | |
| student_emb.weight.data[:min_vocab_size].copy_(teacher_emb_projected.data[:min_vocab_size]) | |
| except Exception as e: | |
| log_message(f"Error copying projected embeddings: {e}", level="warning") | |
| else: | |
| min_vocab_size = min(student_emb.weight.shape[0], teacher_emb.weight.shape[0]) | |
| try: | |
| student_emb.weight.data[:min_vocab_size].copy_(teacher_emb.weight.data[:min_vocab_size]) | |
| except Exception as e: | |
| log_message(f"Error copying embeddings: {e}", level="warning") | |
| if hasattr(student_model.model, "get_output_embeddings") and hasattr(teacher_model.model, "get_output_embeddings"): | |
| student_out_emb = student_model.model.get_output_embeddings() | |
| teacher_out_emb = teacher_model.model.get_output_embeddings() | |
| if student_out_emb is not None and teacher_out_emb is not None: | |
| if project_embeddings: | |
| in_dim = teacher_out_emb.weight.shape[1] | |
| out_dim = student_out_emb.weight.shape[1] | |
| projection = nn.Linear(in_dim, out_dim).to(device) | |
| teacher_out_emb_projected = projection(teacher_out_emb.weight) | |
| student_vocab_size = student_out_emb.weight.shape[0] | |
| teacher_vocab_size_proj = teacher_out_emb_projected.shape[0] | |
| if mean_resizing and student_vocab_size < teacher_vocab_size_proj: | |
| try: | |
| student_model.model.resize_token_embeddings(teacher_vocab_size_proj) | |
| student_out_emb = student_model.model.get_output_embeddings() | |
| except Exception as e: | |
| log_message(f"Error resizing student output embeddings: {e}", level="warning") | |
| min_vocab_size = min(student_out_emb.weight.shape[0], teacher_vocab_size_proj) | |
| try: | |
| student_out_emb.weight.data[:min_vocab_size].copy_(teacher_out_emb_projected.data[:min_vocab_size]) | |
| except Exception as e: | |
| log_message(f"Error copying projected output embeddings: {e}", level="warning") | |
| else: | |
| min_vocab_size = min(student_out_emb.weight.shape[0], teacher_out_emb.weight.shape[0]) | |
| try: | |
| student_out_emb.weight.data[:min_vocab_size].copy_(teacher_out_emb.weight.data[:min_vocab_size]) | |
| except Exception as e: | |
| log_message(f"Error copying output embeddings: {e}", level="warning") | |
| except Exception as e: | |
| log_message(f"Error in unify_embeddings: {e}", level="warning") | |
| def unify_tokenizers(student_tokenizer, teacher_tokenizer, student_model): | |
| if teacher_tokenizer is None: | |
| return student_tokenizer | |
| try: | |
| teacher_vocab = teacher_tokenizer.get_vocab() | |
| student_vocab = student_tokenizer.get_vocab() | |
| new_tokens = [token for token in teacher_vocab if token not in student_vocab] | |
| if new_tokens: | |
| student_tokenizer.add_tokens(new_tokens) | |
| student_model.model.resize_token_embeddings(len(student_tokenizer)) | |
| except Exception as e: | |
| log_message(f"Tokenizer unification error: {e}", level="warning") | |
| return student_tokenizer | |
| def normalize_text(text): | |
| return unicodedata.normalize('NFKC', text) | |
| def generate_predictions(model, tokenizer, texts, max_length=150, **tokenizer_kwargs): | |
| try: | |
| inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=max_length, **tokenizer_kwargs).to(device) | |
| outputs = model.model.generate(**inputs) | |
| return list(map(lambda output: tokenizer.decode(output, skip_special_tokens=True), outputs)) | |
| except Exception as e: | |
| log_message(f"Prediction generation error: {e}", level="warning") | |
| return [""] * len(texts) | |
| fusion_functions = { | |
| 'geometric_mean_double': lambda teacher_states, layer_scale, key, device: geometric_mean_fusion_double(teacher_states, device, len(teacher_states), layer_scale, key), | |
| } | |
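| # Signed geometric mean across teachers, computed elementwise over the overlapping leading dimension: | |
| # sign(prod_i x_i) * exp(mean_i log|x_i|). | |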
| def geometric_mean_fusion_double(states, device, num_teachers, layer_scale, key): | |
| min_len = min(state[key].shape[0] for state in states) | |
| arrays = [state[key][:min_len].detach().cpu().numpy() for state in states] | |
| # np.abs/np.sign operate on the NumPy arrays; a tiny epsilon guards log(0), and the sign product is taken elementwise. | |
| log_abs_sum = sum(np.log(np.abs(arr) + 1e-12) for arr in arrays) | |
| sign_prod = np.prod(np.stack([np.sign(arr) for arr in arrays]), axis=0) | |
| geometric_mean_val = torch.tensor(np.exp(log_abs_sum / num_teachers) * sign_prod, device=device) * layer_scale | |
| return geometric_mean_val | |
| def complete_unify_teacher_models_double(teacher_models, fusion_method='geometric_mean_double', layer_scale=1.0, layer_weights=None, device=device): | |
| teacher_states = [teacher.model.state_dict() for teacher in teacher_models] | |
| unified_state = {} | |
| first_teacher_state = teacher_states[0] | |
| for key in first_teacher_state.keys(): | |
| teacher_layer_states = [] | |
| all_teachers_have_layer = True | |
| for state in teacher_states: | |
| if key in state: | |
| teacher_layer_states.append(state) | |
| else: | |
| all_teachers_have_layer = False | |
| break | |
| if all_teachers_have_layer: | |
| if fusion_method in fusion_functions: | |
| unified_state[key] = fusion_functions[fusion_method](teacher_layer_states, layer_scale, key, device=device) | |
| else: | |
| layer_sum = torch.stack([state[key] for state in teacher_layer_states]).sum(dim=0) | |
| unified_state[key] = layer_sum / len(teacher_models) | |
| else: | |
| unified_state[key] = first_teacher_state[key] | |
| return unified_state | |
| def unify_teacher_into_student(unified_teacher_state, student, force_parameter_copy=False): | |
| student_state = student.model.state_dict() | |
| new_state = {} | |
| for key, student_value in student_state.items(): | |
| if key in unified_teacher_state: | |
| teacher_value = unified_teacher_state[key] | |
| try: | |
| if student_value.shape == teacher_value.shape: | |
| new_state[key] = teacher_value | |
| else: | |
| min_shape = [min(s, t) for s, t in zip(student_value.shape, teacher_value.shape)] | |
| student_slice = tuple([slice(0, s) for s in min_shape]) | |
| teacher_slice = tuple([slice(0, s) for s in min_shape]) | |
| new_state[key] = student_value.clone() | |
| new_state[key][student_slice] = teacher_value[teacher_slice] | |
| except Exception as e: | |
| log_message(f"Parameter assignment error for {key}: {e}", level="warning") | |
| else: | |
| new_state[key] = student_value | |
| student.model.load_state_dict(new_state, strict=False) | |
| return student | |
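| # fuse_tokenizers builds the union of all teacher vocabularies on top of the student tokenizer; tokens | |
| # added here receive freshly initialized embedding rows when the student's embeddings are resized afterwards. | |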
| def fuse_tokenizers(teacher_tokenizers, student_tokenizer): | |
| unified_vocab = set(student_tokenizer.get_vocab().keys()) if student_tokenizer else set() | |
| for teacher_tokenizer in teacher_tokenizers: | |
| if teacher_tokenizer: | |
| teacher_vocab = set(teacher_tokenizer.get_vocab().keys()) | |
| unified_vocab = unified_vocab.union(teacher_vocab) | |
| if student_tokenizer: | |
| try: | |
| student_tokenizer.add_tokens(list(unified_vocab - set(student_tokenizer.get_vocab().keys()))) | |
| except Exception as e: | |
| log_message(f"Error fusing tokenizers: {e}", level="warning") | |
| return student_tokenizer | |
| def update_student_embeddings_double(student, teacher_models, fusion_type='geometric_mean_double', device=device, mean_resizing=False): | |
| emb_student = student.model.get_input_embeddings().weight.data | |
| student_vocab_size, student_emb_dim = emb_student.shape | |
| teacher_embeddings_proj = [] | |
| min_teacher_vocab_size = float('inf') | |
| for teacher in teacher_models: | |
| emb_teacher = teacher.model.get_input_embeddings().weight.data.detach().cpu() | |
| teacher_vocab_size, teacher_emb_dim = emb_teacher.shape | |
| min_teacher_vocab_size = min(min_teacher_vocab_size, teacher_vocab_size) | |
| if teacher_emb_dim != student_emb_dim: | |
| proj = nn.Linear(teacher_emb_dim, student_emb_dim, bias=False).to(device) | |
| emb_teacher_proj = proj(emb_teacher.to(device)).cpu() | |
| else: | |
| emb_teacher_proj = emb_teacher | |
| teacher_embeddings_proj.append(emb_teacher_proj) | |
| min_vocab = min(min_teacher_vocab_size, student_vocab_size) | |
| fusion_function = fusion_functions.get(fusion_type, fusion_functions['geometric_mean_double']) | |
| teacher_states = [{'embedding': emb} for emb in teacher_embeddings_proj] | |
| emb_updated = fusion_function(teacher_states, 1.0, 'embedding', device=device)[:min_vocab] | |
| if ENABLE_EMBEDDING_NOISE: | |
| noise = torch.randn_like(emb_updated) * EMBEDDING_NOISE_STD | |
| emb_updated = emb_updated + noise | |
| try: | |
| student.model.get_input_embeddings().weight.data[:min_vocab].copy_(emb_updated.to(device)) | |
| except Exception as e: | |
| log_message(f"Error updating student embeddings: {e}", level="warning") | |
| return student | |
| def save_checkpoint(model, checkpoint_path): | |
| try: | |
| torch.save(model.model.state_dict(), checkpoint_path) | |
| log_message(f"Checkpoint saved to: {checkpoint_path}", level="info") | |
| except Exception as e: | |
| log_message(f"Error saving checkpoint: {e}", level="warning") | |
| kd_loss_functions = { | |
| 'mse': F.mse_loss, | |
| 'kl': lambda student, teacher: F.kl_div(F.log_softmax(student, dim=-1), F.log_softmax(teacher.detach(), dim=-1), reduction='batchmean', log_target=True), | |
| 'cosine': lambda student, teacher: 1.0 - F.cosine_similarity(student.view(-1, student.size(-1)), teacher.detach().view(-1, teacher.size(-1))).mean(), | |
| } | |
| class ActivationSaver(nn.Module): | |
| # Wraps a module and records its most recent output so the activation / feature-map KD terms can read it back. | |
| def __init__(self, module): | |
| super().__init__() | |
| self.module = module | |
| self.activation_values = None | |
| self.feature_map_values = None | |
| def forward(self, *args, **kwargs): | |
| output = self.module(*args, **kwargs) | |
| if isinstance(output, tuple): | |
| self.activation_values = output[0] | |
| self.feature_map_values = output[1] if len(output) > 1 else None | |
| else: | |
| self.activation_values = output | |
| return output | |
| def attach_activation_saver(model, activation_modules, feature_map_modules): | |
| # `model` is the raw HF model. Dotted names from named_modules() must be resolved to their parent module, | |
| # otherwise setattr would only add a stray top-level attribute instead of replacing the child. Note that | |
| # wrapping nests the original weights under a `.module` prefix in the state_dict (applied to student and | |
| # teachers alike), an accepted trade-off of this simple wrapper. | |
| targets = [(name, module) for name, module in model.named_modules() if type(module).__name__ in activation_modules or type(module).__name__ in feature_map_modules] | |
| for name, module in targets: | |
| parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model | |
| setattr(parent, name.rsplit(".", 1)[-1], ActivationSaver(module)) | |
| return model | |
| def attention_kd_loss(student_attention, teacher_attention, loss_type='mse'): | |
| try: | |
| if loss_type == 'mse': | |
| return F.mse_loss(student_attention, teacher_attention.detach()) | |
| elif loss_type == 'kl': | |
| return F.kl_div(F.log_softmax(student_attention, dim=-1), F.log_softmax(teacher_attention.detach(), dim=-1), reduction='batchmean', log_target=True) | |
| elif loss_type == 'cosine': | |
| return 1.0 - F.cosine_similarity(student_attention.view(-1), teacher_attention.detach().view(-1)).mean() | |
| except Exception as e: | |
| log_message(f"Attention KD Loss error: {e}", level="warning") | |
| return torch.tensor(0.0, device=device) | |
| def hidden_state_kd_loss(student_hidden, teacher_hidden, loss_type='mse'): | |
| try: | |
| if loss_type == 'mse': | |
| return F.mse_loss(student_hidden, teacher_hidden.detach()) | |
| elif loss_type == 'kl': | |
| return F.kl_div(F.log_softmax(student_hidden, dim=-1), F.log_softmax(teacher_hidden.detach(), dim=-1), reduction='batchmean', log_target=True) | |
| elif loss_type == 'cosine': | |
| return 1.0 - F.cosine_similarity(student_hidden.view(-1), teacher_hidden.detach().view(-1)).mean() | |
| except Exception as e: | |
| log_message(f"Hidden State KD Loss error: {e}", level="warning") | |
| return torch.tensor(0.0, device=device) | |
| def advanced_knowledge_distillation(teacher_models, student_model, device): | |
| for teacher_model in teacher_models: | |
| teacher_model.model.eval() | |
| student_model.model.train() | |
| total_steps = KD_STEPS | |
| accumulation_steps = ACCUMULATION_STEPS | |
| freeze_steps = FREEZE_STUDENT_STEPS | |
| use_mixed_precision = ENABLE_MIXED_PRECISION | |
| kd_loss_factor = KD_LOSS_FACTOR | |
| ce_loss_factor = CE_LOSS_FACTOR | |
| kd_temperature = KD_TEMPERATURE | |
| attention_kd_enabled = ENABLE_ATTENTION_KD | |
| hidden_kd_enabled = ENABLE_HIDDEN_STATE_KD | |
| intermediate_kd_enabled = ENABLE_INTERMEDIATE_KD | |
| layer_norm_kd_enabled = ENABLE_LAYER_NORM_KD | |
| embedding_kd_enabled = ENABLE_EMBEDDING_KD | |
| parameter_kd_enabled = ENABLE_PARAMETER_KD | |
| activation_kd_enabled = ENABLE_ACTIVATION_KD | |
| logit_masking_enabled = ENABLE_LOGIT_MASKING_KD | |
| sparsity_regularization_enabled = ENABLE_SPARCITY_REGULARIZATION | |
| feature_map_kd_enabled = ENABLE_FEATURE_MAP_KD | |
| output_logit_kd_enabled = ENABLE_OUTPUT_LOGIT_KD | |
| layerwise_parameter_kd_enabled = ENABLE_LAYERWISE_PARAMETER_KD | |
| vocab_projection_kd_enabled = ENABLE_VOCAB_PROJECTION_KD | |
| contrastive_kd_enabled = ENABLE_CONTRASTIVE_KD | |
| rdrop_kd_enabled = ENABLE_RDROP_KD | |
| adaptive_temperature_kd_enabled = ENABLE_ADAPTIVE_TEMPERATURE_KD | |
| layer_wise_kd_enabled = ENABLE_LAYER_WISE_KD | |
| activation_reg_enabled = ENABLE_ACTIVATION_REGULARIZATION | |
| neuron_selectivity_kd_enabled = ENABLE_NEURON_SELECTIVITY_KD | |
| weighted_parameter_kd_enabled = ENABLE_WEIGHTED_PARAMETER_KD | |
| fsp_kd_enabled = ENABLE_FSP_KD | |
| attention_alignment_kd_enabled = ENABLE_ATTENTION_ALIGNMENT_KD | |
| gram_matrix_kd_enabled = ENABLE_GRAM_MATRIX_KD | |
| attention_kd_factor = ATTENTION_KD_FACTOR | |
| hidden_kd_factor = HIDDEN_STATE_KD_FACTOR | |
| intermediate_kd_factor = INTERMEDIATE_KD_FACTOR | |
| layer_norm_kd_factor = LAYER_NORM_KD_FACTOR | |
| embedding_kd_factor = EMBEDDING_KD_FACTOR | |
| parameter_kd_factor = PARAMETER_KD_FACTOR | |
| activation_kd_factor = ACTIVATION_KD_FACTOR | |
| logit_masking_factor = LOGIT_MASKING_FACTOR | |
| sparsity_regularization_factor = SPARCITY_REGULARIZATION_FACTOR | |
| feature_map_kd_factor = FEATURE_MAP_KD_FACTOR | |
| output_logit_kd_factor = OUTPUT_LOGIT_KD_FACTOR | |
| layerwise_parameter_kd_factor = LAYERWISE_PARAMETER_KD_FACTOR | |
| vocab_projection_kd_factor = VOCAB_PROJECTION_KD_FACTOR | |
| contrastive_kd_factor = CONTRASTIVE_KD_FACTOR | |
| rdrop_kd_factor = RDROP_KD_FACTOR | |
| adaptive_temperature_kd_factor = ADAPTIVE_TEMPERATURE_KD_FACTOR | |
| layer_wise_kd_factor = LAYER_WISE_KD_FACTOR | |
| activation_reg_lambda = ACTIVATION_REG_LAMBDA | |
| neuron_selectivity_kd_factor = NEURON_SELECTIVITY_KD_FACTOR | |
| weighted_parameter_kd_factor = WEIGHTED_PARAMETER_KD_FACTOR | |
| fsp_kd_factor = FSP_KD_FACTOR | |
| attention_alignment_kd_factor = ATTENTION_ALIGNMENT_KD_FACTOR | |
| gram_matrix_kd_factor = GRAM_MATRIX_KD_FACTOR | |
| attention_kd_loss_type = ATTENTION_KD_LOSS_TYPE | |
| hidden_kd_loss_type = HIDDEN_STATE_KD_LOSS_TYPE | |
| output_logit_kd_loss_type = OUTPUT_LOGIT_KD_LOSS_TYPE | |
| contrastive_kd_loss_type = CONTRASTIVE_KD_LOSS_TYPE | |
| rdrop_kd_loss_type = RDROP_KD_LOSS_TYPE | |
| layer_wise_loss_type = LAYER_WISE_LOSS_TYPE | |
| neuron_selectivity_loss_type = NEURON_SELECTIVITY_LOSS_TYPE | |
| weighted_parameter_loss_type = WEIGHTED_PARAMETER_LOSS_TYPE | |
| fsp_kd_loss_type = FSP_KD_LOSS_TYPE | |
| attention_alignment_loss_type = ATTENTION_ALIGNMENT_LOSS_TYPE | |
| gram_matrix_loss_type = GRAM_MATRIX_LOSS_TYPE | |
| intermediate_layers = INTERMEDIATE_LAYERS | |
| layer_norm_modules = LAYER_NORM_MODULES | |
| activation_modules = ACTIVATION_MODULES | |
| feature_map_modules = FEATURE_MAP_MODULES | |
| layerwise_parameter_modules = LAYERWISE_PARAMETER_MODULES | |
| vocab_projection_modules = VOCAB_PROJECTION_MODULES | |
| layer_wise_modules = LAYER_WISE_MODULES | |
| neuron_selectivity_modules = NEURON_SELECTIVITY_MODULES | |
| weighted_parameter_modules = WEIGHTED_PARAMETER_MODULES | |
| fsp_modules = FSP_MODULES | |
| attention_alignment_modules = ATTENTION_ALIGNMENT_MODULES | |
| gram_matrix_modules = GRAM_MATRIX_MODULES | |
| adaptive_temperature_initial = ADAPTIVE_TEMPERATURE_INITIAL | |
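| # The distillation temperature decays multiplicatively every step and is clamped at a floor of 1.0 further below. | |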
| adaptive_temperature_decay_rate = ADAPTIVE_TEMPERATURE_DECAY_RATE | |
| adaptive_temperature = torch.tensor(adaptive_temperature_initial, device=device, requires_grad=False) | |
| student_model.model = attach_activation_saver(student_model.model, activation_modules, feature_map_modules) | |
| for teacher_model in teacher_models: | |
| teacher_model.model = attach_activation_saver(teacher_model.model, activation_modules, feature_map_modules) | |
| optimizer = torch.optim.AdamW( | |
| student_model.model.parameters(), | |
| lr=LR_INITIAL, | |
| weight_decay=WEIGHT_DECAY | |
| ) | |
| scheduler = get_scheduler( | |
| name=SCHEDULER_TYPE, | |
| optimizer=optimizer, | |
| num_warmup_steps=WARMUP_STEPS, | |
| num_training_steps=total_steps | |
| ) if ENABLE_LR_SCHEDULER else None | |
| scaler = torch.amp.GradScaler(device.type, enabled=use_mixed_precision) | |
| if ENABLE_STUDENT_PARAMETER_FREEZE: | |
| for param in student_model.model.parameters(): | |
| param.requires_grad = False | |
| log_to_file(f"Starting KD: freezing for {freeze_steps} steps.") | |
| else: | |
| log_to_file("KD started without initial freezing.") | |
| writer = SummaryWriter(log_dir=TENSORBOARD_LOG_DIR) | |
| MODELS['writer'] = writer | |
| best_loss = float('inf') | |
| patience_counter = 0 | |
| start_time = time.time() | |
| progress_bar = tqdm(range(KD_STEPS), desc="Knowledge Distillation Progress") | |
| for step in progress_bar: | |
| try: | |
| optimizer.zero_grad() | |
| loss_accum = 0.0 | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| for _ in range(ACCUMULATION_STEPS): | |
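| # NOTE: placeholder batch. These two hard-coded sentences stand in for a real training dataloader; | |
| # plug in your own corpus here for meaningful distillation. | |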
| batch_texts = ["This is a sample text for distillation.", "Another example sentence."] | |
| try: | |
| with torch.amp.autocast(device_type=device.type, enabled=ENABLE_MIXED_PRECISION): | |
| teacher_outputs_list = [] | |
| teacher_inputs_list = [] | |
| for teacher_model in MODELS['teacher']: | |
| with torch.no_grad(): | |
| teacher_inputs = teacher_model.tokenizer( | |
| batch_texts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=128, | |
| ).to(device) | |
| teacher_inputs_list.append(teacher_inputs) | |
| teacher_outputs = teacher_model.model(**teacher_inputs, output_attentions=ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD, output_hidden_states=ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD) | |
| teacher_outputs_list.append(teacher_outputs) | |
| student_inputs = MODELS['student'].tokenizer( | |
| batch_texts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=128, | |
| ).to(device) | |
| student_outputs = MODELS['student'].model(**student_inputs, output_attentions=ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD, output_hidden_states=ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD) | |
| student_logits = student_outputs.logits | |
| student_attentions = student_outputs.attentions if ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD else None | |
| student_hidden_states = student_outputs.hidden_states if ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD else None | |
| teacher_logits = torch.stack([teacher.logits for teacher in teacher_outputs_list]).mean(dim=0) | |
| teacher_attentions = [teacher.attentions for teacher in teacher_outputs_list] | |
| teacher_attentions = torch.stack([torch.stack(attn) for attn in teacher_attentions]).mean(dim=0) if teacher_attentions[0] is not None else None | |
| teacher_hidden_states = [teacher.hidden_states for teacher in teacher_outputs_list] | |
| teacher_hidden_states = torch.stack([torch.stack(hiddens) for hiddens in teacher_hidden_states]).mean(dim=0) if teacher_hidden_states[0] is not None else None | |
| min_vocab_size_logits = min(teacher_logits.size(-1), student_logits.size(-1)) | |
| teacher_logits_trimmed = teacher_logits[:, :, :min_vocab_size_logits] | |
| student_logits_trimmed = student_logits[:, :, :min_vocab_size_logits] | |
| # Embedding KD uses the first teacher's input embeddings as the reference target (teachers may differ in hidden size). | |
| teacher_embeds = MODELS['teacher'][0].model.get_input_embeddings()(teacher_inputs_list[0].input_ids) | |
| student_embeds = MODELS['student'].model.get_input_embeddings()(student_inputs.input_ids) | |
| ce_loss = torch.tensor(0.0).to(device) | |
| output_kd_loss = OUTPUT_LOGIT_KD_FACTOR * kd_loss_functions[OUTPUT_LOGIT_KD_LOSS_TYPE](student_logits_trimmed / adaptive_temperature, teacher_logits_trimmed / adaptive_temperature) if ENABLE_OUTPUT_LOGIT_KD else torch.tensor(0.0).to(device) | |
| loss = output_kd_loss | |
| loss = loss + CE_LOSS_FACTOR * ce_loss | |
| attn_loss = ATTENTION_KD_FACTOR * attention_kd_loss(student_attentions[-1], teacher_attentions[-1], ATTENTION_KD_LOSS_TYPE) if ENABLE_ATTENTION_KD and teacher_attentions is not None and student_attentions is not None else torch.tensor(0.0).to(device) | |
| loss += attn_loss | |
| hidden_loss = HIDDEN_STATE_KD_FACTOR * hidden_state_kd_loss(student_hidden_states[-1], teacher_hidden_states[-1], HIDDEN_STATE_KD_LOSS_TYPE) if ENABLE_HIDDEN_STATE_KD and teacher_hidden_states is not None and student_hidden_states is not None else torch.tensor(0.0).to(device) | |
| loss += hidden_loss | |
| intermediate_loss = 0.0 | |
| if ENABLE_INTERMEDIATE_KD and teacher_hidden_states is not None and student_hidden_states is not None: | |
| for layer_idx in INTERMEDIATE_LAYERS: | |
| if layer_idx < len(teacher_hidden_states) and layer_idx < len(student_hidden_states): | |
| teacher_layer_output = teacher_hidden_states[layer_idx] | |
| student_layer_output = student_hidden_states[layer_idx] | |
| layer_loss = hidden_state_kd_loss(student_layer_output, teacher_layer_output, HIDDEN_STATE_KD_LOSS_TYPE) | |
| intermediate_loss += layer_loss | |
| loss += INTERMEDIATE_KD_FACTOR * intermediate_loss | |
| layer_norm_loss = 0.0 | |
| if ENABLE_LAYER_NORM_KD: | |
| for layer_norm_module_name in LAYER_NORM_MODULES: | |
| if hasattr(MODELS['student'].model, layer_norm_module_name): | |
| student_ln = getattr(MODELS['student'].model, layer_norm_module_name) | |
| layer_norm_weights = [] | |
| layer_norm_biases = [] | |
| for teacher_model in MODELS['teacher']: | |
| if hasattr(teacher_model.model, layer_norm_module_name): | |
| teacher_ln = getattr(teacher_model.model, layer_norm_module_name) | |
| if isinstance(student_ln, nn.LayerNorm) and isinstance(teacher_ln, nn.LayerNorm): | |
| layer_norm_weights.append(teacher_ln.weight) | |
| layer_norm_biases.append(teacher_ln.bias) | |
| if layer_norm_weights: | |
| teacher_weight_mean = torch.stack(layer_norm_weights).mean(dim=0) | |
| teacher_bias_mean = torch.stack(layer_norm_biases).mean(dim=0) | |
| layer_norm_loss += F.mse_loss(student_ln.weight, teacher_weight_mean) | |
| layer_norm_loss += F.mse_loss(student_ln.bias, teacher_bias_mean) | |
| loss += LAYER_NORM_KD_FACTOR * layer_norm_loss | |
| embed_loss = EMBEDDING_KD_FACTOR * hidden_state_kd_loss(student_embeds, teacher_embeds, HIDDEN_STATE_KD_LOSS_TYPE) if ENABLE_EMBEDDING_KD else torch.tensor(0.0).to(device) | |
| loss += embed_loss | |
| parameter_loss = 0.0 | |
| if ENABLE_PARAMETER_KD: | |
| for name, student_param in MODELS['student'].model.named_parameters(): | |
| param_values = [] | |
| for teacher_model in MODELS['teacher']: | |
| if name in teacher_model.model.state_dict(): | |
| param_values.append(teacher_model.model.state_dict()[name]) | |
| if param_values: | |
| teacher_param_mean = torch.stack(param_values).mean(dim=0) | |
| parameter_loss += kd_loss_functions[WEIGHTED_PARAMETER_LOSS_TYPE](student_param, teacher_param_mean) | |
| loss += PARAMETER_KD_FACTOR * parameter_loss | |
| layerwise_parameter_loss = 0.0 | |
| if ENABLE_LAYERWISE_PARAMETER_KD: | |
| for layer_module_name in LAYERWISE_PARAMETER_MODULES: | |
| if hasattr(MODELS['student'].model, layer_module_name): | |
| student_layer_module = getattr(MODELS['student'].model, layer_module_name) | |
| for name, student_param in student_layer_module.named_parameters(): | |
| param_values = [] | |
| full_name = f"{layer_module_name}.{name}" | |
| for teacher_model in MODELS['teacher']: | |
| if hasattr(teacher_model.model, layer_module_name) and full_name in teacher_model.model.state_dict(): | |
| param_values.append(teacher_model.model.state_dict()[full_name]) | |
| if param_values: | |
| teacher_param_mean = torch.stack(param_values).mean(dim=0) | |
| layerwise_parameter_loss += kd_loss_functions[WEIGHTED_PARAMETER_LOSS_TYPE](student_param, teacher_param_mean) | |
| loss += LAYERWISE_PARAMETER_KD_FACTOR * layerwise_parameter_loss | |
| vocab_projection_loss = 0.0 | |
| if ENABLE_VOCAB_PROJECTION_KD: | |
| for vocab_module_name in VOCAB_PROJECTION_MODULES: | |
| if hasattr(MODELS['student'].model, vocab_module_name): | |
| student_vocab_module = getattr(MODELS['student'].model, vocab_module_name) | |
| projection_weights = [] | |
| for teacher_model in MODELS['teacher']: | |
| if hasattr(teacher_model.model, vocab_module_name): | |
| teacher_vocab_module = getattr(teacher_model.model, vocab_module_name) | |
| projection_weights.append(teacher_vocab_module.weight) | |
| if projection_weights: | |
| teacher_weight_mean = torch.stack(projection_weights).mean(dim=0) | |
| vocab_projection_loss += F.mse_loss(student_vocab_module.weight, teacher_weight_mean) | |
| loss += VOCAB_PROJECTION_KD_FACTOR * vocab_projection_loss | |
| activation_loss = 0.0 | |
| if ENABLE_ACTIVATION_KD: | |
| for name, module in MODELS['student'].model.named_modules(): | |
| if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES: | |
| if module.activation_values is not None: | |
| activations = module.activation_values | |
| activation_loss += F.mse_loss(activations, torch.zeros_like(activations)) | |
| loss += ACTIVATION_KD_FACTOR * activation_loss | |
| logit_masking_loss = 0.0 | |
| if ENABLE_LOGIT_MASKING_KD: | |
| # Distill only over the shared (trimmed) vocabulary: everything outside the teacher's top-k tokens is masked out of the student logits. | |
| teacher_probs = F.softmax(teacher_logits_trimmed, dim=-1) | |
| top_k_indices = torch.topk(teacher_probs, k=10, dim=-1)[1] | |
| mask = torch.ones_like(student_logits_trimmed).scatter_(-1, top_k_indices, 0.0).bool() | |
| masked_student_logits = student_logits_trimmed.masked_fill(mask, -1e9) | |
| logit_masking_loss = kd_loss_functions['kl'](masked_student_logits / adaptive_temperature, teacher_logits_trimmed / adaptive_temperature) * (adaptive_temperature ** 2) | |
| loss += LOGIT_MASKING_FACTOR * logit_masking_loss | |
| sparsity_loss = 0.0 | |
| if ENABLE_SPARCITY_REGULARIZATION: | |
| for name, module in MODELS['student'].model.named_modules(): | |
| if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES: | |
| if module.activation_values is not None: | |
| activations = module.activation_values | |
| sparsity_loss += torch.norm(activations, 1) | |
| loss += SPARCITY_REGULARIZATION_FACTOR * sparsity_loss | |
| feature_map_loss = 0.0 | |
| if ENABLE_FEATURE_MAP_KD: | |
| for name, student_module in MODELS['student'].model.named_modules(): | |
| if isinstance(student_module, ActivationSaver) and type(student_module.module).__name__ in FEATURE_MAP_MODULES: | |
| teacher_module = MODELS['teacher'][0].model.get_submodule(name) | |
| if isinstance(teacher_module, ActivationSaver) and teacher_module.feature_map_values is not None and student_module.feature_map_values is not None: | |
| student_feature_map = student_module.feature_map_values | |
| teacher_feature_map = teacher_module.feature_map_values | |
| feature_map_loss += F.mse_loss(student_feature_map, teacher_feature_map) | |
| loss += FEATURE_MAP_KD_FACTOR * feature_map_loss | |
| contrastive_loss = 0.0 | |
| if ENABLE_CONTRASTIVE_KD: | |
| student_vec = student_hidden_states[-1][:, 0, :] | |
| teacher_vec = teacher_hidden_states[-1][:, 0, :] | |
| contrastive_loss = CONTRASTIVE_KD_FACTOR * kd_loss_functions[CONTRASTIVE_KD_LOSS_TYPE](student_vec, teacher_vec) | |
| loss += contrastive_loss | |
| rdrop_loss = 0.0 | |
| if ENABLE_RDROP_KD: | |
| r_student_outputs = MODELS['student'].model(**student_inputs, output_attentions=ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD, output_hidden_states=ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD) | |
| r_student_logits = r_student_outputs.logits | |
| # kd_loss_functions['kl'] applies log-softmax internally, so the raw logits are passed in both directions. | |
| rdrop_loss = RDROP_KD_FACTOR * (kd_loss_functions['kl'](student_logits, r_student_logits) + kd_loss_functions['kl'](r_student_logits, student_logits)) | |
| loss += rdrop_loss | |
| layer_wise_loss = 0.0 | |
| if ENABLE_LAYER_WISE_KD and teacher_hidden_states is not None and student_hidden_states is not None: | |
| layer_wise_loss = 0.0 | |
| for layer_module_name in LAYER_WISE_MODULES: | |
| if hasattr(MODELS['student'].model, layer_module_name): | |
| student_layer_module = getattr(MODELS['student'].model, layer_module_name) | |
| for student_layer, teacher_layer in zip(student_layer_module.modules(), MODELS['teacher'][0].model.get_submodule(layer_module_name).modules()): | |
| if isinstance(student_layer, nn.Linear) and isinstance(teacher_layer, nn.Linear): | |
| student_output = student_layer(student_hidden_states[-1]) | |
| teacher_output = teacher_layer(teacher_hidden_states[-1]) | |
| layer_wise_loss += hidden_state_kd_loss(student_output, teacher_output, LAYER_WISE_LOSS_TYPE) | |
| loss += LAYER_WISE_KD_FACTOR * layer_wise_loss | |
| activation_regularization_loss = 0.0 | |
| if ENABLE_ACTIVATION_REGULARIZATION: | |
| activation_regularization_loss = 0.0 | |
| for name, module in MODELS['student'].model.named_modules(): | |
| if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES: | |
| if module.activation_values is not None: | |
| activations = module.activation_values | |
| activation_regularization_loss += torch.norm(activations, p=2) | |
| loss += ACTIVATION_REG_LAMBDA * activation_regularization_loss | |
| neuron_selectivity_loss = 0.0 | |
| if ENABLE_NEURON_SELECTIVITY_KD: | |
| neuron_selectivity_loss = 0.0 | |
| for name, module in MODELS['student'].model.named_modules(): | |
| if isinstance(module, ActivationSaver) and type(module.module).__name__ in NEURON_SELECTIVITY_MODULES: | |
| if module.activation_values is not None: | |
| student_activations = module.activation_values | |
| teacher_module = MODELS['teacher'][0].model.get_submodule(name) | |
| if isinstance(teacher_module, ActivationSaver) and teacher_module.activation_values is not None: | |
| teacher_activations = teacher_module.activation_values | |
| neuron_selectivity_loss += NEURON_SELECTIVITY_KD_FACTOR * kd_loss_functions[NEURON_SELECTIVITY_LOSS_TYPE](student_activations.mean(dim=0), teacher_activations.mean(dim=0)) | |
| loss += neuron_selectivity_loss | |
| l2_reg_loss = 0.0 | |
| if ENABLE_L2_REGULARIZATION: | |
| for param in MODELS['student'].model.parameters(): | |
| l2_reg_loss += torch.norm(param, p=2) | |
| loss += L2_LAMBDA * l2_reg_loss | |
| l1_reg_loss = 0.0 | |
| if ENABLE_L1_REGULARIZATION: | |
| for param in MODELS['student'].model.parameters(): | |
| l1_reg_loss += torch.norm(param, p=1) | |
| loss += L1_LAMBDA * l1_reg_loss | |
| loss_accum += (kd_loss_factor * loss + ce_loss_factor * ce_loss) / accumulation_steps | |
| except Exception as e: | |
| log_message(f"Error in inner distillation loop: {e}", level="warning") | |
| scaler.scale(loss_accum).backward() | |
| if ENABLE_GRADIENT_CLIPPING: | |
| try: | |
| torch.nn.utils.clip_grad_norm_(MODELS['student'].model.parameters(), GRAD_CLIP_VALUE) | |
| except Exception as e: | |
| log_message(f"Gradient clipping error: {e}", level="warning") | |
| scaler.step(optimizer) | |
| scaler.update() | |
| if scheduler: | |
| scheduler.step() | |
| if ENABLE_ADAPTIVE_TEMPERATURE_KD: | |
| # Keep the temperature a tensor so the .item() logging below keeps working, and clamp it at a floor of 1.0. | |
| adaptive_temperature = torch.clamp(adaptive_temperature * adaptive_temperature_decay_rate, min=1.0) | |
| if ENABLE_STUDENT_PARAMETER_FREEZE: | |
| if ENABLE_DYNAMIC_FREEZE: | |
| if loss_accum.item() < FREEZE_THRESHOLD: | |
| for param in MODELS['student'].model.parameters(): | |
| param.requires_grad = True | |
| if step == 0: | |
| log_message("Dynamic unfreezing activated due to low loss.", level="info") | |
| elif step + 1 == FREEZE_STUDENT_STEPS: | |
| for param in MODELS['student'].model.parameters(): | |
| param.requires_grad = True | |
| log_message("Unfreezing after freeze_steps.", level="info") | |
| if SAVE_CHECKPOINTS and (step + 1) % CHECKPOINT_INTERVAL == 0: | |
| ckpt_path = os.path.join(CHECKPOINT_DIR, f"student_step_{step + 1}.pt") | |
| os.makedirs(CHECKPOINT_DIR, exist_ok=True) | |
| save_checkpoint(MODELS['student'], ckpt_path) | |
| log_to_file(f"Checkpoint saved to {ckpt_path}") | |
| if (step + 1) % LOG_INTERVAL == 0: | |
| elapsed = time.time() - start_time | |
| lr = optimizer.param_groups[0]['lr'] | |
| grad_norm = 0.0 | |
| num_params_with_grad = 0 | |
| for param in MODELS['student'].model.parameters(): | |
| if param.grad is not None: | |
| grad_norm += param.grad.data.norm(2).item() | |
| num_params_with_grad += 1 | |
| avg_grad_norm = grad_norm / num_params_with_grad if num_params_with_grad > 0 else 0.0 | |
| log_msg = (f"[KD] Step {step + 1}/{KD_STEPS}, Loss: {loss_accum.item():.4f}, " | |
| f"LR: {lr:.6f}, GradNorm: {avg_grad_norm:.4f}, Time: {elapsed:.2f}s, Temp: {adaptive_temperature.item():.3f}") | |
| print(log_msg) | |
| log_to_file(log_msg) | |
| writer.add_scalar("Loss/accumulated", loss_accum.item(), step + 1) | |
| writer.add_scalar("Learning_Rate", lr, step + 1) | |
| writer.add_scalar("GradNorm", avg_grad_norm, step + 1) | |
| writer.add_scalar("Output_Logit_KD_Loss", output_kd_loss.item(), step + 1) | |
| writer.add_scalar("Attention_KD_Loss", attn_loss.item(), step + 1) | |
| writer.add_scalar("HiddenState_KD_Loss", hidden_loss.item(), step + 1) | |
| writer.add_scalar("Intermediate_KD_Loss", intermediate_loss.item(), step + 1) | |
| writer.add_scalar("LayerNorm_KD_Loss", layer_norm_loss.item(), step + 1) | |
| writer.add_scalar("Embedding_KD_Loss", embed_loss.item(), step + 1) | |
| writer.add_scalar("Parameter_KD_Loss", parameter_loss.item(), step + 1) | |
| writer.add_scalar("Activation_KD_Loss", activation_loss.item(), step + 1) | |
| writer.add_scalar("LogitMasking_KD_Loss", logit_masking_loss.item(), step + 1) | |
| writer.add_scalar("Sparcity_Regularization_Loss", sparsity_loss.item(), step + 1) | |
| writer.add_scalar("FeatureMap_KD_Loss", feature_map_loss.item(), step + 1) | |
| writer.add_scalar("Layerwise_Parameter_KD_Loss", layerwise_parameter_loss.item(), step + 1) | |
| writer.add_scalar("VocabProjection_KD_Loss", vocab_projection_loss.item(), step + 1) | |
| writer.add_scalar("Contrastive_KD_Loss", contrastive_loss.item(), step + 1) | |
| writer.add_scalar("RDrop_KD_Loss", rdrop_loss.item(), step + 1) | |
| writer.add_scalar("Adaptive_Temperature", adaptive_temperature.item(), step + 1) | |
| writer.add_scalar("LayerWise_KD_Loss", layer_wise_loss.item(), step + 1) | |
| writer.add_scalar("Activation_Regularization_Loss", activation_regularization_loss.item(), step + 1) | |
| writer.add_scalar("NeuronSelectivity_KD_Loss", neuron_selectivity_loss.item(), step + 1) | |
| writer.add_scalar("WeightedParameter_KD_Loss", weighted_parameter_loss.item(), step + 1) | |
| if ENABLE_EARLY_STOPPING: | |
| if loss_accum.item() < best_loss: | |
| best_loss = loss_accum.item() | |
| patience_counter = 0 | |
| else: | |
| patience_counter += 1 | |
| if patience_counter >= EARLY_STOPPING_PATIENCE: | |
| log_message(f"Early stopping activated at step {step + 1} after {EARLY_STOPPING_PATIENCE} steps.", level="info") | |
| log_to_file(f"Early stopping activated at step {step + 1}.") | |
| break | |
| except Exception as e: | |
| log_message(f"Error in distillation step {step + 1}: {e}", level="warning") | |
| continue | |
| progress_bar.close() | |
| writer.close() | |
| MODELS['writer'] = None | |
| cleanup_and_exit(0) | |
| return MODELS['student'] | |
| def push_model_to_hub(model, tokenizer, quantization_method, repo_name, use_auth_token): | |
| try: | |
| log_message(f"Saving {model.__class__.__name__} to {repo_name} with method '{quantization_method}'...", level="info") | |
| model.model.save_pretrained(repo_name, push_to_hub=False) | |
| tokenizer.save_pretrained(repo_name, push_to_hub=False) | |
| model.model.push_to_hub(repo_name, token=use_auth_token) | |
| tokenizer.push_to_hub(repo_name, token=use_auth_token) | |
| log_message("Upload completed.", level="info") | |
| log_to_file(f"Model and tokenizer uploaded to {repo_name} with method {quantization_method}.") | |
| except Exception as e: | |
| log_message(f"Error during push to hub: {e}", level="warning") | |
| def login_to_huggingface(token): | |
| global AUTH_TOKEN, USER_NAME | |
| try: | |
| user_info = HfApi(token=token).whoami() | |
| AUTH_TOKEN = token | |
| USER_NAME = user_info['name'] | |
| log_message(f"Successfully logged in to Hugging Face as {USER_NAME}.", level="info") | |
| return AUTH_TOKEN, USER_NAME, None | |
| except Exception as e: | |
| log_message(f"Hugging Face Hub login error: {e}", level="warning") | |
| return None, None, "Invalid Hugging Face token." | |
| def run_fusion_distillation(teacher_model_ckpt_1, teacher_model_ckpt_2, student_model_ckpt, repo_name, disable_mean_resizing, huggingface_token): | |
| global MODELS, LOG_TEXT, GRADIO_LOG_OUTPUT, AUTH_TOKEN, USER_NAME | |
| LOG_TEXT = "" | |
| if GRADIO_LOG_OUTPUT is not None: | |
| # Mutating .value does not live-update a running Gradio UI; logs also go to the console and training_log.txt. | |
| GRADIO_LOG_OUTPUT.value = LOG_TEXT | |
| mean_resizing = not disable_mean_resizing | |
| token, username, error_message = login_to_huggingface(huggingface_token) | |
| if token is None: | |
| log_message("Hugging Face login failed.", level="warning") | |
| return error_message | |
| AUTH_TOKEN = token | |
| USER_NAME = username | |
| try: | |
| log_message(f"Authenticated with Hugging Face Hub as user: {USER_NAME}.", level="info") | |
| except Exception as e: | |
| log_message(f"Hub login error: {e}", level="warning") | |
| return "Hugging Face Hub login failed: " + str(e) | |
| try: | |
| student_tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt) | |
| MODELS['student'] = pipeline(model=student_model_ckpt, tokenizer=student_tokenizer, device=device) | |
| if MODELS['student'].tokenizer is None: | |
| log_message("Student model tokenizer could not be loaded.", level="warning") | |
| return "Student model tokenizer could not be loaded." | |
| special_token = "[UNFILTERED]" | |
| if special_token not in MODELS['student'].tokenizer.get_vocab(): | |
| MODELS['student'].tokenizer.add_special_tokens({"additional_special_tokens": [special_token]}) | |
| log_message("Special token added to student tokenizer.", level="info") | |
| MODELS['student'].model.resize_token_embeddings(len(MODELS['student'].tokenizer)) | |
| except Exception as e: | |
| log_message(f"Error loading student model: {e}", level="warning") | |
| return "Error loading student model: " + str(e) | |
| teacher_models = [] | |
| teacher_tokenizers = [] | |
| teacher_model_checkpoints = [teacher_model_ckpt_1] | |
| if teacher_model_ckpt_2: | |
| teacher_model_checkpoints.append(teacher_model_ckpt_2) | |
| if not isinstance(teacher_model_checkpoints, list): | |
| log_message("Error loading teacher models: Teacher Model Checkpoints must be a list.", level="warning") | |
| return "Error loading teacher models: Teacher Model Checkpoints must be a list." | |
| for i, teacher_model_ckpt in enumerate(teacher_model_checkpoints): | |
| if not isinstance(teacher_model_ckpt, str): | |
| log_message(f"Error loading teacher models: not a string at index {i}", level="warning") | |
| return f"Error loading teacher models: not a string at index {i}" | |
| try: | |
| teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_ckpt) | |
| teacher_pipeline = pipeline(model=teacher_model_ckpt, tokenizer=teacher_tokenizer, device=device) | |
| teacher_models.append(teacher_pipeline) | |
| teacher_tokenizers.append(teacher_tokenizer) | |
| log_message(f"Teacher model {i + 1} loaded: {teacher_model_ckpt}", level="info") | |
| except Exception as e: | |
| log_message(f"Error loading teacher model {i + 1} ({teacher_model_ckpt}): {e}", level="warning") | |
| return f"Error loading teacher model {i + 1}: {e}" | |
| MODELS['teacher'] = teacher_models | |
| try: | |
| unify_parameters(MODELS['student'], MODELS['teacher'][0]) | |
| unify_embeddings(MODELS['student'], MODELS['teacher'][0], mean_resizing=mean_resizing) | |
| MODELS['student'].tokenizer = unify_tokenizers(MODELS['student'].tokenizer, teacher_tokenizers[0], MODELS['student']) | |
| MODELS['student'].tokenizer = fuse_tokenizers(teacher_tokenizers, MODELS['student'].tokenizer) | |
| # Keep the student's embedding matrix in sync with the fused vocabulary. | |
| MODELS['student'].model.resize_token_embeddings(len(MODELS['student'].tokenizer)) | |
| except Exception as e: | |
| log_message(f"Model unification error: {e}", level="warning") | |
| fusion_method_name = 'geometric_mean_double' | |
| try: | |
| unified_teacher_state = complete_unify_teacher_models_double(MODELS['teacher'], fusion_method=fusion_method_name, layer_scale=1.0, device=device) | |
| MODELS['student'] = unify_teacher_into_student(unified_teacher_state, MODELS['student'], force_parameter_copy=True) | |
| MODELS['student'] = update_student_embeddings_double(MODELS['student'], MODELS['teacher'], fusion_type=fusion_method_name, device=device, mean_resizing=mean_resizing) | |
| except Exception as e: | |
| log_message(f"Teacher model fusion error: {e}", level="warning") | |
| try: | |
| unified_student = advanced_knowledge_distillation(MODELS['teacher'], MODELS['student'], device) | |
| except Exception as e: | |
| log_message(f"Knowledge distillation error: {e}", level="warning") | |
| return "Knowledge Distillation Failed, but Fusion might be partially completed. Check logs." | |
| if ENABLE_PARAMETER_COUNT_CHECK: | |
| student_param_count = sum(p.numel() for p in unified_student.model.parameters()) | |
| teacher_param_count = sum(p.numel() for p in MODELS['teacher'][0].model.parameters()) | |
| if student_param_count != teacher_param_count: | |
| log_message(f"Warning: student parameters ({student_param_count:,}) differ from teacher ({teacher_param_count:,}).", level="warning") | |
| else: | |
| log_message(f"Parameter count consistent: Student and Teacher ({teacher_param_count:,}).", level="info") | |
| sample_texts = ["What is the capital of France?", "Solve: 3 + 5 * 2", "Define a function in Python."] | |
| try: | |
| predictions = generate_predictions(unified_student, MODELS['student'].tokenizer, sample_texts) | |
| log_message("Sample Predictions after KD:\n" + "\n".join(predictions), level="info") | |
| except Exception as e: | |
| log_message(f"Sample prediction error: {e}", level="warning") | |
| api = HfApi() | |
| try: | |
| api.create_repo(repo_name, exist_ok=True, token=AUTH_TOKEN) | |
| push_model_to_hub(unified_student, MODELS['student'].tokenizer, "unified_teacher_kd_full_options_default_true_geometric_mean_double_fusion_v11", repo_name, AUTH_TOKEN) | |
| final_msg = f"Student model fused, distilled, and uploaded to Hub to repo: {repo_name} with all advanced options and 'geometric_mean_double' fusion!" | |
| print(final_msg) | |
| log_to_file(final_msg, "training_log.txt") | |
| return "Fusion and Distillation Completed! Check console and logs." | |
| except Exception as e: | |
| log_message(f"Final upload/repo creation error: {e}", level="warning") | |
| return "Error during final upload or repository creation: " + str(e) | |
| if __name__ == "__main__": | |
| with gr.Blocks(css=".gradio-container {padding: 20px}") as iface: | |
| gr.Markdown("# Fusion and Distillation Pipeline for Language Models") | |
| gr.Markdown( | |
| "This application fuses and distills knowledge from teacher language models into a student model. " | |
| "It supports advanced knowledge distillation techniques and model fusion strategies." | |
| ) | |
| huggingface_token_input = gr.Textbox(label="Hugging Face Token", type="password", visible=True) | |
| username_display = gr.Textbox(label="Hugging Face Username", interactive=False, visible=False) | |
| def process_token(huggingface_token): | |
| _, username, error_message = login_to_huggingface(huggingface_token) | |
| # Only one output component is wired up below, so return the username on success or the error text otherwise. | |
| return username if error_message is None else error_message | |
| huggingface_token_input.change( | |
| process_token, | |
| inputs=[huggingface_token_input], | |
| outputs=[username_display] | |
| ) | |
| with gr.Column(): | |
| gr.Markdown("## Model Selection") | |
| teacher_model_ckpt_1_input = HuggingfaceHubSearch(label="Teacher Model 1") | |
| teacher_model_ckpt_2_input = HuggingfaceHubSearch(label="Teacher Model 2 (Optional)") | |
| student_model_ckpt_input = HuggingfaceHubSearch(label="Student Model") | |
| repo_name_input = gr.Textbox( | |
| label="Repository Name", | |
| info="Enter the name of the Hugging Face repository to create or update with the distilled model." | |
| ) | |
| disable_mean_resizing_checkbox = gr.Checkbox( | |
| label="Disable Mean Resizing", | |
| info="Check to disable mean resizing of embeddings during unification." | |
| ) | |
| with gr.Column(): | |
| gr.Markdown("## Execution and Logs") | |
| run_button = gr.Button("Run Fusion and Distillation", variant="primary", interactive=True, elem_id='run-button') | |
| output_status_textbox = gr.Textbox( | |
| info="Real-time status and logs of the fusion and distillation process.", | |
| label="Status", | |
| lines=3 | |
| ) | |
| GRADIO_LOG_OUTPUT = gr.Textbox( | |
| value="", | |
| label="Detailed Log Output", | |
| lines=10, | |
| interactive=False | |
| ) | |
| def update_run_button_interactivity(token): | |
| # gr.Button.update() was removed in Gradio 4.x; gr.update() works across versions. | |
| return gr.update(interactive=True) | |
| huggingface_token_input.change( | |
| update_run_button_interactivity, | |
| inputs=[huggingface_token_input], | |
| outputs=[run_button] | |
| ) | |
| run_button.click( | |
| run_fusion_distillation, | |
| inputs=[ | |
| teacher_model_ckpt_1_input, | |
| teacher_model_ckpt_2_input, | |
| student_model_ckpt_input, | |
| repo_name_input, | |
| disable_mean_resizing_checkbox, | |
| huggingface_token_input | |
| ], | |
| outputs=output_status_textbox, | |
| ) | |
| iface.launch(server_name="0.0.0.0", server_port=7860) |