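# End-to-end model unification and distillation pipeline: load an image-text
# teacher and a causal-LM teacher, merge their parameters and vocabularies
# into a student causal LM, run a short CPU distillation loop, quantize the
# student, save everything as safetensors, and upload the files to the
# Hugging Face Hub, all driven from a Gradio UI.
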
import os
import logging
import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from safetensors.numpy import save_file
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoTokenizer, AutoModelForCausalLM, pipeline
from PIL import Image
import gradio as gr
from huggingface_hub import login, HfApi


GRADIO_LOG = ""
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide GPUs; everything runs on CPU
device = torch.device("cpu")
checkpoint_path = pathlib.Path("/tmp/gemma_pytorch_models/checkpoint.pth")
save_dir = pathlib.Path("/tmp/gemma_pytorch_models/")
save_dir.mkdir(exist_ok=True, parents=True)


def log_message(msg, level="info"):
    """Log a message and append it to the global buffer shown in the Gradio UI."""
    global GRADIO_LOG
    if level == "info":
        logging.info(msg)
    elif level == "warning":
        logging.warning(msg)
    else:
        logging.debug(msg)
    GRADIO_LOG += msg + "\n"


def load_checkpoint(model):
    if checkpoint_path.exists():
        try:
            state = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state)
            log_message("Checkpoint loaded.", "info")
        except Exception as e:
            log_message(f"Error loading checkpoint: {e}", "warning")
    else:
        log_message("No checkpoint found.", "info")


def save_checkpoint(model):
    try:
        torch.save(model.state_dict(), checkpoint_path)
        log_message("Checkpoint saved.", "info")
    except Exception as e:
        log_message(f"Error saving checkpoint: {e}", "warning")


def merge_tokenizer_vocabs(t1, t2, student_tok):
    try:
        def get_vocab(obj):
            # Processors wrap their tokenizer, so fall back to obj.tokenizer.
            try:
                return obj.get_vocab()
            except Exception:
                try:
                    return obj.tokenizer.get_vocab()
                except Exception:
                    return {}
        vocab1 = get_vocab(t1)
        vocab2 = get_vocab(t2)
        merged = {}
        merged.update(vocab1)
        merged.update(vocab2)
        student_vocab = get_vocab(student_tok)
        missing = [token for token in merged if token not in student_vocab]
        if missing:
            log_message(f"Adding {len(missing)} missing tokens to the student tokenizer", "info")
            student_tok.add_tokens(missing)
            # Note: the caller must also resize the student model's embeddings
            # (see unify_tokenizers below) before the new tokens are usable.
    except Exception as e:
        log_message(f"Error in merge_tokenizer_vocabs: {e}", "warning")


def enhanced_transfer(t1_tok, t2_tok, student_tok, teacher, student):
    try:
        teacher_state = teacher.state_dict()
        student_state = student.state_dict()
        # Copy only parameters whose names and shapes match exactly.
        for name, param in teacher_state.items():
            if name in student_state and student_state[name].shape == param.shape:
                student_state[name].copy_(param)
        student.load_state_dict(student_state)
        merge_tokenizer_vocabs(t1_tok, t2_tok, student_tok)
    except Exception as e:
        log_message(f"Error in enhanced_transfer: {e}", "warning")


def hidden_shared_copy(teacher1, teacher2, student):
    try:
        s_state = student.state_dict()
        t1_state = teacher1.state_dict()
        t2_state = teacher2.state_dict()
        # Blend each student parameter with the average of both teachers.
        for key in s_state.keys():
            if key in t1_state and key in t2_state:
                avg = (t1_state[key] + t2_state[key]) / 2.0
                s_state[key] = (s_state[key] + avg) / 2.0
        student.load_state_dict(s_state)
    except Exception as e:
        log_message(f"Error in hidden_shared_copy: {e}", "warning")


def get_embedding(model):
    """Return the first nn.Embedding found in the model, or None."""
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            return m
    return None


def copy_embeddings(teacher1, teacher2, student):
    try:
        emb1 = get_embedding(teacher1)
        emb2 = get_embedding(teacher2)
        emb_student = get_embedding(student)
        if emb1 is not None and emb2 is not None and emb_student is not None:
            if emb1.weight.shape == emb2.weight.shape == emb_student.weight.shape:
                avg_embed = (emb1.weight.data + emb2.weight.data) / 2.0
                emb_student.weight.data.copy_(avg_embed)
                log_message("Copied averaged embeddings", "info")
            else:
                log_message(f"Embeddings not copied, incompatible shapes: teacher1 {emb1.weight.shape}, teacher2 {emb2.weight.shape}, student {emb_student.weight.shape}", "warning")
    except Exception as e:
        log_message(f"Error in copy_embeddings: {e}", "warning")


def copy_additional_parameters(teacher1, teacher2, student):
    try:
        s_state = student.state_dict()
        t1_state = teacher1.state_dict()
        t2_state = teacher2.state_dict()
        # Note: state_dict() returns detached tensors, so requires_grad is
        # False here and the flag does not survive load_state_dict; this step
        # is effectively a no-op kept for completeness.
        for key in s_state.keys():
            if key in t1_state and key in t2_state:
                s_state[key].requires_grad = t1_state[key].requires_grad or t2_state[key].requires_grad
        student.load_state_dict(s_state)
    except Exception as e:
        log_message(f"Error in copy_additional_parameters: {e}", "warning")


def copy_internal_attributes(teacher, student):
    try:
        for key, value in teacher.__dict__.items():
            if not key.startswith("_") and key not in student.__dict__:
                student.__dict__[key] = value
    except Exception as e:
        log_message(f"Error in copy_internal_attributes: {e}", "warning")


def copy_extra_metadata(teacher, student):
    try:
        for key, value in teacher.__dict__.items():
            if any(substr in key.lower() for substr in ["optimizer", "dropout", "lr", "schedule", "loss", "metric"]):
                student.__dict__[key] = value
    except Exception as e:
        log_message(f"Error in copy_extra_metadata: {e}", "warning")


def copy_all_remaining_attributes(teacher, student):
    try:
        for key, value in teacher.__dict__.items():
            if key not in student.__dict__:
                student.__dict__[key] = value
    except Exception as e:
        log_message(f"Error in copy_all_remaining_attributes: {e}", "warning")


def copy_attention_attributes(teacher, student):
    try:
        if hasattr(teacher, "attention") and not hasattr(student, "attention"):
            student.attention = teacher.attention
    except Exception as e:
        log_message(f"Error in copy_attention_attributes: {e}", "warning")


def copy_model_config(teacher, student):
    try:
        if hasattr(teacher, "config"):
            student.config = teacher.config
    except Exception as e:
        log_message(f"Error in copy_model_config: {e}", "warning")


def complete_copy_from_teachers_to_student(teacher1, teacher2, student, t1_tok, t2_tok, student_tok, texts):
    # Run every transfer step, logging failures individually so that one
    # broken step does not abort the rest of the copy.
    steps = [
        ("enhanced_transfer", lambda: enhanced_transfer(t1_tok, t2_tok, student_tok, teacher1, student)),
        ("hidden_shared_copy", lambda: hidden_shared_copy(teacher1, teacher2, student)),
        ("copy_embeddings", lambda: copy_embeddings(teacher1, teacher2, student)),
        ("copy_additional_parameters", lambda: copy_additional_parameters(teacher1, teacher2, student)),
        ("copy_internal_attributes (teacher1)", lambda: copy_internal_attributes(teacher1, student)),
        ("copy_internal_attributes (teacher2)", lambda: copy_internal_attributes(teacher2, student)),
        ("copy_extra_metadata (teacher1)", lambda: copy_extra_metadata(teacher1, student)),
        ("copy_extra_metadata (teacher2)", lambda: copy_extra_metadata(teacher2, student)),
        ("copy_all_remaining_attributes", lambda: copy_all_remaining_attributes(teacher1, student)),
        ("copy_attention_attributes", lambda: copy_attention_attributes(teacher1, student)),
        ("copy_model_config", lambda: copy_model_config(teacher1, student)),
        ("merge_tokenizer_vocabs", lambda: merge_tokenizer_vocabs(t1_tok, t2_tok, student_tok)),
    ]
    for name, step in tqdm(steps, desc="Completing Copy"):
        try:
            step()
        except Exception as e:
            print(f"Error in {name}: {e}")
    print("Full copy finished")


def unify_parametersx(student_model, teacher_model, exclude_layers=None):
    for _ in tqdm(range(1), desc="Unify Parameters"):
        try:
            teacher_state = teacher_model.state_dict()
            student_state = student_model.state_dict()
            excluded = exclude_layers or []
            for name, param in student_state.items():
                if any(ex in name for ex in excluded):
                    continue
                if name in teacher_state and student_state[name].shape == teacher_state[name].shape:
                    student_state[name].copy_(teacher_state[name])
            student_model.load_state_dict(student_state, strict=False)
            print("Model parameters unified.")
        except Exception as e:
            print(f"Error in unify_parametersx: {e}")


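# Hedged usage sketch (assumption; unify_parametersx is not called anywhere in
# this script as written): copy every shape-compatible teacher parameter into
# the student while leaving, say, embeddings and the output head untouched.
# The substrings passed to exclude_layers are illustrative only.
#
#   unify_parametersx(student_model, teacher_model, exclude_layers=["embed", "lm_head"])

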
def unify_embeddings(student_model, teacher_model, project_embeddings=True, mean_resizing=False):
    for _ in tqdm(range(1), desc="Unify Embeddings"):
        try:
            student_emb = student_model.get_input_embeddings()
            teacher_emb = teacher_model.get_input_embeddings()
            if project_embeddings:
                in_dim = teacher_emb.weight.shape[1]
                out_dim = student_emb.weight.shape[1]
                if in_dim != out_dim:
                    # Caution: this projection is randomly initialized, so it
                    # maps the teacher embeddings through untrained weights
                    # rather than a learned alignment.
                    projection = nn.Linear(in_dim, out_dim).to(device)
                    with torch.no_grad():
                        teacher_emb_projected = projection(teacher_emb.weight)
                else:
                    teacher_emb_projected = teacher_emb.weight
                teacher_vocab_size_proj = teacher_emb_projected.shape[0]
                min_vocab = min(student_emb.weight.shape[0], teacher_vocab_size_proj)
                student_emb.weight.data[:min_vocab].copy_(teacher_emb_projected.data[:min_vocab])
            else:
                min_vocab = min(student_emb.weight.shape[0], teacher_emb.weight.shape[0])
                student_emb.weight.data[:min_vocab].copy_(teacher_emb.weight.data[:min_vocab])
            if hasattr(student_model, "get_output_embeddings") and hasattr(teacher_model, "get_output_embeddings"):
                student_out_emb = student_model.get_output_embeddings()
                teacher_out_emb = teacher_model.get_output_embeddings()
                if student_out_emb is not None and teacher_out_emb is not None:
                    if project_embeddings:
                        in_dim = teacher_out_emb.weight.shape[1]
                        out_dim = student_out_emb.weight.shape[1]
                        if in_dim != out_dim:
                            projection = nn.Linear(in_dim, out_dim).to(device)
                            with torch.no_grad():
                                teacher_out_emb_projected = projection(teacher_out_emb.weight)
                        else:
                            teacher_out_emb_projected = teacher_out_emb.weight
                        teacher_vocab_size_proj = teacher_out_emb_projected.shape[0]
                        min_vocab = min(student_out_emb.weight.shape[0], teacher_vocab_size_proj)
                        student_out_emb.weight.data[:min_vocab].copy_(teacher_out_emb_projected.data[:min_vocab])
                    else:
                        min_vocab = min(student_out_emb.weight.shape[0], teacher_out_emb.weight.shape[0])
                        student_out_emb.weight.data[:min_vocab].copy_(teacher_out_emb.weight.data[:min_vocab])
            print("Embeddings unified.")
        except Exception as e:
            print(f"Error in unify_embeddings: {e}")


def unify_tokenizers(student_tokenizer, teacher_tokenizer, student_model):
    for _ in tqdm(range(1), desc="Unify Tokenizers"):
        print("Unifying tokenizers...")
        if teacher_tokenizer is None:
            print("No teacher tokenizer provided, skipping unification.")
            return student_tokenizer
        try:
            teacher_vocab = teacher_tokenizer.get_vocab()
            student_vocab = student_tokenizer.get_vocab()
            new_tokens = [token for token in teacher_vocab if token not in student_vocab]
            if new_tokens:
                print(f"Adding {len(new_tokens)} new tokens to student tokenizer.")
                student_tokenizer.add_tokens(new_tokens)
                student_model.resize_token_embeddings(len(student_tokenizer))
                print("Student tokenizer vocab resized.")
            else:
                print("No new tokens to add to student tokenizer.")
        except Exception as e:
            print(f"Error in unify_tokenizers: {e}")
        print("Tokenizers unified.")
    return student_tokenizer


def unify_teacher_into_student(unified_teacher_state, student, force_parameter_copy=False):
    for _ in tqdm(range(1), desc="Unify Teacher into Student"):
        print("Unifying teacher into student...")
        student_state = student.model.state_dict()
        new_state = {}
        for key, student_val in student_state.items():
            if key in unified_teacher_state:
                teacher_val = unified_teacher_state[key]
                try:
                    if student_val.shape == teacher_val.shape:
                        new_state[key] = teacher_val
                    elif len(student_val.shape) == len(teacher_val.shape):
                        # Same rank but different sizes: copy the overlapping slice.
                        min_shape = [min(s, t) for s, t in zip(student_val.shape, teacher_val.shape)]
                        overlap = tuple(slice(0, s) for s in min_shape)
                        new_state[key] = student_val.clone()
                        new_state[key][overlap] = teacher_val[overlap]
                    else:
                        print(f"Shape mismatch for {key}. Skipping parameter copy.")
                        new_state[key] = student_val
                except Exception as e:
                    print(f"Error copying parameter {key}: {e}. Skipping.")
                    new_state[key] = student_val
            else:
                new_state[key] = student_val
        student.model.load_state_dict(new_state, strict=False)
        print("Teacher unified into student.")
    return student


def fuse_tokenizers(teacher_tokenizers, student_tokenizer):
    for _ in tqdm(range(1), desc="Fuse Tokenizers"):
        print("Fusing tokenizer vocabularies...")
        unified_vocab = set(student_tokenizer.get_vocab().keys()) if student_tokenizer else set()
        for t_tok in teacher_tokenizers:
            if t_tok:
                unified_vocab = unified_vocab.union(set(t_tok.get_vocab().keys()))
        new_tokens = list(unified_vocab - set(student_tokenizer.get_vocab().keys()))
        if new_tokens:
            print(f"Adding {len(new_tokens)} tokens from teacher tokenizers to student tokenizer.")
            try:
                student_tokenizer.add_tokens(new_tokens)
            except Exception as e:
                print(f"Error adding tokens: {e}")
        else:
            print("No new tokens to add from teacher tokenizers.")
        print("Tokenizer vocabularies fused.")
    return student_tokenizer


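# Hedged usage sketch (assumption; fuse_tokenizers is not called in this
# script as written). After fusing vocabularies the student embedding matrix
# must be resized to match the enlarged tokenizer:
#
#   student_tokenizer = fuse_tokenizers([teacher1_tokenizer, teacher2_tokenizer], student_tokenizer)
#   student_model.resize_token_embeddings(len(student_tokenizer))

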
def distillation_loss(student_logits, teacher_logits):
    # Blend an MSE on temperature-softened distributions with an MSE on the
    # raw logits; both terms require student and teacher logits of equal shape.
    alpha = 0.5
    temperature = 2.0
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    student_soft = F.softmax(student_logits / temperature, dim=-1)
    loss_soft = F.mse_loss(student_soft, teacher_soft)
    loss_hard = F.mse_loss(student_logits, teacher_logits)
    return alpha * loss_soft + (1 - alpha) * loss_hard


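# A minimal sketch of the more common KL-divergence formulation (Hinton-style
# distillation), offered as an alternative to the MSE loss above under the
# same shape assumptions; distillation_loss_kl is not used by this pipeline.
def distillation_loss_kl(student_logits, teacher_logits, alpha=0.5, temperature=2.0):
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    log_student = F.log_softmax(student_logits / temperature, dim=-1)
    # The temperature**2 factor keeps gradient magnitudes comparable across T.
    loss_soft = F.kl_div(log_student, soft_targets, reduction="batchmean") * (temperature ** 2)
    loss_hard = F.mse_loss(student_logits, teacher_logits)
    return alpha * loss_soft + (1 - alpha) * loss_hard

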
def run_pipeline(teacher1_id, teacher2_id, student_id, hf_token_input):
    global GRADIO_LOG
    GRADIO_LOG = ""
    try:
        login(token=hf_token_input)
        api = HfApi()
        proc = AutoProcessor.from_pretrained(teacher1_id)
        teacher1_model = AutoModelForImageTextToText.from_pretrained(teacher1_id)
        teacher1_model.to(device)
        teacher1_model.eval()
        teacher1_tokenizer = proc.tokenizer if hasattr(proc, "tokenizer") else proc
        teacher2_tokenizer = AutoTokenizer.from_pretrained(teacher2_id)
        teacher2_model = AutoModelForCausalLM.from_pretrained(teacher2_id)
        teacher2_model.to(device)
        teacher2_model.eval()
        student_tokenizer = AutoTokenizer.from_pretrained(student_id)
        if student_tokenizer.pad_token is None:
            student_tokenizer.add_special_tokens({'pad_token': student_tokenizer.eos_token if student_tokenizer.eos_token is not None else "[PAD]"})
        student_model = AutoModelForCausalLM.from_pretrained(student_id)
        student_model.to(device)
        print("Models loaded")
    except Exception as e:
        return f"Error loading models: {e}"
    try:
        t1_task = getattr(teacher1_model.config, "task_type", "image-to-text")
    except Exception:
        t1_task = "image-to-text"
    try:
        t2_task = getattr(teacher2_model.config, "task", "text-generation")
    except Exception:
        t2_task = "text-generation"
    try:
        s_task = getattr(student_model.config, "task", "text-generation")
    except Exception:
        s_task = "text-generation"
    try:
        pipe_t1 = pipeline(t1_task, model=teacher1_model, tokenizer=teacher1_tokenizer, device=-1)
        pipe_t2 = pipeline(t2_task, model=teacher2_model, tokenizer=teacher2_tokenizer, device=-1)
        pipe_s = pipeline(s_task, model=student_model, tokenizer=student_tokenizer, device=-1)
        print("Pipelines created:")
        print("Teacher1:", pipe_t1.task)
        print("Teacher2:", pipe_t2.task)
        print("Student:", pipe_s.task)
    except Exception as e:
        return f"Error creating pipelines: {e}"
    try:
        # Smoke-test each pipeline with a dummy input.
        out_t1 = pipe_t1(Image.new("RGB", (224, 224), color=(0, 0, 0)))
        out_t2 = pipe_t2("Test input")
        out_s = pipe_s("Test input")
        print("Test outputs:")
        print("Teacher1:", out_t1)
        print("Teacher2:", out_t2)
        print("Student:", out_s)
    except Exception as e:
        return f"Error running pipeline smoke tests: {e}"

    optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
    student_model.train()
    print("Starting distillation training")
    texts = ["Example text 1", "Example text 2"]
    for epoch in range(1):
        print(f"Epoch {epoch + 1}")
        for text in tqdm(texts, desc="Distillation training"):
            optimizer.zero_grad()
            inputs = student_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=16)
            with torch.no_grad():
                t1_inputs = teacher1_tokenizer(text=[text], return_tensors="pt", padding="max_length", truncation=True, max_length=16)
                t2_inputs = teacher2_tokenizer(text=[text], return_tensors="pt", padding="max_length", truncation=True, max_length=16)
                # Plain forward passes: generation kwargs such as max_new_tokens
                # are not valid here. Averaging assumes both teachers produce
                # logits of the same shape, and the loss further assumes they
                # match the student's.
                teacher1_out = teacher1_model(**t1_inputs)
                teacher2_out = teacher2_model(**t2_inputs)
                teacher_logits = (teacher1_out.logits + teacher2_out.logits) / 2.0
            student_out = student_model(**inputs)
            student_logits = student_out.logits
            loss = distillation_loss(student_logits, teacher_logits)
            loss.backward()
            optimizer.step()
            print(f"Text: {text} | Loss: {loss.item():.6f}")
    print("Training finished")
    save_checkpoint(student_model)
| print("Aplicando cuantizaci贸n din谩mica al modelo student") |
| student_quantized = torch.quantization.quantize_dynamic(student_model, {nn.Linear}, dtype=torch.qint8) |
| def save_model_state(model, filename): |
| state_dict = model.state_dict() |
| np_state = {k: v.cpu().numpy() for k, v in state_dict.items()} |
| save_file(np_state, str(filename)) |
| for _ in tqdm(range(1), desc="Guardando modelos"): |
| pass |
| teacher1_file = save_dir / "gemma_teacher_model_quant.safetensors" |
| teacher2_file = save_dir / "llama_teacher_model_quant.safetensors" |
| student_file = save_dir / "gemma_student_model_quant.safetensors" |
| save_model_state(teacher1_model, teacher1_file) |
| save_model_state(teacher2_model, teacher2_file) |
| save_model_state(student_quantized, student_file) |
| print(f"Modelos guardados en {save_dir}") |
    try:
        user_info = api.whoami(token=hf_token_input)
        username = user_info["name"]
        repo_id_t1 = f"{username}/gemma-teacher-pytorch-safetensors"
        api.create_repo(repo_id_t1, token=hf_token_input, exist_ok=True)
        api.upload_file(token=hf_token_input, path_or_fileobj=str(teacher1_file), path_in_repo="gemma_teacher_model_quant.safetensors", repo_id=repo_id_t1)
        repo_id_t2 = f"{username}/llama-teacher-pytorch-safetensors"
        api.create_repo(repo_id_t2, token=hf_token_input, exist_ok=True)
        api.upload_file(token=hf_token_input, path_or_fileobj=str(teacher2_file), path_in_repo="llama_teacher_model_quant.safetensors", repo_id=repo_id_t2)
        repo_id_student = f"{username}/gemma-student-pytorch-safetensors"
        api.create_repo(repo_id_student, token=hf_token_input, exist_ok=True)
        api.upload_file(token=hf_token_input, path_or_fileobj=str(student_file), path_in_repo="gemma_student_model_quant.safetensors", repo_id=repo_id_student)
        print("Models uploaded to the Hugging Face Hub")
    except Exception as e:
        print(f"Error uploading models: {e}")
    return GRADIO_LOG + "\nProcess completed."


iface = gr.Interface(
    fn=run_pipeline,
    inputs=[
        gr.Textbox(label="Teacher1 Model ID (ImageText)"),
        gr.Textbox(label="Teacher2 Model ID (CausalLM)"),
        gr.Textbox(label="Student Model ID (CausalLM)"),
        gr.Textbox(label="Hugging Face Token"),
    ],
    outputs="text",
    title="Model Unification and Distillation Pipeline",
    description="Enter the model IDs and your HF token to run the pipeline automatically.",
)


iface.launch()