import gradio as gr
import os
import unicodedata
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from transformers import AutoConfig, get_scheduler, pipeline, AutoTokenizer, AutoModelForPreTraining, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModel
from huggingface_hub import HfApi
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import psutil
import signal
import sys
import gc
from gradio_huggingfacehub_search import HuggingfaceHubSearch
warnings.filterwarnings("ignore", message="cannot set number of interop threads after parallel work has started or set_num_interop_threads called")
torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))
device = torch.device("cpu")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
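# Global state shared between the Gradio UI callbacks and the distillation pipeline.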
MODELS = {}
LOG_TEXT = ""
GRADIO_LOG_OUTPUT = None
HF_API = HfApi()
AUTH_TOKEN = None
USER_NAME = None
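# Distillation configuration: checkpointing, loss toggles, per-loss weights, module filters, and optimizer settings.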
SAVE_ON_RESOURCE_TERMINATION = True
CHECKPOINT_DIR = "./fusion_distillation_checkpoints"
TENSORBOARD_LOG_DIR = "./tensorboard_logs"
SAVE_CHECKPOINTS = True
CHECKPOINT_INTERVAL = 1000
LOG_INTERVAL = 100
KD_STEPS = 10000
ACCUMULATION_STEPS = 1
FREEZE_STUDENT_STEPS = 2000
ENABLE_MIXED_PRECISION = False
KD_LOSS_FACTOR = 1.0
CE_LOSS_FACTOR = 0.0
KD_TEMPERATURE = 2.0
ENABLE_ATTENTION_KD = True
ENABLE_HIDDEN_STATE_KD = True
ENABLE_INTERMEDIATE_KD = True
ENABLE_LAYER_NORM_KD = True
ENABLE_EMBEDDING_KD = True
ENABLE_PARAMETER_KD = True
ENABLE_ACTIVATION_KD = True
ENABLE_LOGIT_MASKING_KD = True
ENABLE_SPARCITY_REGULARIZATION = True
ENABLE_FEATURE_MAP_KD = True
ENABLE_OUTPUT_LOGIT_KD = True
ENABLE_LAYERWISE_PARAMETER_KD = True
ENABLE_VOCAB_PROJECTION_KD = True
ENABLE_CONTRASTIVE_KD = True
ENABLE_RDROP_KD = True
ENABLE_ADAPTIVE_TEMPERATURE_KD = True
ENABLE_LAYER_WISE_KD = True
ENABLE_ACTIVATION_REGULARIZATION = True
ENABLE_NEURON_SELECTIVITY_KD = True
ENABLE_WEIGHTED_PARAMETER_KD = True
ENABLE_FSP_KD = True
ENABLE_ATTENTION_ALIGNMENT_KD = True
ENABLE_GRAM_MATRIX_KD = True
ATTENTION_KD_FACTOR = 0.1
HIDDEN_STATE_KD_FACTOR = 0.1
INTERMEDIATE_KD_FACTOR = 0.1
LAYER_NORM_KD_FACTOR = 0.01
EMBEDDING_KD_FACTOR = 0.01
PARAMETER_KD_FACTOR = 0.001
ACTIVATION_KD_FACTOR = 0.0001
LOGIT_MASKING_FACTOR = 0.01
SPARCITY_REGULARIZATION_FACTOR = 1e-5
FEATURE_MAP_KD_FACTOR = 0.005
OUTPUT_LOGIT_KD_FACTOR = 1.0
LAYERWISE_PARAMETER_KD_FACTOR = 0.0005
VOCAB_PROJECTION_KD_FACTOR = 0.001
CONTRASTIVE_KD_FACTOR = 0.05
RDROP_KD_FACTOR = 0.02
ADAPTIVE_TEMPERATURE_KD_FACTOR = 0.01
LAYER_WISE_KD_FACTOR = 0.05
ACTIVATION_REG_LAMBDA = 1e-7
NEURON_SELECTIVITY_KD_FACTOR = 0.001
WEIGHTED_PARAMETER_KD_FACTOR = 0.0002
FSP_KD_FACTOR = 0.001
ATTENTION_ALIGNMENT_KD_FACTOR = 0.01
GRAM_MATRIX_KD_FACTOR = 0.0005
ATTENTION_KD_LOSS_TYPE = 'mse'
HIDDEN_STATE_KD_LOSS_TYPE = 'mse'
OUTPUT_LOGIT_KD_LOSS_TYPE = 'kl'
CONTRASTIVE_KD_LOSS_TYPE = 'cosine'
RDROP_KD_LOSS_TYPE = 'kl'
LAYER_WISE_LOSS_TYPE = 'mse'
NEURON_SELECTIVITY_LOSS_TYPE = 'mse'
WEIGHTED_PARAMETER_LOSS_TYPE = 'mse'
FSP_KD_LOSS_TYPE = 'mse'
ATTENTION_ALIGNMENT_LOSS_TYPE = 'mse'
GRAM_MATRIX_LOSS_TYPE = 'mse'
INTERMEDIATE_LAYERS = [2, 5, 8]
LAYER_NORM_MODULES = ['LayerNorm']
ACTIVATION_MODULES = ['Linear']
FEATURE_MAP_MODULES = ['Linear']
LAYERWISE_PARAMETER_MODULES = ['transformer.h']
VOCAB_PROJECTION_MODULES = ['lm_head']
LAYER_WISE_MODULES = ['transformer.h']
NEURON_SELECTIVITY_MODULES = ['Linear']
WEIGHTED_PARAMETER_MODULES = ['Linear']
FSP_MODULES = ['Linear']
ATTENTION_ALIGNMENT_MODULES = ['SelfAttention']
GRAM_MATRIX_MODULES = ['Linear']
ADAPTIVE_TEMPERATURE_INITIAL = 2.0
ADAPTIVE_TEMPERATURE_DECAY_RATE = 0.999
LR_INITIAL = 5e-5
WEIGHT_DECAY = 1e-4
SCHEDULER_TYPE = "linear"
WARMUP_STEPS = 500
ENABLE_LR_SCHEDULER = True
ENABLE_STUDENT_PARAMETER_FREEZE = True
ENABLE_DYNAMIC_FREEZE = False
FREEZE_THRESHOLD = 0.1
ENABLE_GRADIENT_CLIPPING = True
GRAD_CLIP_VALUE = 1.0
ENABLE_EARLY_STOPPING = True
EARLY_STOPPING_PATIENCE = 5
ENABLE_PARAMETER_COUNT_CHECK = False
ENABLE_EMBEDDING_NOISE = True
EMBEDDING_NOISE_STD = 1e-4
ENABLE_L2_REGULARIZATION = False
L2_LAMBDA = 1e-6
ENABLE_L1_REGULARIZATION = False
L1_LAMBDA = 1e-6
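# Save a final student checkpoint (if enabled) and close the TensorBoard writer during shutdown or cleanup.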
def cleanup_and_exit(signal_code):
log_message("Initiating cleanup...", level="info")
if MODELS.get('student') and SAVE_ON_RESOURCE_TERMINATION:
try:
student_model = MODELS['student']
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
checkpoint_path = os.path.join(CHECKPOINT_DIR, "student_model_terminated.pt")
save_checkpoint(student_model, checkpoint_path)
log_message(f"Model checkpoint saved to {checkpoint_path} due to termination signal.", level="info")
except Exception as e:
log_message(f"Error during checkpoint save on cleanup: {e}", level="error")
if MODELS.get('writer'):
try:
MODELS['writer'].close()
log_message("TensorBoard writer closed.", level="info")
except Exception as e:
log_message(f"Error closing TensorBoard writer during cleanup: {e}", level="error")
log_message(f"Cleanup completed.", level="info")
def signal_handler(sig, frame):
    log_message(f"Received signal {sig}. Saving state before continuing...", level="warning")
    cleanup_and_exit(sig)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
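# Logging helpers: mirror messages to the in-memory Gradio log, the standard logger, stdout, and an optional file.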
def log_message(message, level="info"):
global LOG_TEXT, GRADIO_LOG_OUTPUT
log_str = f"{message}\n"
LOG_TEXT += log_str
if GRADIO_LOG_OUTPUT:
GRADIO_LOG_OUTPUT.value = LOG_TEXT
if level == "info":
logger.info(message)
elif level == "warning":
logger.warning(message)
elif level == "error":
logger.error(message)
print(message)
tqdm.write(message)
def log_to_file(message, log_file="training_log.txt"):
try:
with open(log_file, "a") as f:
f.write(f"{message}\n")
except Exception as e:
log_message(f"Error writing to log file: {e}", level="warning")
def unify_parameters(student_model, teacher_model, exclude_layers=None):
try:
teacher_state = teacher_model.model.state_dict()
student_state = student_model.model.state_dict()
excluded_names = exclude_layers or []
for name, param in student_state.items():
if any(excluded_name in name for excluded_name in excluded_names):
continue
if name in teacher_state:
try:
if student_state[name].shape == teacher_state[name].shape:
student_state[name].copy_(teacher_state[name])
else:
min_shape = [min(s, t) for s, t in zip(student_state[name].shape, teacher_state[name].shape)]
student_slice = tuple([slice(0, s) for s in min_shape])
teacher_slice = tuple([slice(0, s) for s in min_shape])
student_state[name][student_slice].copy_(teacher_state[name][teacher_slice])
except Exception as e:
log_message(f"Parameter copy error for {name}: {e}", level="warning")
student_model.model.load_state_dict(student_state, strict=False)
except Exception as e:
log_message(f"Error in unify_parameters: {e}", level="warning")
def unify_embeddings(student_model, teacher_model, project_embeddings=True, mean_resizing=False):
try:
if hasattr(student_model.model, "get_input_embeddings") and hasattr(teacher_model.model, "get_input_embeddings"):
student_emb = student_model.model.get_input_embeddings()
teacher_emb = teacher_model.model.get_input_embeddings()
if project_embeddings:
in_dim = teacher_emb.weight.shape[1]
out_dim = student_emb.weight.shape[1]
projection = nn.Linear(in_dim, out_dim).to(device)
teacher_emb_projected = projection(teacher_emb.weight)
student_vocab_size = student_emb.weight.shape[0]
teacher_vocab_size_proj = teacher_emb_projected.shape[0]
if mean_resizing and student_vocab_size < teacher_vocab_size_proj:
try:
student_model.model.resize_token_embeddings(teacher_vocab_size_proj)
student_emb = student_model.model.get_input_embeddings()
except Exception as e:
log_message(f"Error resizing student embeddings: {e}", level="warning")
min_vocab_size = min(student_emb.weight.shape[0], teacher_vocab_size_proj)
try:
student_emb.weight.data[:min_vocab_size].copy_(teacher_emb_projected.data[:min_vocab_size])
except Exception as e:
log_message(f"Error copying projected embeddings: {e}", level="warning")
else:
min_vocab_size = min(student_emb.weight.shape[0], teacher_emb.weight.shape[0])
try:
student_emb.weight.data[:min_vocab_size].copy_(teacher_emb.weight.data[:min_vocab_size])
except Exception as e:
log_message(f"Error copying embeddings: {e}", level="warning")
if hasattr(student_model.model, "get_output_embeddings") and hasattr(teacher_model.model, "get_output_embeddings"):
student_out_emb = student_model.model.get_output_embeddings()
teacher_out_emb = teacher_model.model.get_output_embeddings()
if student_out_emb is not None and teacher_out_emb is not None:
if project_embeddings:
in_dim = teacher_out_emb.weight.shape[1]
out_dim = student_out_emb.weight.shape[1]
projection = nn.Linear(in_dim, out_dim).to(device)
teacher_out_emb_projected = projection(teacher_out_emb.weight)
student_vocab_size = student_out_emb.weight.shape[0]
teacher_vocab_size_proj = teacher_out_emb_projected.shape[0]
if mean_resizing and student_vocab_size < teacher_vocab_size_proj:
try:
student_model.model.resize_token_embeddings(teacher_vocab_size_proj)
student_out_emb = student_model.model.get_output_embeddings()
except Exception as e:
log_message(f"Error resizing student output embeddings: {e}", level="warning")
min_vocab_size = min(student_out_emb.weight.shape[0], teacher_vocab_size_proj)
try:
student_out_emb.weight.data[:min_vocab_size].copy_(teacher_out_emb_projected.data[:min_vocab_size])
except Exception as e:
log_message(f"Error copying projected output embeddings: {e}", level="warning")
else:
min_vocab_size = min(student_out_emb.weight.shape[0], teacher_out_emb.weight.shape[0])
try:
student_out_emb.weight.data[:min_vocab_size].copy_(teacher_out_emb.weight.data[:min_vocab_size])
except Exception as e:
log_message(f"Error copying output embeddings: {e}", level="warning")
except Exception as e:
log_message(f"Error in unify_embeddings: {e}", level="warning")
def unify_tokenizers(student_tokenizer, teacher_tokenizer, student_model):
if teacher_tokenizer is None:
return student_tokenizer
try:
teacher_vocab = teacher_tokenizer.get_vocab()
student_vocab = student_tokenizer.get_vocab()
new_tokens = [token for token in teacher_vocab if token not in student_vocab]
if new_tokens:
student_tokenizer.add_tokens(new_tokens)
student_model.model.resize_token_embeddings(len(student_tokenizer))
except Exception as e:
log_message(f"Tokenizer unification error: {e}", level="warning")
return student_tokenizer
def normalize_text(text):
return unicodedata.normalize('NFKC', text)
def generate_predictions(model, tokenizer, texts, max_length=150, **tokenizer_kwargs):
try:
        inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=max_length, **tokenizer_kwargs).to(device)
        outputs = model.model.generate(**inputs)
        return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
except Exception as e:
log_message(f"Prediction generation error: {e}", level="warning")
return [""] * len(texts)
fusion_functions = {
'geometric_mean_double': lambda teacher_states, layer_scale, key, device: geometric_mean_fusion_double(teacher_states, device, len(teacher_states), layer_scale, key),
}
def geometric_mean_fusion_double(states, device, num_teachers, layer_scale, key):
    # Element-wise geometric mean across teachers: average the log-magnitudes and keep the product of signs.
    min_len = min(state[key].shape[0] for state in states)
    log_abs_sum = sum(np.log(np.abs(state[key][:min_len].detach().cpu().numpy()) + 1e-12) for state in states)
    sign_prod = np.prod([np.sign(state[key][:min_len].detach().cpu().numpy()) for state in states], axis=0)
    geometric_mean_val = torch.tensor(np.exp(log_abs_sum / num_teachers) * sign_prod, device=device) * layer_scale
return geometric_mean_val
def complete_unify_teacher_models_double(teacher_models, fusion_method='geometric_mean_double', layer_scale=1.0, layer_weights=None, device=device):
teacher_states = [teacher.model.state_dict() for teacher in teacher_models]
unified_state = {}
first_teacher_state = teacher_states[0]
for key in first_teacher_state.keys():
teacher_layer_states = []
all_teachers_have_layer = True
for state in teacher_states:
if key in state:
teacher_layer_states.append(state)
else:
all_teachers_have_layer = False
break
if all_teachers_have_layer:
if fusion_method in fusion_functions:
unified_state[key] = fusion_functions['geometric_mean_double'](teacher_layer_states, layer_scale, key, device=device)
else:
layer_sum = torch.stack([state[key] for state in teacher_layer_states]).sum(dim=0)
unified_state[key] = layer_sum / len(teacher_models)
else:
unified_state[key] = first_teacher_state[key]
return unified_state
def unify_teacher_into_student(unified_teacher_state, student, force_parameter_copy=False):
student_state = student.model.state_dict()
new_state = {}
for key, student_value in student_state.items():
if key in unified_teacher_state:
teacher_value = unified_teacher_state[key]
try:
if student_value.shape == teacher_value.shape:
new_state[key] = teacher_value
else:
min_shape = [min(s, t) for s, t in zip(student_value.shape, teacher_value.shape)]
student_slice = tuple([slice(0, s) for s in min_shape])
teacher_slice = tuple([slice(0, s) for s in min_shape])
new_state[key] = student_value.clone()
new_state[key][student_slice] = teacher_value[teacher_slice]
except Exception as e:
log_message(f"Parameter assignment error for {key}: {e}", level="warning")
else:
new_state[key] = student_value
student.model.load_state_dict(new_state, strict=False)
return student
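# Merge every teacher vocabulary into the student tokenizer by adding tokens the student is missing.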
def fuse_tokenizers(teacher_tokenizers, student_tokenizer):
unified_vocab = set(student_tokenizer.get_vocab().keys()) if student_tokenizer else set()
for teacher_tokenizer in teacher_tokenizers:
if teacher_tokenizer:
teacher_vocab = set(teacher_tokenizer.get_vocab().keys())
unified_vocab = unified_vocab.union(teacher_vocab)
if student_tokenizer:
try:
student_tokenizer.add_tokens(list(unified_vocab - set(student_tokenizer.get_vocab().keys())))
except Exception as e:
log_message(f"Error fusing tokenizers: {e}", level="warning")
return student_tokenizer
def update_student_embeddings_double(student, teacher_models, fusion_type='geometric_mean_double', device=device, mean_resizing=False):
emb_student = student.model.get_input_embeddings().weight.data
student_vocab_size, student_emb_dim = emb_student.shape
teacher_embeddings_proj = []
min_teacher_vocab_size = float('inf')
for teacher in teacher_models:
emb_teacher = teacher.model.get_input_embeddings().weight.data.detach().cpu()
teacher_vocab_size, teacher_emb_dim = emb_teacher.shape
min_teacher_vocab_size = min(min_teacher_vocab_size, teacher_vocab_size)
if teacher_emb_dim != student_emb_dim:
proj = nn.Linear(teacher_emb_dim, student_emb_dim, bias=False).to(device)
emb_teacher_proj = proj(emb_teacher.to(device)).cpu()
else:
emb_teacher_proj = emb_teacher
teacher_embeddings_proj.append(emb_teacher_proj)
min_vocab = min(min_teacher_vocab_size, student_vocab_size)
fusion_function = fusion_functions.get(fusion_type, fusion_functions['geometric_mean_double'])
teacher_states = [{'embedding': emb} for emb in teacher_embeddings_proj]
    emb_updated = fusion_function(teacher_states, 1.0, 'embedding', device=device)[:min_vocab]
if ENABLE_EMBEDDING_NOISE:
noise = torch.randn_like(emb_updated) * EMBEDDING_NOISE_STD
emb_updated = emb_updated + noise
try:
student.model.get_input_embeddings().weight.data[:min_vocab].copy_(emb_updated.to(device))
except Exception as e:
log_message(f"Error updating student embeddings: {e}", level="warning")
return student
def save_checkpoint(model, checkpoint_path):
try:
torch.save(model.model.state_dict(), checkpoint_path)
log_message(f"Checkpoint saved to: {checkpoint_path}", level="info")
except Exception as e:
log_message(f"Error saving checkpoint: {e}", level="warning")
kd_loss_functions = {
'mse': F.mse_loss,
    'kl': lambda student, teacher: F.kl_div(F.log_softmax(student, dim=-1), F.log_softmax(teacher.detach(), dim=-1), reduction='batchmean', log_target=True),
'cosine': lambda student, teacher: 1.0 - F.cosine_similarity(student.view(-1, student.size(-1)), teacher.detach().view(-1, teacher.size(-1))).mean(),
}
class ActivationSaver(nn.Module):
    """Wraps a submodule and records its latest output for the activation/feature-map KD losses."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.activation_values = None
        self.feature_map_values = None
    def forward(self, *args, **kwargs):
        output = self.module(*args, **kwargs)
        if isinstance(output, tuple):
            self.activation_values = output[0]
            self.feature_map_values = output[1] if len(output) > 1 else None
        else:
            self.activation_values = output
            self.feature_map_values = None
        return output
    def __getattr__(self, name):
        # Delegate unknown attributes (e.g. .weight) to the wrapped module so downstream code keeps working.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(super().__getattr__("module"), name)
def attach_activation_saver(model, activation_modules, feature_map_modules):
    # Replace matching submodules with ActivationSaver wrappers via their parent module so dotted names
    # (e.g. "transformer.h.0.mlp.c_fc") are resolved correctly. Note: wrapping inserts an extra ".module"
    # level into the affected state_dict keys.
    for name, module in list(model.named_modules()):
        if name and (type(module).__name__ in activation_modules or type(module).__name__ in feature_map_modules):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, ActivationSaver(module))
    return model
def attention_kd_loss(student_attention, teacher_attention, loss_type='mse'):
    try:
        if loss_type == 'mse':
            return F.mse_loss(student_attention, teacher_attention.detach())
        elif loss_type == 'kl':
            return F.kl_div(F.log_softmax(student_attention, dim=-1), F.log_softmax(teacher_attention.detach(), dim=-1), reduction='batchmean', log_target=True)
        elif loss_type == 'cosine':
            return 1.0 - F.cosine_similarity(student_attention.view(-1), teacher_attention.detach().view(-1), dim=0)
        return torch.tensor(0.0, device=device)
    except Exception as e:
        log_message(f"Attention KD Loss error: {e}", level="warning")
        return torch.tensor(0.0, device=device)
def hidden_state_kd_loss(student_hidden, teacher_hidden, loss_type='mse'):
    try:
        if loss_type == 'mse':
            return F.mse_loss(student_hidden, teacher_hidden.detach())
        elif loss_type == 'kl':
            return F.kl_div(F.log_softmax(student_hidden, dim=-1), F.log_softmax(teacher_hidden.detach(), dim=-1), reduction='batchmean', log_target=True)
        elif loss_type == 'cosine':
            return 1.0 - F.cosine_similarity(student_hidden.view(-1), teacher_hidden.detach().view(-1), dim=0)
        return torch.tensor(0.0, device=device)
    except Exception as e:
        log_message(f"Hidden State KD Loss error: {e}", level="warning")
        return torch.tensor(0.0, device=device)
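# Main knowledge-distillation routine: averages teacher outputs and minimizes a weighted sum of the enabled KD losses.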
def advanced_knowledge_distillation(teacher_models, student_model, device):
for teacher_model in teacher_models:
teacher_model.model.eval()
student_model.model.train()
total_steps = KD_STEPS
accumulation_steps = ACCUMULATION_STEPS
freeze_steps = FREEZE_STUDENT_STEPS
use_mixed_precision = ENABLE_MIXED_PRECISION
kd_loss_factor = KD_LOSS_FACTOR
ce_loss_factor = CE_LOSS_FACTOR
kd_temperature = KD_TEMPERATURE
attention_kd_enabled = ENABLE_ATTENTION_KD
hidden_kd_enabled = ENABLE_HIDDEN_STATE_KD
intermediate_kd_enabled = ENABLE_INTERMEDIATE_KD
layer_norm_kd_enabled = ENABLE_LAYER_NORM_KD
embedding_kd_enabled = ENABLE_EMBEDDING_KD
parameter_kd_enabled = ENABLE_PARAMETER_KD
activation_kd_enabled = ENABLE_ACTIVATION_KD
logit_masking_enabled = ENABLE_LOGIT_MASKING_KD
sparsity_regularization_enabled = ENABLE_SPARCITY_REGULARIZATION
feature_map_kd_enabled = ENABLE_FEATURE_MAP_KD
output_logit_kd_enabled = ENABLE_OUTPUT_LOGIT_KD
layerwise_parameter_kd_enabled = ENABLE_LAYERWISE_PARAMETER_KD
vocab_projection_kd_enabled = ENABLE_VOCAB_PROJECTION_KD
contrastive_kd_enabled = ENABLE_CONTRASTIVE_KD
rdrop_kd_enabled = ENABLE_RDROP_KD
adaptive_temperature_kd_enabled = ENABLE_ADAPTIVE_TEMPERATURE_KD
layer_wise_kd_enabled = ENABLE_LAYER_WISE_KD
activation_reg_enabled = ENABLE_ACTIVATION_REGULARIZATION
neuron_selectivity_kd_enabled = ENABLE_NEURON_SELECTIVITY_KD
weighted_parameter_kd_enabled = ENABLE_WEIGHTED_PARAMETER_KD
fsp_kd_enabled = ENABLE_FSP_KD
attention_alignment_kd_enabled = ENABLE_ATTENTION_ALIGNMENT_KD
gram_matrix_kd_enabled = ENABLE_GRAM_MATRIX_KD
attention_kd_factor = ATTENTION_KD_FACTOR
hidden_kd_factor = HIDDEN_STATE_KD_FACTOR
intermediate_kd_factor = INTERMEDIATE_KD_FACTOR
layer_norm_kd_factor = LAYER_NORM_KD_FACTOR
embedding_kd_factor = EMBEDDING_KD_FACTOR
parameter_kd_factor = PARAMETER_KD_FACTOR
activation_kd_factor = ACTIVATION_KD_FACTOR
logit_masking_factor = LOGIT_MASKING_FACTOR
sparsity_regularization_factor = SPARCITY_REGULARIZATION_FACTOR
feature_map_kd_factor = FEATURE_MAP_KD_FACTOR
output_logit_kd_factor = OUTPUT_LOGIT_KD_FACTOR
layerwise_parameter_kd_factor = LAYERWISE_PARAMETER_KD_FACTOR
vocab_projection_kd_factor = VOCAB_PROJECTION_KD_FACTOR
contrastive_kd_factor = CONTRASTIVE_KD_FACTOR
rdrop_kd_factor = RDROP_KD_FACTOR
adaptive_temperature_kd_factor = ADAPTIVE_TEMPERATURE_KD_FACTOR
layer_wise_kd_factor = LAYER_WISE_KD_FACTOR
activation_reg_lambda = ACTIVATION_REG_LAMBDA
neuron_selectivity_kd_factor = NEURON_SELECTIVITY_KD_FACTOR
weighted_parameter_kd_factor = WEIGHTED_PARAMETER_KD_FACTOR
fsp_kd_factor = FSP_KD_FACTOR
attention_alignment_kd_factor = ATTENTION_ALIGNMENT_KD_FACTOR
gram_matrix_kd_factor = GRAM_MATRIX_KD_FACTOR
attention_kd_loss_type = ATTENTION_KD_LOSS_TYPE
hidden_kd_loss_type = HIDDEN_STATE_KD_LOSS_TYPE
output_logit_kd_loss_type = OUTPUT_LOGIT_KD_LOSS_TYPE
contrastive_kd_loss_type = CONTRASTIVE_KD_LOSS_TYPE
rdrop_kd_loss_type = RDROP_KD_LOSS_TYPE
layer_wise_loss_type = LAYER_WISE_LOSS_TYPE
neuron_selectivity_loss_type = NEURON_SELECTIVITY_LOSS_TYPE
weighted_parameter_loss_type = WEIGHTED_PARAMETER_LOSS_TYPE
fsp_kd_loss_type = FSP_KD_LOSS_TYPE
attention_alignment_loss_type = ATTENTION_ALIGNMENT_LOSS_TYPE
gram_matrix_loss_type = GRAM_MATRIX_LOSS_TYPE
intermediate_layers = INTERMEDIATE_LAYERS
layer_norm_modules = LAYER_NORM_MODULES
activation_modules = ACTIVATION_MODULES
feature_map_modules = FEATURE_MAP_MODULES
layerwise_parameter_modules = LAYERWISE_PARAMETER_MODULES
    vocab_projection_modules = VOCAB_PROJECTION_MODULES
layer_wise_modules = LAYER_WISE_MODULES
neuron_selectivity_modules = NEURON_SELECTIVITY_MODULES
weighted_parameter_modules = WEIGHTED_PARAMETER_MODULES
fsp_modules = FSP_MODULES
attention_alignment_modules = ATTENTION_ALIGNMENT_MODULES
gram_matrix_modules = GRAM_MATRIX_MODULES
adaptive_temperature_initial = ADAPTIVE_TEMPERATURE_INITIAL
    adaptive_temperature_decay_rate = ADAPTIVE_TEMPERATURE_DECAY_RATE
adaptive_temperature = torch.tensor(adaptive_temperature_initial, device=device, requires_grad=False)
student_model.model = attach_activation_saver(student_model.model, activation_modules, feature_map_modules)
for teacher_model in teacher_models:
teacher_model.model = attach_activation_saver(teacher_model.model, activation_modules, feature_map_modules)
optimizer = torch.optim.AdamW(
student_model.model.parameters(),
lr=LR_INITIAL,
weight_decay=WEIGHT_DECAY
)
scheduler = get_scheduler(
name=SCHEDULER_TYPE,
optimizer=optimizer,
num_warmup_steps=WARMUP_STEPS,
num_training_steps=total_steps
) if ENABLE_LR_SCHEDULER else None
scaler = torch.amp.GradScaler(enabled=use_mixed_precision)
if ENABLE_STUDENT_PARAMETER_FREEZE:
for param in student_model.model.parameters():
param.requires_grad = False
log_to_file(f"Starting KD: freezing for {freeze_steps} steps.")
else:
log_to_file("KD started without initial freezing.")
writer = SummaryWriter(log_dir=TENSORBOARD_LOG_DIR)
MODELS['writer'] = writer
best_loss = float('inf')
patience_counter = 0
start_time = time.time()
progress_bar = tqdm(range(KD_STEPS), desc="Knowledge Distillation Progress")
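    # Each step: run all teachers without gradients, run the student, accumulate the enabled KD losses, then backprop and update.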
for step in progress_bar:
try:
optimizer.zero_grad()
loss_accum = 0.0
gc.collect()
torch.cuda.empty_cache()
for _ in range(ACCUMULATION_STEPS):
batch_texts = ["This is a sample text for distillation.", "Another example sentence."]
try:
with torch.amp.autocast(device_type=device.type, enabled=ENABLE_MIXED_PRECISION):
teacher_outputs_list = []
teacher_inputs_list = []
for teacher_model in MODELS['teacher']:
with torch.no_grad():
teacher_inputs = teacher_model.tokenizer(
batch_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=128,
).to(device)
teacher_inputs_list.append(teacher_inputs)
teacher_outputs = teacher_model.model(**teacher_inputs, output_attentions=ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD, output_hidden_states=ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD)
teacher_outputs_list.append(teacher_outputs)
student_inputs = MODELS['student'].tokenizer(
batch_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=128,
).to(device)
student_outputs = MODELS['student'].model(**student_inputs, output_attentions=ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD, output_hidden_states=ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD)
student_logits = student_outputs.logits
student_attentions = student_outputs.attentions if ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD else None
student_hidden_states = student_outputs.hidden_states if ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD else None
teacher_logits = torch.stack([teacher.logits for teacher in teacher_outputs_list]).mean(dim=0)
teacher_attentions = [teacher.attentions for teacher in teacher_outputs_list]
teacher_attentions = torch.stack([torch.stack(attn) for attn in teacher_attentions]).mean(dim=0) if teacher_attentions[0] is not None else None
teacher_hidden_states = [teacher.hidden_states for teacher in teacher_outputs_list]
teacher_hidden_states = torch.stack([torch.stack(hiddens) for hiddens in teacher_hidden_states]).mean(dim=0) if teacher_hidden_states[0] is not None else None
min_vocab_size_logits = min(teacher_logits.size(-1), student_logits.size(-1))
teacher_logits_trimmed = teacher_logits[:, :, :min_vocab_size_logits]
student_logits_trimmed = student_logits[:, :, :min_vocab_size_logits]
                        # Reference embeddings from the first teacher (the original code averaged identical copies of teacher 0).
                        teacher_embeds = MODELS['teacher'][0].model.get_input_embeddings()(teacher_inputs_list[0].input_ids)
student_embeds = MODELS['student'].model.get_input_embeddings()(student_inputs.input_ids)
ce_loss = torch.tensor(0.0).to(device)
output_kd_loss = OUTPUT_LOGIT_KD_FACTOR * kd_loss_functions[OUTPUT_LOGIT_KD_LOSS_TYPE](student_logits_trimmed / adaptive_temperature, teacher_logits_trimmed / adaptive_temperature) if ENABLE_OUTPUT_LOGIT_KD else torch.tensor(0.0).to(device)
loss = output_kd_loss
loss = loss + CE_LOSS_FACTOR * ce_loss
attn_loss = ATTENTION_KD_FACTOR * attention_kd_loss(student_attentions[-1], teacher_attentions[-1], ATTENTION_KD_LOSS_TYPE) if ENABLE_ATTENTION_KD and teacher_attentions is not None and student_attentions is not None else torch.tensor(0.0).to(device)
loss += attn_loss
hidden_loss = HIDDEN_STATE_KD_FACTOR * hidden_state_kd_loss(student_hidden_states[-1], teacher_hidden_states[-1], HIDDEN_STATE_KD_LOSS_TYPE) if ENABLE_HIDDEN_STATE_KD and teacher_hidden_states is not None and student_hidden_states is not None else torch.tensor(0.0).to(device)
loss += hidden_loss
intermediate_loss = 0.0
if ENABLE_INTERMEDIATE_KD and teacher_hidden_states is not None and student_hidden_states is not None:
for layer_idx in INTERMEDIATE_LAYERS:
if layer_idx < len(teacher_hidden_states) and layer_idx < len(student_hidden_states):
teacher_layer_output = teacher_hidden_states[layer_idx]
student_layer_output = student_hidden_states[layer_idx]
layer_loss = hidden_state_kd_loss(student_layer_output, teacher_layer_output, HIDDEN_STATE_KD_LOSS_TYPE)
intermediate_loss += layer_loss
loss += INTERMEDIATE_KD_FACTOR * intermediate_loss
layer_norm_loss = 0.0
if ENABLE_LAYER_NORM_KD:
for layer_norm_module_name in LAYER_NORM_MODULES:
if hasattr(MODELS['student'].model, layer_norm_module_name):
student_ln = getattr(MODELS['student'].model, layer_norm_module_name)
layer_norm_weights = []
layer_norm_biases = []
for teacher_model in MODELS['teacher']:
if hasattr(teacher_model.model, layer_norm_module_name):
teacher_ln = getattr(teacher_model.model, layer_norm_module_name)
if isinstance(student_ln, nn.LayerNorm) and isinstance(teacher_ln, nn.LayerNorm):
layer_norm_weights.append(teacher_ln.weight)
layer_norm_biases.append(teacher_ln.bias)
if layer_norm_weights:
teacher_weight_mean = torch.stack(layer_norm_weights).mean(dim=0)
teacher_bias_mean = torch.stack(layer_norm_biases).mean(dim=0)
layer_norm_loss += F.mse_loss(student_ln.weight, teacher_weight_mean)
layer_norm_loss += F.mse_loss(student_ln.bias, teacher_bias_mean)
loss += LAYER_NORM_KD_FACTOR * layer_norm_loss
embed_loss = EMBEDDING_KD_FACTOR * hidden_state_kd_loss(student_embeds, teacher_embeds, HIDDEN_STATE_KD_LOSS_TYPE) if ENABLE_EMBEDDING_KD else torch.tensor(0.0).to(device)
loss += embed_loss
parameter_loss = 0.0
if ENABLE_PARAMETER_KD:
for name, student_param in MODELS['student'].model.named_parameters():
param_values = []
for teacher_model in MODELS['teacher']:
if name in teacher_model.model.state_dict():
param_values.append(teacher_model.model.state_dict()[name])
if param_values:
teacher_param_mean = torch.stack(param_values).mean(dim=0)
parameter_loss += kd_loss_functions[WEIGHTED_PARAMETER_LOSS_TYPE](student_param, teacher_param_mean)
loss += PARAMETER_KD_FACTOR * parameter_loss
layerwise_parameter_loss = 0.0
if ENABLE_LAYERWISE_PARAMETER_KD:
for layer_module_name in LAYERWISE_PARAMETER_MODULES:
if hasattr(MODELS['student'].model, layer_module_name):
student_layer_module = getattr(MODELS['student'].model, layer_module_name)
for name, student_param in student_layer_module.named_parameters():
param_values = []
full_name = f"{layer_module_name}.{name}"
for teacher_model in MODELS['teacher']:
if hasattr(teacher_model.model, layer_module_name) and full_name in teacher_model.model.state_dict():
param_values.append(teacher_model.model.state_dict()[full_name])
if param_values:
teacher_param_mean = torch.stack(param_values).mean(dim=0)
layerwise_parameter_loss += kd_loss_functions[WEIGHTED_PARAMETER_LOSS_TYPE](student_param, teacher_param_mean)
loss += LAYERWISE_PARAMETER_KD_FACTOR * layerwise_parameter_loss
vocab_projection_loss = 0.0
if ENABLE_VOCAB_PROJECTION_KD:
for vocab_module_name in VOCAB_PROJECTION_MODULES:
if hasattr(MODELS['student'].model, vocab_module_name):
student_vocab_module = getattr(MODELS['student'].model, vocab_module_name)
projection_weights = []
for teacher_model in MODELS['teacher']:
if hasattr(teacher_model.model, vocab_module_name):
teacher_vocab_module = getattr(teacher_model.model, vocab_module_name)
projection_weights.append(teacher_vocab_module.weight)
if projection_weights:
teacher_weight_mean = torch.stack(projection_weights).mean(dim=0)
vocab_projection_loss += F.mse_loss(student_vocab_module.weight, teacher_weight_mean)
loss += VOCAB_PROJECTION_KD_FACTOR * vocab_projection_loss
activation_loss = 0.0
if ENABLE_ACTIVATION_KD:
for name, module in MODELS['student'].model.named_modules():
if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES:
if module.activation_values is not None:
activations = module.activation_values
activation_loss += F.mse_loss(activations, torch.zeros_like(activations))
loss += ACTIVATION_KD_FACTOR * activation_loss
logit_masking_loss = 0.0
if ENABLE_LOGIT_MASKING_KD:
                            # Use the vocabulary-trimmed logits so the top-k mask indices stay in range for the student.
                            teacher_probs = F.softmax(teacher_logits_trimmed, dim=-1)
                            top_k_indices = torch.topk(teacher_probs, k=min(10, teacher_probs.size(-1)), dim=-1)[1]
                            mask = torch.ones_like(student_logits_trimmed).scatter_(-1, top_k_indices, 0.0).bool()
                            masked_student_logits = student_logits_trimmed.masked_fill(mask, -1e9)
                            logit_masking_loss = kd_loss_functions['kl'](masked_student_logits / adaptive_temperature, teacher_logits_trimmed / adaptive_temperature) * (adaptive_temperature ** 2)
loss += LOGIT_MASKING_FACTOR * logit_masking_loss
sparsity_loss = 0.0
if ENABLE_SPARCITY_REGULARIZATION:
for name, module in MODELS['student'].model.named_modules():
if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES:
if module.activation_values is not None:
activations = module.activation_values
sparsity_loss += torch.norm(activations, 1)
loss += SPARCITY_REGULARIZATION_FACTOR * sparsity_loss
feature_map_loss = 0.0
if ENABLE_FEATURE_MAP_KD:
                            teacher_named_modules = dict(MODELS['teacher'][0].model.named_modules())
                            for name, student_module in MODELS['student'].model.named_modules():
                                if isinstance(student_module, ActivationSaver) and type(student_module.module).__name__ in FEATURE_MAP_MODULES:
                                    teacher_module = teacher_named_modules.get(name)
                                    if isinstance(teacher_module, ActivationSaver) and teacher_module.feature_map_values is not None and student_module.feature_map_values is not None:
                                        student_feature_map = student_module.feature_map_values
                                        teacher_feature_map = teacher_module.feature_map_values
                                        feature_map_loss += F.mse_loss(student_feature_map, teacher_feature_map)
loss += FEATURE_MAP_KD_FACTOR * feature_map_loss
contrastive_loss = 0.0
if ENABLE_CONTRASTIVE_KD:
student_vec = student_hidden_states[-1][:, 0, :]
teacher_vec = teacher_hidden_states[-1][:, 0, :]
contrastive_loss = CONTRASTIVE_KD_FACTOR * kd_loss_functions[CONTRASTIVE_KD_LOSS_TYPE](student_vec, teacher_vec)
loss += contrastive_loss
rdrop_loss = 0.0
if ENABLE_RDROP_KD:
                            r_student_outputs = MODELS['student'].model(**student_inputs, output_attentions=ENABLE_ATTENTION_KD or ENABLE_ATTENTION_ALIGNMENT_KD, output_hidden_states=ENABLE_HIDDEN_STATE_KD or ENABLE_INTERMEDIATE_KD or ENABLE_LAYER_WISE_KD)
                            r_student_logits = r_student_outputs.logits
                            # Symmetric KL between two dropout passes; the 'kl' loss applies log-softmax internally, so pass raw logits.
                            rdrop_loss = RDROP_KD_FACTOR * (kd_loss_functions['kl'](student_logits, r_student_logits) + kd_loss_functions['kl'](r_student_logits, student_logits))
loss += rdrop_loss
layer_wise_loss = 0.0
if ENABLE_LAYER_WISE_KD and teacher_hidden_states is not None and student_hidden_states is not None:
layer_wise_loss = 0.0
for layer_module_name in LAYER_WISE_MODULES:
if hasattr(MODELS['student'].model, layer_module_name):
student_layer_module = getattr(MODELS['student'].model, layer_module_name)
                                    for student_layer, teacher_layer in zip(student_layer_module.modules(), MODELS['teacher'][0].model.get_submodule(layer_module_name).modules()):
if isinstance(student_layer, nn.Linear) and isinstance(teacher_layer, nn.Linear):
student_output = student_layer(student_hidden_states[-1])
teacher_output = teacher_layer(teacher_hidden_states[-1])
layer_wise_loss += hidden_state_kd_loss(student_output, teacher_output, LAYER_WISE_LOSS_TYPE)
loss += LAYER_WISE_KD_FACTOR * layer_wise_loss
activation_regularization_loss = 0.0
if ENABLE_ACTIVATION_REGULARIZATION:
activation_regularization_loss = 0.0
for name, module in MODELS['student'].model.named_modules():
if isinstance(module, ActivationSaver) and type(module.module).__name__ in ACTIVATION_MODULES:
if module.activation_values is not None:
activations = module.activation_values
activation_regularization_loss += torch.norm(activations, p=2)
loss += ACTIVATION_REG_LAMBDA * activation_regularization_loss
neuron_selectivity_loss = 0.0
if ENABLE_NEURON_SELECTIVITY_KD:
neuron_selectivity_loss = 0.0
                            teacher_named_modules = dict(MODELS['teacher'][0].model.named_modules())
                            for name, module in MODELS['student'].model.named_modules():
                                if isinstance(module, ActivationSaver) and type(module.module).__name__ in NEURON_SELECTIVITY_MODULES:
                                    if module.activation_values is not None:
                                        student_activations = module.activation_values
                                        teacher_module = teacher_named_modules.get(name)
                                        if isinstance(teacher_module, ActivationSaver) and teacher_module.activation_values is not None:
                                            teacher_activations = teacher_module.activation_values
                                            neuron_selectivity_loss += NEURON_SELECTIVITY_KD_FACTOR * kd_loss_functions[NEURON_SELECTIVITY_LOSS_TYPE](student_activations.mean(dim=0), teacher_activations.mean(dim=0))
loss += neuron_selectivity_loss
l2_reg_loss = 0.0
if ENABLE_L2_REGULARIZATION:
for param in MODELS['student'].model.parameters():
l2_reg_loss += torch.norm(param, p=2)
loss += L2_LAMBDA * l2_reg_loss
l1_reg_loss = 0.0
if ENABLE_L1_REGULARIZATION:
for param in MODELS['student'].model.parameters():
l1_reg_loss += torch.norm(param, p=1)
loss += L1_LAMBDA * l1_reg_loss
loss_accum += (kd_loss_factor * loss + ce_loss_factor * ce_loss) / accumulation_steps
except Exception as e:
log_message(f"Error in inner distillation loop: {e}", level="warning")
scaler.scale(loss_accum).backward()
if ENABLE_GRADIENT_CLIPPING:
try:
torch.nn.utils.clip_grad_norm_(MODELS['student'].model.parameters(), GRAD_CLIP_VALUE)
except Exception as e:
log_message(f"Gradient clipping error: {e}", level="warning")
scaler.step(optimizer)
scaler.update()
if scheduler:
scheduler.step()
if ENABLE_ADAPTIVE_TEMPERATURE_KD:
                # Decay the temperature, keeping it a tensor and never letting it drop below 1.0.
                adaptive_temperature = torch.clamp(adaptive_temperature * adaptive_temperature_decay_rate, min=1.0)
if ENABLE_STUDENT_PARAMETER_FREEZE:
if ENABLE_DYNAMIC_FREEZE:
if loss_accum.item() < FREEZE_THRESHOLD:
for param in MODELS['student'].model.parameters():
param.requires_grad = True
if step == 0:
log_message("Dynamic unfreezing activated due to low loss.", level="info")
elif step + 1 == FREEZE_STUDENT_STEPS:
for param in MODELS['student'].model.parameters():
param.requires_grad = True
log_message("Unfreezing after freeze_steps.", level="info")
if SAVE_CHECKPOINTS and (step + 1) % CHECKPOINT_INTERVAL == 0:
ckpt_path = os.path.join(CHECKPOINT_DIR, f"student_step_{step + 1}.pt")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
save_checkpoint(MODELS['student'], ckpt_path)
log_to_file(f"Checkpoint saved to {ckpt_path}")
if (step + 1) % LOG_INTERVAL == 0:
elapsed = time.time() - start_time
lr = optimizer.param_groups[0]['lr']
grad_norm = 0.0
num_params_with_grad = 0
for param in MODELS['student'].model.parameters():
if param.grad is not None:
grad_norm += param.grad.data.norm(2).item()
num_params_with_grad += 1
avg_grad_norm = grad_norm / num_params_with_grad if num_params_with_grad > 0 else 0.0
log_msg = (f"[KD] Step {step + 1}/{KD_STEPS}, Loss: {loss_accum.item():.4f}, "
f"LR: {lr:.6f}, GradNorm: {avg_grad_norm:.4f}, Time: {elapsed:.2f}s, Temp: {adaptive_temperature.item():.3f}")
print(log_msg)
log_to_file(log_msg)
writer.add_scalar("Loss/accumulated", loss_accum.item(), step + 1)
writer.add_scalar("Learning_Rate", lr, step + 1)
writer.add_scalar("GradNorm", avg_grad_norm, step + 1)
writer.add_scalar("Output_Logit_KD_Loss", output_kd_loss.item(), step + 1)
writer.add_scalar("Attention_KD_Loss", attn_loss.item(), step + 1)
writer.add_scalar("HiddenState_KD_Loss", hidden_loss.item(), step + 1)
writer.add_scalar("Intermediate_KD_Loss", intermediate_loss.item(), step + 1)
writer.add_scalar("LayerNorm_KD_Loss", layer_norm_loss.item(), step + 1)
writer.add_scalar("Embedding_KD_Loss", embed_loss.item(), step + 1)
writer.add_scalar("Parameter_KD_Loss", parameter_loss.item(), step + 1)
writer.add_scalar("Activation_KD_Loss", activation_loss.item(), step + 1)
writer.add_scalar("LogitMasking_KD_Loss", logit_masking_loss.item(), step + 1)
writer.add_scalar("Sparcity_Regularization_Loss", sparsity_loss.item(), step + 1)
writer.add_scalar("FeatureMap_KD_Loss", feature_map_loss.item(), step + 1)
writer.add_scalar("Layerwise_Parameter_KD_Loss", layerwise_parameter_loss.item(), step + 1)
writer.add_scalar("VocabProjection_KD_Loss", vocab_projection_loss.item(), step + 1)
writer.add_scalar("Contrastive_KD_Loss", contrastive_loss.item(), step + 1)
writer.add_scalar("RDrop_KD_Loss", rdrop_loss.item(), step + 1)
writer.add_scalar("Adaptive_Temperature", adaptive_temperature.item(), step + 1)
writer.add_scalar("LayerWise_KD_Loss", layer_wise_loss.item(), step + 1)
writer.add_scalar("Activation_Regularization_Loss", activation_regularization_loss.item(), step + 1)
writer.add_scalar("NeuronSelectivity_KD_Loss", neuron_selectivity_loss.item(), step + 1)
writer.add_scalar("WeightedParameter_KD_Loss", weighted_parameter_loss.item(), step + 1)
if ENABLE_EARLY_STOPPING:
if loss_accum.item() < best_loss:
best_loss = loss_accum.item()
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= EARLY_STOPPING_PATIENCE:
log_message(f"Early stopping activated at step {step + 1} after {EARLY_STOPPING_PATIENCE} steps.", level="info")
log_to_file(f"Early stopping activated at step {step + 1}.")
break
except Exception as e:
log_message(f"Error in distillation step {step + 1}: {e}", level="warning")
continue
progress_bar.close()
writer.close()
MODELS['writer'] = None
cleanup_and_exit(0)
return MODELS['student']
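# Save the distilled student and its tokenizer locally, then push both to the Hugging Face Hub.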
def push_model_to_hub(model, tokenizer, quantization_method, repo_name, use_auth_token):
try:
log_message(f"Saving {model.__class__.__name__} to {repo_name} with method '{quantization_method}'...", level="info")
model.model.save_pretrained(repo_name, push_to_hub=False)
tokenizer.save_pretrained(repo_name, push_to_hub=False)
        model.model.push_to_hub(repo_name, token=use_auth_token)
        tokenizer.push_to_hub(repo_name, token=use_auth_token)
log_message("Upload completed.", level="info")
log_to_file(f"Model and tokenizer uploaded to {repo_name} with method {quantization_method}.")
except Exception as e:
log_message(f"Error during push to hub: {e}", level="warning")
def login_to_huggingface(token):
global AUTH_TOKEN, USER_NAME
try:
user_info = HfApi(token=token).whoami()
AUTH_TOKEN = token
USER_NAME = user_info['name']
log_message(f"Successfully logged in to Hugging Face as {USER_NAME}.", level="info")
return AUTH_TOKEN, USER_NAME, None
except Exception as e:
log_message(f"Hugging Face Hub login error: {e}", level="warning")
return None, None, "Invalid Hugging Face token."
def run_fusion_distillation(teacher_model_ckpt_1, teacher_model_ckpt_2, student_model_ckpt, repo_name, disable_mean_resizing, huggingface_token):
global MODELS, LOG_TEXT, GRADIO_LOG_OUTPUT, AUTH_TOKEN, USER_NAME
LOG_TEXT = ""
GRADIO_LOG_OUTPUT.value = LOG_TEXT
mean_resizing = not disable_mean_resizing
token, username, error_message = login_to_huggingface(huggingface_token)
if token is None:
log_message("Hugging Face login failed.", level="warning")
return error_message
AUTH_TOKEN = token
USER_NAME = username
try:
log_message(f"Authenticated with Hugging Face Hub as user: {USER_NAME}.", level="info")
except Exception as e:
log_message(f"Hub login error: {e}", level="warning")
return "Hugging Face Hub login failed: " + str(e)
try:
student_tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt)
MODELS['student'] = pipeline(model=student_model_ckpt, tokenizer=student_tokenizer, device=device)
if MODELS['student'].tokenizer is None:
log_message("Student model tokenizer could not be loaded.", level="warning")
return "Student model tokenizer could not be loaded."
special_token = "[UNFILTERED]"
if special_token not in MODELS['student'].tokenizer.get_vocab():
MODELS['student'].tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
log_message("Special token added to student tokenizer.", level="info")
MODELS['student'].model.resize_token_embeddings(len(MODELS['student'].tokenizer))
except Exception as e:
log_message(f"Error loading student model: {e}", level="warning")
return "Error loading student model: " + str(e)
teacher_models = []
teacher_tokenizers = []
teacher_model_checkpoints = [teacher_model_ckpt_1]
if teacher_model_ckpt_2:
teacher_model_checkpoints.append(teacher_model_ckpt_2)
if not isinstance(teacher_model_checkpoints, list):
log_message("Error loading teacher models: Teacher Model Checkpoints must be a list.", level="warning")
return "Error loading teacher models: Teacher Model Checkpoints must be a list."
for i, teacher_model_ckpt in enumerate(teacher_model_checkpoints):
if not isinstance(teacher_model_ckpt, str):
log_message(f"Error loading teacher models: not a string at index {i}", level="warning")
return f"Error loading teacher models: not a string at index {i}"
try:
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_ckpt)
teacher_pipeline = pipeline(model=teacher_model_ckpt, tokenizer=teacher_tokenizer, device=device)
teacher_models.append(teacher_pipeline)
teacher_tokenizers.append(teacher_tokenizer)
log_message(f"Teacher model {i + 1} loaded: {teacher_model_ckpt}", level="info")
except Exception as e:
log_message(f"Error loading teacher model {i + 1} ({teacher_model_ckpt}): {e}", level="warning")
return f"Error loading teacher model {i + 1}: {e}"
MODELS['teacher'] = teacher_models
try:
unify_parameters(MODELS['student'], MODELS['teacher'][0])
unify_embeddings(MODELS['student'], MODELS['teacher'][0], mean_resizing=mean_resizing)
MODELS['student'].tokenizer = unify_tokenizers(MODELS['student'].tokenizer, teacher_tokenizers[0], MODELS['student'])
MODELS['student'].tokenizer = fuse_tokenizers(teacher_tokenizers, MODELS['student'].tokenizer)
except Exception as e:
log_message(f"Model unification error: {e}", level="warning")
fusion_method_name = 'geometric_mean_double'
try:
unified_teacher_state = complete_unify_teacher_models_double(MODELS['teacher'], fusion_method=fusion_method_name, layer_scale=1.0, device=device)
MODELS['student'] = unify_teacher_into_student(unified_teacher_state, MODELS['student'], force_parameter_copy=True)
MODELS['student'] = update_student_embeddings_double(MODELS['student'], MODELS['teacher'], fusion_type=fusion_method_name, device=device, mean_resizing=mean_resizing)
except Exception as e:
log_message(f"Teacher model fusion error: {e}", level="warning")
try:
unified_student = advanced_knowledge_distillation(MODELS['teacher'], MODELS['student'], device)
except Exception as e:
log_message(f"Knowledge distillation error: {e}", level="warning")
return "Knowledge Distillation Failed, but Fusion might be partially completed. Check logs."
if ENABLE_PARAMETER_COUNT_CHECK:
student_param_count = sum(p.numel() for p in unified_student.model.parameters())
teacher_param_count = sum(p.numel() for p in MODELS['teacher'][0].model.parameters())
if student_param_count != teacher_param_count:
log_message(f"Warning: student parameters ({student_param_count:,}) differ from teacher ({teacher_param_count:,}).", level="warning")
else:
log_message(f"Parameter count consistent: Student and Teacher ({teacher_param_count:,}).", level="info")
sample_texts = ["What is the capital of France?", "Solve: 3 + 5 * 2", "Define a function in Python."]
try:
predictions = generate_predictions(unified_student, MODELS['student'].tokenizer, sample_texts)
log_message("Sample Predictions after KD:\n" + "\n".join(predictions), level="info")
except Exception as e:
log_message(f"Sample prediction error: {e}", level="warning")
api = HfApi()
try:
api.create_repo(repo_name, exist_ok=True, token=AUTH_TOKEN)
push_model_to_hub(unified_student, MODELS['student'].tokenizer, "unified_teacher_kd_full_options_default_true_geometric_mean_double_fusion_v11", repo_name, AUTH_TOKEN)
final_msg = f"Student model fused, distilled, and uploaded to Hub to repo: {repo_name} with all advanced options and 'geometric_mean_double' fusion!"
print(final_msg)
log_to_file(final_msg, "training_log.txt")
return "Fusion and Distillation Completed! Check console and logs."
except Exception as e:
log_message(f"Final upload/repo creation error: {e}", level="warning")
return "Error during final upload or repository creation: " + str(e)
if __name__ == "__main__":
with gr.Blocks(css=".gradio-container {padding: 20px}") as iface:
gr.Markdown("# Fusion and Distillation Pipeline for Language Models")
gr.Markdown(
"This application fuses and distills knowledge from teacher language models into a student model. "
"It supports advanced knowledge distillation techniques and model fusion strategies."
)
huggingface_token_input = gr.Textbox(label="Hugging Face Token", type="password", visible=True)
username_display = gr.Textbox(label="Hugging Face Username", interactive=False, visible=False)
        def process_token(huggingface_token):
            _, username, error_message = login_to_huggingface(huggingface_token)
            return username if username else (error_message or "")
huggingface_token_input.change(
process_token,
inputs=[huggingface_token_input],
outputs=[username_display]
)
with gr.Column():
gr.Markdown("## Model Selection")
teacher_model_ckpt_1_input = HuggingfaceHubSearch(label="Teacher Model 1")
teacher_model_ckpt_2_input = HuggingfaceHubSearch(label="Teacher Model 2 (Optional)")
student_model_ckpt_input = HuggingfaceHubSearch(label="Student Model")
repo_name_input = gr.Textbox(
label="Repository Name",
info="Enter the name of the Hugging Face repository to create or update with the distilled model."
)
disable_mean_resizing_checkbox = gr.Checkbox(
label="Disable Mean Resizing",
info="Check to disable mean resizing of embeddings during unification."
)
with gr.Column():
gr.Markdown("## Execution and Logs")
run_button = gr.Button("Run Fusion and Distillation", variant="primary", interactive=True, elem_id='run-button')
output_status_textbox = gr.Textbox(
info="Real-time status and logs of the fusion and distillation process.",
label="Status",
lines=3
)
GRADIO_LOG_OUTPUT = gr.Textbox(
value="",
label="Detailed Log Output",
lines=10,
interactive=False
)
        def update_run_button_interactivity(token):
            return gr.update(interactive=bool(token))
huggingface_token_input.change(
update_run_button_interactivity,
inputs=[huggingface_token_input],
outputs=[run_button]
)
run_button.click(
run_fusion_distillation,
inputs=[
teacher_model_ckpt_1_input,
teacher_model_ckpt_2_input,
student_model_ckpt_input,
repo_name_input,
disable_mean_resizing_checkbox,
huggingface_token_input
],
outputs=output_status_textbox,
)
iface.launch(server_name="0.0.0.0", server_port=7860)