import os
import io
import json
import tempfile
import string

# Installing dependencies at import time is a common Hugging Face Spaces shortcut;
# pinning them in requirements.txt is the more reproducible alternative.
os.system("pip install -U transformers peft datasets accelerate bitsandbytes trl scipy einops evaluate zstandard wandb tokenizers")
os.system("pip install spaces-0.1.0-py3-none-any.whl")

import gradio as gr
import spaces
import torch
import logging
import multiprocessing
import uuid
import gc
import math
from itertools import islice

from datasets import load_dataset, IterableDataset, interleave_datasets
from huggingface_hub import login, whoami, create_repo, upload_folder
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig,
    IntervalStrategy, LlamaConfig, LlamaForCausalLM, MistralConfig,
    MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config,
    GPT2LMHeadModel, PreTrainedTokenizerFast
)
from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
from tokenizers import ByteLevelBPETokenizer

logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
num_workers = multiprocessing.cpu_count()

ARCHITECTURE_MAP = {
    "Llama": (LlamaConfig, LlamaForCausalLM),
    "Mistral": (MistralConfig, MistralForCausalLM),
    "Gemma": (GemmaConfig, GemmaForCausalLM),
    "GPT2": (GPT2Config, GPT2LMHeadModel),
}


def _normalize_text_helper(text, do_lowercase, do_remove_punct):
    """Optionally lowercase the text and strip punctuation."""
    if not isinstance(text, str):
        return text
    if do_lowercase:
        text = text.lower()
    if do_remove_punct:
        text = text.translate(str.maketrans('', '', string.punctuation))
    return text


def _load_hf_streaming(ids):
    """Load Hub datasets in streaming mode, dropping columns that break interleaving."""
    streams = {"train": [], "validation": []}
    CONFLICT_COLUMNS = ['id', 'raw_id', 'shard_id', 'num_shards', 'meta']
    for ident in ids:
        try:
            d = load_dataset(ident, streaming=True)

            def clean_and_keep_text(example):
                keys_to_drop = [k for k in example.keys() if k in CONFLICT_COLUMNS or k.endswith('_id')]
                for k in keys_to_drop:
                    if k in example:
                        del example[k]
                return example

            if isinstance(d, dict):
                for split, ds in d.items():
                    cleaned_ds = ds.map(clean_and_keep_text, batched=False)
                    if "train" in split:
                        streams["train"].append(cleaned_ds)
                    elif "validation" in split or "test" in split:
                        streams["validation"].append(cleaned_ds)
            else:
                cleaned_ds = d.map(clean_and_keep_text, batched=False)
                streams["train"].append(cleaned_ds)
                # Single-split stream: reuse a small prefix as a validation sample
                # (note that it overlaps with the training stream).
                streams["validation"].append(cleaned_ds.take(500))
        except Exception as e:
            logger.error(f"Error loading or cleaning dataset {ident}: {e}")
            continue
    return streams


def _load_uploaded_stream(files):
    """Parse uploaded csv/json/jsonl/txt files into streaming datasets.

    Expects file paths, matching gr.Files(type="filepath") in the UI below.
    """
    all_rows = []
    for path in files or []:
        name = os.path.basename(path).lower()
        with open(path, "rb") as fh:
            content = fh.read().decode("utf-8", errors="ignore")
        if name.endswith(".csv"):
            import csv
            reader = csv.DictReader(io.StringIO(content))
            for row in reader:
                all_rows.append(row)
        elif name.endswith(".jsonl"):
            for line in io.StringIO(content):
                try:
                    all_rows.append(json.loads(line))
                except Exception:
                    pass
        elif name.endswith(".json"):
            try:
                data = json.loads(content)
                if isinstance(data, list):
                    all_rows.extend(data)
                elif isinstance(data, dict):
                    all_rows.append(data)
            except Exception:
                pass
        elif name.endswith(".txt"):
            for line in io.StringIO(content):
                line = line.strip()
                if line:
                    all_rows.append({"text": line})
    if not all_rows:
        return {"train": IterableDataset.from_generator(lambda: iter([])),
                "validation": IterableDataset.from_generator(lambda: iter([]))}
    n = len(all_rows)
    val_size = max(1, n // 100)
    train_data = all_rows[:-val_size]
    val_data = all_rows[-val_size:]
    train_ds = IterableDataset.from_generator(lambda: iter(train_data))
    val_ds = IterableDataset.from_generator(lambda: iter(val_data))
    return {"train": train_ds, "validation": val_ds}
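
# Minimal usage sketch (assumes a hypothetical local "data.jsonl" whose rows look
# like {"text": "..."} or, in CoT mode, {"prompt": ..., "reasoning": ..., "response": ...}):
#
#   streams = _load_uploaded_stream(["data.jsonl"])
#   next(iter(streams["train"]))   # -> first parsed row as a dict
#
# Uploaded files get an exact 99/1 train/validation split; Hub streams cannot be
# split exactly because their length is unknown up front.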


def _guess_columns(sample):
    """Guess which column of an example holds the training text."""
    text_col = None
    if isinstance(sample, dict):
        keys = list(sample.keys())
        for k in keys:
            if k.lower() in ["text", "content", "prompt", "input", "messages",
                             "sentence", "review", "body", "ctx", "question"]:
                text_col = k
    return text_col or "text", None


def hf_login(token):
    if not token or token.strip() == "":
        return "Error: Please enter your Hugging Face token."
    try:
        login(token=token.strip(), add_to_git_credential=True)
        user = whoami()
        name = user.get("name") or user.get("fullname") or "user"
        email = user.get("email") or ""
        return f"Logged in: {name} {f'({email})' if email else ''}"
    except Exception as e:
        return f"Login error: {e}"


def _sft_formatting_func(example, text_col, tokenizer, do_lowercase, do_remove_punct,
                         enable_cot, prompt_col, reasoning_col, response_col):
    """Turn a dataset example into a single training string for SFTTrainer."""
    if enable_cot:
        prompt = example.get(prompt_col, "")
        reasoning = example.get(reasoning_col, "")
        response = example.get(response_col, "")
        prompt = _normalize_text_helper(prompt, do_lowercase, do_remove_punct)
        reasoning = _normalize_text_helper(reasoning, do_lowercase, do_remove_punct)
        response = _normalize_text_helper(response, do_lowercase, do_remove_punct)
        if reasoning:
            return f"Prompt: {prompt}\n\nReasoning: {reasoning}\n\nResponse: {response}"
        else:
            return f"Prompt: {prompt}\n\nResponse: {response}"
    if text_col == "messages" and hasattr(tokenizer, 'apply_chat_template'):
        processed_messages = []
        for msg in example.get(text_col, []):
            new_msg = msg.copy()
            if 'content' in new_msg and isinstance(new_msg['content'], str):
                new_msg['content'] = _normalize_text_helper(new_msg['content'], do_lowercase, do_remove_punct)
            processed_messages.append(new_msg)
        return tokenizer.apply_chat_template(processed_messages, tokenize=False, add_generation_prompt=False)
    text = example.get(text_col)
    if isinstance(text, str):
        return _normalize_text_helper(text, do_lowercase, do_remove_punct)
    return ""


def get_training_corpus_iterator(dataset, text_col, chunk_size=1000):
    """Yield batches of raw text from a (streaming) dataset for tokenizer training."""
    if not dataset or not text_col:
        return
    iterator = iter(dataset)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            break
        texts = []
        for example in chunk:
            text = example.get(text_col)
            if isinstance(text, str) and text.strip():
                texts.append(text.strip())
            elif text_col == "messages" and isinstance(text, list):
                for msg in text:
                    if isinstance(msg, dict) and isinstance(msg.get('content'), str):
                        texts.append(msg['content'].strip())
        if texts:
            yield texts


def _count_dataset_size(dataset):
    """Exhaust a (streaming) dataset once to count its examples."""
    count = 0
    try:
        for _ in dataset:
            count += 1
    except Exception:
        pass
    return count


def _calculate_auto_config(train_dataset, block_size, scratch_architecture):
    """Derive model hyperparameters from the (log of the) dataset size."""
    size = _count_dataset_size(train_dataset)
    if size == 0:
        return 32000, 512, 1024, 8, 8, 512, False, 8, "meta-llama/Meta-Llama-3-8B"
    log_size = math.log2(max(1000, size))
    vocab_size = min(65536, 32000 + int(log_size * 2000))
    hidden_size = min(2048, 512 + int(log_size * 50))
    hidden_size = max(512, hidden_size)
    intermediate_size = hidden_size * 2
    layers = min(24, 4 + int(log_size * 0.5))
    layers = max(4, int(layers))
    heads = max(4, hidden_size // 64)
    heads = min(32, heads)
    max_pos_embed = int(block_size)
    kv_heads = heads
    if scratch_architecture in ["Mistral", "Llama", "Gemma"]:
        kv_heads = max(1, heads // 8)
        if hidden_size < 1024:
            kv_heads = heads
    tie_embeddings = False
    base_tokenizer_name = "meta-llama/Meta-Llama-3-8B"
    return (vocab_size, hidden_size, intermediate_size, layers, heads,
            max_pos_embed, tie_embeddings, kv_heads, base_tokenizer_name)
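
# Worked example of the auto-config heuristic above (hypothetical corpus size):
#   size = 100_000 examples -> log2(size) ~= 16.61
#   vocab_size        = min(65536, 32000 + int(16.61 * 2000)) = 65219
#   hidden_size       = max(512, min(2048, 512 + int(16.61 * 50))) = 1342
#   intermediate_size = 1342 * 2 = 2684
#   layers            = max(4, min(24, 4 + int(16.61 * 0.5))) = 12
#   heads             = min(32, max(4, 1342 // 64)) = 20
#   kv_heads (GQA)    = max(1, 20 // 8) = 2, since hidden_size >= 1024
# Caveat: _count_dataset_size consumes the streaming dataset once in full.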


@spaces.GPU()
def train_and_upload(model_base_input, datasets_hf_text, uploads, repo_name_input,
                     train_from_scratch, scratch_architecture, add_eos_token,
                     auto_find_batch_size, chat_template, disable_gradient_checkpointing,
                     distributed_backend, eval_strategy, load_best_model_at_end,
                     merge_adapter, mixed_precision, optimizer, peft, padding,
                     quantization, scheduler, batch_size, block_size, epochs,
                     gradient_accumulation, learning_rate, logging_steps, lora_alpha,
                     lora_dropout, lora_r, max_grad_norm, save_total_limit, seed,
                     warmup_ratio, weight_decay, target_modules, steps_per_epoch_estimate,
                     trust_remote_code_input, attn_implementation_input,
                     new_special_tokens_input, apply_lowercase_input,
                     remove_punctuation_input, enable_cot_input, prompt_col_input,
                     reasoning_col_input, response_col_input, wandb_project_input,
                     wandb_api_key_input, scratch_vocab_size, scratch_special_tokens,
                     scratch_base_tokenizer, scratch_hidden_size, scratch_intermediate_size,
                     scratch_num_hidden_layers, scratch_num_attention_heads,
                     scratch_num_key_value_heads, scratch_max_pos_embed,
                     scratch_tie_word_embeddings, auto_config_scratch,
                     progress=gr.Progress()):
    temp_dir = tempfile.mkdtemp()
    logs = ""
    repo_link = ""
    config_data = {}

    def update_logs(new_msg, phase_msg, step_ratio=None):
        nonlocal logs
        logs += f"{new_msg}\n"
        if step_ratio is not None:
            progress(step_ratio)
        return logs, phase_msg, repo_link, None

    try:
        user = whoami()
        username = user.get("name") or "hf_user"
        model = model_base_input.strip()
        attn_impl_val = "auto"  # refined below in fine-tuning mode
        if not model and not train_from_scratch:
            logs += "Error: You must specify a **Base Model** ID or enable 'Train from Scratch'.\n"
            yield logs, "Error", repo_link, None
            return
        if repo_name_input and repo_name_input.strip():
            repo_base = repo_name_input.strip().replace(" ", "-")
        else:
            random_suffix = uuid.uuid4().hex[:6]
            if train_from_scratch:
                model_slug = f"{scratch_architecture.lower()}-{int(scratch_hidden_size)}-{int(scratch_num_hidden_layers)}l"
            else:
                model_slug = model.split('/')[-1].replace('.', '-').lower()
            repo_base = f"{model_slug}-sft-{random_suffix}"
        repo_id = f"{username}/{repo_base}"
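        # Naming sketch (hypothetical values): base model "meta-llama/Meta-Llama-3-8B"
        # with no repo name given becomes e.g. "your-user/meta-llama-3-8b-sft-a1b2c3"
        # (slug from the model ID plus a uuid4 hex suffix).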
        hf_ids = [x.strip() for x in (datasets_hf_text or "").split(",") if x.strip()]
        all_ds = {"train": [], "validation": []}
        if hf_ids or uploads:
            yield update_logs("Loading datasets in streaming mode...", "Loading Data", 0.05)
            dsh = _load_hf_streaming(hf_ids)
            all_ds["train"].extend(dsh.get("train", []))
            all_ds["validation"].extend(dsh.get("validation", []))
            dsu = _load_uploaded_stream(uploads)
            all_ds["train"].append(dsu["train"])
            all_ds["validation"].append(dsu["validation"])
        valid_train_streams = [ds for ds in all_ds["train"] if ds and next(iter(ds), None) is not None]
        valid_val_streams = [ds for ds in all_ds["validation"] if ds and next(iter(ds), None) is not None]
        if not valid_train_streams:
            logs += "Error: No valid, non-empty datasets were found for training.\n"
            yield logs, "Error", repo_link, None
            return
        if not valid_val_streams and (eval_strategy.lower() == "steps" or load_best_model_at_end):
            logs += ("Warning: Evaluation is enabled (`eval_strategy` or `load_best_model_at_end`) "
                     "but no validation datasets were found. Disabling evaluation.\n")
            eval_strategy = "no"
            load_best_model_at_end = False

        train_dataset = interleave_datasets(valid_train_streams)
        validation_dataset = None
        if valid_val_streams:
            validation_dataset = interleave_datasets(valid_val_streams)

        text_col = "text"
        if not enable_cot_input:
            try:
                sample = next(iter(train_dataset))
                text_col, _ = _guess_columns(sample)
            except Exception:
                text_col = "text"
            yield update_logs(f"Detected text column: {text_col}", "Detected", 0.10)
        else:
            yield update_logs("Reasoning (CoT) format enabled.", "Detected", 0.10)
            logs += f"  Prompt: {prompt_col_input}\n"
            if reasoning_col_input:
                logs += f"  Reasoning: {reasoning_col_input}\n"
            logs += f"  Response: {response_col_input}\n"
            yield logs, "Detected", repo_link, None

        if train_from_scratch:
            yield update_logs(f"--- 'Train from Scratch' mode enabled (architecture: {scratch_architecture}) ---", "Preparing Model", 0.15)
            logs += "PEFT, quantization and merge will be disabled.\n"
            peft = False
            merge_adapter = False
            quantization = "none"
            config_data["train_from_scratch"] = True
            config_data["architecture"] = scratch_architecture
            if auto_config_scratch:
                yield update_logs("Computing model configuration automatically...", "Preparing Model")
                (auto_vocab_size, auto_hidden_size, auto_intermediate_size, auto_layers,
                 auto_heads, auto_max_pos_embed, auto_tie_embeddings, auto_kv_heads,
                 auto_base_tokenizer_name) = _calculate_auto_config(train_dataset, int(block_size), scratch_architecture)
                scratch_vocab_size = auto_vocab_size
                scratch_hidden_size = auto_hidden_size
                scratch_intermediate_size = auto_intermediate_size
                scratch_num_hidden_layers = auto_layers
                scratch_num_attention_heads = auto_heads
                scratch_num_key_value_heads = auto_kv_heads
                scratch_max_pos_embed = auto_max_pos_embed
                scratch_tie_word_embeddings = auto_tie_embeddings
                scratch_base_tokenizer = auto_base_tokenizer_name
                yield update_logs(f"Auto-computed config: Vocab={auto_vocab_size}, Hidden={auto_hidden_size}, Layers={auto_layers}, Heads={auto_heads}, KV={auto_kv_heads}", "Preparing Model")
            yield update_logs("Training a new tokenizer...", "Preparing Model")
            base_tok_name = scratch_base_tokenizer.strip() if scratch_base_tokenizer.strip() else "meta-llama/Meta-Llama-3-8B"
            yield update_logs(f"Loading base tokenizer: {base_tok_name}", "Preparing Model")
            base_tok = AutoTokenizer.from_pretrained(base_tok_name, use_fast=True)
            special_tokens_list = [t.strip() for t in scratch_special_tokens.split(",") if t.strip()]
            corpus_iterator = get_training_corpus_iterator(train_dataset, text_col)
            tokenizer = base_tok.train_new_from_iterator(
                corpus_iterator,
                vocab_size=int(scratch_vocab_size)
            )
            yield update_logs(f"New tokenizer trained. Vocab size: {tokenizer.vocab_size}", "Preparing Model")
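            # train_new_from_iterator keeps the base tokenizer's algorithm and
            # pre-/post-processing and only relearns the vocabulary from the
            # streamed corpus, so the result stays compatible with the base class.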
            if special_tokens_list:
                current_special_tokens = tokenizer.all_special_tokens
                new_tokens_to_add = [t for t in special_tokens_list if t not in current_special_tokens]
                if new_tokens_to_add:
                    tokenizer.add_special_tokens({"additional_special_tokens": new_tokens_to_add})
                if "<s>" in special_tokens_list:
                    tokenizer.bos_token = "<s>"
                if "</s>" in special_tokens_list:
                    tokenizer.eos_token = "</s>"
                if "<unk>" in special_tokens_list:
                    tokenizer.unk_token = "<unk>"
                if "<pad>" in special_tokens_list:
                    tokenizer.pad_token = "<pad>"
                if "<mask>" in special_tokens_list:
                    tokenizer.mask_token = "<mask>"
            if tokenizer.pad_token is None:
                yield update_logs("Warning: '<pad>' was not found among the special tokens. Using the eos token as pad_token.", "Preparing Model")
                tokenizer.pad_token = tokenizer.eos_token

            if scratch_architecture not in ARCHITECTURE_MAP:
                logs += f"Error: Architecture '{scratch_architecture}' is not supported.\n"
                yield logs, "Error", repo_link, None
                return
            ConfigClass, ModelClass = ARCHITECTURE_MAP[scratch_architecture]
            yield update_logs(f"Creating {ConfigClass.__name__} for the new model...", "Preparing Model")
            config_params = {
                "vocab_size": len(tokenizer),  # len() includes the added special tokens
                "pad_token_id": tokenizer.pad_token_id,
                "bos_token_id": tokenizer.bos_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "initializer_range": 0.02,
                "use_cache": True,
                "tie_word_embeddings": scratch_tie_word_embeddings,
            }
            if scratch_architecture == "GPT2":
                config_params.update({
                    "n_embd": int(scratch_hidden_size),
                    "n_layer": int(scratch_num_hidden_layers),
                    "n_head": int(scratch_num_attention_heads),
                    "n_positions": int(scratch_max_pos_embed),
                })
            else:
                config_params.update({
                    "hidden_size": int(scratch_hidden_size),
                    "intermediate_size": int(scratch_intermediate_size),
                    "num_hidden_layers": int(scratch_num_hidden_layers),
                    "num_attention_heads": int(scratch_num_attention_heads),
                    "max_position_embeddings": int(scratch_max_pos_embed),
                    "rms_norm_eps": 1e-6,
                })
            if scratch_architecture in ["Mistral", "Gemma"] and scratch_num_key_value_heads > 0:
                config_params["num_key_value_heads"] = int(scratch_num_key_value_heads)
            config = ConfigClass(**config_params)
            config_data["model_config"] = config.to_dict()
            yield update_logs(f"Initializing {ModelClass.__name__} from scratch...", "Preparing Model")
            model_hf = ModelClass(config).to(device)
            torch_dtype = torch.float32
        else:
            yield update_logs("--- 'Fine-Tuning' mode enabled ---", "Preparing Model", 0.15)
            config_data["train_from_scratch"] = False
            config_data["model_base"] = model
            quantization_val = quantization if quantization != "none" else None
            bnb_config = None
            if quantization_val == "int4":
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
            elif quantization_val == "int8":
                bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            torch_dtype = torch.float32
            if mixed_precision == "bf16":
                torch_dtype = torch.bfloat16
            elif mixed_precision == "fp16":
                torch_dtype = torch.float16

            yield update_logs("Loading tokenizer and model (fine-tuning)...", "Loading Model", 0.20)
            tokenizer = AutoTokenizer.from_pretrained(
                model,
                padding_side=padding,
                add_eos_token=add_eos_token,
                trust_remote_code=trust_remote_code_input,
                use_fast=False
            )
            chat_template_str = chat_template.strip() if chat_template else "tokenizer"
            if chat_template_str.lower() == "none":
                tokenizer.chat_template = None
                yield update_logs("Chat template disabled.", "Loading Model")
            elif chat_template_str.lower() != "tokenizer":
                tokenizer.chat_template = chat_template_str
                yield update_logs("Applying custom chat template.", "Loading Model")
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model_kwargs = {
                "quantization_config": bnb_config,
                "device_map": "auto",
                "trust_remote_code": trust_remote_code_input,
                "torch_dtype": torch_dtype,
            }
            attn_impl_val = attn_implementation_input.strip().lower() if attn_implementation_input else "auto"
            if attn_impl_val and attn_impl_val != "auto":
                model_kwargs["attn_implementation"] = attn_impl_val
                yield update_logs(f"Using attn_implementation: {attn_impl_val}", "Loading Model")
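            # "sdpa" maps to PyTorch's scaled_dot_product_attention and is a safe
            # default on recent PyTorch; "flash_attention_2" additionally requires
            # the flash-attn package and fp16/bf16 weights.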
            model_hf = AutoModelForCausalLM.from_pretrained(
                model,
                **model_kwargs
            )
            # device_map="auto" already places the weights; moving a
            # bitsandbytes-quantized model with .to(device) would raise an error.
            if new_special_tokens_input:
                tokens_to_add = [t.strip() for t in new_special_tokens_input.split(",") if t.strip()]
                if tokens_to_add:
                    yield update_logs(f"Adding {len(tokens_to_add)} tokens and resizing embeddings...", "Loading Model")
                    tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})
                    model_hf.resize_token_embeddings(len(tokenizer))
                    if hasattr(model_hf, "tie_weights"):
                        model_hf.tie_weights()
                    else:
                        input_embeddings = model_hf.get_input_embeddings()
                        output_embeddings = model_hf.get_output_embeddings()
                        if output_embeddings is not None and input_embeddings.weight.shape == output_embeddings.weight.shape:
                            output_embeddings.weight.data = input_embeddings.weight.data
            if quantization_val is not None:
                model_hf = prepare_model_for_kbit_training(model_hf)

        peft_config = None
        if peft and not train_from_scratch:
            yield update_logs("Configuring PEFT (LoRA)...", "Preparing Trainer", 0.25)
            peft_config = LoraConfig(
                r=int(lora_r),
                lora_alpha=float(lora_alpha),
                lora_dropout=float(lora_dropout),
                target_modules=target_modules.split(",") if target_modules != "all-linear" else None,
                bias="none",
                task_type="CAUSAL_LM"
            )
            config_data["lora_config"] = peft_config.to_dict()
        else:
            yield update_logs("Full training (no PEFT) enabled.", "Preparing Trainer", 0.25)

        eval_strategy_lower = eval_strategy.lower()
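        # Worked example with the UI defaults (a sketch of the arithmetic below):
        #   epochs=1, steps_per_epoch_estimate=10000, gradient_accumulation=8,
        #   batch_size=1 -> max_steps = int(1 * 10000 / 8 / 1) = 1250.
        # Streaming datasets have no known length, so max_steps must be fixed up front.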
        num_steps_calculated = int(float(epochs) * float(steps_per_epoch_estimate) / int(gradient_accumulation) / int(batch_size))
        yield update_logs(f"max_steps computed for streaming (based on {steps_per_epoch_estimate} steps/epoch): {num_steps_calculated} steps.", "Preparing Trainer")
        if eval_strategy_lower == "steps":
            save_strategy_val = IntervalStrategy.STEPS
        elif eval_strategy_lower == "epoch":
            save_strategy_val = IntervalStrategy.EPOCH
        else:
            save_strategy_val = IntervalStrategy.NO

        report_to_val = "none"
        if wandb_project_input and wandb_project_input.strip():
            yield update_logs(f"Enabling Weights & Biases logging (project: {wandb_project_input.strip()}).", "Preparing Trainer")
            os.environ["WANDB_DISABLED"] = "false"
            os.environ["WANDB_PROJECT"] = wandb_project_input.strip()
            if wandb_api_key_input and wandb_api_key_input.strip():
                os.environ["WANDB_API_KEY"] = wandb_api_key_input.strip()
            report_to_val = "wandb"
        else:
            os.environ["WANDB_DISABLED"] = "true"
            yield update_logs("W&B logging disabled.", "Preparing Trainer")

        training_args = TrainingArguments(
            output_dir=os.path.join(temp_dir, "results"),
            num_train_epochs=float(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            gradient_accumulation_steps=int(gradient_accumulation),
            optim=optimizer,
            lr_scheduler_type=scheduler,
            seed=int(seed),
            auto_find_batch_size=bool(auto_find_batch_size),
            save_strategy=save_strategy_val,
            save_steps=int(logging_steps) * 10,
            logging_steps=int(logging_steps),
            evaluation_strategy=eval_strategy_lower,
            eval_steps=int(logging_steps) if eval_strategy_lower == "steps" else None,
            learning_rate=float(learning_rate),
            fp16=(mixed_precision == "fp16"),
            bf16=(mixed_precision == "bf16"),
            max_grad_norm=float(max_grad_norm),
            warmup_ratio=float(warmup_ratio),
            weight_decay=float(weight_decay),
            load_best_model_at_end=load_best_model_at_end,
            save_total_limit=int(save_total_limit),
            gradient_checkpointing=not disable_gradient_checkpointing,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            push_to_hub=True,
            hub_model_id=repo_id,
            hub_private_repo=False,
            disable_tqdm=False,
            max_steps=num_steps_calculated,
            report_to=report_to_val,
            save_safetensors=True,
            save_only_model=True,
            dataloader_num_workers=0,
        )

        yield update_logs("Initializing SFT Trainer...", "Preparing Trainer", 0.30)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        formatting_lambda = lambda example: _sft_formatting_func(
            example, text_col, tokenizer,
            apply_lowercase_input, remove_punctuation_input,
            enable_cot_input, prompt_col_input, reasoning_col_input, response_col_input
        )

        trainer = SFTTrainer(
            model=model_hf,
            tokenizer=tokenizer,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
            peft_config=peft_config,
            formatting_func=formatting_lambda,
            max_seq_length=int(block_size)
        )

        yield update_logs("Starting training...", "Training", 0.40)
        trainer.train()

        yield update_logs("Training finished. Saving final model...", "Saving Model", 0.85)
        trainer.save_model(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        del model_hf
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if merge_adapter and peft and not train_from_scratch:
            yield update_logs("Merging LoRA adapters into the base model...", "Merging", 0.90)
            del trainer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            base_model_for_merge = AutoModelForCausalLM.from_pretrained(
                model,
                trust_remote_code=trust_remote_code_input,
                torch_dtype=torch_dtype,
                attn_implementation=attn_impl_val if attn_impl_val != "auto" else None
            ).to(device)
            ft = PeftModel.from_pretrained(
                base_model_for_merge,
                training_args.output_dir,
                torch_dtype=torch.float32,
                is_trainable=False,
                device_map={"": device}
            ).merge_and_unload()
            output_model_dir = os.path.join(tempfile.mkdtemp(), "final_merged_model")
            ft.save_pretrained(output_model_dir, safe_serialization=True)
            tokenizer.save_pretrained(output_model_dir)
            upload_folder(
                folder_path=output_model_dir,
                repo_id=repo_id,
                commit_message="Merged model (PEFT + base) in safetensors",
                allow_patterns=["*"]
            )
            del ft, base_model_for_merge
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            yield update_logs("Upload handled by Trainer.push_to_hub().", "Uploading to Hub", 0.95)
            trainer.push_to_hub()

        repo_link = f"https://huggingface.co/{repo_id}"
        config_data.update({
            "training_arguments": {},
            "quantization": quantization if not train_from_scratch else "none",
            "text_column": text_col if not enable_cot_input else "N/A (CoT enabled)",
            "hub_url": repo_link,
            "load_config": {
                "trust_remote_code": trust_remote_code_input,
                "attn_implementation": attn_impl_val,
                "new_special_tokens": new_special_tokens_input,
                "chat_template": tokenizer.chat_template
            },
            "normalization_config": {
                "apply_lowercase": apply_lowercase_input,
                "remove_punctuation": remove_punctuation_input
            },
            "cot_config": {
                "enabled": enable_cot_input,
                "prompt_col": prompt_col_input,
                "reasoning_col": reasoning_col_input,
                "response_col": response_col_input
            },
            "wandb_config": {"project": wandb_project_input}
        })
        training_args_dict = training_args.to_dict()
        for key, value in training_args_dict.items():
            if isinstance(value, set):
                config_data["training_arguments"][key] = list(value)
            elif isinstance(value, IntervalStrategy):
                config_data["training_arguments"][key] = str(value)
            else:
                config_data["training_arguments"][key] = value
        cfg_path = os.path.join(tempfile.mkdtemp(), "config_result.json")
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(config_data, f, ensure_ascii=False, indent=2)
        logs += f"Upload complete: {repo_link}\n"
        yield logs, "Done", repo_link, cfg_path
    except Exception as e:
        import traceback
        err = f"Fatal execution error: {type(e).__name__}: {e}\n"
        err += traceback.format_exc()
        yield err, "Error", "", None
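
# train_and_upload is a generator: each `yield logs, phase, repo_link, cfg_path`
# streams a live update into the four output widgets wired up in run_btn.click()
# at the bottom of the file, which is how the UI shows progress mid-run.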

css = """
:root {
  --bg:#f8f8f8; --card:#ffffff; --muted:#e0e0e0; --text:#1f2937; --sub:#6b7280;
  --accent:#818cf8; --accent-hover:#a78bfa; --shadow:rgba(0,0,0,0.1);
  --font-family: 'Inter', sans-serif;
}
.gradio-container { background-color: var(--bg) !important; font-family: var(--font-family); color: var(--text); }
div:not(.label-wrap), label:not(.label-wrap label) { color: var(--text) !important; }
h1, h2, h3 { letter-spacing: .2px; color: var(--text); font-weight: 700; }
button { border-radius: 12px !important; padding: 10px 14px !important; border: 1px solid var(--muted) !important; background: linear-gradient(180deg, #f0f0f0, #ffffff) !important; color: var(--text) !important; transition: all .15s ease; box-shadow: 0 2px 5px var(--shadow); }
button:hover { transform: translateY(-1px); border-color: #c0c0c0 !important; box-shadow: 0 4px 10px var(--shadow); }
.primary-btn button { background: linear-gradient(90deg, #818cf8, #a855f7, #ec4899) !important; border: none !important; color: white !important; font-weight: 600; }
.primary-btn button:hover { transform: scale(1.01); box-shadow: 0 8px 20px rgba(130, 100, 255, 0.4); }
input[type="text"], input[type="password"], input[type="number"], textarea, select { background: var(--card) !important; border-radius: 10px !important; border: 1px solid var(--muted) !important; color: var(--text) !important; padding: 10px; }
.label-wrap label { color: var(--sub) !important; font-size: 14px; }
.card { background: var(--card); border: 1px solid var(--muted); border-radius: 16px; padding: 20px; box-shadow: 0 4px 15px var(--shadow); }
.textbox-logs textarea { font-family: monospace; font-size: 12px; line-height: 1.4; background: #f0f0f0 !important; color: var(--text) !important; }
.title-h1 { text-align: center; font-size: 32px; font-weight: 800; color: #1f2937; letter-spacing: 1px; margin-bottom: 4px; }
.title-subtitle { text-align: center; font-size: 14px; color: #6b7280; margin-bottom: 20px; }
"""

with gr.Blocks(title="HuggingFace SFT Trainer Studio", theme="gradio/soft", css=css) as demo:
    gr.Markdown("# HuggingFace SFT Trainer Studio", elem_classes="title-h1")
    gr.Markdown("Advanced multi-dataset fine-tuning with TRL (SFT) and a streaming data pipeline.", elem_classes="title-subtitle")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            with gr.Group(elem_classes="card"):
                gr.Markdown("### 1. Authentication")
                token_input = gr.Textbox(
                    label="Hugging Face Access Token",
                    type="password",
                    placeholder="Enter your token (hf_xxx...)",
                    lines=1
                )
                login_btn = gr.Button("Connect and Save Token")
                login_status = gr.Textbox(label="Status", lines=1)
                login_btn.click(fn=hf_login, inputs=[token_input], outputs=login_status)
                gr.Markdown("### 2. Model and Repository Setup")
                train_from_scratch = gr.Checkbox(
                    label="Train from Scratch (new model and tokenizer)",
                    value=False,
                    info="If enabled, 'Base Model' is ignored and a new model/tokenizer is trained."
                )
                model_base_input = gr.Textbox(
                    label="Base Model (Hub ID)",
                    value="",
                    placeholder="E.g. meta-llama/Meta-Llama-3-8B. (Ignored if 'Train from Scratch' is enabled)",
                    visible=True
                )
                repo_name_input = gr.Textbox(
                    label="Output Repository Name",
                    value="",
                    placeholder="Optional. E.g. my-custom-llama. (Generated automatically if empty)"
                )
        with gr.Column(scale=2):
            with gr.Group(elem_classes="card"):
                gr.Markdown("### 3. Training Data (Streaming)")
                ds_text = gr.Textbox(
                    label="Hugging Face Datasets (comma-separated)",
                    placeholder="tatsu-lab/alpaca, OpenAssistant/oasst1",
                    lines=2
                )
                uploads = gr.Files(label="Upload local datasets (csv/json/jsonl/txt)", file_count="multiple", type="filepath")
                with gr.Accordion("⚙️ Advanced Parameters", open=False):
                    with gr.Tabs() as tabs_advanced:
                        with gr.TabItem("Architecture (From Scratch)", id="tab_scratch"):
                            scratch_group = gr.Group(visible=False)
                            with scratch_group:
                                auto_config_scratch = gr.Checkbox(
                                    label="Compute Model Configuration Automatically (based on dataset size)",
                                    value=False
                                )
                                gr.Markdown("#### Tokenizer and Architecture")
                                with gr.Row():
                                    scratch_architecture = gr.Dropdown(
                                        label="Model Architecture",
                                        choices=list(ARCHITECTURE_MAP.keys()),
                                        value="Llama",
                                        info="Base architecture for the new model."
                                    )
                                    scratch_base_tokenizer = gr.Textbox(
                                        label="Base Tokenizer to Train",
                                        value="meta-llama/Meta-Llama-3-8B",
                                        info="Used for 'train_new_from_iterator'."
                                    )
                                with gr.Row():
                                    scratch_vocab_size = gr.Number(label="Vocabulary Size (new tokenizer)", value=32000)
                                    scratch_special_tokens = gr.Textbox(
                                        label="Special Tokens (new tokenizer)",
                                        value="<s>,</s>,<unk>,<pad>,<mask>,<|user|>,<|bot|>,<|end|>",
                                        info="Comma-separated."
                                    )
                                gr.Markdown("#### Model Hyperparameters")
                                with gr.Row():
                                    scratch_hidden_size = gr.Number(label="hidden_size / n_embd", value=512)
                                    scratch_intermediate_size = gr.Number(label="intermediate_size", value=1024, info="Ignored by architectures such as GPT2.")
                                with gr.Row():
                                    scratch_num_hidden_layers = gr.Number(label="num_hidden_layers / n_layer", value=8)
                                    scratch_num_attention_heads = gr.Number(label="num_attention_heads / n_head", value=8)
                                with gr.Row():
                                    scratch_num_key_value_heads = gr.Number(label="num_key_value_heads", value=8, info="Important for Mistral/Gemma (GQA). Leave at 0 if unused.")
                                    scratch_max_pos_embed = gr.Number(label="max_position_embeddings / n_positions", value=512)
                                with gr.Row():
                                    scratch_tie_word_embeddings = gr.Checkbox(label="tie_word_embeddings", value=False)
                        with gr.TabItem("General Options (Trainer)", id="tab_general"):
                            general_group = gr.Group(visible=True)
                            with general_group:
                                with gr.Row():
                                    add_eos_token = gr.Checkbox(label="add_eos_token", value=True)
                                    auto_find_batch_size = gr.Checkbox(label="auto_find_batch_size (accelerate)", value=True)
                                    disable_gradient_checkpointing = gr.Checkbox(label="disable_gradient_checkpointing", value=False)
                                    load_best_model_at_end = gr.Checkbox(label="load_best_model_at_end", value=False, info="Requires a validation dataset and `eval_strategy='steps'`. Keeps the best model at the end.")
                                with gr.Row():
                                    chat_template = gr.Textbox(label="chat_template", value="tokenizer", lines=3, placeholder="Leave 'tokenizer' to use the model's own template, 'none' to disable, or paste a custom Jinja2 template.\nE.g.: {% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n' }}{% else %}{{ 'Assistant: ' + message['content'] + '\n' }}{% endif %}{% endfor %}", info="Custom chat template (Jinja2).")
                                with gr.Row():
                                    distributed_backend = gr.Textbox(label="distributed_backend (ignored on a single GPU)", value="ddp", placeholder="ddp or fsdp")
                                    eval_strategy = gr.Textbox(label="eval_strategy", value="steps", placeholder="no, steps or epoch", info="Recommended: 'steps' to use the validation dataset.")
                                    merge_adapter = gr.Checkbox(label="merge_adapter (upload full model)", value=True, info="Ignored if 'Train from Scratch' is enabled.")
                                    mixed_precision = gr.Textbox(label="mixed_precision", value="bf16", placeholder="none, fp16 or bf16")
                                with gr.Row():
                                    optimizer = gr.Textbox(label="optimizer", value="adamw_torch", placeholder="adamw_8bit, adamw_torch, etc.")
                                    peft = gr.Checkbox(label="peft (enable LoRA)", value=True, info="Ignored if 'Train from Scratch' is enabled.")
                                    padding = gr.Textbox(label="padding", value="left", placeholder="left or right")
                                    quantization = gr.Textbox(label="quantization", value="int4", placeholder="none, int8 or int4", info="Ignored if 'Train from Scratch' is enabled.")
                                    scheduler = gr.Textbox(label="scheduler", value="cosine", placeholder="linear, cosine or constant")
                        with gr.TabItem("Numeric Hyperparameters", id="tab_numeric"):
                            numeric_group = gr.Group(visible=True)
                            with numeric_group:
                                with gr.Row():
                                    batch_size = gr.Slider(1, 128, value=1, step=1, label="per_device_train_batch_size")
                                    block_size = gr.Slider(16, 8192, value=1024, step=16, label="max_seq_length (text block)")
                                    epochs = gr.Slider(1, 20, value=1, step=1, label="num_train_epochs")
                                    gradient_accumulation = gr.Slider(1, 64, value=8, step=1, label="gradient_accumulation_steps")
                                with gr.Row():
                                    learning_rate = gr.Number(value=1e-5, label="learning_rate (lr)")
                                    logging_steps = gr.Number(value=5, label="logging_steps", info="Also controls `eval_steps`.")
                                    max_grad_norm = gr.Number(value=0.3, label="max_grad_norm", info="Gradient clipping.")
                                    steps_per_epoch_estimate = gr.Number(value=10000, label="Estimated Steps per Epoch", info="Used to compute `max_steps` when streaming.")
                                with gr.Row():
                                    save_total_limit = gr.Number(value=1, label="save_total_limit", info="Number of checkpoints to keep.")
                                    seed = gr.Number(value=42, label="seed")
                                    warmup_ratio = gr.Number(value=0.05, label="warmup_ratio")
                                    weight_decay = gr.Number(value=0.01, label="weight_decay")
                        with gr.TabItem("PEFT / LoRA", id="tab_peft"):
                            peft_group = gr.Group(visible=True)
                            with peft_group:
                                gr.Markdown("#### PEFT / LoRA (ignored if 'Train from Scratch' is enabled or `peft` is disabled)")
                                with gr.Row():
                                    lora_r = gr.Number(value=32, label="lora_r (r)")
                                    lora_alpha = gr.Number(value=32, label="lora_alpha")
                                    lora_dropout = gr.Number(value=0.05, label="lora_dropout")
                                    target_modules = gr.Textbox(value="q_proj,k_proj,v_proj,o_proj", placeholder="E.g. all-linear or q_proj,v_proj", label="target_modules (comma-separated)")
                        with gr.TabItem("Model Configuration (Loading)", id="tab_load"):
                            load_config_group = gr.Group(visible=True)
                            with load_config_group:
                                gr.Markdown("#### Model loading options (ignored if 'Train from Scratch' is enabled)")
                                with gr.Row():
                                    trust_remote_code_input = gr.Checkbox(label="trust_remote_code", value=True, info="Allow loading models with custom code.")
gr.Textbox(label="attn_implementation", value="sdpa", placeholder="auto, eager, sdpa, flash_attention_2", info="Optimización de atención (ej: flash_attention_2 si está disponible).") with gr.Row(): new_special_tokens_input = gr.Textbox(label="Nuevos tokens especiales (coma-separado)", placeholder="<|im_start|>, <|im_end|>, <|system|>", info="Añadir tokens al vocabulario. Redimensiona embeddings.") with gr.TabItem("Formato de Datos (Razonamiento)", id="tab_cot"): cot_group = gr.Group(visible=True) with cot_group: with gr.Row(): enable_cot_input = gr.Checkbox(label="Activar formato de razonamiento (CoT)", value=False, info="Anula la detección automática de columna 'text' o 'messages'.") with gr.Row(): prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt", info="Nombre de la columna con el prompt/pregunta.") with gr.Row(): reasoning_col_input = gr.Textbox(label="Columna de Razonamiento (Opcional)", value="reasoning", info="Nombre de la columna con los pasos de CoT.") with gr.Row(): response_col_input = gr.Textbox(label="Columna de Respuesta", value="response", info="Nombre de la columna con la respuesta final.") with gr.TabItem("Normalización de Texto (Datos)", id="tab_norm"): normalization_group = gr.Group(visible=True) with normalization_group: with gr.Row(): apply_lowercase_input = gr.Checkbox(label="Aplicar minúsculas", value=False, info="Convierte todo el texto a minúsculas antes de tokenizar.") remove_punctuation_input = gr.Checkbox(label="Eliminar puntuación", value=False, info="Elimina signos de puntuación del texto antes de tokenizar.") with gr.TabItem("Logging (Weights & Biases)", id="tab_wandb"): wandb_group = gr.Group(visible=True) with wandb_group: with gr.Row(): wandb_project_input = gr.Textbox(label="Nombre del Proyecto W&B", placeholder="mi-proyecto-sft", info="Dejar vacío para deshabilitar W&B.") with gr.Row(): wandb_api_key_input = gr.Textbox(label="API Key de W&B (Opcional)", type="password", placeholder="w_**************************************", info="Opcional. 

    def toggle_scratch_config(scratch_checked):
        scratch_visible = scratch_checked
        finetune_visible = not scratch_checked
        return {
            model_base_input: gr.Textbox(visible=finetune_visible),
            scratch_group: gr.Group(visible=scratch_visible),
            peft_group: gr.Group(visible=finetune_visible),
            load_config_group: gr.Group(visible=finetune_visible),
        }

    train_from_scratch.change(
        fn=toggle_scratch_config,
        inputs=[train_from_scratch],
        outputs=[model_base_input, scratch_group, peft_group, load_config_group],
        queue=False
    )

    run_btn = gr.Button("🚀 Start Training and Upload to the Hugging Face Hub", elem_classes="primary-btn mt-6")
    gr.Markdown("## Results and Progress")
    with gr.Group(elem_classes="card"):
        with gr.Row():
            logs = gr.Textbox(label="Execution Logs", lines=12, elem_classes="textbox-logs", scale=3)
            with gr.Column(scale=2):
                phase = gr.Textbox(label="Current Phase", lines=1)
                link = gr.Textbox(label="Output Repository", lines=1)
                cfg_out = gr.File(label="Final Generated Configuration (params.json)", file_types=[".json"])

    scratch_inputs = [
        scratch_vocab_size, scratch_special_tokens, scratch_base_tokenizer,
        scratch_hidden_size, scratch_intermediate_size, scratch_num_hidden_layers,
        scratch_num_attention_heads, scratch_num_key_value_heads,
        scratch_max_pos_embed, scratch_tie_word_embeddings
    ]
    # The order of all_inputs must match the train_and_upload signature exactly:
    # the scratch_* components come right after wandb_api_key_input, and
    # auto_config_scratch is the final argument.
    all_inputs = [
        model_base_input, ds_text, uploads, repo_name_input, train_from_scratch,
        scratch_architecture, add_eos_token, auto_find_batch_size, chat_template,
        disable_gradient_checkpointing, distributed_backend, eval_strategy,
        load_best_model_at_end, merge_adapter, mixed_precision, optimizer, peft,
        padding, quantization, scheduler, batch_size, block_size, epochs,
        gradient_accumulation, learning_rate, logging_steps, lora_alpha,
        lora_dropout, lora_r, max_grad_norm, save_total_limit, seed, warmup_ratio,
        weight_decay, target_modules, steps_per_epoch_estimate,
        trust_remote_code_input, attn_implementation_input,
        new_special_tokens_input, apply_lowercase_input, remove_punctuation_input,
        enable_cot_input, prompt_col_input, reasoning_col_input, response_col_input,
        wandb_project_input, wandb_api_key_input
    ] + scratch_inputs + [auto_config_scratch]
    all_outputs = [logs, phase, link, cfg_out]

    run_btn.click(
        fn=train_and_upload,
        inputs=all_inputs,
        outputs=all_outputs
    )

demo.launch(server_name="0.0.0.0", server_port=7860)