import os, io, json, tempfile, string
# Install dependencies at startup. The SFTTrainer call below uses the trl 0.8.x
# API (tokenizer=/max_seq_length= keyword arguments), so trl and transformers are
# pinned instead of blindly upgraded; the exact pins are an assumption, adjust to
# your environment.
os.system("pip install transformers==4.41.2 peft datasets accelerate bitsandbytes trl==0.8.6 scipy einops evaluate zstandard wandb tokenizers")
os.system("pip install spaces-0.1.0-py3-none-any.whl")
import gradio as gr
import spaces
import torch
import logging
import multiprocessing
import uuid
import gc
import math
from itertools import islice
from datasets import load_dataset, IterableDataset, interleave_datasets
from huggingface_hub import login, whoami, create_repo, upload_folder
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
    IntervalStrategy,
    LlamaConfig, LlamaForCausalLM,
    MistralConfig, MistralForCausalLM,
    GemmaConfig, GemmaForCausalLM,
    GPT2Config, GPT2LMHeadModel,
    PreTrainedTokenizerFast
)
from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
from tokenizers import ByteLevelBPETokenizer

logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
num_workers = multiprocessing.cpu_count()
ARCHITECTURE_MAP = {
    "Llama": (LlamaConfig, LlamaForCausalLM),
    "Mistral": (MistralConfig, MistralForCausalLM),
    "Gemma": (GemmaConfig, GemmaForCausalLM),
    "GPT2": (GPT2Config, GPT2LMHeadModel),
}
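# Usage sketch: ARCHITECTURE_MAP resolves the UI choice to its (config, model)
# pair, e.g. config_cls, model_cls = ARCHITECTURE_MAP["Llama"], later instantiated
# as model_cls(config_cls(**params)) when training from scratch.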
def _normalize_text_helper(text, do_lowercase, do_remove_punct):
    if not isinstance(text, str):
        return text
    if do_lowercase:
        text = text.lower()
    if do_remove_punct:
        text = text.translate(str.maketrans('', '', string.punctuation))
    return text
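# A minimal sanity check (hypothetical input):
#   _normalize_text_helper("Hello, World!", True, True)  ->  "hello world"
# Non-string values (e.g. None or a list of chat messages) pass through untouched.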
def _load_hf_streaming(ids):
    streams = {"train": [], "validation": []}
    CONFLICT_COLUMNS = ['id', 'raw_id', 'shard_id', 'num_shards', 'meta']
    for ident in ids:
        try:
            d = load_dataset(ident, streaming=True)

            def clean_and_keep_text(example):
                # Drop id-like/meta columns that collide when interleaving datasets.
                keys_to_drop = [k for k in example.keys() if k in CONFLICT_COLUMNS or k.endswith('_id')]
                for k in keys_to_drop:
                    if k in example:
                        del example[k]
                return example

            if isinstance(d, dict):
                for split, ds in d.items():
                    cleaned_ds = ds.map(clean_and_keep_text, batched=False)
                    if "train" in split:
                        streams["train"].append(cleaned_ds)
                    elif "validation" in split or "test" in split:
                        streams["validation"].append(cleaned_ds)
            else:
                cleaned_ds = d.map(clean_and_keep_text, batched=False)
                streams["train"].append(cleaned_ds)
                # Derive a makeshift validation stream from the same source by
                # skipping the first example (the original chained a broken
                # with_format() call here).
                streams["validation"].append(cleaned_ds.skip(1))
        except Exception as e:
            logger.error(f"Error loading or cleaning dataset {ident}: {e}")
            continue
    return streams
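# Illustrative return shape:
#   {"train": [<IterableDataset>, ...], "validation": [<IterableDataset>, ...]}
# Each entry is one streaming split; they are interleaved downstream (rather than
# concatenated), so examples from different corpora alternate during training.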
def _load_uploaded_stream(files):
    all_rows = []
    for f in files or []:
        # gr.Files(type="filepath") yields path strings; older Gradio versions
        # yield tempfile objects with a .name attribute, so accept both.
        path = f if isinstance(f, str) else f.name
        name = path.lower()
        with open(path, "rb") as fh:
            content = fh.read().decode("utf-8", errors="ignore")
        if name.endswith(".csv"):
            import csv
            reader = csv.DictReader(io.StringIO(content))
            for row in reader:
                all_rows.append(row)
        elif name.endswith(".jsonl"):
            for line in io.StringIO(content):
                try:
                    all_rows.append(json.loads(line))
                except json.JSONDecodeError:
                    pass
        elif name.endswith(".json"):
            try:
                data = json.loads(content)
                if isinstance(data, list):
                    all_rows.extend(data)
                elif isinstance(data, dict):
                    all_rows.append(data)
            except json.JSONDecodeError:
                pass
        elif name.endswith(".txt"):
            for line in io.StringIO(content):
                line = line.strip()
                if line:
                    all_rows.append({"text": line})
    if not all_rows:
        return {"train": IterableDataset.from_generator(lambda: iter([])),
                "validation": IterableDataset.from_generator(lambda: iter([]))}
    # Hold out roughly 1% (at least one row) as a validation split.
    n = len(all_rows)
    val_size = max(1, n // 100)
    train_data = all_rows[:-val_size]
    val_data = all_rows[-val_size:]
    train_ds = IterableDataset.from_generator(lambda: iter(train_data))
    val_ds = IterableDataset.from_generator(lambda: iter(val_data))
    return {"train": train_ds, "validation": val_ds}
def _guess_columns(sample):
    text_col = None
    if isinstance(sample, dict):
        for k in sample.keys():
            if k.lower() in ["text", "content", "prompt", "input", "messages", "sentence", "review", "body", "ctx", "question"]:
                text_col = k
                break  # take the first match rather than the last
    return text_col or "text", None
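# E.g. a sample like {"content": "...", "score": 3} resolves to ("content", None);
# anything unrecognized falls back to ("text", None).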
def hf_login(token):
    if not token or token.strip() == "":
        return "Error: please enter your Hugging Face token."
    try:
        login(token=token.strip(), add_to_git_credential=True)
        user = whoami()
        name = user.get("name") or user.get("fullname") or "user"
        email = user.get("email") or ""
        return f"Logged in: {name} {f'({email})' if email else ''}"
    except Exception as e:
        return f"Login error: {e}"
def _sft_formatting_func(example, text_col, tokenizer,
                         do_lowercase, do_remove_punct,
                         enable_cot, prompt_col, reasoning_col, response_col):
    if enable_cot:
        prompt = example.get(prompt_col, "")
        reasoning = example.get(reasoning_col, "")
        response = example.get(response_col, "")
        prompt = _normalize_text_helper(prompt, do_lowercase, do_remove_punct)
        reasoning = _normalize_text_helper(reasoning, do_lowercase, do_remove_punct)
        response = _normalize_text_helper(response, do_lowercase, do_remove_punct)
        if reasoning:
            return f"Prompt: {prompt}\n\nReasoning: {reasoning}\n\nResponse: {response}"
        else:
            return f"Prompt: {prompt}\n\nResponse: {response}"
    if text_col == "messages" and hasattr(tokenizer, 'apply_chat_template'):
        processed_messages = []
        for msg in example.get(text_col, []):
            new_msg = msg.copy()
            if 'content' in new_msg and isinstance(new_msg['content'], str):
                new_msg['content'] = _normalize_text_helper(new_msg['content'], do_lowercase, do_remove_punct)
            processed_messages.append(new_msg)
        return tokenizer.apply_chat_template(processed_messages, tokenize=False, add_generation_prompt=False)
    text = example.get(text_col)
    if isinstance(text, str):
        return _normalize_text_helper(text, do_lowercase, do_remove_punct)
    return ""
def get_training_corpus_iterator(dataset, text_col, chunk_size=1000):
    if not dataset or not text_col:
        return
    iterator = iter(dataset)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            break
        texts = []
        for example in chunk:
            text = example.get(text_col)
            if isinstance(text, str) and text.strip():
                texts.append(text.strip())
            elif text_col == "messages" and isinstance(text, list):
                for msg in text:
                    if isinstance(msg, dict) and isinstance(msg.get('content'), str):
                        texts.append(msg['content'].strip())
        if texts:
            yield texts
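# The generator above yields lists of up to 1000 raw strings, the shape that
# Hugging Face tokenizers expect when retraining, e.g. (illustrative):
#   new_tok = old_tok.train_new_from_iterator(
#       get_training_corpus_iterator(train_ds, "text"), vocab_size=32000)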
def _count_dataset_size(dataset):
    # Note: this exhausts one full pass over a streaming dataset, which can be
    # slow for large corpora; it is only used by the auto-config heuristic.
    count = 0
    try:
        for _ in dataset:
            count += 1
    except Exception:
        pass
    return count
def _calculate_auto_config(train_dataset, block_size, scratch_architecture):
    size = _count_dataset_size(train_dataset)
    if size == 0:
        return 32000, 512, 1024, 8, 8, 512, False, 8, "meta-llama/Meta-Llama-3-8B"
    # Heuristic scaling: grow vocab/width/depth with log2 of the corpus size.
    log_size = math.log2(max(1000, size))
    vocab_size = min(65536, 32000 + int(log_size * 2000))
    hidden_size = min(2048, 512 + int(log_size * 50))
    hidden_size = max(512, hidden_size)
    intermediate_size = hidden_size * 2
    layers = min(24, 4 + int(log_size * 0.5))
    layers = max(4, int(layers))
    heads = max(4, hidden_size // 64)
    heads = min(32, heads)
    max_pos_embed = int(block_size)
    kv_heads = heads
    if scratch_architecture in ["Mistral", "Llama", "Gemma"]:
        # Grouped-query attention for larger models; full MHA when the model is small.
        kv_heads = max(1, heads // 8)
        if hidden_size < 1024:
            kv_heads = heads
    tie_embeddings = False
    base_tokenizer_name = "meta-llama/Meta-Llama-3-8B"
    return vocab_size, hidden_size, intermediate_size, layers, heads, max_pos_embed, tie_embeddings, kv_heads, base_tokenizer_name
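# Worked example: a 100,000-example corpus gives log2(100000) ~= 16.61, hence
# vocab_size = 32000 + 33219 = 65219, hidden_size = 512 + 830 = 1342,
# intermediate_size = 2684, layers = 4 + 8 = 12, heads = 1342 // 64 = 20, and
# kv_heads = 20 // 8 = 2 for Llama/Mistral/Gemma (GQA applies since hidden >= 1024).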
def train_and_upload(model_base_input, datasets_hf_text, uploads, repo_name_input,
                     train_from_scratch, scratch_architecture,
                     add_eos_token, auto_find_batch_size, chat_template,
                     disable_gradient_checkpointing, distributed_backend, eval_strategy,
                     load_best_model_at_end,
                     merge_adapter, mixed_precision, optimizer, peft, padding,
                     quantization, scheduler, batch_size, block_size, epochs,
                     gradient_accumulation, learning_rate, logging_steps, lora_alpha,
                     lora_dropout, lora_r, max_grad_norm,
                     save_total_limit, seed, warmup_ratio, weight_decay, target_modules,
                     steps_per_epoch_estimate,
                     trust_remote_code_input, attn_implementation_input, new_special_tokens_input,
                     apply_lowercase_input, remove_punctuation_input,
                     enable_cot_input, prompt_col_input, reasoning_col_input, response_col_input,
                     wandb_project_input, wandb_api_key_input,
                     scratch_vocab_size, scratch_special_tokens, scratch_base_tokenizer,
                     scratch_hidden_size, scratch_intermediate_size, scratch_num_hidden_layers,
                     scratch_num_attention_heads, scratch_num_key_value_heads, scratch_max_pos_embed, scratch_tie_word_embeddings,
                     auto_config_scratch,
                     progress=gr.Progress()):
    temp_dir = tempfile.mkdtemp()
    logs = ""
    repo_link = ""
    config_data = {}

    def update_logs(new_msg, phase_msg, step_ratio=None):
        nonlocal logs
        logs += f"{new_msg}\n"
        if step_ratio is not None:
            progress(step_ratio)
        return logs, phase_msg, repo_link, None
    try:
        user = whoami()
        username = user.get("name") or "hf_user"
        model = model_base_input.strip()
        if not model and not train_from_scratch:
            logs += "Error: you must specify a **Base Model** ID or enable 'Train from Scratch'.\n"
            yield logs, "Error", repo_link, None
            return
        if repo_name_input and repo_name_input.strip():
            repo_base = repo_name_input.strip().replace(" ", "-")
        else:
            random_suffix = uuid.uuid4().hex[:6]
            if train_from_scratch:
                model_slug = f"{scratch_architecture.lower()}-{int(scratch_hidden_size)}-{int(scratch_num_hidden_layers)}l"
            else:
                model_slug = model.split('/')[-1].replace('.', '-').lower()
            repo_base = f"{model_slug}-sft-{random_suffix}"
        repo_id = f"{username}/{repo_base}"
        hf_ids = [x.strip() for x in (datasets_hf_text or "").split(",") if x.strip()]
        all_ds = {"train": [], "validation": []}
        if hf_ids or uploads:
            yield update_logs("Loading datasets in streaming mode...", "Loading Data", 0.05)
        dsh = _load_hf_streaming(hf_ids)
        all_ds["train"].extend(dsh.get("train", []))
        all_ds["validation"].extend(dsh.get("validation", []))
        dsu = _load_uploaded_stream(uploads)
        all_ds["train"].append(dsu["train"])
        all_ds["validation"].append(dsu["validation"])
        # Probe each stream once to drop empty ones before interleaving.
        valid_train_streams = [ds for ds in all_ds["train"] if ds and next(iter(ds), None) is not None]
        valid_val_streams = [ds for ds in all_ds["validation"] if ds and next(iter(ds), None) is not None]
        if not valid_train_streams:
            logs += "Error: no valid, non-empty datasets were found to train on.\n"
            yield logs, "Error", repo_link, None
            return
        if not valid_val_streams and (eval_strategy.lower() in ("steps", "epoch") or load_best_model_at_end):
            logs += "Warning: evaluation is enabled (`eval_strategy` or `load_best_model_at_end`) but no validation datasets were found. Disabling evaluation.\n"
            eval_strategy = "no"
            load_best_model_at_end = False
        train_dataset = interleave_datasets(valid_train_streams)
        validation_dataset = None
        if valid_val_streams:
            validation_dataset = interleave_datasets(valid_val_streams)
        text_col = "text"
        if not enable_cot_input:
            try:
                sample = next(iter(train_dataset))
                text_col, _ = _guess_columns(sample)
            except Exception:
                text_col = "text"
            yield update_logs(f"Detected text column: {text_col}", "Detected", 0.10)
        else:
            yield update_logs("Reasoning (CoT) format enabled.", "Detected", 0.10)
            logs += f"  Prompt: {prompt_col_input}\n"
            if reasoning_col_input:
                logs += f"  Reasoning: {reasoning_col_input}\n"
            logs += f"  Response: {response_col_input}\n"
            yield logs, "Detected", repo_link, None
        # Resolve attn_implementation up front: it is written into config_result.json
        # at the end even when training from scratch.
        attn_impl_val = attn_implementation_input.strip().lower() if attn_implementation_input else "auto"
        if train_from_scratch:
            yield update_logs(f"--- 'Train from Scratch' mode enabled (architecture: {scratch_architecture}) ---", "Preparing Model", 0.15)
            logs += "PEFT, quantization and merge will be disabled.\n"
            peft = False
            merge_adapter = False
            quantization = "none"
            config_data["train_from_scratch"] = True
            config_data["architecture"] = scratch_architecture
            if auto_config_scratch:
                yield update_logs("Computing model configuration automatically...", "Preparing Model")
                (auto_vocab_size, auto_hidden_size, auto_intermediate_size,
                 auto_layers, auto_heads, auto_max_pos_embed, auto_tie_embeddings, auto_kv_heads,
                 auto_base_tokenizer_name) = _calculate_auto_config(train_dataset, int(block_size), scratch_architecture)
                scratch_vocab_size = auto_vocab_size
                scratch_hidden_size = auto_hidden_size
                scratch_intermediate_size = auto_intermediate_size
                scratch_num_hidden_layers = auto_layers
                scratch_num_attention_heads = auto_heads
                scratch_num_key_value_heads = auto_kv_heads
                scratch_max_pos_embed = auto_max_pos_embed
                scratch_tie_word_embeddings = auto_tie_embeddings
                scratch_base_tokenizer = auto_base_tokenizer_name
                yield update_logs(f"Auto-computed config: Vocab={auto_vocab_size}, Hidden={auto_hidden_size}, Layers={auto_layers}, Heads={auto_heads}, KV={auto_kv_heads}", "Preparing Model")
            yield update_logs("Starting training of a new tokenizer...", "Preparing Model")
            base_tok_name = scratch_base_tokenizer.strip() if scratch_base_tokenizer.strip() else "meta-llama/Meta-Llama-3-8B"
            yield update_logs(f"Loading base tokenizer: {base_tok_name}", "Preparing Model")
            base_tok = AutoTokenizer.from_pretrained(base_tok_name, use_fast=True)
            special_tokens_list = [t.strip() for t in scratch_special_tokens.split(",") if t.strip()]
            corpus_iterator = get_training_corpus_iterator(train_dataset, text_col)
            tokenizer = base_tok.train_new_from_iterator(
                corpus_iterator,
                vocab_size=int(scratch_vocab_size)
            )
            yield update_logs(f"New tokenizer trained. Vocab size: {tokenizer.vocab_size}", "Preparing Model")
            if special_tokens_list:
                current_special_tokens = tokenizer.all_special_tokens
                new_tokens_to_add = [t for t in special_tokens_list if t not in current_special_tokens]
                if new_tokens_to_add:
                    tokenizer.add_special_tokens({"additional_special_tokens": new_tokens_to_add})
                if "<s>" in special_tokens_list: tokenizer.bos_token = "<s>"
                if "</s>" in special_tokens_list: tokenizer.eos_token = "</s>"
                if "<unk>" in special_tokens_list: tokenizer.unk_token = "<unk>"
                if "<pad>" in special_tokens_list: tokenizer.pad_token = "<pad>"
                if "<mask>" in special_tokens_list: tokenizer.mask_token = "<mask>"
            if tokenizer.pad_token is None:
                yield update_logs("Warning: '<pad>' not found among the special tokens. Using the EOS token as pad_token.", "Preparing Model")
                tokenizer.pad_token = tokenizer.eos_token
            if scratch_architecture not in ARCHITECTURE_MAP:
                logs += f"Error: architecture '{scratch_architecture}' is not supported.\n"
                yield logs, "Error", repo_link, None
                return
            ConfigClass, ModelClass = ARCHITECTURE_MAP[scratch_architecture]
            yield update_logs(f"Creating {ConfigClass.__name__} for the new model...", "Preparing Model")
            config_params = {
                # len(tokenizer) counts added special tokens as well;
                # tokenizer.vocab_size does not, which would leave the new tokens
                # outside the embedding matrix.
                "vocab_size": len(tokenizer),
                "pad_token_id": tokenizer.pad_token_id,
                "bos_token_id": tokenizer.bos_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "initializer_range": 0.02,
                "use_cache": True,
                "tie_word_embeddings": scratch_tie_word_embeddings,
            }
            if scratch_architecture == "GPT2":
                config_params.update({
                    "n_embd": int(scratch_hidden_size),
                    "n_layer": int(scratch_num_hidden_layers),
                    "n_head": int(scratch_num_attention_heads),
                    "n_positions": int(scratch_max_pos_embed),
                })
            else:
                config_params.update({
                    "hidden_size": int(scratch_hidden_size),
                    "intermediate_size": int(scratch_intermediate_size),
                    "num_hidden_layers": int(scratch_num_hidden_layers),
                    "num_attention_heads": int(scratch_num_attention_heads),
                    "max_position_embeddings": int(scratch_max_pos_embed),
                    "rms_norm_eps": 1e-6,
                })
            if scratch_architecture in ["Mistral", "Gemma"] and scratch_num_key_value_heads > 0:
                config_params["num_key_value_heads"] = int(scratch_num_key_value_heads)
            config = ConfigClass(**config_params)
            config_data["model_config"] = config.to_dict()
            yield update_logs(f"Initializing {ModelClass.__name__} from scratch...", "Preparing Model")
            model_hf = ModelClass(config).to(device)
            torch_dtype = torch.float32
        else:
            yield update_logs("--- 'Fine-Tuning' mode enabled ---", "Preparing Model", 0.15)
            config_data["train_from_scratch"] = False
            config_data["model_base"] = model
            quantization_val = quantization if quantization != "none" else None
            bnb_config = None
            if quantization_val == "int4":
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
                )
            elif quantization_val == "int8":
                bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            # quantization == "none": leave bnb_config as None (the original fell
            # back to int8 here, silently contradicting the user's setting).
            torch_dtype = torch.float32
            if mixed_precision == "bf16":
                torch_dtype = torch.bfloat16
            elif mixed_precision == "fp16":
                torch_dtype = torch.float16
            yield update_logs("Loading tokenizer and model (fine-tuning)...", "Loading Model", 0.20)
            tokenizer = AutoTokenizer.from_pretrained(
                model,
                padding_side=padding,
                add_eos_token=add_eos_token,
                trust_remote_code=trust_remote_code_input,
                use_fast=False
            )
            chat_template_str = chat_template.strip() if chat_template else "tokenizer"
            if chat_template_str.lower() == "none":
                tokenizer.chat_template = None
                yield update_logs("Chat template disabled.", "Loading Model")
            elif chat_template_str.lower() != "tokenizer":
                tokenizer.chat_template = chat_template_str
                yield update_logs("Applying custom chat template.", "Loading Model")
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model_kwargs = {
                "quantization_config": bnb_config,
                "device_map": "auto",
                "trust_remote_code": trust_remote_code_input,
                "torch_dtype": torch_dtype,
            }
            if attn_impl_val and attn_impl_val != "auto":
                model_kwargs["attn_implementation"] = attn_impl_val
                yield update_logs(f"Using attn_implementation: {attn_impl_val}", "Loading Model")
            # device_map="auto" already places the model; calling .to(device) on a
            # quantized model raises, so the original trailing .to(device) is dropped.
            model_hf = AutoModelForCausalLM.from_pretrained(
                model,
                **model_kwargs
            )
            if new_special_tokens_input:
                tokens_to_add = [t.strip() for t in new_special_tokens_input.split(",") if t.strip()]
                if tokens_to_add:
                    yield update_logs(f"Adding {len(tokens_to_add)} tokens and resizing embeddings...", "Loading Model")
                    tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})
                    model_hf.resize_token_embeddings(len(tokenizer))
                    if hasattr(model_hf, "tie_weights"):
                        model_hf.tie_weights()
                    else:
                        input_embeddings = model_hf.get_input_embeddings()
                        output_embeddings = model_hf.get_output_embeddings()
                        if output_embeddings is not None and input_embeddings.weight.shape == output_embeddings.weight.shape:
                            output_embeddings.weight.data = input_embeddings.weight.data
            if quantization_val is not None:
                model_hf = prepare_model_for_kbit_training(model_hf)
        peft_config = None
        if peft and not train_from_scratch:
            yield update_logs("Configuring PEFT (LoRA)...", "Preparing Trainer", 0.25)
            peft_config = LoraConfig(
                r=int(lora_r),
                lora_alpha=float(lora_alpha),
                lora_dropout=float(lora_dropout),
                # "all-linear" is a value LoraConfig understands directly; anything
                # else is treated as a comma-separated module list.
                target_modules="all-linear" if target_modules == "all-linear" else [m.strip() for m in target_modules.split(",") if m.strip()],
                bias="none",
                task_type="CAUSAL_LM"
            )
            config_data["lora_config"] = peft_config.to_dict()
        else:
            yield update_logs("Full training (no PEFT) enabled.", "Preparing Trainer", 0.25)
        eval_strategy_lower = eval_strategy.lower()
        # Streaming datasets have no known length, so max_steps is derived from a
        # user-supplied per-epoch estimate instead of being inferred.
        num_steps_calculated = int(float(epochs) * float(steps_per_epoch_estimate) / int(gradient_accumulation) / int(batch_size))
        yield update_logs(f"Computed max_steps for streaming (based on {steps_per_epoch_estimate} examples/epoch): {num_steps_calculated} steps.", "Preparing Trainer")
        if eval_strategy_lower == "steps":
            save_strategy_val = IntervalStrategy.STEPS
        elif eval_strategy_lower == "epoch":
            save_strategy_val = IntervalStrategy.EPOCH
        else:
            save_strategy_val = IntervalStrategy.NO
        report_to_val = "none"
        if wandb_project_input and wandb_project_input.strip():
            yield update_logs(f"Enabling Weights & Biases logging (project: {wandb_project_input.strip()}).", "Preparing Trainer")
            os.environ["WANDB_DISABLED"] = "false"
            os.environ["WANDB_PROJECT"] = wandb_project_input.strip()
            if wandb_api_key_input and wandb_api_key_input.strip():
                os.environ["WANDB_API_KEY"] = wandb_api_key_input.strip()
            report_to_val = "wandb"
        else:
            os.environ["WANDB_DISABLED"] = "true"
            yield update_logs("W&B logging disabled.", "Preparing Trainer")
        # scheduler, seed and auto_find_batch_size come from the UI; the original
        # collected them but never passed them to TrainingArguments.
        training_args = TrainingArguments(
            output_dir=os.path.join(temp_dir, "results"),
            num_train_epochs=float(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            gradient_accumulation_steps=int(gradient_accumulation),
            optim=optimizer,
            lr_scheduler_type=scheduler,
            seed=int(seed),
            auto_find_batch_size=auto_find_batch_size,
            save_strategy=save_strategy_val,
            save_steps=int(logging_steps) * 10,
            logging_steps=int(logging_steps),
            evaluation_strategy=eval_strategy_lower,
            eval_steps=int(logging_steps) if eval_strategy_lower == "steps" else None,
            learning_rate=float(learning_rate),
            fp16=(mixed_precision == "fp16"),
            bf16=(mixed_precision == "bf16"),
            max_grad_norm=float(max_grad_norm),
            warmup_ratio=float(warmup_ratio),
            weight_decay=float(weight_decay),
            load_best_model_at_end=load_best_model_at_end,
            save_total_limit=int(save_total_limit),
            gradient_checkpointing=not disable_gradient_checkpointing,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            push_to_hub=True,
            hub_model_id=repo_id,
            hub_private_repo=False,
            disable_tqdm=False,
            max_steps=num_steps_calculated,
            report_to=report_to_val,
            save_safetensors=True,
            save_only_model=True,
            dataloader_num_workers=0,
        )
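        # With IterableDataset inputs the Trainer cannot measure an epoch, so
        # max_steps above is what actually bounds the run; num_train_epochs is
        # retained mainly for bookkeeping in the saved arguments.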
| yield update_logs("Inicializando SFT Trainer...", "Preparando Trainer", 0.30) | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| formatting_lambda = lambda example: _sft_formatting_func( | |
| example, | |
| text_col, | |
| tokenizer, | |
| apply_lowercase_input, | |
| remove_punctuation_input, | |
| enable_cot_input, | |
| prompt_col_input, | |
| reasoning_col_input, | |
| response_col_input | |
| ) | |
| trainer = SFTTrainer( | |
| model=model_hf, | |
| tokenizer=tokenizer, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=validation_dataset, | |
| peft_config=peft_config, | |
| formatting_func=formatting_lambda, | |
| max_seq_length=int(block_size) | |
| ) | |
| yield update_logs("Iniciando entrenamiento...", "Entrenando", 0.40) | |
| trainer.train() | |
| yield update_logs("Entrenamiento finalizado. Guardando modelo final...", "Guardando Modelo", 0.85) | |
| trainer.save_model(training_args.output_dir) | |
| tokenizer.save_pretrained(training_args.output_dir) | |
| del model_hf | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if merge_adapter and peft and not train_from_scratch: | |
| yield update_logs("Fusionando adaptadores LoRA con el modelo base...", "Fusionando", 0.90) | |
| del trainer | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| base_model_for_merge = AutoModelForCausalLM.from_pretrained( | |
| model, | |
| trust_remote_code=trust_remote_code_input, | |
| torch_dtype=torch_dtype, | |
| attn_implementation=attn_impl_val if attn_impl_val != "auto" else None | |
| ).to(device) | |
| ft = PeftModel.from_pretrained( | |
| base_model_for_merge, | |
| training_args.output_dir, torch_dtype=torch.float32, | |
| is_trainable=False, device_map={"": device} | |
| ).merge_and_unload() | |
| output_model_dir = os.path.join(tempfile.mkdtemp(), "final_merged_model") | |
| ft.save_pretrained(output_model_dir, safe_serialization=True) | |
| tokenizer.save_pretrained(output_model_dir) | |
| upload_folder( | |
| folder_path=output_model_dir, | |
| repo_id=repo_id, | |
| commit_message="Modelo fusionado (PEFT y base) en safetensors", | |
| allow_patterns=["*"] | |
| ) | |
| del ft, base_model_for_merge | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| else: | |
| yield update_logs("Subida manejada por Trainer.push_to_hub().", "Subiendo a Hub", 0.95) | |
| repo_link = f"https://huggingface.co/{repo_id}" | |
| config_data.update({ | |
| "training_arguments": {}, | |
| "quantization": quantization if not train_from_scratch else "none", | |
| "text_column": text_col if not enable_cot_input else "N/A (CoT Enabled)", | |
| "hub_url": repo_link, | |
| "load_config": { | |
| "trust_remote_code": trust_remote_code_input, | |
| "attn_implementation": attn_impl_val, | |
| "new_special_tokens": new_special_tokens_input, | |
| "chat_template": tokenizer.chat_template | |
| }, | |
| "normalization_config": { | |
| "apply_lowercase": apply_lowercase_input, | |
| "remove_punctuation": remove_punctuation_input | |
| }, | |
| "cot_config": { | |
| "enabled": enable_cot_input, | |
| "prompt_col": prompt_col_input, | |
| "reasoning_col": reasoning_col_input, | |
| "response_col": response_col_input | |
| }, | |
| "wandb_config": { | |
| "project": wandb_project_input | |
| } | |
| }) | |
| training_args_dict = training_args.to_dict() | |
| for key, value in training_args_dict.items(): | |
| if isinstance(value, set): | |
| config_data["training_arguments"][key] = list(value) | |
| elif isinstance(value, IntervalStrategy): | |
| config_data["training_arguments"][key] = str(value) | |
| else: | |
| config_data["training_arguments"][key] = value | |
| cfg_path = os.path.join(tempfile.mkdtemp(), "config_result.json") | |
| with open(cfg_path, "w", encoding="utf-8") as f: | |
| json.dump(config_data, f, ensure_ascii=False, indent=2) | |
| logs += f"Subida completa: {repo_link}\n" | |
| yield logs, "Listo", repo_link, cfg_path | |
| except Exception as e: | |
| err = f"Error fatal de ejecuci贸n: {type(e).__name__}: {e}\n" | |
| import traceback | |
| err += traceback.format_exc() | |
| yield err, "Error", "", None | |
| css = """ | |
| :root { | |
| --bg:#f8f8f8; | |
| --card:#ffffff; | |
| --muted:#e0e0e0; | |
| --text:#1f2937; | |
| --sub:#6b7280; | |
| --accent:#818cf8; | |
| --accent-hover:#a78bfa; | |
| --shadow:rgba(0,0,0,0.1); | |
| --font-family: 'Inter', sans-serif; | |
| } | |
| .gradio-container { | |
| background-color: var(--bg) !important; | |
| font-family: var(--font-family); | |
| color: var(--text); | |
| } | |
| div:not(.label-wrap), label:not(.label-wrap label) { color: var(--text) !important; } | |
| h1, h2, h3 { | |
| letter-spacing: .2px; | |
| color: var(--text); | |
| font-weight: 700; | |
| } | |
| button { | |
| border-radius: 12px !important; | |
| padding: 10px 14px !important; | |
| border: 1px solid var(--muted) !important; | |
| background: linear-gradient(180deg, #f0f0f0, #ffffff) !important; | |
| color: var(--text) !important; | |
| transition: all .15s ease; | |
| box-shadow: 0 2px 5px var(--shadow); | |
| } | |
| button:hover { | |
| transform: translateY(-1px); | |
| border-color: #c0c0c0 !important; | |
| box-shadow: 0 4px 10px var(--shadow); | |
| } | |
| .primary-btn button { | |
| background: linear-gradient(90deg, #818cf8, #a855f7, #ec4899) !important; | |
| border: none !important; | |
| color: white !important; | |
| font-weight: 600; | |
| } | |
| .primary-btn button:hover { | |
| transform: scale(1.01); | |
| box-shadow: 0 8px 20px rgba(130, 100, 255, 0.4); | |
| } | |
| input[type="text"], | |
| input[type="password"], | |
| input[type="number"], | |
| textarea, | |
| select { | |
| background: var(--card) !important; | |
| border-radius: 10px !important; | |
| border: 1px solid var(--muted) !important; | |
| color: var(--text) !important; | |
| padding: 10px; | |
| } | |
| .label-wrap label { | |
| color: var(--sub) !important; | |
| font-size: 14px; | |
| } | |
| .card { | |
| background: var(--card); | |
| border: 1px solid var(--muted); | |
| border-radius: 16px; | |
| padding: 20px; | |
| box-shadow: 0 4px 15px var(--shadow); | |
| } | |
| .textbox-logs textarea { | |
| font-family: monospace; | |
| font-size: 12px; | |
| line-height: 1.4; | |
| background: #f0f0f0 !important; | |
| color: var(--text) !important; | |
| } | |
| .title-h1 { | |
| text-align: center; | |
| font-size: 32px; | |
| font-weight: 800; | |
| color: #1f2937; | |
| letter-spacing: 1px; | |
| margin-bottom: 4px; | |
| } | |
| .title-subtitle { | |
| text-align: center; | |
| font-size: 14px; | |
| color: #6b7280; | |
| margin-bottom: 20px; | |
| } | |
| """ | |
with gr.Blocks(title="HuggingFace SFT Trainer Studio", theme="gradio/soft", css=css) as demo:
    gr.Markdown("# HuggingFace SFT Trainer Studio", elem_classes="title-h1")
    gr.Markdown("Advanced multi-dataset fine-tuning with TRL (SFT) and a streaming data pipeline.", elem_classes="title-subtitle")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            with gr.Group(elem_classes="card"):
                gr.Markdown("### 1. Authentication")
                token_input = gr.Textbox(
                    label="Hugging Face Access Token",
                    type="password",
                    placeholder="Enter your token (hf_xxx...)",
                    lines=1
                )
                login_btn = gr.Button("Connect and Save Token")
                login_status = gr.Textbox(label="Status", lines=1)
                login_btn.click(fn=hf_login, inputs=[token_input], outputs=login_status)
                gr.Markdown("### 2. Model and Repository Configuration")
                train_from_scratch = gr.Checkbox(
                    label="Train from Scratch (New Model and Tokenizer)",
                    value=False,
                    info="If enabled, 'Base Model' is ignored and a new model/tokenizer is trained."
                )
                model_base_input = gr.Textbox(
                    label="Base Model (Hub ID)",
                    value="",
                    placeholder="E.g. meta-llama/Meta-Llama-3-8B. (Ignored if 'Train from Scratch' is enabled)",
                    visible=True
                )
                repo_name_input = gr.Textbox(
                    label="Output Repository Name",
                    value="",
                    placeholder="Optional. E.g. my-custom-llama. (Auto-generated if left empty)"
                )
        with gr.Column(scale=2):
            with gr.Group(elem_classes="card"):
                gr.Markdown("### 3. Training Data (Streaming)")
                ds_text = gr.Textbox(
                    label="Hugging Face Datasets (comma-separated)",
                    placeholder="tatsu-lab/alpaca, OpenAssistant/oasst1",
                    lines=2
                )
                # type="filepath" matches _load_uploaded_stream, which opens the
                # returned paths (the original "binary" passed raw bytes instead).
                uploads = gr.Files(label="Upload local datasets (csv/json/jsonl/txt)", file_count="multiple", type="filepath")
    with gr.Accordion("⚙️ Advanced Parameters", open=False):
        with gr.Tabs() as tabs_advanced:
            with gr.TabItem("Architecture (From Scratch)", id="tab_scratch"):
                scratch_group = gr.Group(visible=False)
                with scratch_group:
                    auto_config_scratch = gr.Checkbox(
                        label="Compute Model Configuration Automatically (based on dataset size)",
                        value=False
                    )
                    gr.Markdown("#### Tokenizer and Architecture Configuration")
                    with gr.Row():
                        scratch_architecture = gr.Dropdown(
                            label="Model Architecture",
                            choices=list(ARCHITECTURE_MAP.keys()),
                            value="Llama",
                            info="Choose the base architecture for the new model."
                        )
                        scratch_base_tokenizer = gr.Textbox(
                            label="Base Tokenizer to Train From",
                            value="meta-llama/Meta-Llama-3-8B",
                            info="This tokenizer seeds 'train_new_from_iterator'."
                        )
                    with gr.Row():
                        scratch_vocab_size = gr.Number(label="Vocabulary Size (New Tokenizer)", value=32000)
                        scratch_special_tokens = gr.Textbox(
                            label="Special Tokens (New Tokenizer)",
                            value="<s>,<pad>,</s>,<unk>,<mask>,<|user|>,<|bot|>,<|end|>",
                            info="Comma-separated."
                        )
                    gr.Markdown("#### Model Configuration (Hyperparameters)")
                    with gr.Row():
                        scratch_hidden_size = gr.Number(label="hidden_size / n_embd", value=512)
                        scratch_intermediate_size = gr.Number(label="intermediate_size", value=1024, info="Ignored by architectures such as GPT2.")
                    with gr.Row():
                        scratch_num_hidden_layers = gr.Number(label="num_hidden_layers / n_layer", value=8)
                        scratch_num_attention_heads = gr.Number(label="num_attention_heads / n_head", value=8)
                    with gr.Row():
                        scratch_num_key_value_heads = gr.Number(label="num_key_value_heads", value=8, info="Relevant for Mistral/Gemma (GQA). Leave at 0 if unused.")
                        scratch_max_pos_embed = gr.Number(label="max_position_embeddings / n_positions", value=512)
                    with gr.Row():
                        scratch_tie_word_embeddings = gr.Checkbox(label="tie_word_embeddings", value=False)
| with gr.TabItem("Opciones Generales (Trainer)", id="tab_general"): | |
| general_group = gr.Group(visible=True) | |
| with general_group: | |
| with gr.Row(): | |
| add_eos_token = gr.Checkbox(label="add_eos_token", value=True) | |
| auto_find_batch_size = gr.Checkbox(label="auto_find_batch_size (Acelerador)", value=True) | |
| disable_gradient_checkpointing = gr.Checkbox(label="disable_gradient_checkpointing", value=False) | |
| load_best_model_at_end = gr.Checkbox(label="load_best_model_at_end", value=False, info="Requiere dataset de validaci贸n y `eval_strategy='steps'`. Guarda el mejor modelo al final.") | |
| with gr.Row(): | |
| chat_template = gr.Textbox(label="chat_template", value="tokenizer", lines=3, | |
| placeholder="Dejar 'tokenizer' para usar la del modelo. 'none' para deshabilitar. O pegar plantilla Jinja2 personalizada.\nEj: {% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n' }}{% else %}{{ 'Assistant: ' + message['content'] + '\n' }}{% endif %}{% endfor %}", | |
| info="Plantilla de chat personalizada (Jinja2).") | |
| with gr.Row(): | |
| distributed_backend = gr.Textbox(label="distributed_backend (Ignorado en GPU 煤nica)", value="ddp", placeholder="ddp o fsdp") | |
| eval_strategy = gr.Textbox(label="eval_strategy", value="steps", placeholder="no, steps o epoch", info="Recomendado: 'steps' para usar `validation_dataset`.") | |
| merge_adapter = gr.Checkbox(label="merge_adapter (Subir modelo completo)", value=True, info="Ignorado si 'Entrenar desde Cero' est谩 activo.") | |
| mixed_precision = gr.Textbox(label="mixed_precision", value="bf16", placeholder="none, fp16 o bf16") | |
| with gr.Row(): | |
| optimizer = gr.Textbox(label="optimizer", value="adamw_torch", placeholder="adamw_8bit, adamw_torch, etc.") | |
| peft = gr.Checkbox(label="peft (Activar LoRA)", value=True, info="Ignorado si 'Entrenar desde Cero' est谩 activo.") | |
| padding = gr.Textbox(label="padding", value="left", placeholder="left o right") | |
| quantization = gr.Textbox(label="quantization", value="int4", placeholder="none, int8 o int4", info="Ignorado si 'Entrenar desde Cero' est谩 activo.") | |
| scheduler = gr.Textbox(label="scheduler", value="cosine", placeholder="linear, cosine o constant") | |
| with gr.TabItem("Hiperpar谩metros Num茅ricos", id="tab_numeric"): | |
| numeric_group = gr.Group(visible=True) | |
| with numeric_group: | |
| with gr.Row(): | |
| batch_size = gr.Slider(1, 128, value=1, step=1, label="per_device_train_batch_size") | |
| block_size = gr.Slider(16, 8192, value=1024, step=16, label="max_seq_length (Bloque de texto)") | |
| epochs = gr.Slider(1, 20, value=1, step=1, label="num_train_epochs") | |
| gradient_accumulation = gr.Slider(1, 64, value=8, step=1, label="gradient_accumulation_steps") | |
| with gr.Row(): | |
| learning_rate = gr.Number(value=1e-5, label="learning_rate (lr)") | |
| logging_steps = gr.Number(value=5, label="logging_steps", info="Tambi茅n controla `evaluation_steps`.") | |
| max_grad_norm = gr.Number(value=0.3, label="max_grad_norm", info="Normalizaci贸n de gradientes (Clipping).") | |
| steps_per_epoch_estimate = gr.Number(value=10000, label="Estimaci贸n de Pasos por 脡poca", info="Usado para calcular `max_steps` en streaming.") | |
| with gr.Row(): | |
| save_total_limit = gr.Number(value=1, label="save_total_limit", info="N煤mero de checkpoints a guardar.") | |
| seed = gr.Number(value=42, label="seed") | |
| warmup_ratio = gr.Number(value=0.05, label="warmup_ratio") | |
| weight_decay = gr.Number(value=0.01, label="weight_decay") | |
| with gr.TabItem("PEFT / LoRA", id="tab_peft"): | |
| peft_group = gr.Group(visible=True) | |
| with peft_group: | |
| gr.Markdown("#### PEFT / LoRA (Ignorado si 'Entrenar desde Cero' est谩 activo o `peft` est谩 deshabilitado)") | |
| with gr.Row(): | |
| lora_r = gr.Number(value=32, label="lora_r (r)") | |
| lora_alpha = gr.Number(value=32, label="lora_alpha") | |
| lora_dropout = gr.Number(value=0.05, label="lora_dropout") | |
| target_modules = gr.Textbox(value="q_proj,k_proj,v_proj,o_proj", placeholder="Ej: all-linear o q_proj,v_proj", label="target_modules (Separados por coma)") | |
| with gr.TabItem("Configuraci贸n del Modelo (Carga)", id="tab_load"): | |
| load_config_group = gr.Group(visible=True) | |
| with load_config_group: | |
| gr.Markdown("#### Configuraci贸n del Modelo (Carga - Ignorado si 'Entrenar desde Cero' est谩 activo)") | |
| with gr.Row(): | |
| trust_remote_code_input = gr.Checkbox(label="trust_remote_code", value=True, info="Permitir cargar modelos con c贸digo personalizado.") | |
| attn_implementation_input = gr.Textbox(label="attn_implementation", value="sdpa", placeholder="auto, eager, sdpa, flash_attention_2", info="Optimizaci贸n de atenci贸n (ej: flash_attention_2 si est谩 disponible).") | |
| with gr.Row(): | |
| new_special_tokens_input = gr.Textbox(label="Nuevos tokens especiales (coma-separado)", placeholder="<|im_start|>, <|im_end|>, <|system|>", info="A帽adir tokens al vocabulario. Redimensiona embeddings.") | |
| with gr.TabItem("Formato de Datos (Razonamiento)", id="tab_cot"): | |
| cot_group = gr.Group(visible=True) | |
| with cot_group: | |
| with gr.Row(): | |
| enable_cot_input = gr.Checkbox(label="Activar formato de razonamiento (CoT)", value=False, info="Anula la detecci贸n autom谩tica de columna 'text' o 'messages'.") | |
| with gr.Row(): | |
| prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt", info="Nombre de la columna con el prompt/pregunta.") | |
| with gr.Row(): | |
| reasoning_col_input = gr.Textbox(label="Columna de Razonamiento (Opcional)", value="reasoning", info="Nombre de la columna con los pasos de CoT.") | |
| with gr.Row(): | |
| response_col_input = gr.Textbox(label="Columna de Respuesta", value="response", info="Nombre de la columna con la respuesta final.") | |
| with gr.TabItem("Normalizaci贸n de Texto (Datos)", id="tab_norm"): | |
| normalization_group = gr.Group(visible=True) | |
| with normalization_group: | |
| with gr.Row(): | |
| apply_lowercase_input = gr.Checkbox(label="Aplicar min煤sculas", value=False, info="Convierte todo el texto a min煤sculas antes de tokenizar.") | |
| remove_punctuation_input = gr.Checkbox(label="Eliminar puntuaci贸n", value=False, info="Elimina signos de puntuaci贸n del texto antes de tokenizar.") | |
| with gr.TabItem("Logging (Weights & Biases)", id="tab_wandb"): | |
| wandb_group = gr.Group(visible=True) | |
| with wandb_group: | |
| with gr.Row(): | |
| wandb_project_input = gr.Textbox(label="Nombre del Proyecto W&B", placeholder="mi-proyecto-sft", info="Dejar vac铆o para deshabilitar W&B.") | |
| with gr.Row(): | |
| wandb_api_key_input = gr.Textbox(label="API Key de W&B (Opcional)", type="password", placeholder="w_**************************************", info="Opcional. Usar si no est谩 configurado globalmente.") | |
    def toggle_scratch_config(scratch_checked):
        scratch_visible = scratch_checked
        finetune_visible = not scratch_checked
        # Returning a dict keyed by components lets Gradio match updates to outputs.
        updates = {
            model_base_input: gr.Textbox(visible=finetune_visible),
            scratch_group: gr.Group(visible=scratch_visible),
            peft_group: gr.Group(visible=finetune_visible),
            load_config_group: gr.Group(visible=finetune_visible),
        }
        return updates

    train_from_scratch.change(
        fn=toggle_scratch_config,
        inputs=[train_from_scratch],
        outputs=[model_base_input, scratch_group, peft_group, load_config_group],
        queue=False
    )
    run_btn = gr.Button("🚀 Start Training and Upload to Hugging Face Hub", elem_classes="primary-btn mt-6")
    gr.Markdown("## Results and Progress")
    with gr.Group(elem_classes="card"):
        with gr.Row():
            logs = gr.Textbox(label="Run Logs", lines=12, elem_classes="textbox-logs", scale=3)
            with gr.Column(scale=2):
                phase = gr.Textbox(label="Current Phase", lines=1)
                link = gr.Textbox(label="Output Repository", lines=1)
                cfg_out = gr.File(label="Generated Final Configuration (params.json)", file_types=[".json"])
    scratch_inputs = [
        scratch_vocab_size, scratch_special_tokens, scratch_base_tokenizer,
        scratch_hidden_size, scratch_intermediate_size, scratch_num_hidden_layers,
        scratch_num_attention_heads, scratch_num_key_value_heads, scratch_max_pos_embed, scratch_tie_word_embeddings
    ]
    # Order must mirror train_and_upload's signature exactly: the scratch_* inputs
    # come before auto_config_scratch there, so append it last (the original list
    # put auto_config_scratch first, shifting every scratch value by one position).
    all_inputs = [
        model_base_input, ds_text, uploads, repo_name_input,
        train_from_scratch, scratch_architecture,
        add_eos_token, auto_find_batch_size, chat_template,
        disable_gradient_checkpointing, distributed_backend, eval_strategy,
        load_best_model_at_end,
        merge_adapter, mixed_precision, optimizer, peft, padding,
        quantization, scheduler, batch_size, block_size, epochs,
        gradient_accumulation, learning_rate, logging_steps, lora_alpha,
        lora_dropout, lora_r, max_grad_norm,
        save_total_limit, seed, warmup_ratio, weight_decay, target_modules,
        steps_per_epoch_estimate,
        trust_remote_code_input, attn_implementation_input, new_special_tokens_input,
        apply_lowercase_input, remove_punctuation_input,
        enable_cot_input, prompt_col_input, reasoning_col_input, response_col_input,
        wandb_project_input, wandb_api_key_input,
    ] + scratch_inputs + [auto_config_scratch]
    all_outputs = [logs, phase, link, cfg_out]
    run_btn.click(
        fn=train_and_upload,
        inputs=all_inputs,
        outputs=all_outputs
    )
demo.launch(server_name="0.0.0.0", server_port=7860)