import os, io, json, tempfile, string
os.system("pip install -U transformers peft datasets accelerate bitsandbytes trl scipy einops evaluate zstandard wandb tokenizers")
os.system("pip install spaces-0.1.0-py3-none-any.whl")
import gradio as gr
import spaces
import torch
import logging
import multiprocessing
import uuid
import gc
import math
from itertools import islice
from datasets import load_dataset, IterableDataset, interleave_datasets
from huggingface_hub import login, whoami, create_repo, upload_folder
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
BitsAndBytesConfig,
IntervalStrategy,
LlamaConfig, LlamaForCausalLM,
MistralConfig, MistralForCausalLM,
GemmaConfig, GemmaForCausalLM,
GPT2Config, GPT2LMHeadModel,
PreTrainedTokenizerFast
)
from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
num_workers = multiprocessing.cpu_count()
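# Architectures supported for from-scratch training: UI name -> (config class,
# causal-LM class) pair from transformers.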
ARCHITECTURE_MAP = {
"Llama": (LlamaConfig, LlamaForCausalLM),
"Mistral": (MistralConfig, MistralForCausalLM),
"Gemma": (GemmaConfig, GemmaForCausalLM),
"GPT2": (GPT2Config, GPT2LMHeadModel),
}
def _normalize_text_helper(text, do_lowercase, do_remove_punct):
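    """Optionally lowercase text and/or strip ASCII punctuation."""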
if not isinstance(text, str):
return text
if do_lowercase:
text = text.lower()
if do_remove_punct:
text = text.translate(str.maketrans('', '', string.punctuation))
return text
def _load_hf_streaming(ids):
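    """Stream each Hub dataset id, dropping columns that commonly collide
    when interleaving heterogeneous datasets (ids, shard metadata, etc.)."""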
streams = {"train": [], "validation": []}
CONFLICT_COLUMNS = ['id', 'raw_id', 'shard_id', 'num_shards', 'meta']
for ident in ids:
try:
d = load_dataset(ident, streaming=True)
def clean_and_keep_text(example):
keys_to_drop = [k for k in example.keys() if k in CONFLICT_COLUMNS or k.endswith('_id')]
for k in keys_to_drop:
if k in example:
del example[k]
return example
if isinstance(d, dict):
for split, ds in d.items():
cleaned_ds = ds.map(clean_and_keep_text, batched=False)
if "train" in split: streams["train"].append(cleaned_ds)
elif "validation" in split or "test" in split: streams["validation"].append(cleaned_ds)
else:
cleaned_ds = d.map(clean_and_keep_text, batched=False)
streams["train"].append(cleaned_ds)
streams["validation"].append(cleaned_ds.skip(1).map(lambda x: x).with_format(type_injection_class=IterableDataset).map(lambda x: x, batched=False))
except Exception as e:
logger.error(f"Error cargando o limpiando dataset {ident}: {e}")
continue
return streams
def _load_uploaded_stream(files):
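    """Parse uploaded csv/json/jsonl/txt files into row dicts and wrap them
    as train/validation IterableDatasets (last ~1% of rows as validation)."""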
all_rows = []
    for f in files or []:
        name = f.name.lower()
        # Gradio supplies tempfile wrappers; re-read the bytes from disk.
        with open(f.name, "rb") as fh:
            content = fh.read().decode("utf-8", errors="ignore")
if name.endswith(".csv"):
import csv
reader = csv.DictReader(io.StringIO(content))
for row in reader: all_rows.append(row)
elif name.endswith(".jsonl"):
for line in io.StringIO(content):
try: all_rows.append(json.loads(line))
                except json.JSONDecodeError: pass
elif name.endswith(".json"):
try:
data = json.loads(content)
if isinstance(data, list): all_rows.extend(data)
elif isinstance(data, dict): all_rows.append(data)
            except json.JSONDecodeError: pass
elif name.endswith(".txt"):
for line in io.StringIO(content):
line = line.strip()
if line: all_rows.append({"text": line})
if not all_rows:
return {"train": IterableDataset.from_generator(lambda: iter([])),
"validation": IterableDataset.from_generator(lambda: iter([]))}
n = len(all_rows)
val_size = max(1, n // 100)
train_data = all_rows[:-val_size]
val_data = all_rows[-val_size:]
train_ds = IterableDataset.from_generator(lambda: iter(train_data))
val_ds = IterableDataset.from_generator(lambda: iter(val_data))
return {"train": train_ds, "validation": val_ds}
def _guess_columns(sample):
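    """Best-effort guess of the text column from a sample row's keys;
    falls back to "text" when nothing matches."""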
text_col = None
if isinstance(sample, dict):
keys = list(sample.keys())
for k in keys:
if k.lower() in ["text","content","prompt","input","messages","sentence","review","body","ctx","question"]: text_col = k
return text_col or "text", None
def hf_login(token):
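    """Log in to the Hugging Face Hub with a user token and report status."""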
if not token or token.strip() == "":
return "Error: Por favor, introduce tu token de Hugging Face."
try:
login(token=token.strip(), add_to_git_credential=True)
user = whoami()
name = user.get("name") or user.get("fullname") or "usuario"
email = user.get("email") or ""
return f"Sesi贸n iniciada: {name} {f'({email})' if email else ''}"
except Exception as e:
return f"Error de inicio de sesi贸n: {e}"
def _sft_formatting_func(example, text_col, tokenizer,
do_lowercase, do_remove_punct,
enable_cot, prompt_col, reasoning_col, response_col):
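    """Render one example as a training string: CoT prompt/reasoning/response
    triplets, chat-template formatting for `messages` columns, or plain text."""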
if enable_cot:
prompt = example.get(prompt_col, "")
reasoning = example.get(reasoning_col, "")
response = example.get(response_col, "")
prompt = _normalize_text_helper(prompt, do_lowercase, do_remove_punct)
reasoning = _normalize_text_helper(reasoning, do_lowercase, do_remove_punct)
response = _normalize_text_helper(response, do_lowercase, do_remove_punct)
if reasoning:
return f"Prompt: {prompt}\n\nReasoning: {reasoning}\n\nResponse: {response}"
else:
return f"Prompt: {prompt}\n\nResponse: {response}"
if text_col == "messages" and hasattr(tokenizer, 'apply_chat_template'):
processed_messages = []
for msg in example.get(text_col, []):
new_msg = msg.copy()
if 'content' in new_msg and isinstance(new_msg['content'], str):
new_msg['content'] = _normalize_text_helper(new_msg['content'], do_lowercase, do_remove_punct)
processed_messages.append(new_msg)
return tokenizer.apply_chat_template(processed_messages, tokenize=False, add_generation_prompt=False)
text = example.get(text_col)
if isinstance(text, str):
return _normalize_text_helper(text, do_lowercase, do_remove_punct)
return ""
def get_training_corpus_iterator(dataset, text_col, chunk_size=1000):
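    """Yield chunks of raw text from a (streaming) dataset, sized for
    tokenizer training via `train_new_from_iterator`."""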
if not dataset or not text_col:
return
iterator = iter(dataset)
while True:
chunk = list(islice(iterator, chunk_size))
if not chunk:
break
texts = []
for example in chunk:
text = example.get(text_col)
if isinstance(text, str) and text.strip():
texts.append(text.strip())
elif text_col == "messages" and isinstance(text, list):
for msg in text:
if isinstance(msg, dict) and isinstance(msg.get('content'), str):
texts.append(msg['content'].strip())
if texts:
yield texts
def _count_dataset_size(dataset):
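    """Count examples by exhausting the stream once; this consumes a full
    pass over a streaming dataset and can be slow."""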
count = 0
try:
for _ in dataset:
count += 1
except Exception:
pass
return count
def _calculate_auto_config(train_dataset, block_size, scratch_architecture):
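    """Heuristically size a from-scratch model from the dataset size: widths
    and depths grow with log2(num_examples), capped at vocab 65536, hidden
    2048, 24 layers and 32 heads. Returns (vocab_size, hidden_size,
    intermediate_size, layers, heads, max_pos_embed, tie_embeddings,
    kv_heads, base_tokenizer_name)."""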
size = _count_dataset_size(train_dataset)
if size == 0:
return 32000, 512, 1024, 8, 8, 512, False, 8, "meta-llama/Meta-Llama-3-8B"
log_size = math.log2(max(1000, size))
    vocab_size = min(65536, 32000 + int(log_size * 2000))
    hidden_size = min(2048, 512 + int(log_size * 50))
    hidden_size = max(512, hidden_size)
    # Round down to a multiple of 64 so hidden_size divides evenly by the
    # head count chosen below (head_dim = 64); Llama-style configs require
    # hidden_size % num_attention_heads == 0.
    hidden_size = (hidden_size // 64) * 64
    intermediate_size = hidden_size * 2
    layers = min(24, 4 + int(log_size * 0.5))
    layers = max(4, int(layers))
    heads = max(4, hidden_size // 64)
    heads = min(32, heads)
    max_pos_embed = int(block_size)
    kv_heads = heads
    if scratch_architecture in ["Mistral", "Llama", "Gemma"]:
        kv_heads = max(1, heads // 8)
        # GQA requires num_attention_heads to be divisible by num_key_value_heads.
        while heads % kv_heads != 0:
            kv_heads -= 1
        if hidden_size < 1024:
            kv_heads = heads
tie_embeddings = False
base_tokenizer_name = "meta-llama/Meta-Llama-3-8B"
return vocab_size, hidden_size, intermediate_size, layers, heads, max_pos_embed, tie_embeddings, kv_heads, base_tokenizer_name
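# Worked example of the heuristic above (hypothetical dataset size): ~100k
# streamed examples give log2(1e5) ≈ 16.6, hence vocab ≈ 65,219, hidden_size
# 1280 (rounded down to a multiple of 64), 12 layers, 20 heads, and 2 KV
# heads on GQA architectures.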
@spaces.GPU()
def train_and_upload(model_base_input, datasets_hf_text, uploads, repo_name_input,
train_from_scratch, scratch_architecture,
add_eos_token, auto_find_batch_size, chat_template,
disable_gradient_checkpointing, distributed_backend, eval_strategy,
load_best_model_at_end,
merge_adapter, mixed_precision, optimizer, peft, padding,
quantization, scheduler, batch_size, block_size, epochs,
gradient_accumulation, learning_rate, logging_steps, lora_alpha,
lora_dropout, lora_r, max_grad_norm,
save_total_limit, seed, warmup_ratio, weight_decay, target_modules,
steps_per_epoch_estimate,
trust_remote_code_input, attn_implementation_input, new_special_tokens_input,
apply_lowercase_input, remove_punctuation_input,
enable_cot_input, prompt_col_input, reasoning_col_input, response_col_input,
wandb_project_input, wandb_api_key_input,
scratch_vocab_size, scratch_special_tokens, scratch_base_tokenizer,
scratch_hidden_size, scratch_intermediate_size, scratch_num_hidden_layers,
scratch_num_attention_heads, scratch_num_key_value_heads, scratch_max_pos_embed, scratch_tie_word_embeddings,
auto_config_scratch,
progress=gr.Progress()):
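    """End-to-end streaming SFT pipeline: load and interleave datasets,
    optionally train a new tokenizer/model from scratch, fine-tune with TRL's
    SFTTrainer, and push the result to the Hub. Yields (logs, phase,
    repo_link, config_file) tuples consumed by the Gradio UI."""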
temp_dir = tempfile.mkdtemp()
logs = ""
repo_link = ""
config_data = {}
def update_logs(new_msg, phase_msg, step_ratio=None):
nonlocal logs
logs += f"{new_msg}\n"
if step_ratio is not None:
progress(step_ratio)
return logs, phase_msg, repo_link, None
try:
user = whoami()
username = user.get("name") or "hf_user"
model = model_base_input.strip()
if not model and not train_from_scratch:
logs += "Error: Debe especificar un ID de **Modelo Base** o activar 'Entrenar desde Cero'.\n"
yield logs, "Error", repo_link, None
return
if repo_name_input and repo_name_input.strip():
repo_base = repo_name_input.strip().replace(" ", "-")
else:
random_suffix = uuid.uuid4().hex[:6]
if train_from_scratch:
model_slug = f"{scratch_architecture.lower()}-{int(scratch_hidden_size)}-{int(scratch_num_hidden_layers)}l"
else:
model_slug = model.split('/')[-1].replace('.', '-').lower()
repo_base = f"{model_slug}-sft-{random_suffix}"
repo_id = f"{username}/{repo_base}"
hf_ids = [x.strip() for x in (datasets_hf_text or "").split(",") if x.strip()]
all_ds = {"train": [], "validation": []}
if hf_ids or uploads:
yield update_logs("Cargando datasets en streaming...", "Cargando Datos", 0.05)
dsh = _load_hf_streaming(hf_ids)
all_ds["train"].extend(dsh.get("train", []))
all_ds["validation"].extend(dsh.get("validation", []))
dsu = _load_uploaded_stream(uploads)
all_ds["train"].append(dsu["train"])
all_ds["validation"].append(dsu["validation"])
valid_train_streams = [ds for ds in all_ds["train"] if ds and next(iter(ds), None) is not None]
valid_val_streams = [ds for ds in all_ds["validation"] if ds and next(iter(ds), None) is not None]
if not valid_train_streams:
logs += "Error: No se encontraron datasets v谩lidos o con contenido para entrenar.\n"
yield logs, "Error", repo_link, None
return
if not valid_val_streams and (eval_strategy.lower() == "steps" or load_best_model_at_end):
logs += "Advertencia: La evaluaci贸n est谩 habilitada (`eval_strategy` o `load_best_model_at_end`) pero no se encontraron datasets de validaci贸n. Deshabilitando evaluaci贸n.\n"
eval_strategy = "no"
load_best_model_at_end = False
train_dataset = interleave_datasets(valid_train_streams)
validation_dataset = None
if valid_val_streams:
validation_dataset = interleave_datasets(valid_val_streams)
text_col = "text"
if not enable_cot_input:
try:
sample = next(iter(train_dataset))
text_col, _ = _guess_columns(sample)
except Exception:
text_col = "text"
yield update_logs(f"Columna de texto detectada: {text_col}", "Detectado", 0.10)
else:
yield update_logs("Formato de Razonamiento (CoT) activado.", "Detectado", 0.10)
logs += f" Prompt: {prompt_col_input}\n"
if reasoning_col_input:
logs += f" Razonamiento: {reasoning_col_input}\n"
logs += f" Respuesta: {response_col_input}\n"
yield logs, "Detectado", repo_link, None
if train_from_scratch:
yield update_logs(f"--- Modo 'Entrenar desde Cero' activado (Arquitectura: {scratch_architecture}) ---", "Preparando Modelo", 0.15)
logs += "PEFT, Quantization y Merge ser谩n desactivados.\n"
peft = False
merge_adapter = False
quantization = "none"
config_data["train_from_scratch"] = True
config_data["architecture"] = scratch_architecture
if auto_config_scratch:
yield update_logs("Calculando configuraci贸n del modelo autom谩ticamente...", "Preparando Modelo")
(auto_vocab_size, auto_hidden_size, auto_intermediate_size,
auto_layers, auto_heads, auto_max_pos_embed, auto_tie_embeddings, auto_kv_heads,
auto_base_tokenizer_name) = _calculate_auto_config(train_dataset, int(block_size), scratch_architecture)
scratch_vocab_size = auto_vocab_size
scratch_hidden_size = auto_hidden_size
scratch_intermediate_size = auto_intermediate_size
scratch_num_hidden_layers = auto_layers
scratch_num_attention_heads = auto_heads
scratch_num_key_value_heads = auto_kv_heads
scratch_max_pos_embed = auto_max_pos_embed
scratch_tie_word_embeddings = auto_tie_embeddings
scratch_base_tokenizer = auto_base_tokenizer_name
yield update_logs(f"Config Autocalculada: Vocab={auto_vocab_size}, Hidden={auto_hidden_size}, Layers={auto_layers}, Heads={auto_heads}, KV={auto_kv_heads}", "Preparando Modelo")
yield update_logs("Iniciando entrenamiento de nuevo tokenizer...", "Preparando Modelo")
base_tok_name = scratch_base_tokenizer.strip() if scratch_base_tokenizer.strip() else "meta-llama/Meta-Llama-3-8B"
yield update_logs(f"Cargando tokenizer base: {base_tok_name}", "Preparando Modelo")
base_tok = AutoTokenizer.from_pretrained(base_tok_name, use_fast=True)
special_tokens_list = [t.strip() for t in scratch_special_tokens.split(",") if t.strip()]
corpus_iterator = get_training_corpus_iterator(train_dataset, text_col)
tokenizer = base_tok.train_new_from_iterator(
corpus_iterator,
vocab_size=int(scratch_vocab_size)
)
yield update_logs(f"Nuevo tokenizer entrenado. Vocab Size: {tokenizer.vocab_size}", "Preparando Modelo")
if special_tokens_list:
current_special_tokens = tokenizer.all_special_tokens
new_tokens_to_add = [t for t in special_tokens_list if t not in current_special_tokens]
if new_tokens_to_add:
tokenizer.add_special_tokens({"additional_special_tokens": new_tokens_to_add})
if "<s>" in special_tokens_list: tokenizer.bos_token = "<s>"
if "</s>" in special_tokens_list: tokenizer.eos_token = "</s>"
if "<unk>" in special_tokens_list: tokenizer.unk_token = "<unk>"
if "<pad>" in special_tokens_list: tokenizer.pad_token = "<pad>"
if "<mask>" in special_tokens_list: tokenizer.mask_token = "<mask>"
if tokenizer.pad_token is None:
yield update_logs("Advertencia: No se encontr贸 '<pad>' en los tokens especiales. Usando '</s>' como pad_token.", "Preparando Modelo")
tokenizer.pad_token = tokenizer.eos_token
if scratch_architecture not in ARCHITECTURE_MAP:
logs += f"Error: Arquitectura '{scratch_architecture}' no soportada.\n"
yield logs, "Error", repo_link, None
return
ConfigClass, ModelClass = ARCHITECTURE_MAP[scratch_architecture]
yield update_logs(f"Creando {ConfigClass.__name__} para nuevo modelo...", "Preparando Modelo")
config_params = {
"vocab_size": tokenizer.vocab_size,
"pad_token_id": tokenizer.pad_token_id,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"initializer_range": 0.02,
"use_cache": True,
"tie_word_embeddings": scratch_tie_word_embeddings,
}
if scratch_architecture == "GPT2":
config_params.update({
"n_embd": int(scratch_hidden_size),
"n_layer": int(scratch_num_hidden_layers),
"n_head": int(scratch_num_attention_heads),
"n_positions": int(scratch_max_pos_embed),
})
else:
config_params.update({
"hidden_size": int(scratch_hidden_size),
"intermediate_size": int(scratch_intermediate_size),
"num_hidden_layers": int(scratch_num_hidden_layers),
"num_attention_heads": int(scratch_num_attention_heads),
"max_position_embeddings": int(scratch_max_pos_embed),
"rms_norm_eps": 1e-6,
})
if scratch_architecture in ["Mistral", "Gemma"] and scratch_num_key_value_heads > 0:
config_params["num_key_value_heads"] = int(scratch_num_key_value_heads)
config = ConfigClass(**config_params)
config_data["model_config"] = config.to_dict()
yield update_logs(f"Inicializando {ModelClass.__name__} desde cero...", "Preparando Modelo")
model_hf = ModelClass(config).to(device)
torch_dtype = torch.float32
else:
yield update_logs("--- Modo 'Fine-Tuning' activado ---", "Preparando Modelo", 0.15)
config_data["train_from_scratch"] = False
config_data["model_base"] = model
quantization_val = quantization if quantization != "none" else None
bnb_config = None
if quantization_val == "int4":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
)
elif quantization_val == "int8":
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            elif mixed_precision in ["fp16", "bf16"]:
                pass
            else:
                # Neither quantization nor mixed precision requested: fall back
                # to 8-bit loading and record it so k-bit preparation runs below.
                bnb_config = BitsAndBytesConfig(load_in_8bit=True)
                quantization_val = "int8"
            torch_dtype = torch.float32
if mixed_precision == "bf16":
torch_dtype = torch.bfloat16
elif mixed_precision == "fp16":
torch_dtype = torch.float16
yield update_logs("Cargando Tokenizer y Modelo (Fine-Tuning)...", "Cargando Modelo", 0.20)
tokenizer = AutoTokenizer.from_pretrained(
model,
padding_side=padding,
add_eos_token=add_eos_token,
trust_remote_code=trust_remote_code_input,
use_fast=False
)
chat_template_str = chat_template.strip() if chat_template else "tokenizer"
if chat_template_str.lower() == "none":
tokenizer.chat_template = None
yield update_logs("Plantilla de chat deshabilitada.", "Cargando Modelo")
elif chat_template_str.lower() != "tokenizer":
tokenizer.chat_template = chat_template_str
yield update_logs("Aplicando plantilla de chat personalizada.", "Cargando Modelo")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model_kwargs = {
"quantization_config": bnb_config,
"device_map": "auto",
"trust_remote_code": trust_remote_code_input,
"torch_dtype": torch_dtype,
}
attn_impl_val = attn_implementation_input.strip().lower() if attn_implementation_input else "auto"
if attn_impl_val and attn_impl_val != "auto":
model_kwargs["attn_implementation"] = attn_impl_val
yield update_logs(f"Usando attn_implementation: {attn_impl_val}", "Cargando Modelo")
model_hf = AutoModelForCausalLM.from_pretrained(
model,
**model_kwargs
).to(device)
if new_special_tokens_input:
tokens_to_add = [t.strip() for t in new_special_tokens_input.split(",") if t.strip()]
if tokens_to_add:
yield update_logs(f"A帽adiendo {len(tokens_to_add)} tokens y redimensionando embeddings...", "Cargando Modelo")
tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})
model_hf.resize_token_embeddings(len(tokenizer))
if hasattr(model_hf, "tie_weights"):
model_hf.tie_weights()
else:
input_embeddings = model_hf.get_input_embeddings()
output_embeddings = model_hf.get_output_embeddings()
if output_embeddings is not None and input_embeddings.weight.shape == output_embeddings.weight.shape:
output_embeddings.weight.data = input_embeddings.weight.data
if quantization_val is not None:
model_hf = prepare_model_for_kbit_training(model_hf)
peft_config = None
if peft and not train_from_scratch:
yield update_logs("Configurando PEFT (LoRA)...", "Preparando Trainer", 0.25)
peft_config = LoraConfig(
r=int(lora_r),
lora_alpha=float(lora_alpha),
lora_dropout=float(lora_dropout),
                target_modules=[m.strip() for m in target_modules.split(",")] if target_modules != "all-linear" else "all-linear",
bias="none",
task_type="CAUSAL_LM"
)
config_data["lora_config"] = peft_config.to_dict()
else:
yield update_logs("Entrenamiento completo (sin PEFT) activado.", "Preparando Trainer", 0.25)
eval_strategy_lower = eval_strategy.lower()
        num_steps_calculated = max(1, int(float(epochs) * float(steps_per_epoch_estimate) / int(gradient_accumulation) / int(batch_size)))
        yield update_logs(f"Computed max_steps for streaming (based on {steps_per_epoch_estimate} steps/epoch): {num_steps_calculated} steps.", "Preparing Trainer")
if eval_strategy_lower == "steps":
save_strategy_val = IntervalStrategy.STEPS
elif eval_strategy_lower == "epoch":
save_strategy_val = IntervalStrategy.EPOCH
else:
save_strategy_val = IntervalStrategy.NO
report_to_val = "none"
if wandb_project_input and wandb_project_input.strip():
yield update_logs(f"Habilitando logging en Weights & Biases (Proyecto: {wandb_project_input.strip()}).", "Preparando Trainer")
os.environ["WANDB_DISABLED"] = "false"
os.environ["WANDB_PROJECT"] = wandb_project_input.strip()
if wandb_api_key_input and wandb_api_key_input.strip():
os.environ["WANDB_API_KEY"] = wandb_api_key_input.strip()
report_to_val = "wandb"
else:
os.environ["WANDB_DISABLED"] = "true"
yield update_logs("Logging en W&B deshabilitado.", "Preparando Trainer")
training_args = TrainingArguments(
output_dir=os.path.join(temp_dir, "results"),
num_train_epochs=float(epochs),
per_device_train_batch_size=int(batch_size),
per_device_eval_batch_size=int(batch_size),
gradient_accumulation_steps=int(gradient_accumulation),
            optim=optimizer,
            lr_scheduler_type=scheduler,
            seed=int(seed),
            auto_find_batch_size=auto_find_batch_size,
save_strategy=save_strategy_val,
save_steps=int(logging_steps) * 10,
logging_steps=int(logging_steps),
evaluation_strategy=eval_strategy_lower,
eval_steps=int(logging_steps) if eval_strategy_lower == "steps" else None,
learning_rate=float(learning_rate),
fp16=(mixed_precision == "fp16"),
bf16=(mixed_precision == "bf16"),
max_grad_norm=float(max_grad_norm),
warmup_ratio=float(warmup_ratio),
weight_decay=float(weight_decay),
load_best_model_at_end=load_best_model_at_end,
save_total_limit=int(save_total_limit),
gradient_checkpointing=not disable_gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
push_to_hub=True,
hub_model_id=repo_id,
hub_private_repo=False,
disable_tqdm=False,
max_steps=num_steps_calculated,
report_to=report_to_val,
save_safetensors=True,
save_only_model=True,
dataloader_num_workers=0,
)
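        # With IterableDataset streams the Trainer cannot infer an epoch
        # length, so max_steps (derived from the per-epoch estimate above)
        # bounds the run.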
yield update_logs("Inicializando SFT Trainer...", "Preparando Trainer", 0.30)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
formatting_lambda = lambda example: _sft_formatting_func(
example,
text_col,
tokenizer,
apply_lowercase_input,
remove_punctuation_input,
enable_cot_input,
prompt_col_input,
reasoning_col_input,
response_col_input
)
trainer = SFTTrainer(
model=model_hf,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
peft_config=peft_config,
formatting_func=formatting_lambda,
max_seq_length=int(block_size)
)
yield update_logs("Iniciando entrenamiento...", "Entrenando", 0.40)
trainer.train()
yield update_logs("Entrenamiento finalizado. Guardando modelo final...", "Guardando Modelo", 0.85)
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
del model_hf
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if merge_adapter and peft and not train_from_scratch:
yield update_logs("Fusionando adaptadores LoRA con el modelo base...", "Fusionando", 0.90)
del trainer
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
base_model_for_merge = AutoModelForCausalLM.from_pretrained(
model,
trust_remote_code=trust_remote_code_input,
torch_dtype=torch_dtype,
attn_implementation=attn_impl_val if attn_impl_val != "auto" else None
).to(device)
ft = PeftModel.from_pretrained(
base_model_for_merge,
training_args.output_dir, torch_dtype=torch.float32,
is_trainable=False, device_map={"": device}
).merge_and_unload()
output_model_dir = os.path.join(tempfile.mkdtemp(), "final_merged_model")
ft.save_pretrained(output_model_dir, safe_serialization=True)
tokenizer.save_pretrained(output_model_dir)
upload_folder(
folder_path=output_model_dir,
repo_id=repo_id,
commit_message="Modelo fusionado (PEFT y base) en safetensors",
allow_patterns=["*"]
)
del ft, base_model_for_merge
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
yield update_logs("Subida manejada por Trainer.push_to_hub().", "Subiendo a Hub", 0.95)
repo_link = f"https://huggingface.co/{repo_id}"
config_data.update({
"training_arguments": {},
"quantization": quantization if not train_from_scratch else "none",
"text_column": text_col if not enable_cot_input else "N/A (CoT Enabled)",
"hub_url": repo_link,
"load_config": {
"trust_remote_code": trust_remote_code_input,
"attn_implementation": attn_impl_val,
"new_special_tokens": new_special_tokens_input,
"chat_template": tokenizer.chat_template
},
"normalization_config": {
"apply_lowercase": apply_lowercase_input,
"remove_punctuation": remove_punctuation_input
},
"cot_config": {
"enabled": enable_cot_input,
"prompt_col": prompt_col_input,
"reasoning_col": reasoning_col_input,
"response_col": response_col_input
},
"wandb_config": {
"project": wandb_project_input
}
})
training_args_dict = training_args.to_dict()
for key, value in training_args_dict.items():
if isinstance(value, set):
config_data["training_arguments"][key] = list(value)
elif isinstance(value, IntervalStrategy):
config_data["training_arguments"][key] = str(value)
else:
config_data["training_arguments"][key] = value
cfg_path = os.path.join(tempfile.mkdtemp(), "config_result.json")
with open(cfg_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, ensure_ascii=False, indent=2)
logs += f"Subida completa: {repo_link}\n"
yield logs, "Listo", repo_link, cfg_path
except Exception as e:
err = f"Error fatal de ejecuci贸n: {type(e).__name__}: {e}\n"
import traceback
err += traceback.format_exc()
yield err, "Error", "", None
css = """
:root {
--bg:#f8f8f8;
--card:#ffffff;
--muted:#e0e0e0;
--text:#1f2937;
--sub:#6b7280;
--accent:#818cf8;
--accent-hover:#a78bfa;
--shadow:rgba(0,0,0,0.1);
--font-family: 'Inter', sans-serif;
}
.gradio-container {
background-color: var(--bg) !important;
font-family: var(--font-family);
color: var(--text);
}
div:not(.label-wrap), label:not(.label-wrap label) { color: var(--text) !important; }
h1, h2, h3 {
letter-spacing: .2px;
color: var(--text);
font-weight: 700;
}
button {
border-radius: 12px !important;
padding: 10px 14px !important;
border: 1px solid var(--muted) !important;
background: linear-gradient(180deg, #f0f0f0, #ffffff) !important;
color: var(--text) !important;
transition: all .15s ease;
box-shadow: 0 2px 5px var(--shadow);
}
button:hover {
transform: translateY(-1px);
border-color: #c0c0c0 !important;
box-shadow: 0 4px 10px var(--shadow);
}
.primary-btn button {
background: linear-gradient(90deg, #818cf8, #a855f7, #ec4899) !important;
border: none !important;
color: white !important;
font-weight: 600;
}
.primary-btn button:hover {
transform: scale(1.01);
box-shadow: 0 8px 20px rgba(130, 100, 255, 0.4);
}
input[type="text"],
input[type="password"],
input[type="number"],
textarea,
select {
background: var(--card) !important;
border-radius: 10px !important;
border: 1px solid var(--muted) !important;
color: var(--text) !important;
padding: 10px;
}
.label-wrap label {
color: var(--sub) !important;
font-size: 14px;
}
.card {
background: var(--card);
border: 1px solid var(--muted);
border-radius: 16px;
padding: 20px;
box-shadow: 0 4px 15px var(--shadow);
}
.textbox-logs textarea {
font-family: monospace;
font-size: 12px;
line-height: 1.4;
background: #f0f0f0 !important;
color: var(--text) !important;
}
.title-h1 {
text-align: center;
font-size: 32px;
font-weight: 800;
color: #1f2937;
letter-spacing: 1px;
margin-bottom: 4px;
}
.title-subtitle {
text-align: center;
font-size: 14px;
color: #6b7280;
margin-bottom: 20px;
}
"""
with gr.Blocks(title="HuggingFace SFT Trainer Studio", theme="gradio/soft", css=css) as demo:
gr.Markdown("# HuggingFace SFT Trainer Studio", elem_classes="title-h1")
gr.Markdown("Afinaci贸n avanzada multi-dataset con TRL (SFT) y pipeline de datos en streaming.", elem_classes="title-subtitle")
with gr.Row(equal_height=True):
with gr.Column(scale=1):
with gr.Group(elem_classes="card"):
gr.Markdown("### 1. Autenticaci贸n")
token_input = gr.Textbox(
label="Token de Acceso de Hugging Face",
type="password",
placeholder="Ingresa tu token (hf_xxx...)",
lines=1
)
login_btn = gr.Button("Conectar y Guardar Token")
login_status = gr.Textbox(label="Estado", lines=1)
login_btn.click(fn=hf_login, inputs=[token_input], outputs=login_status)
gr.Markdown("### 2. Configuraci贸n de Modelo y Repositorio")
train_from_scratch = gr.Checkbox(
label="Entrenar desde Cero (Nuevo Modelo y Tokenizer)",
value=False,
info="Si se activa, se ignorar谩 el 'Modelo Base' y se entrenar谩 un nuevo modelo/tokenizer."
)
model_base_input = gr.Textbox(
label="Modelo Base (ID del Hub)",
value="",
placeholder="Ej: meta-llama/Meta-Llama-3-8B. (Ignorado si 'Entrenar desde Cero' est谩 activo)",
visible=True
)
repo_name_input = gr.Textbox(
label="Nombre del Repositorio de Salida",
value="",
placeholder="Opcional. Ej: mi-llama-personalizado. (Se generar谩 uno si est谩 vac铆o)"
)
with gr.Column(scale=2):
with gr.Group(elem_classes="card"):
gr.Markdown("### 3. Datos de Entrenamiento (Streaming)")
ds_text = gr.Textbox(
label="Datasets de Hugging Face (separados por coma)",
placeholder="tatsu-lab/alpaca, OpenAssistant/oasst1",
lines=2
)
                uploads = gr.Files(label="Upload local datasets (csv/json/jsonl/txt)", file_count="multiple", type="file")
    with gr.Accordion("⚙️ Advanced Parameters", open=False):
        with gr.Tabs() as tabs_advanced:
            with gr.TabItem("Architecture (From Scratch)", id="tab_scratch"):
                scratch_group = gr.Group(visible=False)
                with scratch_group:
                    auto_config_scratch = gr.Checkbox(
                        label="Automatically Compute Model Configuration (Based on dataset size)",
                        value=False
                    )
                    gr.Markdown("#### Tokenizer and Architecture Configuration")
                    with gr.Row():
                        scratch_architecture = gr.Dropdown(
                            label="Model Architecture",
                            choices=list(ARCHITECTURE_MAP.keys()),
                            value="Llama",
                            info="Choose the base architecture for the new model."
                        )
                        scratch_base_tokenizer = gr.Textbox(
                            label="Base Tokenizer to Train From",
                            value="meta-llama/Meta-Llama-3-8B",
                            info="This tokenizer is used for 'train_new_from_iterator'."
                        )
                    with gr.Row():
                        scratch_vocab_size = gr.Number(label="Vocabulary Size (New Tokenizer)", value=32000)
                        scratch_special_tokens = gr.Textbox(
                            label="Special Tokens (New Tokenizer)",
                            value="<s>,<pad>,</s>,<unk>,<mask>,<|user|>,<|bot|>,<|end|>",
                            info="Comma-separated."
                        )
                    gr.Markdown("#### Model Configuration (Hyperparameters)")
                    with gr.Row():
                        scratch_hidden_size = gr.Number(label="hidden_size / n_embd", value=512)
                        scratch_intermediate_size = gr.Number(label="intermediate_size", value=1024, info="Ignored by architectures like GPT2.")
                    with gr.Row():
                        scratch_num_hidden_layers = gr.Number(label="num_hidden_layers / n_layer", value=8)
                        scratch_num_attention_heads = gr.Number(label="num_attention_heads / n_head", value=8)
                    with gr.Row():
                        scratch_num_key_value_heads = gr.Number(label="num_key_value_heads", value=8, info="Important for Mistral/Gemma (GQA). Leave at 0 if unused.")
                        scratch_max_pos_embed = gr.Number(label="max_position_embeddings / n_positions", value=512)
                    with gr.Row():
                        scratch_tie_word_embeddings = gr.Checkbox(label="tie_word_embeddings", value=False)
with gr.TabItem("Opciones Generales (Trainer)", id="tab_general"):
general_group = gr.Group(visible=True)
with general_group:
with gr.Row():
add_eos_token = gr.Checkbox(label="add_eos_token", value=True)
auto_find_batch_size = gr.Checkbox(label="auto_find_batch_size (Acelerador)", value=True)
disable_gradient_checkpointing = gr.Checkbox(label="disable_gradient_checkpointing", value=False)
load_best_model_at_end = gr.Checkbox(label="load_best_model_at_end", value=False, info="Requiere dataset de validaci贸n y `eval_strategy='steps'`. Guarda el mejor modelo al final.")
with gr.Row():
chat_template = gr.Textbox(label="chat_template", value="tokenizer", lines=3,
placeholder="Dejar 'tokenizer' para usar la del modelo. 'none' para deshabilitar. O pegar plantilla Jinja2 personalizada.\nEj: {% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n' }}{% else %}{{ 'Assistant: ' + message['content'] + '\n' }}{% endif %}{% endfor %}",
info="Plantilla de chat personalizada (Jinja2).")
with gr.Row():
distributed_backend = gr.Textbox(label="distributed_backend (Ignorado en GPU 煤nica)", value="ddp", placeholder="ddp o fsdp")
eval_strategy = gr.Textbox(label="eval_strategy", value="steps", placeholder="no, steps o epoch", info="Recomendado: 'steps' para usar `validation_dataset`.")
merge_adapter = gr.Checkbox(label="merge_adapter (Subir modelo completo)", value=True, info="Ignorado si 'Entrenar desde Cero' est谩 activo.")
mixed_precision = gr.Textbox(label="mixed_precision", value="bf16", placeholder="none, fp16 o bf16")
with gr.Row():
optimizer = gr.Textbox(label="optimizer", value="adamw_torch", placeholder="adamw_8bit, adamw_torch, etc.")
peft = gr.Checkbox(label="peft (Activar LoRA)", value=True, info="Ignorado si 'Entrenar desde Cero' est谩 activo.")
padding = gr.Textbox(label="padding", value="left", placeholder="left o right")
quantization = gr.Textbox(label="quantization", value="int4", placeholder="none, int8 o int4", info="Ignorado si 'Entrenar desde Cero' est谩 activo.")
scheduler = gr.Textbox(label="scheduler", value="cosine", placeholder="linear, cosine o constant")
with gr.TabItem("Hiperpar谩metros Num茅ricos", id="tab_numeric"):
numeric_group = gr.Group(visible=True)
with numeric_group:
with gr.Row():
batch_size = gr.Slider(1, 128, value=1, step=1, label="per_device_train_batch_size")
block_size = gr.Slider(16, 8192, value=1024, step=16, label="max_seq_length (Bloque de texto)")
epochs = gr.Slider(1, 20, value=1, step=1, label="num_train_epochs")
gradient_accumulation = gr.Slider(1, 64, value=8, step=1, label="gradient_accumulation_steps")
with gr.Row():
learning_rate = gr.Number(value=1e-5, label="learning_rate (lr)")
logging_steps = gr.Number(value=5, label="logging_steps", info="Tambi茅n controla `evaluation_steps`.")
max_grad_norm = gr.Number(value=0.3, label="max_grad_norm", info="Normalizaci贸n de gradientes (Clipping).")
steps_per_epoch_estimate = gr.Number(value=10000, label="Estimaci贸n de Pasos por 脡poca", info="Usado para calcular `max_steps` en streaming.")
with gr.Row():
save_total_limit = gr.Number(value=1, label="save_total_limit", info="N煤mero de checkpoints a guardar.")
seed = gr.Number(value=42, label="seed")
warmup_ratio = gr.Number(value=0.05, label="warmup_ratio")
weight_decay = gr.Number(value=0.01, label="weight_decay")
with gr.TabItem("PEFT / LoRA", id="tab_peft"):
peft_group = gr.Group(visible=True)
with peft_group:
gr.Markdown("#### PEFT / LoRA (Ignorado si 'Entrenar desde Cero' est谩 activo o `peft` est谩 deshabilitado)")
with gr.Row():
lora_r = gr.Number(value=32, label="lora_r (r)")
lora_alpha = gr.Number(value=32, label="lora_alpha")
lora_dropout = gr.Number(value=0.05, label="lora_dropout")
target_modules = gr.Textbox(value="q_proj,k_proj,v_proj,o_proj", placeholder="Ej: all-linear o q_proj,v_proj", label="target_modules (Separados por coma)")
with gr.TabItem("Configuraci贸n del Modelo (Carga)", id="tab_load"):
load_config_group = gr.Group(visible=True)
with load_config_group:
gr.Markdown("#### Configuraci贸n del Modelo (Carga - Ignorado si 'Entrenar desde Cero' est谩 activo)")
with gr.Row():
trust_remote_code_input = gr.Checkbox(label="trust_remote_code", value=True, info="Permitir cargar modelos con c贸digo personalizado.")
attn_implementation_input = gr.Textbox(label="attn_implementation", value="sdpa", placeholder="auto, eager, sdpa, flash_attention_2", info="Optimizaci贸n de atenci贸n (ej: flash_attention_2 si est谩 disponible).")
with gr.Row():
new_special_tokens_input = gr.Textbox(label="Nuevos tokens especiales (coma-separado)", placeholder="<|im_start|>, <|im_end|>, <|system|>", info="A帽adir tokens al vocabulario. Redimensiona embeddings.")
with gr.TabItem("Formato de Datos (Razonamiento)", id="tab_cot"):
cot_group = gr.Group(visible=True)
with cot_group:
with gr.Row():
enable_cot_input = gr.Checkbox(label="Activar formato de razonamiento (CoT)", value=False, info="Anula la detecci贸n autom谩tica de columna 'text' o 'messages'.")
with gr.Row():
prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt", info="Nombre de la columna con el prompt/pregunta.")
with gr.Row():
reasoning_col_input = gr.Textbox(label="Columna de Razonamiento (Opcional)", value="reasoning", info="Nombre de la columna con los pasos de CoT.")
with gr.Row():
response_col_input = gr.Textbox(label="Columna de Respuesta", value="response", info="Nombre de la columna con la respuesta final.")
with gr.TabItem("Normalizaci贸n de Texto (Datos)", id="tab_norm"):
normalization_group = gr.Group(visible=True)
with normalization_group:
with gr.Row():
apply_lowercase_input = gr.Checkbox(label="Aplicar min煤sculas", value=False, info="Convierte todo el texto a min煤sculas antes de tokenizar.")
remove_punctuation_input = gr.Checkbox(label="Eliminar puntuaci贸n", value=False, info="Elimina signos de puntuaci贸n del texto antes de tokenizar.")
with gr.TabItem("Logging (Weights & Biases)", id="tab_wandb"):
wandb_group = gr.Group(visible=True)
with wandb_group:
with gr.Row():
wandb_project_input = gr.Textbox(label="Nombre del Proyecto W&B", placeholder="mi-proyecto-sft", info="Dejar vac铆o para deshabilitar W&B.")
with gr.Row():
wandb_api_key_input = gr.Textbox(label="API Key de W&B (Opcional)", type="password", placeholder="w_**************************************", info="Opcional. Usar si no est谩 configurado globalmente.")
    def toggle_scratch_config(scratch_checked):
        scratch_visible = scratch_checked
        finetune_visible = not scratch_checked
        updates = {
            model_base_input: gr.update(visible=finetune_visible),
            scratch_group: gr.update(visible=scratch_visible),
            peft_group: gr.update(visible=finetune_visible),
            load_config_group: gr.update(visible=finetune_visible),
        }
        return updates
train_from_scratch.change(
fn=toggle_scratch_config,
inputs=[train_from_scratch],
outputs=[model_base_input, scratch_group, peft_group, load_config_group],
queue=False
)
    run_btn = gr.Button("🚀 Start Training and Upload to Hugging Face Hub", elem_classes="primary-btn mt-6")
    gr.Markdown("## Results and Progress")
with gr.Group(elem_classes="card"):
with gr.Row():
            logs = gr.Textbox(label="Execution Logs", lines=12, elem_classes="textbox-logs", scale=3)
            with gr.Column(scale=2):
                phase = gr.Textbox(label="Current Phase", lines=1)
                link = gr.Textbox(label="Output Repository", lines=1)
                cfg_out = gr.File(label="Final Generated Configuration (params.json)", file_types=[".json"])
scratch_inputs = [
scratch_vocab_size, scratch_special_tokens, scratch_base_tokenizer,
scratch_hidden_size, scratch_intermediate_size, scratch_num_hidden_layers,
scratch_num_attention_heads, scratch_num_key_value_heads, scratch_max_pos_embed, scratch_tie_word_embeddings
]
    # Input order must match the train_and_upload signature exactly:
    # the scratch hyperparameters come before auto_config_scratch.
    all_inputs = [
        model_base_input, ds_text, uploads, repo_name_input,
        train_from_scratch, scratch_architecture,
        add_eos_token, auto_find_batch_size, chat_template,
        disable_gradient_checkpointing, distributed_backend, eval_strategy,
        load_best_model_at_end,
        merge_adapter, mixed_precision, optimizer, peft, padding,
        quantization, scheduler, batch_size, block_size, epochs,
        gradient_accumulation, learning_rate, logging_steps, lora_alpha,
        lora_dropout, lora_r, max_grad_norm,
        save_total_limit, seed, warmup_ratio, weight_decay, target_modules,
        steps_per_epoch_estimate,
        trust_remote_code_input, attn_implementation_input, new_special_tokens_input,
        apply_lowercase_input, remove_punctuation_input,
        enable_cot_input, prompt_col_input, reasoning_col_input, response_col_input,
        wandb_project_input, wandb_api_key_input,
    ] + scratch_inputs + [auto_config_scratch]
all_outputs = [logs, phase, link, cfg_out]
run_btn.click(
fn=train_and_upload,
inputs=all_inputs,
outputs=all_outputs
)
demo.launch(server_name="0.0.0.0", server_port=7860)