#!pip install torch safetensors huggingface-hub git+https://github.com/huggingface/transformers.git gradio
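# Experimental model-merging / distillation Space: copies weights, embeddings,
# attributes, and tokenizer vocab from two teachers (an image-text model and a
# causal LM) into a student causal LM, runs a tiny CPU distillation loop on toy
# texts, dynamically quantizes the student, and uploads safetensors to the Hub.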
import os, logging, torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim, pathlib
from tqdm import tqdm
from safetensors.numpy import save_file
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoTokenizer, AutoModelForCausalLM, pipeline
from PIL import Image
import gradio as gr
from huggingface_hub import login, HfApi
GRADIO_LOG = ""
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = torch.device("cpu")
checkpoint_path = pathlib.Path("/tmp/gemma_pytorch_models/checkpoint.pth")
save_dir = pathlib.Path("/tmp/gemma_pytorch_models/")
save_dir.mkdir(exist_ok=True, parents=True)
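# Messages are mirrored to the Python logger and accumulated in GRADIO_LOG so the
# Gradio interface can return the full log as its text output.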
def log_message(msg, level="info"):
global GRADIO_LOG
if level == "info":
logging.info(msg)
elif level == "warning":
logging.warning(msg)
else:
logging.debug(msg)
GRADIO_LOG += msg + "\n"
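# Save/restore the student's state_dict at a fixed path under /tmp.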
def load_checkpoint(model):
if checkpoint_path.exists():
try:
state = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state)
log_message("Checkpoint cargado.", "info")
except Exception as e:
log_message(f"Error cargando checkpoint: {e}", "warning")
else:
log_message("No se encontr贸 checkpoint.", "info")
def save_checkpoint(model):
try:
torch.save(model.state_dict(), checkpoint_path)
log_message("Checkpoint guardado.", "info")
except Exception as e:
log_message(f"Error guardando checkpoint: {e}", "warning")
def merge_tokenizer_vocabs(t1, t2, student_tok):
try:
def get_vocab(obj):
try:
return obj.get_vocab()
except Exception:
try:
return obj.tokenizer.get_vocab()
except Exception:
return {}
vocab1 = get_vocab(t1)
vocab2 = get_vocab(t2)
merged = {}
merged.update(vocab1)
merged.update(vocab2)
student_vocab = get_vocab(student_tok)
missing = [token for token in merged if token not in student_vocab]
if missing:
log_message(f"Agregando {len(missing)} tokens faltantes al student tokenizer", "info")
student_tok.add_tokens(missing)
except Exception as e:
log_message(f"Error en merge_tokenizer_vocabs: {e}", "warning")
def enhanced_transfer(t1_tok, t2_tok, student_tok, teacher, student):
try:
teacher_state = teacher.state_dict()
student_state = student.state_dict()
for name, param in teacher_state.items():
if name in student_state and student_state[name].shape == param.shape:
student_state[name].copy_(param)
student.load_state_dict(student_state)
merge_tokenizer_vocabs(t1_tok, t2_tok, student_tok)
except Exception as e:
log_message(f"Error en enhanced_transfer: {e}", "warning")
def hidden_shared_copy(teacher1, teacher2, student):
try:
s_state = student.state_dict()
t1_state = teacher1.state_dict()
t2_state = teacher2.state_dict()
for key in s_state.keys():
            if key in t1_state and key in t2_state and t1_state[key].shape == t2_state[key].shape == s_state[key].shape:
                # Shape guard: a single mismatched tensor would otherwise raise and
                # abort the whole loop via the outer except.
                avg = (t1_state[key] + t2_state[key]) / 2.0
                s_state[key] = (s_state[key] + avg) / 2.0
student.load_state_dict(s_state)
except Exception as e:
log_message(f"Error en hidden_shared_copy: {e}", "warning")
def get_embedding(model):
for m in model.modules():
if isinstance(m, nn.Embedding):
return m
return None
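# If all three embedding matrices have the same shape, overwrite the student's
# input embeddings with the element-wise average of the teachers'.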
def copy_embeddings(teacher1, teacher2, student):
try:
emb1 = get_embedding(teacher1)
emb2 = get_embedding(teacher2)
emb_student = get_embedding(student)
if emb1 is not None and emb2 is not None and emb_student is not None:
if emb1.weight.shape == emb2.weight.shape == emb_student.weight.shape:
avg_embed = (emb1.weight.data + emb2.weight.data) / 2.0
emb_student.weight.data.copy_(avg_embed)
log_message("Copiados embeddings promediados", "info")
else:
log_message(f"No se copian embeddings: dimensiones incompatibles teacher1: {emb1.weight.shape}, teacher2: {emb2.weight.shape}, student: {emb_student.weight.shape}", "warning")
except Exception as e:
log_message(f"Error en copy_embeddings: {e}", "warning")
def copy_additional_parameters(teacher1, teacher2, student):
    try:
        t1_params = dict(teacher1.named_parameters())
        t2_params = dict(teacher2.named_parameters())
        # Set the flags on the live parameters: requires_grad set on state_dict
        # tensors would be discarded by load_state_dict, which only copies values.
        for name, param in student.named_parameters():
            if name in t1_params and name in t2_params:
                param.requires_grad = t1_params[name].requires_grad or t2_params[name].requires_grad
    except Exception as e:
        log_message(f"Error in copy_additional_parameters: {e}", "warning")
def copy_internal_attributes(teacher, student):
try:
for key, value in teacher.__dict__.items():
if not key.startswith("_") and key not in student.__dict__:
student.__dict__[key] = value
except Exception as e:
log_message(f"Error en copy_internal_attributes: {e}", "warning")
def copy_extra_metadata(teacher, student):
try:
for key, value in teacher.__dict__.items():
if any(substr in key.lower() for substr in ["optimizer", "dropout", "lr", "schedule", "loss", "metric"]):
student.__dict__[key] = value
except Exception as e:
log_message(f"Error en copy_extra_metadata: {e}", "warning")
def copy_all_remaining_attributes(teacher, student):
try:
for key, value in teacher.__dict__.items():
if key not in student.__dict__:
student.__dict__[key] = value
except Exception as e:
log_message(f"Error en copy_all_remaining_attributes: {e}", "warning")
def copy_attention_attributes(teacher, student):
try:
if hasattr(teacher, "attention") and not hasattr(student, "attention"):
student.attention = teacher.attention
except Exception as e:
log_message(f"Error en copy_attention_attributes: {e}", "warning")
def copy_model_config(teacher, student):
try:
if hasattr(teacher, "config"):
student.config = teacher.config
except Exception as e:
log_message(f"Error en copy_model_config: {e}", "warning")
def complete_copy_from_teachers_to_student(teacher1, teacher2, student, t1_tok, t2_tok, student_tok, texts):
    # Each step is wrapped individually so that one failure cannot abort the rest.
    steps = [
        ("enhanced_transfer", lambda: enhanced_transfer(t1_tok, t2_tok, student_tok, teacher1, student)),
        ("hidden_shared_copy", lambda: hidden_shared_copy(teacher1, teacher2, student)),
        ("copy_embeddings", lambda: copy_embeddings(teacher1, teacher2, student)),
        ("copy_additional_parameters", lambda: copy_additional_parameters(teacher1, teacher2, student)),
        ("copy_internal_attributes (teacher1)", lambda: copy_internal_attributes(teacher1, student)),
        ("copy_internal_attributes (teacher2)", lambda: copy_internal_attributes(teacher2, student)),
        ("copy_extra_metadata (teacher1)", lambda: copy_extra_metadata(teacher1, student)),
        ("copy_extra_metadata (teacher2)", lambda: copy_extra_metadata(teacher2, student)),
        ("copy_all_remaining_attributes", lambda: copy_all_remaining_attributes(teacher1, student)),
        ("copy_attention_attributes", lambda: copy_attention_attributes(teacher1, student)),
        ("copy_model_config", lambda: copy_model_config(teacher1, student)),
        ("merge_tokenizer_vocabs", lambda: merge_tokenizer_vocabs(t1_tok, t2_tok, student_tok)),
    ]
    for name, step in tqdm(steps, desc="Completing Copy"):
        try:
            step()
        except Exception as e:
            print(f"Error in {name}: {e}")
    print("Full copy finished")
def unify_parametersx(student_model, teacher_model, exclude_layers=None):
for _ in tqdm(range(1), desc="Unify Parameters"):
try:
teacher_state = teacher_model.state_dict()
student_state = student_model.state_dict()
excluded = exclude_layers or []
for name, param in student_state.items():
if any(ex in name for ex in excluded):
continue
if name in teacher_state and student_state[name].shape == teacher_state[name].shape:
student_state[name].copy_(teacher_state[name])
student_model.load_state_dict(student_state, strict=False)
print("Model parameters unified.")
except Exception as e:
print(f"Error in unify_parameters: {e}")
def unify_embeddings(student_model, teacher_model, project_embeddings=True, mean_resizing=False):
for _ in tqdm(range(1), desc="Unify Embeddings"):
try:
student_emb = student_model.get_input_embeddings()
teacher_emb = teacher_model.get_input_embeddings()
if project_embeddings:
in_dim = teacher_emb.weight.shape[1]
out_dim = student_emb.weight.shape[1]
if in_dim != out_dim:
projection = nn.Linear(in_dim, out_dim).to(device)
teacher_emb_projected = projection(teacher_emb.weight)
else:
teacher_emb_projected = teacher_emb.weight
teacher_vocab_size_proj = teacher_emb_projected.shape[0]
min_vocab = min(student_emb.weight.shape[0], teacher_vocab_size_proj)
student_emb.weight.data[:min_vocab].copy_(teacher_emb_projected.data[:min_vocab])
else:
min_vocab = min(student_emb.weight.shape[0], teacher_emb.weight.shape[0])
student_emb.weight.data[:min_vocab].copy_(teacher_emb.weight.data[:min_vocab])
if hasattr(student_model, "get_output_embeddings") and hasattr(teacher_model, "get_output_embeddings"):
student_out_emb = student_model.get_output_embeddings()
teacher_out_emb = teacher_model.get_output_embeddings()
if student_out_emb is not None and teacher_out_emb is not None:
if project_embeddings:
in_dim = teacher_out_emb.weight.shape[1]
out_dim = student_out_emb.weight.shape[1]
if in_dim != out_dim:
projection = nn.Linear(in_dim, out_dim).to(device)
teacher_out_emb_projected = projection(teacher_out_emb.weight)
else:
teacher_out_emb_projected = teacher_out_emb.weight
teacher_vocab_size_proj = teacher_out_emb_projected.shape[0]
min_vocab = min(student_out_emb.weight.shape[0], teacher_vocab_size_proj)
student_out_emb.weight.data[:min_vocab].copy_(teacher_out_emb_projected.data[:min_vocab])
else:
min_vocab = min(student_out_emb.weight.shape[0], teacher_out_emb.weight.shape[0])
student_out_emb.weight.data[:min_vocab].copy_(teacher_out_emb.weight.data[:min_vocab])
print("Embeddings unified.")
except Exception as e:
print(f"Error in unify_embeddings: {e}")
def unify_tokenizers(student_tokenizer, teacher_tokenizer, student_model):
for _ in tqdm(range(1), desc="Unify Tokenizers"):
print("Unifying tokenizers...")
if teacher_tokenizer is None:
print("No teacher tokenizer provided, skipping unification.")
return student_tokenizer
try:
teacher_vocab = teacher_tokenizer.get_vocab()
student_vocab = student_tokenizer.get_vocab()
new_tokens = [token for token in teacher_vocab if token not in student_vocab]
if new_tokens:
print(f"Adding {len(new_tokens)} new tokens to student tokenizer.")
student_tokenizer.add_tokens(new_tokens)
student_model.resize_token_embeddings(len(student_tokenizer))
print("Student tokenizer vocab resized.")
else:
print("No new tokens to add to student tokenizer.")
except Exception as e:
print(f"Error in unify_tokenizers: {e}")
print("Tokenizers unified.")
return student_tokenizer
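# Load a pre-merged teacher state_dict into student.model; when shapes differ but
# ranks match, copy the overlapping slice and keep the student's values elsewhere.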
def unify_teacher_into_student(unified_teacher_state, student, force_parameter_copy=False):
for _ in tqdm(range(1), desc="Unify Teacher into Student"):
print("Unifying teacher into student...")
student_state = student.model.state_dict()
new_state = {}
for key, student_val in student_state.items():
if key in unified_teacher_state:
teacher_val = unified_teacher_state[key]
try:
if student_val.shape == teacher_val.shape:
new_state[key] = teacher_val
elif len(student_val.shape) == len(teacher_val.shape):
min_shape = [min(s, t) for s, t in zip(student_val.shape, teacher_val.shape)]
student_slice = tuple([slice(0, s) for s in min_shape])
teacher_slice = tuple([slice(0, s) for s in min_shape])
new_state[key] = student_val.clone()
new_state[key][student_slice] = teacher_val[teacher_slice]
else:
print(f"Shape mismatch for {key}. Skipping parameter copy.")
new_state[key] = student_val
except Exception as e:
print(f"Error copying parameter {key}: {e}. Skipping.")
new_state[key] = student_val
else:
new_state[key] = student_val
student.model.load_state_dict(new_state, strict=False)
print("Teacher unified into student.")
return student
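# Union of all teacher vocabularies into the student tokenizer (no embedding resize here).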
def fuse_tokenizers(teacher_tokenizers, student_tokenizer):
for _ in tqdm(range(1), desc="Fuse Tokenizers"):
print("Fusing tokenizers vocabularies...")
unified_vocab = set(student_tokenizer.get_vocab().keys()) if student_tokenizer else set()
for t_tok in teacher_tokenizers:
if t_tok:
unified_vocab = unified_vocab.union(set(t_tok.get_vocab().keys()))
new_tokens = list(unified_vocab - set(student_tokenizer.get_vocab().keys()))
if new_tokens:
print(f"Adding {len(new_tokens)} tokens from teacher tokenizers to student tokenizer.")
try:
student_tokenizer.add_tokens(new_tokens)
except Exception as e:
print(f"Error adding tokens: {e}")
else:
print("No new tokens to add from teacher tokenizers.")
print("Tokenizers vocabularies fused.")
return student_tokenizer
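# MSE-based distillation loss: an MSE between temperature-softened distributions
# blended with an MSE between raw logits. Standard distillation typically uses KL
# divergence on the softened distributions; this is a simpler MSE variant.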
def distillation_loss(student_logits, teacher_logits):
alpha = 0.5
temperature = 2.0
teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
student_soft = F.softmax(student_logits / temperature, dim=-1)
loss_soft = F.mse_loss(student_soft, teacher_soft)
loss_hard = F.mse_loss(student_logits, teacher_logits)
return alpha * loss_soft + (1 - alpha) * loss_hard
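# End-to-end Gradio handler: log in to the Hub, load the three models on CPU,
# smoke-test pipelines, run the toy distillation loop, quantize, save, and upload.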
def run_pipeline(teacher1_id, teacher2_id, student_id, hf_token_input):
global GRADIO_LOG
GRADIO_LOG = ""
try:
login(token=hf_token_input)
api = HfApi()
proc = AutoProcessor.from_pretrained(teacher1_id)
teacher1_model = AutoModelForImageTextToText.from_pretrained(teacher1_id)
teacher1_model.to(device); teacher1_model.eval()
teacher1_tokenizer = proc.tokenizer if hasattr(proc, "tokenizer") else proc
teacher2_tokenizer = AutoTokenizer.from_pretrained(teacher2_id)
teacher2_model = AutoModelForCausalLM.from_pretrained(teacher2_id)
teacher2_model.to(device); teacher2_model.eval()
student_tokenizer = AutoTokenizer.from_pretrained(student_id)
if student_tokenizer.pad_token is None:
student_tokenizer.add_special_tokens({'pad_token': student_tokenizer.eos_token if student_tokenizer.eos_token is not None else "[PAD]"})
student_model = AutoModelForCausalLM.from_pretrained(student_id)
student_model.to(device)
print("Modelos cargados")
except Exception as e:
return f"Error cargando modelos: {e}"
try:
t1_task = getattr(teacher1_model.config, "task_type", "image-to-text")
except Exception:
t1_task = "image-to-text"
try:
t2_task = getattr(teacher2_model.config, "task", "text-generation")
except Exception:
t2_task = "text-generation"
try:
s_task = getattr(student_model.config, "task", "text-generation")
except Exception:
s_task = "text-generation"
try:
pipe_t1 = pipeline(t1_task, model=teacher1_model, tokenizer=teacher1_tokenizer, device=-1)
pipe_t2 = pipeline(t2_task, model=teacher2_model, tokenizer=teacher2_tokenizer, device=-1)
pipe_s = pipeline(s_task, model=student_model, tokenizer=student_tokenizer, device=-1)
print("Pipelines creados:")
print("Teacher1:", pipe_t1.task)
print("Teacher2:", pipe_t2.task)
print("Student:", pipe_s.task)
except Exception as e:
return f"Error creando pipelines: {e}"
try:
out_t1 = pipe_t1(Image.new("RGB", (224,224), color=(0,0,0)))
out_t2 = pipe_t2("Test input")
out_s = pipe_s("Test input")
print("Outputs de prueba:")
print("Teacher1:", out_t1)
print("Teacher2:", out_t2)
print("Student:", out_s)
except Exception as e:
return f"Error en ejecuci贸n de pipelines de prueba: {e}"
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
student_model.train()
print("Inicio del entrenamiento por destilaci贸n")
texts = ["Texto de ejemplo 1", "Texto de ejemplo 2"]
for epoch in range(1):
print(f"脡poca {epoch+1}")
for text in tqdm(texts, desc="Entrenamiento destilaci贸n"):
optimizer.zero_grad()
inputs = student_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=16)
with torch.no_grad():
t1_inputs = teacher1_tokenizer(text=[text], return_tensors="pt", padding="max_length", truncation=True, max_length=16)
t2_inputs = teacher2_tokenizer(text=[text], return_tensors="pt", padding="max_length", truncation=True, max_length=16)
teacher1_out = teacher1_model(**t1_inputs, max_new_tokens=10)
teacher2_out = teacher2_model(**t2_inputs, max_new_tokens=10)
teacher_logits = (teacher1_out.logits + teacher2_out.logits) / 2.0
student_out = student_model(**inputs)
student_logits = student_out.logits
loss = distillation_loss(student_logits, teacher_logits)
loss.backward()
optimizer.step()
print(f"Texto: {text} | Loss: {loss.item():.6f}")
print("Entrenamiento completado")
save_checkpoint(student_model)
print("Aplicando cuantizaci贸n din谩mica al modelo student")
student_quantized = torch.quantization.quantize_dynamic(student_model, {nn.Linear}, dtype=torch.qint8)
    def save_model_state(model, filename):
        state_dict = model.state_dict()
        np_state = {}
        for k, v in state_dict.items():
            if not isinstance(v, torch.Tensor):
                continue  # dynamic quantization stores packed params that are not tensors
            # Quantized tensors cannot be converted to numpy directly; dequantize first.
            np_state[k] = (v.dequantize() if v.is_quantized else v).cpu().numpy()
        save_file(np_state, str(filename))
for _ in tqdm(range(1), desc="Guardando modelos"):
pass
teacher1_file = save_dir / "gemma_teacher_model_quant.safetensors"
teacher2_file = save_dir / "llama_teacher_model_quant.safetensors"
student_file = save_dir / "gemma_student_model_quant.safetensors"
save_model_state(teacher1_model, teacher1_file)
save_model_state(teacher2_model, teacher2_file)
save_model_state(student_quantized, student_file)
print(f"Modelos guardados en {save_dir}")
    try:
        # Reuse the HfApi instance created above instead of constructing one per call.
        user_info = api.whoami(token=hf_token_input)
        username = user_info["name"]
        repo_id_t1 = f"{username}/gemma-teacher-pytorch-safetensors"
        api.create_repo(repo_id_t1, token=hf_token_input, exist_ok=True)
        api.upload_file(token=hf_token_input, path_or_fileobj=str(teacher1_file), path_in_repo="gemma_teacher_model_quant.safetensors", repo_id=repo_id_t1)
        repo_id_t2 = f"{username}/llama-teacher-pytorch-safetensors"
        api.create_repo(repo_id_t2, token=hf_token_input, exist_ok=True)
        api.upload_file(token=hf_token_input, path_or_fileobj=str(teacher2_file), path_in_repo="llama_teacher_model_quant.safetensors", repo_id=repo_id_t2)
        repo_id_student = f"{username}/gemma-student-pytorch-safetensors"
        api.create_repo(repo_id_student, token=hf_token_input, exist_ok=True)
        api.upload_file(token=hf_token_input, path_or_fileobj=str(student_file), path_in_repo="gemma_student_model_quant.safetensors", repo_id=repo_id_student)
        print("Models uploaded to the Hugging Face Hub")
    except Exception as e:
        print(f"Error uploading models: {e}")
    return GRADIO_LOG + "\nProcess completed."
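# Minimal Gradio front end: three model IDs plus an HF token in, accumulated log out.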
iface = gr.Interface(
fn=run_pipeline,
inputs=[
gr.Textbox(label="Teacher1 Model ID (ImageText)"),
gr.Textbox(label="Teacher2 Model ID (CausalLM)"),
gr.Textbox(label="Student Model ID (CausalLM)"),
gr.Textbox(label="Hugging Face Token")
],
outputs="text",
title="Pipeline de Unificaci贸n y Destilaci贸n de Modelos",
description="Ingrese los IDs de los modelos y su token HF para ejecutar el pipeline autom谩ticamente."
)
iface.launch()