import gradio as gr
import os
import unicodedata
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from transformers import (
    AutoConfig,
    get_scheduler,
    pipeline,
    AutoTokenizer,
    AutoModelForPreTraining,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModel,
)
from huggingface_hub import HfApi
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import psutil
import signal
import sys
import gc
from gradio_huggingfacehub_search import HuggingfaceHubSearch

warnings.filterwarnings(
    "ignore",
    message="cannot set number of interop threads after parallel work has started or set_num_interop_threads called",
)
torch.set_num_threads(os.cpu_count() // 2 if os.cpu_count() > 1 else 1)
device = torch.device("cpu")

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Global state shared between the training pipeline and the Gradio UI.
MODELS = {}
LOG_TEXT = ""
GRADIO_LOG_OUTPUT = None
HF_API = HfApi()
AUTH_TOKEN = None
USER_NAME = None

# Checkpointing and logging.
SAVE_ON_RESOURCE_TERMINATION = True
CHECKPOINT_DIR = "./fusion_distillation_checkpoints"
TENSORBOARD_LOG_DIR = "./tensorboard_logs"
SAVE_CHECKPOINTS = True
CHECKPOINT_INTERVAL = 1000
LOG_INTERVAL = 100

# Core distillation schedule.
KD_STEPS = 10000
ACCUMULATION_STEPS = 1
FREEZE_STUDENT_STEPS = 2000
ENABLE_MIXED_PRECISION = False
KD_LOSS_FACTOR = 1.0
CE_LOSS_FACTOR = 0.0
KD_TEMPERATURE = 2.0

# Per-technique toggles. NOTE: the FSP, attention-alignment, Gram-matrix, and
# weighted-parameter KD terms are configured here but not yet wired into the
# training loop below.
ENABLE_ATTENTION_KD = True
ENABLE_HIDDEN_STATE_KD = True
ENABLE_INTERMEDIATE_KD = True
ENABLE_LAYER_NORM_KD = True
ENABLE_EMBEDDING_KD = True
ENABLE_PARAMETER_KD = True
ENABLE_ACTIVATION_KD = True
ENABLE_LOGIT_MASKING_KD = True
ENABLE_SPARSITY_REGULARIZATION = True
ENABLE_FEATURE_MAP_KD = True
ENABLE_OUTPUT_LOGIT_KD = True
ENABLE_LAYERWISE_PARAMETER_KD = True
ENABLE_VOCAB_PROJECTION_KD = True
ENABLE_CONTRASTIVE_KD = True
ENABLE_RDROP_KD = True
ENABLE_ADAPTIVE_TEMPERATURE_KD = True
ENABLE_LAYER_WISE_KD = True
ENABLE_ACTIVATION_REGULARIZATION = True
ENABLE_NEURON_SELECTIVITY_KD = True
ENABLE_WEIGHTED_PARAMETER_KD = True
ENABLE_FSP_KD = True
ENABLE_ATTENTION_ALIGNMENT_KD = True
ENABLE_GRAM_MATRIX_KD = True

# Loss weighting factors.
ATTENTION_KD_FACTOR = 0.1
HIDDEN_STATE_KD_FACTOR = 0.1
INTERMEDIATE_KD_FACTOR = 0.1
LAYER_NORM_KD_FACTOR = 0.01
EMBEDDING_KD_FACTOR = 0.01
PARAMETER_KD_FACTOR = 0.001
ACTIVATION_KD_FACTOR = 0.0001
LOGIT_MASKING_FACTOR = 0.01
SPARSITY_REGULARIZATION_FACTOR = 1e-5
FEATURE_MAP_KD_FACTOR = 0.005
OUTPUT_LOGIT_KD_FACTOR = 1.0
LAYERWISE_PARAMETER_KD_FACTOR = 0.0005
VOCAB_PROJECTION_KD_FACTOR = 0.001
CONTRASTIVE_KD_FACTOR = 0.05
RDROP_KD_FACTOR = 0.02
ADAPTIVE_TEMPERATURE_KD_FACTOR = 0.01
LAYER_WISE_KD_FACTOR = 0.05
ACTIVATION_REG_LAMBDA = 1e-7
NEURON_SELECTIVITY_KD_FACTOR = 0.001
WEIGHTED_PARAMETER_KD_FACTOR = 0.0002
FSP_KD_FACTOR = 0.001
ATTENTION_ALIGNMENT_KD_FACTOR = 0.01
GRAM_MATRIX_KD_FACTOR = 0.0005

# Loss type per technique ('mse', 'kl', or 'cosine').
ATTENTION_KD_LOSS_TYPE = 'mse'
HIDDEN_STATE_KD_LOSS_TYPE = 'mse'
OUTPUT_LOGIT_KD_LOSS_TYPE = 'kl'
CONTRASTIVE_KD_LOSS_TYPE = 'cosine'
RDROP_KD_LOSS_TYPE = 'kl'
LAYER_WISE_LOSS_TYPE = 'mse'
NEURON_SELECTIVITY_LOSS_TYPE = 'mse'
WEIGHTED_PARAMETER_LOSS_TYPE = 'mse'
FSP_KD_LOSS_TYPE = 'mse'
ATTENTION_ALIGNMENT_LOSS_TYPE = 'mse'
GRAM_MATRIX_LOSS_TYPE = 'mse'

# Module selectors: entries are either module class names (e.g. 'Linear') or
# dotted attribute paths on the model (e.g. 'transformer.h', 'lm_head').
INTERMEDIATE_LAYERS = [2, 5, 8]
LAYER_NORM_MODULES = ['LayerNorm']
ACTIVATION_MODULES = ['Linear']
FEATURE_MAP_MODULES = ['Linear']
LAYERWISE_PARAMETER_MODULES = ['transformer.h']
VOCAB_PROJECTION_MODULES = ['lm_head']
LAYER_WISE_MODULES = ['transformer.h']
NEURON_SELECTIVITY_MODULES = ['Linear']
WEIGHTED_PARAMETER_MODULES = ['Linear']
FSP_MODULES = ['Linear']
ATTENTION_ALIGNMENT_MODULES = ['SelfAttention']
GRAM_MATRIX_MODULES = ['Linear']
ADAPTIVE_TEMPERATURE_INITIAL = 2.0
ADAPTIVE_TEMPERATURE_DECAY_RATE = 0.999

# Optimizer, schedule, and training safeguards.
LR_INITIAL = 5e-5
WEIGHT_DECAY = 1e-4
SCHEDULER_TYPE = "linear"
WARMUP_STEPS = 500
ENABLE_LR_SCHEDULER = True
ENABLE_STUDENT_PARAMETER_FREEZE = True
ENABLE_DYNAMIC_FREEZE = False
FREEZE_THRESHOLD = 0.1
ENABLE_GRADIENT_CLIPPING = True
GRAD_CLIP_VALUE = 1.0
ENABLE_EARLY_STOPPING = True
EARLY_STOPPING_PATIENCE = 5
ENABLE_PARAMETER_COUNT_CHECK = False
ENABLE_EMBEDDING_NOISE = True
EMBEDDING_NOISE_STD = 1e-4
ENABLE_L2_REGULARIZATION = False
L2_LAMBDA = 1e-6
ENABLE_L1_REGULARIZATION = False
L1_LAMBDA = 1e-6


def cleanup_and_exit(signal_code):
    """Save a final student checkpoint and close the TensorBoard writer."""
    log_message("Initiating cleanup...", level="info")
    if MODELS.get('student') and SAVE_ON_RESOURCE_TERMINATION:
        try:
            student_model = MODELS['student']
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            checkpoint_path = os.path.join(CHECKPOINT_DIR, "student_model_terminated.pt")
            save_checkpoint(student_model, checkpoint_path)
            log_message(f"Model checkpoint saved to {checkpoint_path} due to termination signal.", level="info")
        except Exception as e:
            log_message(f"Error during checkpoint save on cleanup: {e}", level="error")
    if MODELS.get('writer'):
        try:
            MODELS['writer'].close()
            log_message("TensorBoard writer closed.", level="info")
        except Exception as e:
            log_message(f"Error closing TensorBoard writer during cleanup: {e}", level="error")
    log_message("Cleanup completed.", level="info")


def signal_handler(sig, frame):
    # Save state on SIGTERM/SIGINT but keep the process alive; this is the
    # behavior SAVE_ON_RESOURCE_TERMINATION exists to support.
    log_message(f"Received signal {sig}. Saving state and continuing...", level="warning")
    cleanup_and_exit(sig)


signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)


def log_message(message, level="info"):
    global LOG_TEXT, GRADIO_LOG_OUTPUT
    LOG_TEXT += f"{message}\n"
    if GRADIO_LOG_OUTPUT:
        GRADIO_LOG_OUTPUT.value = LOG_TEXT
    if level == "info":
        logger.info(message)
    elif level == "warning":
        logger.warning(message)
    elif level == "error":
        logger.error(message)
    # tqdm.write also prints to stdout without breaking the progress bar, so a
    # separate print() is redundant.
    tqdm.write(message)


def log_to_file(message, log_file="training_log.txt"):
    try:
        with open(log_file, "a") as f:
            f.write(f"{message}\n")
    except Exception as e:
        log_message(f"Error writing to log file: {e}", level="warning")


def unify_parameters(student_model, teacher_model, exclude_layers=None):
    """Copy teacher parameters into the student wherever names match.

    When shapes differ, only the overlapping top-left slice is copied.
    """
    try:
        teacher_state = teacher_model.model.state_dict()
        student_state = student_model.model.state_dict()
        excluded_names = exclude_layers or []
        for name, param in student_state.items():
            if any(excluded_name in name for excluded_name in excluded_names):
                continue
            if name in teacher_state:
                try:
                    if student_state[name].shape == teacher_state[name].shape:
                        student_state[name].copy_(teacher_state[name])
                    else:
                        min_shape = [min(s, t) for s, t in zip(student_state[name].shape, teacher_state[name].shape)]
                        overlap = tuple(slice(0, s) for s in min_shape)
                        student_state[name][overlap].copy_(teacher_state[name][overlap])
                except Exception as e:
                    log_message(f"Parameter copy error for {name}: {e}", level="warning")
        student_model.model.load_state_dict(student_state, strict=False)
    except Exception as e:
        log_message(f"Error in unify_parameters: {e}", level="warning")
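

# A minimal sketch (hypothetical shapes) of the overlapping-slice copy that
# unify_parameters performs when student and teacher tensors disagree in size:
# only the shared top-left region is transferred; the rest keeps the student's
# initialization.
def _demo_overlap_copy():
    student_w = torch.zeros(4, 8)
    teacher_w = torch.ones(6, 6)
    min_shape = [min(s, t) for s, t in zip(student_w.shape, teacher_w.shape)]
    overlap = tuple(slice(0, s) for s in min_shape)
    student_w[overlap].copy_(teacher_w[overlap])
    # student_w now holds ones in its top-left 4x6 block and zeros elsewhere.
    return student_w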
def unify_embeddings(student_model, teacher_model, project_embeddings=True, mean_resizing=False):
    """Align the student's input/output embeddings with a teacher's.

    If the embedding dimensions differ, the teacher embeddings are passed
    through a randomly initialized linear projection (a dimensionality bridge,
    not a learned map), then the overlapping vocabulary rows are copied. The
    input- and output-embedding paths share one helper.
    """

    def _copy_embeddings(student_emb, teacher_emb, get_emb_after_resize):
        if project_embeddings:
            in_dim = teacher_emb.weight.shape[1]
            out_dim = student_emb.weight.shape[1]
            projection = nn.Linear(in_dim, out_dim).to(device)
            with torch.no_grad():
                teacher_emb_projected = projection(teacher_emb.weight)
            student_vocab_size = student_emb.weight.shape[0]
            teacher_vocab_size_proj = teacher_emb_projected.shape[0]
            if mean_resizing and student_vocab_size < teacher_vocab_size_proj:
                try:
                    student_model.model.resize_token_embeddings(teacher_vocab_size_proj)
                    student_emb = get_emb_after_resize()
                except Exception as e:
                    log_message(f"Error resizing student embeddings: {e}", level="warning")
            min_vocab_size = min(student_emb.weight.shape[0], teacher_vocab_size_proj)
            try:
                student_emb.weight.data[:min_vocab_size].copy_(teacher_emb_projected.data[:min_vocab_size])
            except Exception as e:
                log_message(f"Error copying projected embeddings: {e}", level="warning")
        else:
            min_vocab_size = min(student_emb.weight.shape[0], teacher_emb.weight.shape[0])
            try:
                student_emb.weight.data[:min_vocab_size].copy_(teacher_emb.weight.data[:min_vocab_size])
            except Exception as e:
                log_message(f"Error copying embeddings: {e}", level="warning")

    try:
        if hasattr(student_model.model, "get_input_embeddings") and hasattr(teacher_model.model, "get_input_embeddings"):
            _copy_embeddings(
                student_model.model.get_input_embeddings(),
                teacher_model.model.get_input_embeddings(),
                student_model.model.get_input_embeddings,
            )
        if hasattr(student_model.model, "get_output_embeddings") and hasattr(teacher_model.model, "get_output_embeddings"):
            student_out_emb = student_model.model.get_output_embeddings()
            teacher_out_emb = teacher_model.model.get_output_embeddings()
            if student_out_emb is not None and teacher_out_emb is not None:
                _copy_embeddings(student_out_emb, teacher_out_emb, student_model.model.get_output_embeddings)
    except Exception as e:
        log_message(f"Error in unify_embeddings: {e}", level="warning")


def unify_tokenizers(student_tokenizer, teacher_tokenizer, student_model):
    """Add teacher-only tokens to the student tokenizer and resize embeddings."""
    if teacher_tokenizer is None:
        return student_tokenizer
    try:
        teacher_vocab = teacher_tokenizer.get_vocab()
        student_vocab = student_tokenizer.get_vocab()
        new_tokens = [token for token in teacher_vocab if token not in student_vocab]
        if new_tokens:
            student_tokenizer.add_tokens(new_tokens)
            student_model.model.resize_token_embeddings(len(student_tokenizer))
    except Exception as e:
        log_message(f"Tokenizer unification error: {e}", level="warning")
    return student_tokenizer


def normalize_text(text):
    return unicodedata.normalize('NFKC', text)


def generate_predictions(model, tokenizer, texts, max_length=150, **tokenizer_kwargs):
    try:
        # Note: `use_fast` belongs to AutoTokenizer.from_pretrained, not to the
        # tokenizer call itself, so it is not passed here.
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=max_length,
            **tokenizer_kwargs,
        ).to(device)
        outputs = model.model.generate(**inputs)
        return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    except Exception as e:
        log_message(f"Prediction generation error: {e}", level="warning")
        return [""] * len(texts)
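

# A minimal usage sketch for generate_predictions; the checkpoint name is a
# placeholder assumption, any small causal-LM pipeline would work the same way.
def _demo_generate_predictions():
    demo_pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2", device=device)
    return generate_predictions(demo_pipe, demo_pipe.tokenizer, ["Hello, world"], max_length=20)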
def geometric_mean_fusion_double(states, device, num_teachers, layer_scale, key):
    """Sign-preserving geometric mean of the tensors stored under `key`.

    Magnitudes combine as exp(mean(log|x|)); the sign is the elementwise
    product of per-teacher signs. A small epsilon guards log(0), and the
    sign product is taken along the teacher axis (the original reduced over
    all elements to a scalar).
    """
    eps = 1e-12
    min_len = min(state[key].shape[0] for state in states)
    arrays = [state[key][:min_len].detach().cpu().numpy() for state in states]
    log_abs_sum = sum(np.log(np.abs(a) + eps) for a in arrays)
    sign_prod = np.prod([np.sign(a) for a in arrays], axis=0)
    geometric_mean_val = torch.tensor(np.exp(log_abs_sum / num_teachers) * sign_prod, device=device) * layer_scale
    return geometric_mean_val


fusion_functions = {
    'geometric_mean_double': lambda teacher_states, layer_scale, key, device: geometric_mean_fusion_double(
        teacher_states, device, len(teacher_states), layer_scale, key
    ),
}


def complete_unify_teacher_models_double(teacher_models, fusion_method='geometric_mean_double', layer_scale=1.0, layer_weights=None, device=device):
    """Fuse all teacher state dicts into one, layer by layer.

    Layers present in every teacher are fused with the requested method
    (falling back to an arithmetic mean for unknown methods); layers missing
    from any teacher are taken verbatim from the first teacher.
    """
    teacher_states = [teacher.model.state_dict() for teacher in teacher_models]
    unified_state = {}
    first_teacher_state = teacher_states[0]
    for key in first_teacher_state.keys():
        if all(key in state for state in teacher_states):
            if fusion_method in fusion_functions:
                unified_state[key] = fusion_functions[fusion_method](teacher_states, layer_scale, key, device=device)
            else:
                layer_sum = torch.stack([state[key] for state in teacher_states]).sum(dim=0)
                unified_state[key] = layer_sum / len(teacher_models)
        else:
            unified_state[key] = first_teacher_state[key]
    return unified_state


def unify_teacher_into_student(unified_teacher_state, student, force_parameter_copy=False):
    """Write a fused teacher state dict into the student, slicing overlaps."""
    student_state = student.model.state_dict()
    new_state = {}
    for key, student_value in student_state.items():
        if key in unified_teacher_state:
            teacher_value = unified_teacher_state[key]
            try:
                if student_value.shape == teacher_value.shape:
                    new_state[key] = teacher_value
                else:
                    min_shape = [min(s, t) for s, t in zip(student_value.shape, teacher_value.shape)]
                    overlap = tuple(slice(0, s) for s in min_shape)
                    new_state[key] = student_value.clone()
                    new_state[key][overlap] = teacher_value[overlap]
            except Exception as e:
                log_message(f"Parameter assignment error for {key}: {e}", level="warning")
                new_state[key] = student_value
        else:
            new_state[key] = student_value
    student.model.load_state_dict(new_state, strict=False)
    return student


def fuse_tokenizers(teacher_tokenizers, student_tokenizer):
    """Union all teacher vocabularies into the student tokenizer."""
    unified_vocab = set(student_tokenizer.get_vocab().keys()) if student_tokenizer else set()
    for teacher_tokenizer in teacher_tokenizers:
        if teacher_tokenizer:
            unified_vocab = unified_vocab.union(set(teacher_tokenizer.get_vocab().keys()))
    if student_tokenizer:
        try:
            student_tokenizer.add_tokens(list(unified_vocab - set(student_tokenizer.get_vocab().keys())))
        except Exception as e:
            log_message(f"Error fusing tokenizers: {e}", level="warning")
    return student_tokenizer
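

# A minimal sketch of the fusion path on toy "teachers": two linear models
# whose state dicts are fused with the sign-preserving geometric mean. The
# SimpleNamespace wrappers are hypothetical stand-ins for the pipelines used
# elsewhere in this file.
def _demo_fusion():
    from types import SimpleNamespace
    torch.manual_seed(0)
    teachers = [SimpleNamespace(model=nn.Linear(3, 3)) for _ in range(2)]
    unified = complete_unify_teacher_models_double(teachers, fusion_method='geometric_mean_double', device=device)
    # For weights w1, w2 the fused value is sign(w1*w2) * exp((log|w1| + log|w2|) / 2).
    return unified['weight'].shape  # torch.Size([3, 3])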
def update_student_embeddings_double(student, teacher_models, fusion_type='geometric_mean_double', device=device, mean_resizing=False):
    """Fuse all teacher input embeddings and copy them into the student."""
    emb_student = student.model.get_input_embeddings().weight.data
    student_vocab_size, student_emb_dim = emb_student.shape
    teacher_embeddings_proj = []
    min_teacher_vocab_size = float('inf')
    for teacher in teacher_models:
        emb_teacher = teacher.model.get_input_embeddings().weight.data.detach().cpu()
        teacher_vocab_size, teacher_emb_dim = emb_teacher.shape
        min_teacher_vocab_size = min(min_teacher_vocab_size, teacher_vocab_size)
        if teacher_emb_dim != student_emb_dim:
            # Randomly initialized projection: a dimensionality bridge, not a
            # learned mapping.
            proj = nn.Linear(teacher_emb_dim, student_emb_dim, bias=False).to(device)
            with torch.no_grad():
                emb_teacher_proj = proj(emb_teacher.to(device)).cpu()
        else:
            emb_teacher_proj = emb_teacher
        teacher_embeddings_proj.append(emb_teacher_proj)
    min_vocab = min(min_teacher_vocab_size, student_vocab_size)
    fusion_function = fusion_functions.get(fusion_type, fusion_functions['geometric_mean_double'])
    teacher_states = [{'embedding': emb} for emb in teacher_embeddings_proj]
    emb_updated = fusion_function(teacher_states, 1.0, 'embedding', device=device)[:min_vocab]
    if ENABLE_EMBEDDING_NOISE:
        emb_updated = emb_updated + torch.randn_like(emb_updated) * EMBEDDING_NOISE_STD
    try:
        student.model.get_input_embeddings().weight.data[:min_vocab].copy_(emb_updated.to(device))
    except Exception as e:
        log_message(f"Error updating student embeddings: {e}", level="warning")
    return student


def save_checkpoint(model, checkpoint_path):
    try:
        torch.save(model.model.state_dict(), checkpoint_path)
        log_message(f"Checkpoint saved to: {checkpoint_path}", level="info")
    except Exception as e:
        log_message(f"Error saving checkpoint: {e}", level="warning")


# Shared KD loss primitives. The 'kl' entry expects raw (pre-softmax) logits
# for both arguments and detaches the teacher; the target is a probability
# distribution, so log_target must stay False (the default).
kd_loss_functions = {
    'mse': F.mse_loss,
    'kl': lambda student, teacher: F.kl_div(
        F.log_softmax(student, dim=-1),
        F.softmax(teacher.detach(), dim=-1),
        reduction='batchmean',
    ),
    'cosine': lambda student, teacher: 1.0 - F.cosine_similarity(
        student.view(-1, student.size(-1)),
        teacher.detach().view(-1, teacher.size(-1)),
    ).mean(),
}


class ActivationSaver(nn.Module):
    """Wrap a module and record its most recent output(s) for KD losses.

    Subclassing nn.Module keeps the wrapped parameters visible to
    named_modules()/parameters() after the swap.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.activation_values = None
        self.feature_map_values = None

    def forward(self, *args, **kwargs):
        output = self.module(*args, **kwargs)
        if isinstance(output, tuple):
            self.activation_values = output[0]
            self.feature_map_values = output[1] if len(output) > 1 else None
        else:
            self.activation_values = output
        return output


def attach_activation_saver(model, activation_modules, feature_map_modules):
    """Swap matching submodules for ActivationSaver wrappers.

    named_modules() yields dotted paths, so the wrapper is installed on the
    resolved parent module (setattr on the root with a dotted name would
    merely create an oddly named attribute).
    """
    for name, module in list(model.model.named_modules()):
        if isinstance(module, ActivationSaver):
            continue
        if type(module).__name__ in activation_modules or type(module).__name__ in feature_map_modules:
            parent = model.model
            *path, leaf = name.split('.')
            if not leaf:
                continue
            for part in path:
                parent = getattr(parent, part)
            setattr(parent, leaf, ActivationSaver(module))
    return model
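

# A minimal sketch of the ActivationSaver wrapper on a bare nn.Linear: after a
# forward pass the wrapper holds the layer's output, which the activation- and
# feature-map-based KD terms below read back.
def _demo_activation_saver():
    layer = ActivationSaver(nn.Linear(4, 2))
    _ = layer(torch.randn(1, 4))
    return layer.activation_values.shape  # torch.Size([1, 2])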
def attention_kd_loss(student_attention, teacher_attention, loss_type='mse'):
    try:
        if loss_type == 'mse':
            return F.mse_loss(student_attention, teacher_attention.detach())
        elif loss_type == 'kl':
            return F.kl_div(
                F.log_softmax(student_attention, dim=-1),
                F.softmax(teacher_attention.detach(), dim=-1),
                reduction='batchmean',
            )
        elif loss_type == 'cosine':
            # Flattened tensors are 1-D, so compare along dim=0.
            return 1.0 - F.cosine_similarity(student_attention.view(-1), teacher_attention.detach().view(-1), dim=0).mean()
        # Unknown loss types fall back to MSE rather than returning None.
        return F.mse_loss(student_attention, teacher_attention.detach())
    except Exception as e:
        log_message(f"Attention KD Loss error: {e}", level="warning")
        return torch.tensor(0.0, device=device)


def hidden_state_kd_loss(student_hidden, teacher_hidden, loss_type='mse'):
    try:
        if loss_type == 'mse':
            return F.mse_loss(student_hidden, teacher_hidden.detach())
        elif loss_type == 'kl':
            return F.kl_div(
                F.log_softmax(student_hidden, dim=-1),
                F.softmax(teacher_hidden.detach(), dim=-1),
                reduction='batchmean',
            )
        elif loss_type == 'cosine':
            return 1.0 - F.cosine_similarity(student_hidden.view(-1), teacher_hidden.detach().view(-1), dim=0).mean()
        return F.mse_loss(student_hidden, teacher_hidden.detach())
    except Exception as e:
        log_message(f"Hidden State KD Loss error: {e}", level="warning")
        return torch.tensor(0.0, device=device)
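

# A minimal sketch of the temperature-scaled KL objective used in the loop
# below: soften both logit sets by T, take KL(student || teacher), and scale
# by T^2 so gradient magnitudes do not shrink as T grows (Hinton et al., 2015).
def _demo_temperature_kl():
    student_logits = torch.randn(2, 10)
    teacher_logits = torch.randn(2, 10)
    T = KD_TEMPERATURE
    return kd_loss_functions['kl'](student_logits / T, teacher_logits / T) * (T ** 2)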
def advanced_knowledge_distillation(teacher_models, student_model, device):
    for teacher_model in teacher_models:
        teacher_model.model.eval()
    student_model.model.train()

    # Locals for values reused across the loop; the per-technique toggles,
    # factors, and loss types are read directly from the module-level
    # constants. (The original aliased every constant one-to-one and
    # initialized the decay rate from itself, which raised a NameError.)
    total_steps = KD_STEPS
    accumulation_steps = ACCUMULATION_STEPS
    freeze_steps = FREEZE_STUDENT_STEPS
    use_mixed_precision = ENABLE_MIXED_PRECISION
    kd_loss_factor = KD_LOSS_FACTOR
    ce_loss_factor = CE_LOSS_FACTOR
    adaptive_temperature_decay_rate = ADAPTIVE_TEMPERATURE_DECAY_RATE
    adaptive_temperature = torch.tensor(ADAPTIVE_TEMPERATURE_INITIAL, device=device, requires_grad=False)

    def to_scalar(value):
        # Loss terms are plain floats when their technique is disabled and
        # tensors otherwise; normalize for logging.
        return value.item() if torch.is_tensor(value) else float(value)

    # Wrap the pipelines themselves (attach_activation_saver dereferences
    # .model internally).
    student_model = attach_activation_saver(student_model, ACTIVATION_MODULES, FEATURE_MAP_MODULES)
    for teacher_model in teacher_models:
        attach_activation_saver(teacher_model, ACTIVATION_MODULES, FEATURE_MAP_MODULES)

    optimizer = torch.optim.AdamW(
        student_model.model.parameters(), lr=LR_INITIAL, weight_decay=WEIGHT_DECAY
    )
    scheduler = get_scheduler(
        name=SCHEDULER_TYPE,
        optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=total_steps,
    ) if ENABLE_LR_SCHEDULER else None
    scaler = torch.amp.GradScaler("cpu", enabled=use_mixed_precision)

    if ENABLE_STUDENT_PARAMETER_FREEZE:
        for param in student_model.model.parameters():
            param.requires_grad = False
        log_to_file(f"Starting KD: freezing for {freeze_steps} steps.")
    else:
        log_to_file("KD started without initial freezing.")

    writer = SummaryWriter(log_dir=TENSORBOARD_LOG_DIR)
    MODELS['writer'] = writer
    best_loss = float('inf')
    patience_counter = 0
    start_time = time.time()
    progress_bar = tqdm(range(KD_STEPS), desc="Knowledge Distillation Progress")

    for step in progress_bar:
        try:
            optimizer.zero_grad()
            loss_accum = 0.0
            # Logged outside the micro-batch loop; weighted-parameter KD is
            # configured but not computed in this loop, so it stays at zero.
            weighted_parameter_loss = 0.0
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            for _ in range(accumulation_steps):
                batch_texts = ["This is a sample text for distillation.", "Another example sentence."]
                try:
                    with torch.amp.autocast(device_type=device.type, enabled=use_mixed_precision):
                        need_attentions = ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD
                        need_hidden = ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD
                        teacher_outputs_list = []
                        teacher_inputs_list = []
                        for teacher_model in MODELS['teacher']:
                            with torch.no_grad():
                                teacher_inputs = teacher_model.tokenizer(
                                    batch_texts,
                                    return_tensors="pt",
                                    padding=True,
                                    truncation=True,
                                    max_length=128,
                                ).to(device)
                                teacher_inputs_list.append(teacher_inputs)
                                teacher_outputs = teacher_model.model(
                                    **teacher_inputs,
                                    output_attentions=need_attentions,
                                    output_hidden_states=need_hidden,
                                )
                                teacher_outputs_list.append(teacher_outputs)
                        student_inputs = MODELS['student'].tokenizer(
                            batch_texts,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=128,
                        ).to(device)
                        student_outputs = MODELS['student'].model(
                            **student_inputs,
                            output_attentions=need_attentions,
                            output_hidden_states=need_hidden,
                        )
                        student_logits = student_outputs.logits
                        student_attentions = student_outputs.attentions if need_attentions else None
                        student_hidden_states = student_outputs.hidden_states if need_hidden else None
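                        # Multi-teacher ensembling: logits, attention maps, and
                        # hidden states are each stacked across teachers and
                        # averaged elementwise, which assumes the teachers agree
                        # on sequence length, layer count, and head count for
                        # these tensors. Vocabulary sizes may differ; logits are
                        # trimmed to the shared prefix below.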
                        teacher_logits = torch.stack([t.logits for t in teacher_outputs_list]).mean(dim=0)
                        teacher_attentions = [t.attentions for t in teacher_outputs_list]
                        teacher_attentions = (
                            torch.stack([torch.stack(attn) for attn in teacher_attentions]).mean(dim=0)
                            if teacher_attentions[0] is not None else None
                        )
                        teacher_hidden_states = [t.hidden_states for t in teacher_outputs_list]
                        teacher_hidden_states = (
                            torch.stack([torch.stack(h) for h in teacher_hidden_states]).mean(dim=0)
                            if teacher_hidden_states[0] is not None else None
                        )
                        min_vocab_size_logits = min(teacher_logits.size(-1), student_logits.size(-1))
                        teacher_logits_trimmed = teacher_logits[..., :min_vocab_size_logits]
                        student_logits_trimmed = student_logits[..., :min_vocab_size_logits]
                        # Average each teacher's own embedding of its own input
                        # ids (the original re-embedded teacher 0 repeatedly);
                        # this assumes matching embedding dims and lengths.
                        teacher_embeds = torch.stack([
                            t.model.get_input_embeddings()(inp.input_ids)
                            for t, inp in zip(MODELS['teacher'], teacher_inputs_list)
                        ]).mean(dim=0)
                        student_embeds = MODELS['student'].model.get_input_embeddings()(student_inputs.input_ids)
                        # CE against labels is a placeholder: no labeled data
                        # flows through this loop (CE_LOSS_FACTOR defaults to 0).
                        ce_loss = torch.tensor(0.0, device=device)
                        output_kd_loss = (
                            OUTPUT_LOGIT_KD_FACTOR
                            * kd_loss_functions[OUTPUT_LOGIT_KD_LOSS_TYPE](
                                student_logits_trimmed / adaptive_temperature,
                                teacher_logits_trimmed / adaptive_temperature,
                            )
                            * (adaptive_temperature ** 2)
                            if ENABLE_OUTPUT_LOGIT_KD else torch.tensor(0.0, device=device)
                        )
                        loss = output_kd_loss + CE_LOSS_FACTOR * ce_loss
                        attn_loss = (
                            ATTENTION_KD_FACTOR * attention_kd_loss(student_attentions[-1], teacher_attentions[-1], ATTENTION_KD_LOSS_TYPE)
                            if ENABLE_ATTENTION_KD and teacher_attentions is not None and student_attentions is not None
                            else torch.tensor(0.0, device=device)
                        )
                        loss += attn_loss
                        hidden_loss = (
                            HIDDEN_STATE_KD_FACTOR * hidden_state_kd_loss(student_hidden_states[-1], teacher_hidden_states[-1], HIDDEN_STATE_KD_LOSS_TYPE)
                            if ENABLE_HIDDEN_STATE_KD and teacher_hidden_states is not None and student_hidden_states is not None
                            else torch.tensor(0.0, device=device)
                        )
                        loss += hidden_loss
                        intermediate_loss = 0.0
                        if ENABLE_INTERMEDIATE_KD and teacher_hidden_states is not None and student_hidden_states is not None:
                            for layer_idx in INTERMEDIATE_LAYERS:
                                if layer_idx < len(teacher_hidden_states) and layer_idx < len(student_hidden_states):
                                    intermediate_loss += hidden_state_kd_loss(
                                        student_hidden_states[layer_idx],
                                        teacher_hidden_states[layer_idx],
                                        HIDDEN_STATE_KD_LOSS_TYPE,
                                    )
                            loss += INTERMEDIATE_KD_FACTOR * intermediate_loss
                        layer_norm_loss = 0.0
                        if ENABLE_LAYER_NORM_KD:
                            # LAYER_NORM_MODULES holds class names, so match
                            # student/teacher LayerNorms by qualified name (the
                            # original probed for an attribute literally called
                            # 'LayerNorm' on the model root, which never fired).
                            for ln_name, student_ln in MODELS['student'].model.named_modules():
                                if type(student_ln).__name__ not in LAYER_NORM_MODULES:
                                    continue
                                ln_weights, ln_biases = [], []
                                for teacher_model in MODELS['teacher']:
                                    try:
                                        teacher_ln = teacher_model.model.get_submodule(ln_name)
                                    except AttributeError:
                                        continue
                                    if isinstance(teacher_ln, nn.LayerNorm) and teacher_ln.weight.shape == student_ln.weight.shape:
                                        ln_weights.append(teacher_ln.weight)
                                        ln_biases.append(teacher_ln.bias)
                                if ln_weights:
                                    layer_norm_loss += F.mse_loss(student_ln.weight, torch.stack(ln_weights).mean(dim=0).detach())
                                    layer_norm_loss += F.mse_loss(student_ln.bias, torch.stack(ln_biases).mean(dim=0).detach())
                            loss += LAYER_NORM_KD_FACTOR * layer_norm_loss
                        embed_loss = (
                            EMBEDDING_KD_FACTOR * hidden_state_kd_loss(student_embeds, teacher_embeds, HIDDEN_STATE_KD_LOSS_TYPE)
                            if ENABLE_EMBEDDING_KD else torch.tensor(0.0, device=device)
                        )
                        loss += embed_loss
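                        # Parameter-space KD: pull student weights toward the
                        # elementwise mean of same-named, same-shaped teacher
                        # weights. This matches weights rather than outputs, so
                        # it only contributes where the architectures align.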
                        parameter_loss = 0.0
                        if ENABLE_PARAMETER_KD:
                            teacher_state_dicts = [t.model.state_dict() for t in MODELS['teacher']]
                            for name, student_param in MODELS['student'].model.named_parameters():
                                param_values = [
                                    sd[name] for sd in teacher_state_dicts
                                    if name in sd and sd[name].shape == student_param.shape
                                ]
                                if param_values:
                                    teacher_param_mean = torch.stack(param_values).mean(dim=0).detach()
                                    parameter_loss += kd_loss_functions[WEIGHTED_PARAMETER_LOSS_TYPE](student_param, teacher_param_mean)
                            loss += PARAMETER_KD_FACTOR * parameter_loss
                        layerwise_parameter_loss = 0.0
                        if ENABLE_LAYERWISE_PARAMETER_KD:
                            teacher_state_dicts = [t.model.state_dict() for t in MODELS['teacher']]
                            for layer_module_name in LAYERWISE_PARAMETER_MODULES:
                                try:
                                    # Selector entries are dotted paths, so
                                    # get_submodule (not hasattr) resolves them.
                                    student_layer_module = MODELS['student'].model.get_submodule(layer_module_name)
                                except AttributeError:
                                    continue
                                for name, student_param in student_layer_module.named_parameters():
                                    full_name = f"{layer_module_name}.{name}"
                                    param_values = [
                                        sd[full_name] for sd in teacher_state_dicts
                                        if full_name in sd and sd[full_name].shape == student_param.shape
                                    ]
                                    if param_values:
                                        teacher_param_mean = torch.stack(param_values).mean(dim=0).detach()
                                        layerwise_parameter_loss += kd_loss_functions[WEIGHTED_PARAMETER_LOSS_TYPE](student_param, teacher_param_mean)
                            loss += LAYERWISE_PARAMETER_KD_FACTOR * layerwise_parameter_loss
                        vocab_projection_loss = 0.0
                        if ENABLE_VOCAB_PROJECTION_KD:
                            for vocab_module_name in VOCAB_PROJECTION_MODULES:
                                if not hasattr(MODELS['student'].model, vocab_module_name):
                                    continue
                                student_vocab_module = getattr(MODELS['student'].model, vocab_module_name)
                                projection_weights = [
                                    getattr(teacher_model.model, vocab_module_name).weight
                                    for teacher_model in MODELS['teacher']
                                    if hasattr(teacher_model.model, vocab_module_name)
                                    and getattr(teacher_model.model, vocab_module_name).weight.shape == student_vocab_module.weight.shape
                                ]
                                if projection_weights:
                                    vocab_projection_loss += F.mse_loss(
                                        student_vocab_module.weight,
                                        torch.stack(projection_weights).mean(dim=0).detach(),
                                    )
                            loss += VOCAB_PROJECTION_KD_FACTOR * vocab_projection_loss
                        activation_loss = 0.0
                        if ENABLE_ACTIVATION_KD:
                            # Despite the name, this term pushes student
                            # activations toward zero (a magnitude penalty);
                            # it does not compare against the teacher.
                            for name, module in MODELS['student'].model.named_modules():
                                if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES:
                                    if module.activation_values is not None:
                                        activations = module.activation_values
                                        activation_loss += F.mse_loss(activations, torch.zeros_like(activations))
                            loss += ACTIVATION_KD_FACTOR * activation_loss
                        logit_masking_loss = 0.0
                        if ENABLE_LOGIT_MASKING_KD:
                            # Keep only the teacher's top-k vocabulary entries
                            # and distill on that restricted support; the
                            # trimmed logits keep the shapes aligned.
                            teacher_probs = F.softmax(teacher_logits_trimmed, dim=-1)
                            top_k_indices = torch.topk(teacher_probs, k=min(10, teacher_probs.size(-1)), dim=-1)[1]
                            mask = torch.ones_like(student_logits_trimmed).scatter_(-1, top_k_indices, 0.0).bool()
                            masked_student_logits = student_logits_trimmed.masked_fill(mask, -1e9)
                            logit_masking_loss = kd_loss_functions['kl'](
                                masked_student_logits / adaptive_temperature,
                                teacher_logits_trimmed / adaptive_temperature,
                            ) * (adaptive_temperature ** 2)
                            loss += LOGIT_MASKING_FACTOR * logit_masking_loss
                        sparsity_loss = 0.0
                        if ENABLE_SPARSITY_REGULARIZATION:
                            for name, module in MODELS['student'].model.named_modules():
                                if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES:
                                    if module.activation_values is not None:
                                        sparsity_loss += torch.norm(module.activation_values, 1)
                            loss += SPARSITY_REGULARIZATION_FACTOR * sparsity_loss
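                        # The sparsity term above is a plain L1 penalty on the
                        # recorded activations, sum(|a|); with the 1e-5 factor
                        # it nudges representations toward sparsity rather than
                        # enforcing it.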
                        feature_map_loss = 0.0
                        if ENABLE_FEATURE_MAP_KD:
                            for name, student_module in MODELS['student'].model.named_modules():
                                if isinstance(student_module, ActivationSaver) and type(student_module.module).__name__ in FEATURE_MAP_MODULES:
                                    try:
                                        teacher_module = MODELS['teacher'][0].model.get_submodule(name)
                                    except AttributeError:
                                        continue
                                    if (
                                        isinstance(teacher_module, ActivationSaver)
                                        and teacher_module.feature_map_values is not None
                                        and student_module.feature_map_values is not None
                                    ):
                                        feature_map_loss += F.mse_loss(
                                            student_module.feature_map_values,
                                            teacher_module.feature_map_values.detach(),
                                        )
                            loss += FEATURE_MAP_KD_FACTOR * feature_map_loss
                        contrastive_loss = 0.0
                        if ENABLE_CONTRASTIVE_KD and student_hidden_states is not None and teacher_hidden_states is not None:
                            # Compare the first-token representation of the
                            # last layer (CLS-style pooling).
                            student_vec = student_hidden_states[-1][:, 0, :]
                            teacher_vec = teacher_hidden_states[-1][:, 0, :]
                            contrastive_loss = CONTRASTIVE_KD_FACTOR * kd_loss_functions[CONTRASTIVE_KD_LOSS_TYPE](student_vec, teacher_vec)
                            loss += contrastive_loss
                        rdrop_loss = 0.0
                        if ENABLE_RDROP_KD:
                            # Second stochastic forward pass; symmetric KL
                            # between the two passes. Raw logits go in: the
                            # 'kl' helper applies softmax itself (the original
                            # pre-applied log_softmax, softmaxing twice).
                            r_student_outputs = MODELS['student'].model(
                                **student_inputs,
                                output_attentions=need_attentions,
                                output_hidden_states=need_hidden,
                            )
                            r_student_logits = r_student_outputs.logits
                            rdrop_loss = RDROP_KD_FACTOR * (
                                kd_loss_functions['kl'](student_logits, r_student_logits)
                                + kd_loss_functions['kl'](r_student_logits, student_logits)
                            )
                            loss += rdrop_loss
                        layer_wise_loss = 0.0
                        if ENABLE_LAYER_WISE_KD and teacher_hidden_states is not None and student_hidden_states is not None:
                            for layer_module_name in LAYER_WISE_MODULES:
                                try:
                                    student_layer_module = MODELS['student'].model.get_submodule(layer_module_name)
                                    # get_submodule is the actual nn.Module API
                                    # (the original called a nonexistent
                                    # get_module).
                                    teacher_layer_module = MODELS['teacher'][0].model.get_submodule(layer_module_name)
                                except AttributeError:
                                    continue
                                for student_layer, teacher_layer in zip(student_layer_module.modules(), teacher_layer_module.modules()):
                                    if isinstance(student_layer, nn.Linear) and isinstance(teacher_layer, nn.Linear):
                                        if (
                                            student_layer.in_features != student_hidden_states[-1].size(-1)
                                            or teacher_layer.in_features != teacher_hidden_states[-1].size(-1)
                                        ):
                                            continue
                                        student_output = student_layer(student_hidden_states[-1])
                                        teacher_output = teacher_layer(teacher_hidden_states[-1])
                                        layer_wise_loss += hidden_state_kd_loss(student_output, teacher_output, LAYER_WISE_LOSS_TYPE)
                            loss += LAYER_WISE_KD_FACTOR * layer_wise_loss
                        activation_regularization_loss = 0.0
                        if ENABLE_ACTIVATION_REGULARIZATION:
                            for name, module in MODELS['student'].model.named_modules():
                                if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES:
                                    if module.activation_values is not None:
                                        activation_regularization_loss += torch.norm(module.activation_values, p=2)
                            loss += ACTIVATION_REG_LAMBDA * activation_regularization_loss
                        neuron_selectivity_loss = 0.0
                        if ENABLE_NEURON_SELECTIVITY_KD:
                            # Match mean per-neuron activation statistics
                            # between the student and the first teacher.
                            for name, module in MODELS['student'].model.named_modules():
                                if isinstance(module, ActivationSaver) and type(module.module).__name__ in NEURON_SELECTIVITY_MODULES:
                                    if module.activation_values is None:
                                        continue
                                    try:
                                        teacher_module = MODELS['teacher'][0].model.get_submodule(name)
                                    except AttributeError:
                                        continue
                                    if isinstance(teacher_module, ActivationSaver) and teacher_module.activation_values is not None:
                                        neuron_selectivity_loss += NEURON_SELECTIVITY_KD_FACTOR * kd_loss_functions[NEURON_SELECTIVITY_LOSS_TYPE](
                                            module.activation_values.mean(dim=0),
                                            teacher_module.activation_values.mean(dim=0).detach(),
                                        )
                            loss += neuron_selectivity_loss
                        l2_reg_loss = 0.0
                        if ENABLE_L2_REGULARIZATION:
                            for param in MODELS['student'].model.parameters():
                                l2_reg_loss += torch.norm(param, p=2)
                            loss += L2_LAMBDA * l2_reg_loss
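                        # L2 vs. L1 weight penalties: sum(||w||_2) above and
                        # sum(||w||_1) below. Both default to disabled; AdamW's
                        # weight_decay already applies decoupled L2-style
                        # shrinkage.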
                        l1_reg_loss = 0.0
                        if ENABLE_L1_REGULARIZATION:
                            for param in MODELS['student'].model.parameters():
                                l1_reg_loss += torch.norm(param, p=1)
                            loss += L1_LAMBDA * l1_reg_loss
                        loss_accum += (kd_loss_factor * loss + ce_loss_factor * ce_loss) / accumulation_steps
                except Exception as e:
                    log_message(f"Error in inner distillation loop: {e}", level="warning")
            if not torch.is_tensor(loss_accum):
                # Every micro-batch failed; nothing to backpropagate.
                continue
            scaler.scale(loss_accum).backward()
            if ENABLE_GRADIENT_CLIPPING:
                try:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(MODELS['student'].model.parameters(), GRAD_CLIP_VALUE)
                except Exception as e:
                    log_message(f"Gradient clipping error: {e}", level="warning")
            scaler.step(optimizer)
            scaler.update()
            if scheduler:
                scheduler.step()
            if ENABLE_ADAPTIVE_TEMPERATURE_KD:
                # Decay the softmax temperature toward 1.0 over training.
                adaptive_temperature = torch.clamp(adaptive_temperature * adaptive_temperature_decay_rate, min=1.0)
            if ENABLE_STUDENT_PARAMETER_FREEZE:
                if ENABLE_DYNAMIC_FREEZE:
                    if loss_accum.item() < FREEZE_THRESHOLD:
                        for param in MODELS['student'].model.parameters():
                            param.requires_grad = True
                        if step == 0:
                            log_message("Dynamic unfreezing activated due to low loss.", level="info")
                elif step + 1 == freeze_steps:
                    for param in MODELS['student'].model.parameters():
                        param.requires_grad = True
                    log_message("Unfreezing after freeze_steps.", level="info")
            if SAVE_CHECKPOINTS and (step + 1) % CHECKPOINT_INTERVAL == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"student_step_{step + 1}.pt")
                os.makedirs(CHECKPOINT_DIR, exist_ok=True)
                save_checkpoint(MODELS['student'], ckpt_path)
                log_to_file(f"Checkpoint saved to {ckpt_path}")
            if (step + 1) % LOG_INTERVAL == 0:
                elapsed = time.time() - start_time
                lr = optimizer.param_groups[0]['lr']
                grad_norm = 0.0
                num_params_with_grad = 0
                for param in MODELS['student'].model.parameters():
                    if param.grad is not None:
                        grad_norm += param.grad.data.norm(2).item()
                        num_params_with_grad += 1
                avg_grad_norm = grad_norm / num_params_with_grad if num_params_with_grad > 0 else 0.0
                log_msg = (
                    f"[KD] Step {step + 1}/{KD_STEPS}, Loss: {loss_accum.item():.4f}, "
                    f"LR: {lr:.6f}, GradNorm: {avg_grad_norm:.4f}, Time: {elapsed:.2f}s, "
                    f"Temp: {adaptive_temperature.item():.3f}"
                )
                log_message(log_msg)
                log_to_file(log_msg)
                writer.add_scalar("Loss/accumulated", loss_accum.item(), step + 1)
                writer.add_scalar("Learning_Rate", lr, step + 1)
                writer.add_scalar("GradNorm", avg_grad_norm, step + 1)
                # Individual terms may be plain floats when their technique is
                # disabled, hence to_scalar.
                writer.add_scalar("Output_Logit_KD_Loss", to_scalar(output_kd_loss), step + 1)
                writer.add_scalar("Attention_KD_Loss", to_scalar(attn_loss), step + 1)
                writer.add_scalar("HiddenState_KD_Loss", to_scalar(hidden_loss), step + 1)
                writer.add_scalar("Intermediate_KD_Loss", to_scalar(intermediate_loss), step + 1)
                writer.add_scalar("LayerNorm_KD_Loss", to_scalar(layer_norm_loss), step + 1)
                writer.add_scalar("Embedding_KD_Loss", to_scalar(embed_loss), step + 1)
                writer.add_scalar("Parameter_KD_Loss", to_scalar(parameter_loss), step + 1)
                writer.add_scalar("Activation_KD_Loss", to_scalar(activation_loss), step + 1)
                writer.add_scalar("LogitMasking_KD_Loss", to_scalar(logit_masking_loss), step + 1)
                writer.add_scalar("Sparsity_Regularization_Loss", to_scalar(sparsity_loss), step + 1)
                writer.add_scalar("FeatureMap_KD_Loss", to_scalar(feature_map_loss), step + 1)
                writer.add_scalar("Layerwise_Parameter_KD_Loss", to_scalar(layerwise_parameter_loss), step + 1)
                writer.add_scalar("VocabProjection_KD_Loss", to_scalar(vocab_projection_loss), step + 1)
                writer.add_scalar("Contrastive_KD_Loss", to_scalar(contrastive_loss), step + 1)
                writer.add_scalar("RDrop_KD_Loss", to_scalar(rdrop_loss), step + 1)
                writer.add_scalar("Adaptive_Temperature", adaptive_temperature.item(), step + 1)
                writer.add_scalar("LayerWise_KD_Loss", to_scalar(layer_wise_loss), step + 1)
                writer.add_scalar("Activation_Regularization_Loss", to_scalar(activation_regularization_loss), step + 1)
                writer.add_scalar("NeuronSelectivity_KD_Loss", to_scalar(neuron_selectivity_loss), step + 1)
                writer.add_scalar("WeightedParameter_KD_Loss", to_scalar(weighted_parameter_loss), step + 1)
            if ENABLE_EARLY_STOPPING:
                if loss_accum.item() < best_loss:
                    best_loss = loss_accum.item()
                    patience_counter = 0
                else:
                    patience_counter += 1
                if patience_counter >= EARLY_STOPPING_PATIENCE:
                    log_message(f"Early stopping activated at step {step + 1} after {EARLY_STOPPING_PATIENCE} non-improving steps.", level="info")
                    log_to_file(f"Early stopping activated at step {step + 1}.")
                    break
        except Exception as e:
            log_message(f"Error in distillation step {step + 1}: {e}", level="warning")
            continue
    progress_bar.close()
    writer.close()
    MODELS['writer'] = None
    # Final cleanup also writes a terminal checkpoint of the student.
    cleanup_and_exit(0)
    return MODELS['student']


def push_model_to_hub(model, tokenizer, quantization_method, repo_name, use_auth_token):
    try:
        log_message(f"Saving {model.__class__.__name__} to {repo_name} with method '{quantization_method}'...", level="info")
        model.model.save_pretrained(repo_name, push_to_hub=False)
        tokenizer.save_pretrained(repo_name, push_to_hub=False)
        # Push the underlying model rather than the pipeline wrapper; `token`
        # is the current name of the auth argument (`use_auth_token` is
        # deprecated).
        model.model.push_to_hub(repo_name, token=use_auth_token)
        tokenizer.push_to_hub(repo_name, token=use_auth_token)
        log_message("Upload completed.", level="info")
        log_to_file(f"Model and tokenizer uploaded to {repo_name} with method {quantization_method}.")
    except Exception as e:
        log_message(f"Error during push to hub: {e}", level="warning")


def login_to_huggingface(token):
    global AUTH_TOKEN, USER_NAME
    try:
        user_info = HfApi(token=token).whoami()
        AUTH_TOKEN = token
        USER_NAME = user_info['name']
        log_message(f"Successfully logged in to Hugging Face as {USER_NAME}.", level="info")
        return AUTH_TOKEN, USER_NAME, None
    except Exception as e:
        log_message(f"Hugging Face Hub login error: {e}", level="warning")
        return None, None, "Invalid Hugging Face token."
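

# A minimal sketch of the authentication + upload flow with a placeholder
# token; whoami() validates the token and returns the account name the UI
# displays. The repo name below is hypothetical.
#
#   token, username, err = login_to_huggingface("hf_...your token...")
#   if err is None:
#       push_model_to_hub(MODELS['student'], MODELS['student'].tokenizer,
#                         "no_quantization", f"{username}/my-distilled-model", token)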
def run_fusion_distillation(teacher_model_ckpt_1, teacher_model_ckpt_2, student_model_ckpt, repo_name, disable_mean_resizing, huggingface_token):
    global MODELS, LOG_TEXT, GRADIO_LOG_OUTPUT, AUTH_TOKEN, USER_NAME
    LOG_TEXT = ""
    GRADIO_LOG_OUTPUT.value = LOG_TEXT
    mean_resizing = not disable_mean_resizing
    token, username, error_message = login_to_huggingface(huggingface_token)
    if token is None:
        log_message("Hugging Face login failed.", level="warning")
        return error_message
    AUTH_TOKEN = token
    USER_NAME = username
    log_message(f"Authenticated with Hugging Face Hub as user: {USER_NAME}.", level="info")
    try:
        student_tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt)
        MODELS['student'] = pipeline(model=student_model_ckpt, tokenizer=student_tokenizer, device=device)
        if MODELS['student'].tokenizer is None:
            log_message("Student model tokenizer could not be loaded.", level="warning")
            return "Student model tokenizer could not be loaded."
        special_token = "[UNFILTERED]"
        if special_token not in MODELS['student'].tokenizer.get_vocab():
            MODELS['student'].tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
            log_message("Special token added to student tokenizer.", level="info")
            MODELS['student'].model.resize_token_embeddings(len(MODELS['student'].tokenizer))
    except Exception as e:
        log_message(f"Error loading student model: {e}", level="warning")
        return "Error loading student model: " + str(e)
    teacher_models = []
    teacher_tokenizers = []
    teacher_model_checkpoints = [teacher_model_ckpt_1]
    if teacher_model_ckpt_2:
        teacher_model_checkpoints.append(teacher_model_ckpt_2)
    for i, teacher_model_ckpt in enumerate(teacher_model_checkpoints):
        if not isinstance(teacher_model_ckpt, str):
            log_message(f"Error loading teacher models: not a string at index {i}", level="warning")
            return f"Error loading teacher models: not a string at index {i}"
        try:
            teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_ckpt)
            teacher_pipeline = pipeline(model=teacher_model_ckpt, tokenizer=teacher_tokenizer, device=device)
            teacher_models.append(teacher_pipeline)
            teacher_tokenizers.append(teacher_tokenizer)
            log_message(f"Teacher model {i + 1} loaded: {teacher_model_ckpt}", level="info")
        except Exception as e:
            log_message(f"Error loading teacher model {i + 1} ({teacher_model_ckpt}): {e}", level="warning")
            return f"Error loading teacher model {i + 1}: {e}"
    MODELS['teacher'] = teacher_models
    try:
        unify_parameters(MODELS['student'], MODELS['teacher'][0])
        unify_embeddings(MODELS['student'], MODELS['teacher'][0], mean_resizing=mean_resizing)
        MODELS['student'].tokenizer = unify_tokenizers(MODELS['student'].tokenizer, teacher_tokenizers[0], MODELS['student'])
        MODELS['student'].tokenizer = fuse_tokenizers(teacher_tokenizers, MODELS['student'].tokenizer)
    except Exception as e:
        log_message(f"Model unification error: {e}", level="warning")
    fusion_method_name = 'geometric_mean_double'
    try:
        unified_teacher_state = complete_unify_teacher_models_double(MODELS['teacher'], fusion_method=fusion_method_name, layer_scale=1.0, device=device)
        MODELS['student'] = unify_teacher_into_student(unified_teacher_state, MODELS['student'], force_parameter_copy=True)
        MODELS['student'] = update_student_embeddings_double(MODELS['student'], MODELS['teacher'], fusion_type=fusion_method_name, device=device, mean_resizing=mean_resizing)
    except Exception as e:
        log_message(f"Teacher model fusion error: {e}", level="warning")
    try:
        unified_student = advanced_knowledge_distillation(MODELS['teacher'], MODELS['student'], device)
    except Exception as e:
        log_message(f"Knowledge distillation error: {e}", level="warning")
        return "Knowledge Distillation failed, but fusion may be partially completed. Check the logs."
    if ENABLE_PARAMETER_COUNT_CHECK:
        student_param_count = sum(p.numel() for p in unified_student.model.parameters())
        teacher_param_count = sum(p.numel() for p in MODELS['teacher'][0].model.parameters())
        if student_param_count != teacher_param_count:
            log_message(f"Warning: student parameters ({student_param_count:,}) differ from teacher ({teacher_param_count:,}).", level="warning")
        else:
            log_message(f"Parameter count consistent: Student and Teacher ({teacher_param_count:,}).", level="info")
    sample_texts = ["What is the capital of France?", "Solve: 3 + 5 * 2", "Define a function in Python."]
    try:
        predictions = generate_predictions(unified_student, MODELS['student'].tokenizer, sample_texts)
        log_message("Sample Predictions after KD:\n" + "\n".join(predictions), level="info")
    except Exception as e:
        log_message(f"Sample prediction error: {e}", level="warning")
    api = HfApi()
    try:
        api.create_repo(repo_name, exist_ok=True, token=AUTH_TOKEN)
        push_model_to_hub(
            unified_student,
            MODELS['student'].tokenizer,
            "unified_teacher_kd_full_options_default_true_geometric_mean_double_fusion_v11",
            repo_name,
            AUTH_TOKEN,
        )
        final_msg = f"Student model fused, distilled, and uploaded to repo: {repo_name} with all advanced options and 'geometric_mean_double' fusion!"
        log_message(final_msg)
        log_to_file(final_msg, "training_log.txt")
        return "Fusion and Distillation completed! Check the console and logs."
    except Exception as e:
        log_message(f"Final upload/repo creation error: {e}", level="warning")
        return "Error during final upload or repository creation: " + str(e)


if __name__ == "__main__":
    with gr.Blocks(css=".gradio-container {padding: 20px}") as iface:
        gr.Markdown("# Fusion and Distillation Pipeline for Language Models")
        gr.Markdown(
            "This application fuses and distills knowledge from teacher language models into a student model. "
            "It supports advanced knowledge distillation techniques and model fusion strategies."
        )
        huggingface_token_input = gr.Textbox(label="Hugging Face Token", type="password", visible=True)
        username_display = gr.Textbox(label="Hugging Face Username", interactive=False, visible=False)

        def process_token(huggingface_token):
            # One output component is wired up, so return a single value: the
            # username on success, the error text otherwise.
            _, username, error_message = login_to_huggingface(huggingface_token)
            return username if username else (error_message or "")

        huggingface_token_input.change(
            process_token,
            inputs=[huggingface_token_input],
            outputs=[username_display],
        )
        with gr.Column():
            gr.Markdown("## Model Selection")
            teacher_model_ckpt_1_input = HuggingfaceHubSearch(label="Teacher Model 1")
            teacher_model_ckpt_2_input = HuggingfaceHubSearch(label="Teacher Model 2 (Optional)")
            student_model_ckpt_input = HuggingfaceHubSearch(label="Student Model")
            repo_name_input = gr.Textbox(
                label="Repository Name",
                info="Enter the name of the Hugging Face repository to create or update with the distilled model.",
            )
            disable_mean_resizing_checkbox = gr.Checkbox(
                label="Disable Mean Resizing",
                info="Check to disable mean resizing of embeddings during unification.",
            )
        with gr.Column():
            gr.Markdown("## Execution and Logs")
            run_button = gr.Button("Run Fusion and Distillation", variant="primary", interactive=True, elem_id='run-button')
            output_status_textbox = gr.Textbox(
                info="Real-time status and logs of the fusion and distillation process.",
                label="Status",
                lines=3,
            )
            GRADIO_LOG_OUTPUT = gr.Textbox(
                value="",
                label="Detailed Log Output",
                lines=10,
                interactive=False,
            )

            def update_run_button_interactivity(token):
                # gr.Button.update was removed in Gradio 4; gr.update is the
                # supported way to change component properties.
                return gr.update(interactive=True)

            huggingface_token_input.change(
                update_run_button_interactivity,
                inputs=[huggingface_token_input],
                outputs=[run_button],
            )
            run_button.click(
                run_fusion_distillation,
                inputs=[
                    teacher_model_ckpt_1_input,
                    teacher_model_ckpt_2_input,
                    student_model_ckpt_input,
                    repo_name_input,
                    disable_mean_resizing_checkbox,
                    huggingface_token_input,
                ],
                outputs=output_status_textbox,
            )
    iface.launch(server_name="0.0.0.0", server_port=7860)