Spaces:
Runtime error
Runtime error
| import os | |
| os.system("pip install -U transformers peft accelerate trl bitsandbytes datasets diffusers") | |
| os.system("pip install spaces-0.1.0-py3-none-any.whl") | |
| import io | |
| import json | |
| import tempfile | |
| import string | |
| import gc | |
| import math | |
| import uuid | |
| import logging | |
| import traceback | |
| import importlib | |
| import random | |
| import re | |
| import ast | |
| from itertools import islice | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import accelerate | |
| from PIL import Image | |
| from torchvision import transforms | |
| import torchaudio | |
| from bs4 import BeautifulSoup | |
| from langdetect import detect_langs | |
| import textstat | |
| from datasketch import MinHash, MinHashLSH | |
| import gradio as gr | |
| import spaces | |
| from datasets import load_dataset, IterableDataset, interleave_datasets, Audio | |
| from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi | |
| from transformers import ( | |
| AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer, | |
| AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, | |
| SpeechT5ForTextToSpeech, SpeechT5Processor, SpeechT5HifiGan, AutoModelForImageClassification, | |
| AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification, | |
| DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq, | |
| AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer, | |
| DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig, | |
| LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel, | |
| PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM, | |
| DataCollatorForLanguageModeling, DefaultDataCollator | |
| ) | |
| from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training | |
| from trl import SFTTrainer, DPOTrainer | |
| from diffusers import ( | |
| UNet2DConditionModel, DDPMScheduler, AutoencoderKL, | |
| get_scheduler as get_diffusers_scheduler, StableDiffusionPipeline as StableDiffusionText2ImagePipeline, | |
| StableDiffusionImg2ImgPipeline as StableDiffusionImage2ImagePipeline | |
| ) | |
| import evaluate as hf_evaluate | |
| from jinja2 import Template | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| torch_dtype_auto = torch.float32 | |
| ARCHITECTURE_MAP = {"Llama": (LlamaConfig, LlamaForCausalLM), "Mistral": (MistralConfig, MistralForCausalLM), "Gemma": (GemmaConfig, MistralForCausalLM), "GPT2": (GPT2Config, GPT2LMHeadModel), "Phi": (PhiConfig, PhiForCausalLM), "Qwen2": (Qwen2Config, Qwen2ForCausalLM)} | |
| SCRATCH_TOKENIZER_MAP = {"Llama": "meta-llama/Llama-2-7b-hf", "Mistral": "mistralai/Mistral-7B-v0.1", "Gemma": "google/gemma-2b", "GPT2": "gpt2", "Phi": "microsoft/phi-2", "Qwen2": "Qwen/Qwen2-0.5B"} | |
| TRAINING_MODES = [ | |
| "Causal Language Modeling (SFT/LoRA)", | |
| "DPO (Direct Preference Optimization)", | |
| "Question Answering (Text)", | |
| "Token Classification (NER)", | |
| "Sequence Classification (Text)", | |
| "Text-to-Image Generation", | |
| "Image Classification (Vision)", | |
| "Audio Classification (Speech)", | |
| "ASR (Speech-to-Text)", | |
| "Text2Text Generation" | |
| ] | |
| TASK_TO_PIPELINE_MAP = { | |
| "Causal Language Modeling (SFT/LoRA)": "text-generation", | |
| "DPO (Direct Preference Optimization)": "text-generation", | |
| "Question Answering (Text)": "question-answering", | |
| "Token Classification (NER)": "token-classification", | |
| "Sequence Classification (Text)": "text-classification", | |
| "Image Classification (Vision)": "image-classification", | |
| "Audio Classification (Speech)": "audio-classification", | |
| "ASR (Speech-to-Text)": "automatic-speech-recognition", | |
| "Text2Text Generation": "text2text-generation", | |
| "Text-to-Image Generation": "text-to-image", | |
| } | |
| MODEL_CARD_TEMPLATE = """ | |
| --- | |
| language: es | |
| license: apache-2.0 | |
| tags: | |
| - autotrain-advanced | |
| - fine-tuned | |
| - {base_model_name} | |
| widget: | |
| - text: "Hola, ¿cómo estás?" | |
| --- | |
| # {repo_id} | |
| Este modelo es una versión afinada de [{base_model}](https://huggingface.co/{base_model}) entrenado con la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced). | |
| ## Detalles del Entrenamiento | |
| - **Modo de Entrenamiento:** {training_mode} | |
| - **Modelo Base:** `{base_model}` | |
| - **Datasets:** `{datasets}` | |
| - **Entrenado en:** {date} | |
| ### Hiperparámetros de Entrenamiento | |
| ```json | |
| {hyperparameters}``` | |
| ### Frameworks Utilizados | |
| - Transformers | |
| - PEFT | |
| - BitsAndBytes | |
| - Accelerate | |
| - TRL | |
| - Gradio | |
| """ | |
| class DebiasingSFTTrainer(SFTTrainer): | |
| def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.reweighting_terms = [term.strip().lower() for term in reweighting_terms] if reweighting_terms else [] | |
| self.reweighting_factor = reweighting_factor | |
| def compute_loss(self, model, inputs, return_outputs=False): | |
| loss, outputs = super().compute_loss(model, inputs, return_outputs=True) | |
| if self.reweighting_terms and self.reweighting_factor > 1.0: | |
| input_ids = inputs.get("input_ids") | |
| decoded_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) | |
| for text in decoded_texts: | |
| if any(term in text.lower() for term in self.reweighting_terms): | |
| loss *= self.reweighting_factor | |
| break | |
| return (loss, outputs) if return_outputs else loss | |
| class DeduplicatedIterableDataset(IterableDataset): | |
| def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128): | |
| super().__init__(ex_iterable=iter([])) | |
| self.dataset = dataset | |
| self.text_col = text_col | |
| self.method = method | |
| self.threshold = threshold | |
| self.num_perm = num_perm | |
| if hasattr(dataset, '_info'): | |
| self._info = dataset._info | |
| elif hasattr(dataset, 'info'): | |
| self._info = dataset.info | |
| def __iter__(self): | |
| if self.method == 'Exacta': | |
| return self._exact_iter() | |
| elif self.method == 'Semántica (MinHash)': | |
| return self._minhash_iter() | |
| else: | |
| return iter(self.dataset) | |
| def _exact_iter(self): | |
| seen_texts = set() | |
| for example in self.dataset: | |
| text = example.get(self.text_col, "") | |
| if text and isinstance(text, str): | |
| if text not in seen_texts: | |
| seen_texts.add(text) | |
| yield example | |
| else: | |
| yield example | |
| def _minhash_iter(self): | |
| lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm) | |
| for i, example in enumerate(self.dataset): | |
| text = example.get(self.text_col, "") | |
| if text and isinstance(text, str) and text.strip(): | |
| m = MinHash(num_perm=self.num_perm) | |
| for d in text.split(): | |
| m.update(d.encode('utf8')) | |
| if not lsh.query(m): | |
| lsh.insert(f"key_{i}", m) | |
| yield example | |
| else: | |
| yield example | |
| def hf_login(token): | |
| if not token: | |
| return "Por favor, introduce un token." | |
| try: | |
| login(token=token, add_to_git_credential=True) | |
| user = whoami() | |
| return f"✅ Conectado como: {user['name']}" | |
| except Exception as e: | |
| return f"❌ Error en la conexión: {e}" | |
| def _clean_text(example, text_col, **kwargs): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str): | |
| return example | |
| if kwargs.get('remove_html_tags'): | |
| text = BeautifulSoup(text, "html.parser").get_text() | |
| if kwargs.get('remove_urls_emails'): | |
| text = re.sub(r'http\S+|www\S+|httpsS+', '', text, flags=re.MULTILINE) | |
| if kwargs.get('normalize_whitespace'): | |
| text = ' '.join(text.split()) | |
| if kwargs.get('redact_pii'): | |
| text = re.sub(r'\S+@\S+', '<EMAIL>', text) | |
| text = re.sub(r'(\d{1,4}[-.\s]?){7,}|(\+\d{1,3}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}', '<PHONE>', text) | |
| text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text) | |
| example[text_col] = text | |
| return example | |
| def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str): return False | |
| text_len = len(text.split()) | |
| if not (min_len <= text_len <= max_len): return False | |
| words = text.split() | |
| if not words: return False | |
| word_counts = {} | |
| for word in words: word_counts[word] = word_counts.get(word, 0) + 1 | |
| if not word_counts or (max(word_counts.values()) / len(words)) > rep_threshold: return False | |
| lower_text = text.lower() | |
| return not any(keyword in lower_text for keyword in exclude_keywords) | |
| def _get_filter_functions(**kwargs): | |
| filters = [] | |
| if kwargs.get('enable_quality_filter'): | |
| exclude_list = [k.strip().lower() for k in (kwargs.get('exclude_keywords_input', '') + ',' + kwargs.get('bias_keywords_input', '')).split(",") if k.strip()] | |
| filters.append(lambda ex: _apply_quality_filters(ex, kwargs['text_col'], kwargs['min_len_input'], kwargs['max_len_input'], kwargs['rep_threshold_input'], exclude_list)) | |
| if kwargs.get('enable_language_filter'): | |
| allowed_langs = [lang.strip() for lang in kwargs.get('allowed_languages', 'en').split(',')] | |
| lang_threshold = kwargs.get('language_detection_threshold', 0.95) | |
| def lang_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not text or not isinstance(text, str) or len(text.split()) < 5: return True | |
| try: | |
| detected = detect_langs(text) | |
| return any(lang.lang in allowed_langs and lang.prob > lang_threshold for lang in detected) | |
| except: | |
| return False | |
| filters.append(lang_filter) | |
| if kwargs.get('enable_toxicity_filter'): | |
| tox_pipe = pipeline("text-classification", model="unitary/toxic-bert") | |
| tox_threshold = kwargs.get('toxicity_threshold', 0.8) | |
| def tox_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not text or not isinstance(text, str): return True | |
| try: | |
| results = tox_pipe(text[:512], truncation=True) | |
| return not (results[0]['label'] == 'toxic' and results[0]['score'] > tox_threshold) | |
| except: | |
| return True | |
| filters.append(tox_filter) | |
| if any([kwargs.get('enable_readability_filter'), kwargs.get('enable_stopword_filter'), kwargs.get('enable_uniqueness_filter')]): | |
| stop_words = set(textstat.DEFAULT_stopwords) | |
| def stats_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not isinstance(text, str) or not text: return True | |
| words = text.split() | |
| num_words = len(words) | |
| if num_words == 0: return True | |
| if kwargs.get('enable_readability_filter'): | |
| score = textstat.flesch_reading_ease(text) | |
| if not (kwargs['min_readability'] <= score <= kwargs['max_readability']): return False | |
| if kwargs.get('enable_stopword_filter'): | |
| if (textstat.stopword_count(text) / num_words) > kwargs['max_stopword_ratio']: return False | |
| if kwargs.get('enable_uniqueness_filter'): | |
| if (len(set(words)) / num_words) < kwargs['min_uniqueness_ratio']: return False | |
| return True | |
| filters.append(stats_filter) | |
| return filters | |
| def _load_hf_streaming(ids, split="train", probabilities=None): | |
| streams = [] | |
| valid_ids = [] | |
| for ident in ids: | |
| try: | |
| d = load_dataset(ident, streaming=True, trust_remote_code=True, verification_mode="no_checks") | |
| split_found = False | |
| if isinstance(d, dict): | |
| for s_name, ds in d.items(): | |
| if s_name.lower() == split or (split == "train" and "train" in s_name.lower()): | |
| streams.append(ds) | |
| split_found = True | |
| break | |
| else: | |
| streams.append(d) | |
| split_found = True | |
| if split_found: | |
| valid_ids.append(ident) | |
| else: | |
| logger.warning(f"Split '{split}' not found in dataset {ident}. Excluding from this source.") | |
| except Exception as e: | |
| logger.error(f"Error loading dataset {ident} split {split}: {e}. Excluding from this source.") | |
| if not streams: | |
| return None | |
| if probabilities and len(probabilities) != len(streams): | |
| logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.") | |
| probabilities = None | |
| return interleave_datasets(streams, probabilities=probabilities) | |
| def _load_uploaded_stream(files): | |
| all_rows = [] | |
| for f in files or []: | |
| content = f.read().decode("utf-8", errors="ignore") | |
| name = f.name.lower() | |
| if name.endswith(".csv"): | |
| import csv | |
| all_rows.extend(list(csv.DictReader(io.StringIO(content)))) | |
| elif name.endswith(".jsonl"): | |
| all_rows.extend([json.loads(line) for line in io.StringIO(content) if line.strip()]) | |
| elif name.endswith(".json"): | |
| data = json.loads(content) | |
| all_rows.extend(data if isinstance(data, list) else [data]) | |
| elif name.endswith(".txt"): | |
| all_rows.extend([{"text": line} for line in io.StringIO(content) if line.strip()]) | |
| if not all_rows: | |
| return None | |
| val_size = max(1, int(len(all_rows) * 0.01)) | |
| random.shuffle(all_rows) | |
| return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []} | |
| def _guess_columns(sample): | |
| text_col, image_col, audio_col, label_col = "text", "image", "audio", "label" | |
| if not isinstance(sample, dict): | |
| return text_col, image_col, audio_col, label_col | |
| keys = {k.lower(): k for k in sample.keys()} | |
| if "text" in keys: text_col = keys["text"] | |
| elif "content" in keys: text_col = keys["content"] | |
| elif "prompt" in keys: text_col = keys["prompt"] | |
| if "image" in keys: image_col = keys["image"] | |
| elif "img" in keys: image_col = keys["img"] | |
| if "audio" in keys: audio_col = keys["audio"] | |
| elif "speech" in keys: audio_col = keys["speech"] | |
| if "label" in keys: label_col = keys["label"] | |
| elif "labels" in keys: label_col = keys["labels"] | |
| return text_col, image_col, audio_col, label_col | |
| def _apply_cda(dataset, text_col, cda_config_str): | |
| try: | |
| swap_groups = json.loads(cda_config_str) | |
| except (json.JSONDecodeError, ValueError) as e: | |
| logger.error(f"Configuración de CDA inválida: {e}.") | |
| return dataset | |
| def cda_generator(): | |
| for example in dataset: | |
| original_text = example.get(text_col, "") | |
| if not isinstance(original_text, str): | |
| yield example | |
| continue | |
| yield example | |
| generated_texts = {original_text} | |
| current_texts = {original_text} | |
| for group in swap_groups: | |
| next_texts = set() | |
| for text in current_texts: | |
| for word_to_replace in group: | |
| if word_to_replace in text: | |
| for replacement_word in group: | |
| if word_to_replace != replacement_word: | |
| new_text = text.replace(word_to_replace, replacement_word) | |
| if new_text not in generated_texts: | |
| new_example = example.copy() | |
| new_example[text_col] = new_text | |
| yield new_example | |
| generated_texts.add(new_text) | |
| next_texts.add(new_text) | |
| current_texts.update(next_texts) | |
| return IterableDataset.from_generator(cda_generator) | |
| def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id): | |
| if not ratio or ratio <= 0: | |
| return dataset | |
| logger.info(f"Aplicando retrotraducción al {ratio*100}% del dataset.") | |
| try: | |
| pipe_to = pipeline("translation", model=model_id) | |
| pipe_from = pipeline("translation", model=reverse_model_id) | |
| except Exception as e: | |
| logger.error(f"No se pudieron cargar los modelos de traducción: {e}") | |
| return dataset | |
| def bt_generator(): | |
| for example in dataset: | |
| yield example | |
| if random.random() < ratio: | |
| original_text = example.get(text_col, "") | |
| if isinstance(original_text, str) and original_text: | |
| try: | |
| translated = pipe_to(original_text, max_length=512)[0]['translation_text'] | |
| back_translated = pipe_from(translated, max_length=512)[0]['translation_text'] | |
| if back_translated: | |
| new_example = example.copy() | |
| new_example[text_col] = back_translated | |
| yield new_example | |
| except Exception as e: | |
| logger.warning(f"Error en retrotraducción: {e}") | |
| return IterableDataset.from_generator(bt_generator) | |
| def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template): | |
| if not num_samples or num_samples <= 0: | |
| return None | |
| logger.info(f"Iniciando generación de {num_samples} muestras sintéticas con el modelo {model_id}.") | |
| try: | |
| generator = pipeline("text-generation", model=model_id, torch_dtype=torch_dtype_auto) | |
| except Exception as e: | |
| logger.error(f"No se pudo cargar el modelo generador sintético: {e}") | |
| return None | |
| seed_examples = list(islice(original_dataset, 200)) | |
| if not seed_examples: | |
| logger.warning("Dataset original vacío, no se pueden generar datos sintéticos.") | |
| return None | |
| def synthetic_generator(): | |
| for i in range(num_samples): | |
| seed_example = random.choice(seed_examples) | |
| seed_text = seed_example.get(text_col, "") | |
| prompt = Template(prompt_template).render(example_text=seed_text) | |
| try: | |
| generated_output = generator(prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.9, top_p=0.95) | |
| cleaned_text = generated_output[0]['generated_text'][len(prompt):].strip() | |
| if "new example:" in cleaned_text.lower(): | |
| cleaned_text = re.split("new example:", cleaned_text, flags=re.IGNORECASE)[-1].strip() | |
| if cleaned_text: | |
| new_example = seed_example.copy() | |
| new_example[text_col] = cleaned_text | |
| yield new_example | |
| except Exception as e: | |
| logger.warning(f"Error generando una muestra sintética: {e}") | |
| continue | |
| return IterableDataset.from_generator(synthetic_generator) | |
| def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation): | |
| safe_steps = int(steps_per_epoch_estimate or 10000) | |
| safe_batch_size = int(batch_size or 1) | |
| safe_grad_accum = int(gradient_accumulation or 8) | |
| safe_block_size = int(block_size or 1024) | |
| size = safe_steps * safe_batch_size * safe_grad_accum | |
| if size <= 1: | |
| size = 10000 | |
| log_size = math.log2(max(1000, size)) | |
| vocab_size = min(65536, 32000 + int(log_size * 2000)) | |
| preliminary_hidden_size = max(512, min(4096, 512 + int(log_size * 100))) | |
| heads = max(8, min(32, preliminary_hidden_size // 64)) | |
| if heads == 0: heads = 8 | |
| hidden_size = (preliminary_hidden_size // heads) * heads | |
| layers = max(8, min(32, 8 + int(log_size * 1.5))) | |
| kv_heads = heads if is_gpt2_like else (max(1, heads // 4)) | |
| return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads | |
| def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn): | |
| if eval_ds_id: | |
| yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación") | |
| return _load_hf_streaming([eval_ds_id], split="train") | |
| if uploaded_val_data: | |
| yield update_logs_fn("Usando split de validación de archivos subidos.", "Evaluación") | |
| return IterableDataset.from_generator(lambda: iter(uploaded_val_data)) | |
| if train_ds_id: | |
| yield update_logs_fn("Intentando cargar split 'validation' o 'test' del dataset de entrenamiento.", "Evaluación") | |
| try: | |
| for split_name in ["validation", "test"]: | |
| eval_ds = _load_hf_streaming([train_ds_id], split=split_name) | |
| if eval_ds: | |
| yield update_logs_fn(f"Split '{split_name}' encontrado y cargado.", "Evaluación") | |
| return eval_ds | |
| except Exception as e: | |
| yield update_logs_fn(f"Error cargando split de evaluación: {e}. Omitiendo.", "Evaluación") | |
| return None | |
| yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación") | |
| return None | |
| def _create_training_args(output_dir, repo_id, **kwargs): | |
| neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0)) | |
| optim_args_dict = {} | |
| if kwargs.get('optim_args'): | |
| try: | |
| optim_args_dict = ast.literal_eval(f"dict({kwargs['optim_args']})") | |
| except Exception as e: | |
| logger.warning(f"No se pudieron parsear los argumentos del optimizador: {e}.") | |
| args_dict = { | |
| "output_dir": os.path.join(output_dir, "results"), | |
| "per_device_train_batch_size": int(kwargs.get('batch_size', 1)), | |
| "gradient_accumulation_steps": int(kwargs.get('gradient_accumulation', 8)), | |
| "optim": kwargs.get('optimizer', 'adamw_torch'), | |
| "optim_args": optim_args_dict, | |
| "save_strategy": "steps", | |
| "logging_steps": int(kwargs.get('logging_steps', 10)), | |
| "save_steps": int(kwargs.get('save_steps', 50)), | |
| "learning_rate": float(kwargs.get('learning_rate', 2e-5)), | |
| "fp16": False, | |
| "bf16": False, | |
| "max_grad_norm": float(kwargs.get('max_grad_norm', 0.3)), | |
| "warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)), | |
| "lr_scheduler_type": kwargs.get('scheduler', 'cosine'), | |
| "weight_decay": float(kwargs.get('weight_decay', 0.01)), | |
| "load_best_model_at_end": kwargs.get('run_evaluation', False), | |
| "save_total_limit": int(kwargs.get('save_total_limit', 1)), | |
| "gradient_checkpointing": not kwargs.get('disable_gradient_checkpointing', False), | |
| "push_to_hub": True, | |
| "hub_model_id": repo_id, | |
| "hub_strategy": kwargs.get('hub_strategy', 'every_save'), | |
| "dataloader_num_workers": 4, | |
| "report_to": "wandb" if kwargs.get('wandb_api_key_input') else "none", | |
| "remove_unused_columns": False, | |
| "group_by_length": kwargs.get('group_by_length', False), | |
| "metric_for_best_model": kwargs.get('metric_for_best_model', 'loss') if kwargs.get('run_evaluation') else None, | |
| "greater_is_better": kwargs.get('greater_is_better', False), | |
| "neftune_noise_alpha": neftune_alpha if neftune_alpha > 0 else None, | |
| "adam_beta1": float(kwargs.get('adam_beta1', 0.9)), | |
| "adam_beta2": float(kwargs.get('adam_beta2', 0.999)), | |
| "adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)), | |
| "no_cuda": True | |
| } | |
| max_train_samples = int(kwargs.get('max_train_samples', -1)) | |
| if max_train_samples > 0: | |
| max_steps = int(max_train_samples / args_dict["per_device_train_batch_size"] / args_dict["gradient_accumulation_steps"]) | |
| args_dict["max_steps"] = max_steps | |
| else: | |
| args_dict["num_train_epochs"] = float(kwargs.get('epochs', 1.0)) | |
| return TrainingArguments(**args_dict) | |
| def _generic_model_loader(model_name_or_path, model_class, **kwargs): | |
| quantization_type = kwargs.get('quantization', 'no') | |
| if quantization_type != "no": | |
| raise ValueError("La cuantización solo es compatible con GPU, que está deshabilitada.") | |
| attn_implementation = kwargs.get('attn_implementation', 'eager') | |
| config_kwargs = {"trust_remote_code": True} | |
| if kwargs.get('label2id'): | |
| config_kwargs.update({"label2id": kwargs['label2id'], "id2label": kwargs['id2label']}) | |
| config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs) | |
| if kwargs.get('attention_dropout', 0) > 0: config.attention_dropout = kwargs['attention_dropout'] | |
| if kwargs.get('hidden_dropout', 0) > 0: config.hidden_dropout = kwargs['hidden_dropout'] | |
| model_kwargs = { | |
| "trust_remote_code": True, | |
| "config": config, | |
| "attn_implementation": attn_implementation, | |
| "torch_dtype": torch_dtype_auto, | |
| } | |
| if kwargs.get('num_labels'): | |
| model_kwargs.update({"num_labels": kwargs['num_labels'], "ignore_mismatched_sizes": True}) | |
| model = model_class.from_pretrained(model_name_or_path, **model_kwargs) | |
| return model | |
| def _find_all_linear_names(model, quantization_type): | |
| cls = torch.nn.Linear | |
| lora_module_names = set() | |
| for name, module in model.named_modules(): | |
| if isinstance(module, cls): | |
| names = name.split('.') | |
| lora_module_names.add(names[-1]) | |
| if 'lm_head' in lora_module_names: | |
| lora_module_names.remove('lm_head') | |
| common_targets = {'q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'} | |
| return list(lora_module_names.intersection(common_targets)) or list(lora_module_names) | |
| def _conversation_formatting_func(example, tokenizer, **kwargs): | |
| conv_col = "" | |
| for key in ["messages", "conversations", "turns"]: | |
| if key in example: conv_col = key; break | |
| if not conv_col: return "" | |
| conversation = example[conv_col] | |
| if isinstance(conversation, str): | |
| try: conversation = ast.literal_eval(conversation) | |
| except: return "" | |
| return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) | |
| def _sft_formatting_func(example, text_col, tokenizer, **kwargs): | |
| if kwargs.get('enable_cot_input') or kwargs.get('enable_tool_use_input'): | |
| messages = [] | |
| prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "") | |
| if prompt: messages.append({"role": "user", "content": prompt}) | |
| response_parts = [] | |
| if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')): response_parts.append(f"<thinking>{example[kwargs.get('reasoning_col_input', 'reasoning')]}</thinking>") | |
| if kwargs.get('enable_tool_use_input') and example.get(kwargs.get('tool_use_col_input', 'tools')): response_parts.append(f"<tool_code>{example[kwargs.get('tool_use_col_input', 'tools')]}</tool_code>") | |
| if example.get(kwargs.get('response_col_input', 'response')): response_parts.append(example[kwargs.get('response_col_input', 'response')]) | |
| if response_parts: messages.append({"role": "assistant", "content": "\n".join(response_parts)}) | |
| if messages: | |
| try: | |
| return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) | |
| except Exception as e: | |
| logger.error(f"Error applying chat template: {e}.") | |
| return "\n".join([m['content'] for m in messages]) | |
| return example.get(text_col, "") | |
| def _dpo_formatting_func(example, **kwargs): | |
| return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")} | |
| def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col): | |
| model.eval() | |
| encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt") | |
| max_length = model.config.max_position_embeddings | |
| stride = 512 | |
| seq_len = encodings.input_ids.size(1) | |
| nlls = [] | |
| prev_end_loc = 0 | |
| with torch.no_grad(): | |
| for begin_loc in range(0, seq_len, stride): | |
| end_loc = min(begin_loc + max_length, seq_len) | |
| trg_len = end_loc - prev_end_loc | |
| input_ids = encodings.input_ids[:, begin_loc:end_loc] | |
| target_ids = input_ids.clone() | |
| target_ids[:, :-trg_len] = -100 | |
| outputs = model(input_ids, labels=target_ids) | |
| neg_log_likelihood = outputs.loss | |
| nlls.append(neg_log_likelihood) | |
| prev_end_loc = end_loc | |
| if end_loc == seq_len: | |
| break | |
| ppl = torch.exp(torch.stack(nlls).mean()) | |
| return ppl.item() | |
| def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type): | |
| adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()] | |
| if not adapter_ids: | |
| yield "No se proporcionaron IDs de adaptadores válidos. Omitiendo la fusión múltiple." | |
| return base_model_id | |
| try: | |
| weights = [float(w.strip()) for w in weights_str.split(',')] | |
| except: | |
| weights = [1.0] * len(adapter_ids) | |
| if len(weights) != len(adapter_ids): | |
| weights = [1.0] * len(adapter_ids) | |
| yield "Pesos de adaptadores inválidos, usando 1.0 para todos." | |
| yield f"Cargando modelo base {base_model_id} para fusión múltiple..." | |
| model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch_dtype_auto, trust_remote_code=True) | |
| for i, adapter_id in enumerate(adapter_ids): | |
| yield f"Cargando adaptador {i+1}: {adapter_id}" | |
| model.load_adapter(adapter_id, adapter_name=f"adapter_{i}") | |
| adapter_names = [f"adapter_{i}" for i in range(len(adapter_ids))] | |
| yield f"Combinando adaptadores: {adapter_names} con pesos: {weights} y tipo: {combination_type}" | |
| model.add_weighted_adapter(adapters=adapter_names, weights=weights, adapter_name="combined", combination_type=combination_type) | |
| model.set_adapter("combined") | |
| yield "Fusionando combinación de adaptadores en el modelo base..." | |
| merged_model = model.merge_and_unload() | |
| temp_dir = tempfile.mkdtemp() | |
| yield f"Guardando modelo fusionado en {temp_dir}" | |
| merged_model.save_pretrained(temp_dir) | |
| tokenizer = AutoTokenizer.from_pretrained(base_model_id) | |
| tokenizer.save_pretrained(temp_dir) | |
| yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}." | |
| return temp_dir | |
| def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs, model_card_content, **kwargs): | |
| yield update_logs("Iniciando ciclo de entrenamiento...", "Entrenando") | |
| trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False) | |
| yield update_logs("Entrenamiento finalizado.", "Guardando") | |
| output_dir = trainer.args.output_dir | |
| trainer.save_model(output_dir) | |
| if tokenizer: | |
| tokenizer.save_pretrained(output_dir) | |
| with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f: | |
| f.write(model_card_content) | |
| yield update_logs("Subiendo al Hub...", "Subiendo") | |
| upload_folder(folder_path=output_dir, repo_id=repo_id, commit_message="Fin de entrenamiento") | |
| del trainer | |
| gc.collect() | |
| return output_dir | |
| def train_sft_dpo(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)" | |
| text_col = kwargs.get('text_col') | |
| try: | |
| tokenizer_id = kwargs.get('tokenizer_name') or model_name | |
| yield update_logs(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, use_fast=False) | |
| if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token | |
| if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja'] | |
| yield update_logs(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForCausalLM, **kwargs) | |
| peft_config = None | |
| if kwargs.get('peft'): | |
| target_modules = kwargs.get('target_modules').split(",") if not kwargs.get('auto_find_target_modules') else _find_all_linear_names(model, 'no') | |
| yield update_logs(f"Módulos LoRA detectados/especificados: {target_modules}", "Configuración") | |
| peft_config = LoraConfig( | |
| r=int(kwargs.get('lora_r')), lora_alpha=int(kwargs.get('lora_alpha')), lora_dropout=float(kwargs.get('lora_dropout')), | |
| target_modules=target_modules, bias="none", task_type="CAUSAL_LM", use_dora=kwargs.get('use_dora', False), | |
| use_rslora=kwargs.get('use_rslora', False), init_lora_weights=kwargs.get('init_lora_weights', 'gaussian'), | |
| modules_to_save=kwargs.get('modules_to_save').split(',') if kwargs.get('modules_to_save') else None | |
| ) | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset = yield from _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs) | |
| TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer) | |
| trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config} | |
| if is_dpo: | |
| trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2}) | |
| if eval_dataset: | |
| eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs)) | |
| else: | |
| sft_kwargs = kwargs.copy() | |
| if 'text_col' in sft_kwargs: | |
| del sft_kwargs['text_col'] | |
| trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)}) | |
| if kwargs.get('enable_loss_reweighting'): | |
| trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)}) | |
| trainer = TrainerClass(**trainer_kwargs) | |
| yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs, model_card_content, **kwargs) | |
| except Exception as e: | |
| raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}") | |
| def train_sequence_classification(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| labels = [s.strip() for s in kwargs['classification_labels'].split(',')] | |
| label2id = {l: i for i, l in enumerate(labels)} | |
| id2label = {i: l for i, l in enumerate(labels)} | |
| tokenizer_id = kwargs.get('tokenizer_name') or model_name | |
| yield update_logs(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) | |
| yield update_logs(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs) | |
| def preprocess(examples): | |
| return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512) | |
| train_dataset = train_dataset.map(preprocess) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset = yield from _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs) | |
| if eval_dataset: eval_dataset = eval_dataset.map(preprocess) | |
| metric = hf_evaluate.load("accuracy") | |
| def compute_metrics(eval_pred): | |
| logits, labels = eval_pred | |
| predictions = np.argmax(logits, axis=-1) | |
| return metric.compute(predictions=predictions, references=labels) | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| trainer = Trainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, compute_metrics=compute_metrics, | |
| tokenizer=tokenizer, data_collator=DataCollatorWithPadding(tokenizer=tokenizer) | |
| ) | |
| yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs, model_card_content, **kwargs) | |
| except Exception as e: | |
| raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}") | |
| def train_token_classification(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| labels = [s.strip() for s in kwargs['classification_labels'].split(',')] | |
| label2id = {l: i for i, l in enumerate(labels)} | |
| id2label = {i: l for i, l in enumerate(labels)} | |
| tokenizer_id = kwargs.get('tokenizer_name') or model_name | |
| yield update_logs(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, add_prefix_space=True) | |
| yield update_logs(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForTokenClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs) | |
| def tokenize_and_align_labels(examples): | |
| tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True) | |
| labels = [] | |
| for i, label in enumerate(examples["ner_tags"]): | |
| word_ids = tokenized_inputs.word_ids(batch_index=i) | |
| previous_word_idx = None | |
| label_ids = [] | |
| for word_idx in word_ids: | |
| if word_idx is None or word_idx == previous_word_idx: | |
| label_ids.append(-100) | |
| else: | |
| label_ids.append(label[word_idx]) | |
| previous_word_idx = word_idx | |
| labels.append(label_ids) | |
| tokenized_inputs["labels"] = labels | |
| return tokenized_inputs | |
| train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset = yield from _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs) | |
| if eval_dataset: eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True) | |
| metric = hf_evaluate.load("seqeval") | |
| def compute_metrics(p): | |
| predictions, labels = p | |
| predictions = np.argmax(predictions, axis=2) | |
| true_predictions = [[id2label[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)] | |
| true_labels = [[id2label[l] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)] | |
| results = metric.compute(predictions=true_predictions, references=true_labels) | |
| return {"precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"]} | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| trainer = Trainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, tokenizer=tokenizer, | |
| data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer), | |
| compute_metrics=compute_metrics | |
| ) | |
| yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs, model_card_content, **kwargs) | |
| except Exception as e: | |
| raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}") | |
| def train_seq2seq(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| tokenizer_id = kwargs.get('tokenizer_name') or model_name | |
| yield update_logs(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) | |
| yield update_logs(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForSeq2SeqLM, **kwargs) | |
| def preprocess_function(examples): | |
| inputs = [ex[kwargs['text_col']] for ex in examples["translation"]] | |
| targets = [ex[kwargs['label_col']] for ex in examples["translation"]] | |
| model_inputs = tokenizer(inputs, max_length=128, truncation=True) | |
| with tokenizer.as_target_tokenizer(): | |
| labels = tokenizer(targets, max_length=128, truncation=True) | |
| model_inputs["labels"] = labels["input_ids"] | |
| return model_inputs | |
| train_dataset = train_dataset.map(preprocess_function, batched=True) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset = yield from _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs) | |
| if eval_dataset: eval_dataset = eval_dataset.map(preprocess_function, batched=True) | |
| metric = hf_evaluate.load("sacrebleu") | |
| def compute_metrics(eval_preds): | |
| preds, labels = eval_preds | |
| if isinstance(preds, tuple): preds = preds[0] | |
| decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | |
| labels = np.where(labels != -100, labels, tokenizer.pad_token_id) | |
| decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | |
| decoded_preds = [pred.strip() for pred in decoded_preds] | |
| decoded_labels = [[label.strip()] for label in decoded_labels] | |
| result = metric.compute(predictions=decoded_preds, references=decoded_labels) | |
| return {"bleu": result["score"]} | |
| training_args_dict = _create_training_args(output_dir, repo_id, **kwargs).to_dict() | |
| training_args_dict["predict_with_generate"] = True | |
| training_args = Seq2SeqTrainingArguments(**training_args_dict) | |
| trainer = Seq2SeqTrainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, tokenizer=tokenizer, | |
| data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model), | |
| compute_metrics=compute_metrics | |
| ) | |
| yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs, model_card_content, **kwargs) | |
| except Exception as e: | |
| raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}") | |
| def train_text_to_image(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs): | |
| yield update_logs("El entrenamiento de Text-to-Image aún no está implementado.", "Error") | |
| raise NotImplementedError("El entrenamiento de difusión (Text-to-Image) es una característica planificada y aún no está completamente implementada en esta interfaz.") | |
| def _train_and_upload(**kwargs): | |
| logs, repo_link, final_model_path = "", "", None | |
| yield { | |
| training_logs: "Iniciando...", | |
| training_phase: "Inicio", | |
| repo_link_output: "", | |
| start_training_button: gr.update(value="Entrenando...", interactive=False), | |
| stop_training_button: gr.update(visible=True) | |
| } | |
| def update_logs(new_msg, phase_msg): | |
| nonlocal logs | |
| logs += f"[{phase_msg}] {new_msg}\n" | |
| return { | |
| training_logs: logs, | |
| training_phase: phase_msg, | |
| repo_link_output: repo_link | |
| } | |
| try: | |
| yield update_logs("Verificando autenticación...", "Inicio") | |
| user = whoami() | |
| username = user.get("name") | |
| if not username: | |
| raise ValueError("No se pudo obtener el nombre de usuario de Hugging Face. Por favor, verifica tu token.") | |
| model_name = kwargs.get('model_base_input', '').strip() | |
| if kwargs.get('enable_multi_adapter_merge'): | |
| temp_model_path = model_name | |
| lora_merge_generator = _merge_multiple_loras(model_name, kwargs['multi_adapter_model_ids'], kwargs['multi_adapter_weights'], kwargs['multi_adapter_combination_type']) | |
| try: | |
| while True: | |
| status = next(lora_merge_generator) | |
| yield update_logs(status, "Fusión Múltiple") | |
| except StopIteration as e: | |
| temp_model_path = e.value | |
| model_name = temp_model_path | |
| repo_name_input = kwargs.get('repo_name_input', '').strip() | |
| if repo_name_input: | |
| repo_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', repo_name_input) | |
| repo_base = re.sub(r'^[.-]+|[.-]+$', '', repo_base) | |
| else: | |
| model_name_base = model_name.split('/')[-1] if model_name else "finetuned-model" | |
| sanitized_model_name_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', model_name_base) | |
| sanitized_model_name_base = re.sub(r'^[.-]+|[.-]+$', '', sanitized_model_name_base) | |
| repo_base = f"{sanitized_model_name_base}-{uuid.uuid4().hex[:6]}" | |
| if not repo_base: | |
| repo_base = f"autotrain-model-{uuid.uuid4().hex[:8]}" | |
| max_repo_base_len = 96 - (len(username) + 1) | |
| repo_base = repo_base[:max_repo_base_len] | |
| repo_id = f"{username}/{repo_base}" | |
| yield update_logs(f"Creando o verificando repositorio: '{repo_id}'", "Inicio") | |
| create_repo(repo_id, exist_ok=True) | |
| repo_link = f"https://huggingface.co/{repo_id}" | |
| yield update_logs("Repositorio listo.", "Inicio") | |
| base_model_id_for_training = model_name | |
| if kwargs.get('train_from_scratch'): | |
| yield update_logs("Preparando entrenamiento desde cero...", "Modelo Cero") | |
| architecture = kwargs.get('scratch_architecture') | |
| if not architecture or architecture not in ARCHITECTURE_MAP: | |
| raise ValueError(f"Arquitectura '{architecture}' no es válida o no está soportada para entrenamiento desde cero. Opciones válidas: {list(ARCHITECTURE_MAP.keys())}") | |
| config_class, model_class = ARCHITECTURE_MAP[architecture] | |
| if kwargs.get('auto_config_scratch'): | |
| vocab_size, hidden_size, intermediate_size, layers, heads, block_size_val, tie_word_embeddings, kv_heads = _calculate_auto_config(kwargs.get('block_size'), architecture == "GPT2", kwargs.get('steps_per_epoch_estimate'), kwargs.get('batch_size'), kwargs.get('gradient_accumulation')) | |
| else: | |
| vocab_size, hidden_size, intermediate_size, layers, heads, kv_heads, tie_word_embeddings = 32000, 1024, 2048, 8, 8, 8, False | |
| config = config_class(vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=layers, num_attention_heads=heads, num_key_value_heads=kv_heads, max_position_embeddings=int(kwargs.get('block_size', 1024)), tie_word_embeddings=tie_word_embeddings) | |
| model = model_class(config) | |
| temp_model_dir = tempfile.mkdtemp() | |
| model.save_pretrained(temp_model_dir) | |
| tokenizer_id = kwargs.get('tokenizer_name') or SCRATCH_TOKENIZER_MAP.get(architecture, "gpt2") | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) | |
| tokenizer.save_pretrained(temp_model_dir) | |
| yield update_logs(f"Tokenizer base '{tokenizer_id}' guardado para el modelo desde cero.", "Modelo Cero") | |
| except Exception as e: | |
| raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}") | |
| base_model_id_for_training = temp_model_dir | |
| kwargs["peft"] = False | |
| kwargs["merge_adapter"] = False | |
| kwargs['tokenizer_name'] = temp_model_dir | |
| yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") | |
| hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()] | |
| if not hf_ids and not kwargs.get('uploads'): | |
| raise ValueError("No se proporcionaron datasets.") | |
| dataset_weights_str = kwargs.get('dataset_weights', '') | |
| probabilities = None | |
| if dataset_weights_str: | |
| try: | |
| probabilities = [float(w.strip()) for w in dataset_weights_str.split(',')] | |
| except ValueError: | |
| yield update_logs("Pesos de dataset inválidos. Se ignorarán.", "Datos") | |
| probabilities = None | |
| train_dataset, uploaded_val_data = None, None | |
| if kwargs.get('uploads'): | |
| uploaded_data_map = _load_uploaded_stream(kwargs.get('uploads')) | |
| if uploaded_data_map and uploaded_data_map["train"]: | |
| train_dataset = IterableDataset.from_generator(lambda: iter(uploaded_data_map["train"])) | |
| uploaded_val_data = uploaded_data_map["validation"] | |
| yield update_logs(f"Cargados {len(uploaded_data_map['train'])} ejemplos de archivos locales.", "Datos") | |
| if hf_ids: | |
| hf_train_dataset = _load_hf_streaming(hf_ids, split="train", probabilities=probabilities if not train_dataset else None) | |
| if hf_train_dataset: | |
| if train_dataset is None: | |
| train_dataset = hf_train_dataset | |
| else: | |
| all_streams = [train_dataset, hf_train_dataset] | |
| all_probs = [0.5, 0.5] if not probabilities else [probabilities[0]] + probabilities[1:] | |
| train_dataset = interleave_datasets(all_streams, probabilities=all_probs) | |
| if train_dataset is None: | |
| raise ValueError("No se pudieron cargar datos de entrenamiento válidos.") | |
| first_example = next(iter(train_dataset)) | |
| text_col, image_col, audio_col, label_col = _guess_columns(first_example) | |
| kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data}) | |
| yield update_logs(f"Columnas detectadas (texto: {text_col})", "Datos") | |
| if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]): | |
| yield update_logs("Aplicando normalización y limpieza de texto...", "Datos") | |
| clean_kwargs = kwargs.copy() | |
| if 'text_col' in clean_kwargs: | |
| del clean_kwargs['text_col'] | |
| train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs)) | |
| filters = _get_filter_functions(**kwargs) | |
| if filters: | |
| yield update_logs(f"Aplicando {len(filters)} filtro(s) de calidad y contenido...", "Datos") | |
| for f in filters: | |
| train_dataset = train_dataset.filter(f) | |
| if kwargs.get('enable_back_translation'): | |
| train_dataset = _apply_back_translation(train_dataset, text_col, kwargs['bt_augmentation_ratio'], kwargs['bt_model_id'], kwargs['bt_reverse_model_id']) | |
| if kwargs.get('enable_synthetic_data'): | |
| synthetic_ds = _generate_synthetic_data(train_dataset, text_col, kwargs['synthetic_model_id'], int(kwargs['num_synthetic_samples']), kwargs['synthetic_prompt_template']) | |
| if synthetic_ds: | |
| yield update_logs(f"Mezclando dataset con datos sintéticos...", "Datos") | |
| train_dataset = interleave_datasets([train_dataset, synthetic_ds]) | |
| if kwargs.get('enable_cda') and kwargs.get('cda_json_config'): | |
| yield update_logs("Aplicando Aumentación de Datos Contrafactual...", "Datos") | |
| train_dataset = _apply_cda(train_dataset, text_col, kwargs['cda_json_config']) | |
| if kwargs.get('deduplication_method') != 'Ninguna': | |
| yield update_logs(f"Aplicando deduplicación ({kwargs['deduplication_method']})...", "Datos") | |
| train_dataset = DeduplicatedIterableDataset( | |
| dataset=train_dataset, | |
| text_col=text_col, | |
| method=kwargs['deduplication_method'], | |
| threshold=kwargs['minhash_threshold'], | |
| num_perm=kwargs['minhash_num_perm'] | |
| ) | |
| if kwargs.get('wandb_api_key_input'): | |
| os.environ["WANDB_API_KEY"] = kwargs['wandb_api_key_input'] | |
| os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}" | |
| os.environ["WANDB_LOG_MODEL"] = "checkpoint" | |
| from datetime import datetime | |
| model_card_content = MODEL_CARD_TEMPLATE.format( | |
| repo_id=repo_id, | |
| base_model=model_name, | |
| base_model_name=model_name.split('/')[-1], | |
| training_mode=kwargs.get('training_mode'), | |
| datasets=', '.join(hf_ids) if hf_ids else "Archivos locales", | |
| hyperparameters=json.dumps({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool)) and 'token' not in k and 'key' not in k and v is not None}, indent=2), | |
| date=datetime.now().strftime("%Y-%m-%d") | |
| ) | |
| training_mode = kwargs.get('training_mode') | |
| training_function_map = { | |
| "Causal Language Modeling (SFT/LoRA)": train_sft_dpo, | |
| "DPO (Direct Preference Optimization)": train_sft_dpo, | |
| "Sequence Classification (Text)": train_sequence_classification, | |
| "Token Classification (NER)": train_token_classification, | |
| "Text2Text Generation": train_seq2seq, | |
| "Text-to-Image Generation": train_text_to_image, | |
| } | |
| train_func = training_function_map.get(training_mode) | |
| if train_func: | |
| train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, lambda m, p: update_logs(m, p), model_card_content, **kwargs) | |
| while True: | |
| try: | |
| update = next(train_generator) | |
| yield update | |
| except StopIteration as e: | |
| final_model_path = e.value | |
| break | |
| else: | |
| raise ValueError(f"El modo de entrenamiento '{training_mode}' no está implementado.") | |
| if kwargs.get('run_perplexity_evaluation') and kwargs.get('run_evaluation') and final_model_path and training_mode in ["Causal Language Modeling (SFT/LoRA)", "DPO (Direct Preference Optimization)"]: | |
| yield update_logs("Iniciando evaluación de perplejidad...", "Evaluación Final") | |
| model = AutoModelForCausalLM.from_pretrained(final_model_path, torch_dtype=torch_dtype_auto) | |
| tokenizer = AutoTokenizer.from_pretrained(final_model_path) | |
| eval_dataset_perp = yield from _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), uploaded_val_data, lambda m, p: update_logs(m, p)) | |
| if eval_dataset_perp: | |
| ppl = _evaluate_perplexity(model, tokenizer, eval_dataset_perp, text_col) | |
| yield update_logs(f"Evaluación de Perplejidad completada. Perplejidad: {ppl:.4f}", "Evaluación Final") | |
| final_log_update = update_logs(f"✅ Entrenamiento y subida completados: {repo_link}", "Listo") | |
| final_log_update.update({ | |
| start_training_button: gr.update(value="Iniciar Entrenamiento", interactive=True), | |
| stop_training_button: gr.update(visible=False), | |
| repo_link_output: f"### ✅ [Modelo Finalizado: Visita el Repositorio en el Hub]({repo_link})" | |
| }) | |
| yield final_log_update | |
| except Exception as e: | |
| err_msg = f"❌ Error fatal: {type(e).__name__}: {e}\n{traceback.format_exc()}" | |
| error_update = update_logs(err_msg, "Error") | |
| error_update.update({ | |
| start_training_button: gr.update(value="Iniciar Entrenamiento", interactive=True), | |
| stop_training_button: gr.update(visible=False) | |
| }) | |
| yield error_update | |
| def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in): | |
| if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| task_name = TASK_TO_PIPELINE_MAP.get(task_mode) | |
| if not task_name: return f"La inferencia para el modo '{task_mode}' no está soportada.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| try: | |
| pipe = pipeline(task_name, model=model_id, torch_dtype=torch_dtype_auto, trust_remote_code=True) | |
| result = None | |
| if task_name == "text-generation": | |
| if not text_in: return "Por favor, introduce un prompt de texto.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(text_in, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.95) | |
| elif task_name == "question-answering": | |
| if not text_in or not context_in: return "Por favor, introduce una pregunta y un contexto.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(question=text_in, context=context_in) | |
| elif task_name in ["token-classification", "text2text-generation", "text-classification"]: | |
| if not text_in: return f"Por favor, introduce texto para {task_name}.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(text_in) | |
| elif task_name in ["image-classification", "audio-classification", "automatic-speech-recognition"]: | |
| input_data = image_in if "image" in task_name else audio_in | |
| if input_data is None: return f"Por favor, proporciona una entrada de { 'imagen' if 'image' in task_name else 'audio' }.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(input_data) | |
| return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| def update_inference_ui(task_mode): | |
| task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "") | |
| show_text = task_name in ["text-generation", "text2text-generation", "token-classification", "question-answering", "text-classification", "text-to-image"] | |
| show_context = task_name == "question-answering" | |
| show_image = task_name in ["image-classification"] | |
| show_audio = task_name in ["audio-classification", "automatic-speech-recognition"] | |
| text_label = "Pregunta" if task_name == "question-answering" else "Entrada de Texto / Prompt" | |
| context_label = "Contexto (para QA)" | |
| return gr.update(visible=show_text, label=text_label), gr.update(visible=show_context, label=context_label), gr.update(visible=show_image), gr.update(visible=show_audio) | |
def gradio_train_wrapper(*args):
    all_input_keys = [
        "training_mode", "model_base_input", "tokenizer_name_input", "repo_name_input", "train_from_scratch", "auto_config_scratch", "scratch_architecture",
        "enable_multi_adapter_merge", "multi_adapter_model_ids", "multi_adapter_weights", "multi_adapter_combination_type",
        "datasets_hf_text", "uploads", "dataset_weights", "eval_dataset_hf",
        "learning_rate", "epochs", "batch_size", "gradient_accumulation", "block_size", "max_train_samples", "optimizer", "scheduler", "mixed_precision",
        "warmup_ratio", "weight_decay", "max_grad_norm", "logging_steps", "save_steps", "save_total_limit",
        "adam_beta1", "adam_beta2", "adam_epsilon",
        "disable_gradient_checkpointing", "group_by_length", "packing", "neftune_noise_alpha", "optim_args", "attn_implementation",
        "peft", "merge_adapter", "quantization", "lora_r", "lora_alpha", "lora_dropout", "auto_find_target_modules", "target_modules", "modules_to_save", "use_dora", "use_rslora", "init_lora_weights",
        "remove_html_tags", "normalize_whitespace", "remove_urls_emails", "redact_pii",
        "enable_quality_filter", "min_len_input", "max_len_input", "rep_threshold_input", "exclude_keywords_input",
        "enable_language_filter", "allowed_languages", "language_detection_threshold", "enable_toxicity_filter", "toxicity_threshold",
        "deduplication_method", "minhash_threshold", "minhash_num_perm",
        "enable_cda", "cda_json_config", "enable_back_translation", "bt_augmentation_ratio", "bt_model_id", "bt_reverse_model_id",
        "enable_synthetic_data", "synthetic_model_id", "num_synthetic_samples", "synthetic_prompt_template",
        "format_as_conversation", "chat_template_jinja", "prompt_col_input", "dpo_chosen_col_input", "dpo_rejected_col_input",
        "enable_cot_input", "reasoning_col_input", "enable_tool_use_input", "tool_use_col_input", "response_col_input",
        "classification_labels",
        "diffusion_resolution",
        "run_evaluation", "metric_for_best_model", "greater_is_better", "run_perplexity_evaluation",
        "enable_loss_reweighting", "reweighting_terms", "reweighting_factor",
        "wandb_api_key_input", "wandb_project_input",
    ]
    kwargs = dict(zip(all_input_keys, args))
    yield from _train_and_upload(**kwargs)
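# This wrapper exists because Gradio hands component values to callbacks
# positionally; zip(all_input_keys, args) rebuilds the keyword interface of
# _train_and_upload. A minimal sketch of the idiom (hypothetical names):
#
#     keys = ["lr", "epochs"]
#     def wrapper(*args):
#         return train(**dict(zip(keys, args)))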
def toggle_training_mode_ui(is_scratch):
    return {
        model_base_input: gr.update(visible=not is_scratch),
        tokenizer_name_input: gr.update(visible=not is_scratch),
        multi_adapter_accordion: gr.update(visible=not is_scratch),
        peft_accordion: gr.update(visible=not is_scratch),
        auto_config_scratch: gr.update(visible=is_scratch),
        scratch_architecture: gr.update(visible=is_scratch),
    }
def toggle_task_specific_ui(training_mode):
    is_classification = "Classification" in training_mode
    is_dpo = "DPO" in training_mode
    is_sft = "Causal" in training_mode
    is_ner = "Token Classification" in training_mode
    is_diffusion = "Image Generation" in training_mode
    return {
        classification_labels_ui: gr.update(visible=is_classification or is_ner),
        dpo_ui: gr.update(visible=is_dpo),
        sft_ui: gr.update(visible=is_sft),
        diffusion_ui: gr.update(visible=is_diffusion),
        peft_accordion: gr.update(visible=not is_diffusion),
    }
def toggle_auto_modules_ui(is_auto):
    return gr.update(visible=not is_auto)
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# 🚀 AutoTrain-Advanced: Your Model Training Platform")
    gr.Markdown("### A complete interface for fine-tuning, PEFT (LoRA, QLoRA), and deploying models to Hugging Face.")
    with gr.Tab("1. Authentication"):
        gr.Markdown("#### Connect your Hugging Face account to save and load models.")
        with gr.Row():
            hf_token_input = gr.Textbox(label="Hugging Face Token (with write permissions)", type="password", placeholder="hf_...", scale=3)
            login_button = gr.Button("Connect", variant="primary", scale=1)
        login_status = gr.Textbox(label="Connection Status", interactive=False)
        login_button.click(hf_login, inputs=[hf_token_input], outputs=[login_status])
    with gr.Tab("2. Training"):
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("## ⚙️ Training Configuration")
                training_mode = gr.Dropdown(TRAINING_MODES, label="Training Mode", value=TRAINING_MODES[0])
                with gr.Accordion("📦 Model and Repository", open=True):
                    model_base_input = gr.Textbox(label="Base Model ID", placeholder="e.g. 'mistralai/Mistral-7B-v0.1' or 'stabilityai/stable-diffusion-2-1-base'")
                    tokenizer_name_input = gr.Textbox(label="Tokenizer ID (optional)", placeholder="e.g. if the model has no tokenizer or you want to use a different one")
                    repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g. 'my-finetuned-model'")
                    train_from_scratch = gr.Checkbox(label="Train from Scratch", value=False)
                    auto_config_scratch = gr.Checkbox(label="Auto-Configuration", value=True, visible=False)
                    scratch_architecture = gr.Textbox(label="Architecture (e.g. Llama, Mistral, GPT2)", value="Llama", visible=False)
| with gr.Accordion("🔄 Fusión de Múltiples Adaptadores (Avanzado)", open=False) as multi_adapter_accordion: | |
| enable_multi_adapter_merge = gr.Checkbox(label="Habilitar Fusión Múltiple", value=False) | |
| multi_adapter_model_ids = gr.Textbox(label="IDs de Adaptadores (separados por comas)", placeholder="org/adapter1,org/adapter2") | |
| multi_adapter_weights = gr.Textbox(label="Pesos (separados por comas)", placeholder="0.5,0.5") | |
| multi_adapter_combination_type = gr.Dropdown(["slerp", "linear", "cat", "svd", "dare_linear", "dare_ties", "ties"], label="Tipo de Combinación", value="slerp") | |
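                    # Hedged note: most of these combination types mirror the
                    # `combination_type` argument of PEFT's
                    # PeftModel.add_weighted_adapter (whether "slerp" is handled
                    # there or by custom code in this backend is an assumption).
                    # Illustrative sketch with two already-loaded adapters:
                    #
                    #     model.add_weighted_adapter(
                    #         adapters=["adapter1", "adapter2"],
                    #         weights=[0.5, 0.5],
                    #         adapter_name="merged",
                    #         combination_type="ties",
                    #         density=0.5,  # required by the ties/dare variants
                    #     )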
| with gr.Accordion("📚 Dataset", open=True): | |
| datasets_hf_text = gr.Textbox(label="Datasets de Hugging Face (separados por comas)", placeholder="p.ej. 'databricks/databricks-dolly-15k' o 'lambdalabs/pokemon-blip-captions'") | |
| uploads = gr.File(label="Subir Archivos Locales (.jsonl, .csv, .txt)", file_count="multiple") | |
| dataset_weights = gr.Textbox(label="Pesos de los Datasets (opcional, csv)", placeholder="p.ej. 0.7, 0.3") | |
| eval_dataset_hf = gr.Textbox(label="Dataset de Evaluación (opcional)", placeholder="p.ej. 'nombre/dataset_eval'") | |
| with gr.Accordion("🎓 Hiperparámetros", open=False): | |
| with gr.Row(): | |
| learning_rate = gr.Textbox(label="Tasa de Aprendizaje", value="2e-5") | |
| epochs = gr.Textbox(label="Épocas", value="1") | |
| batch_size = gr.Textbox(label="Tamaño de Lote", value="1") | |
| gradient_accumulation = gr.Textbox(label="Acumulación de Gradiente", value="8") | |
| with gr.Row(): | |
| block_size = gr.Textbox(label="Longitud de Secuencia", value="1024") | |
| max_train_samples = gr.Textbox(label="Máx. Muestras de Entrenamiento", value="10000") | |
| optimizer = gr.Dropdown(["adamw_torch", "sgd", "adagrad"], label="Optimizador", value="adamw_torch") | |
| scheduler = gr.Dropdown(["cosine", "linear", "constant", "polynomial"], label="Planificador LR", value="cosine") | |
| mixed_precision = gr.Radio(["no"], label="Precisión Mixta (Solo GPU)", value="no", interactive=False) | |
| with gr.Accordion("Avanzados", open=False): | |
| warmup_ratio = gr.Slider(minimum=0.0, maximum=0.5, step=0.01, label="Ratio de Calentamiento", value=0.03) | |
| weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01") | |
| max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="0.3") | |
| logging_steps = gr.Textbox(label="Pasos de Registro", value="10") | |
| save_steps = gr.Textbox(label="Pasos de Guardado", value="50") | |
| save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1") | |
| with gr.Row(): | |
| adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9") | |
| adam_beta2 = gr.Textbox(label="Adam Beta2", value="0.999") | |
| adam_epsilon = gr.Textbox(label="Adam Epsilon", value="1e-8") | |
| disable_gradient_checkpointing = gr.Checkbox(label="Deshabilitar Gradient Checkpointing", value=False) | |
| group_by_length = gr.Checkbox(label="Agrupar por Longitud", value=False) | |
| packing = gr.Checkbox(label="Packing", value=False) | |
| neftune_noise_alpha = gr.Textbox(label="NEFTune Ruido Alfa (0 para desactivar)", value="0") | |
| optim_args = gr.Textbox(label="Argumentos del Optimizador (formato dict)", placeholder="ej: betas=(0.9,0.995)") | |
| attn_implementation = gr.Dropdown(["eager"], label="Implementación de Atención", value="eager", interactive=False) | |
| with gr.Accordion("🦋 PEFT (LoRA / QLoRA)", open=True) as peft_accordion: | |
| peft = gr.Checkbox(label="Habilitar PEFT/LoRA", value=True) | |
| merge_adapter = gr.Checkbox(label="Fusionar Adaptador al Final", value=False) | |
| quantization = gr.Dropdown(["no"], label="Cuantización (Solo GPU)", value="no", interactive=False) | |
| with gr.Row(): | |
| lora_r = gr.Textbox(label="LoRA r", value="16") | |
| lora_alpha = gr.Textbox(label="LoRA alpha", value="32") | |
| lora_dropout = gr.Textbox(label="LoRA dropout", value="0.05") | |
| auto_find_target_modules = gr.Checkbox(label="Auto-encontrar Módulos de Destino", value=True) | |
| target_modules = gr.Textbox(label="Módulos de Destino (separados por comas)", placeholder="q_proj,v_proj", visible=False) | |
| modules_to_save = gr.Textbox(label="Módulos a Guardar (separados por comas)", placeholder="embed_tokens,lm_head") | |
| with gr.Row(): | |
| use_dora = gr.Checkbox(label="Usar DoRA", value=False) | |
| use_rslora = gr.Checkbox(label="Usar RSLora", value=False) | |
| init_lora_weights = gr.Dropdown(["gaussian", "loftq", "pissa"], label="Inicialización de Pesos LoRA", value="gaussian") | |
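                    # Hedged sketch of how these fields are assumed to map onto
                    # a peft.LoraConfig in the training backend:
                    #
                    #     lora_config = LoraConfig(
                    #         r=16, lora_alpha=32, lora_dropout=0.05,
                    #         target_modules=["q_proj", "v_proj"],
                    #         modules_to_save=["embed_tokens", "lm_head"],
                    #         use_dora=False, use_rslora=False,
                    #         init_lora_weights="gaussian",
                    #         task_type="CAUSAL_LM",
                    #     )
                    #     model = get_peft_model(base_model, lora_config)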
| with gr.Accordion("🧹 Procesamiento y Aumentación de Datos", open=False): | |
| with gr.Tab("Limpieza y Normalización"): | |
| remove_html_tags = gr.Checkbox(label="Eliminar Etiquetas HTML", value=True) | |
| normalize_whitespace = gr.Checkbox(label="Normalizar Espacios en Blanco", value=True) | |
| remove_urls_emails = gr.Checkbox(label="Eliminar URLs/Emails", value=True) | |
| redact_pii = gr.Checkbox(label="Redactar PII (Teléfonos, Emails, IPs)", value=True) | |
| with gr.Tab("Filtrado"): | |
| enable_quality_filter = gr.Checkbox(label="Habilitar Filtros de Calidad Básicos", value=True) | |
| min_len_input = gr.Slider(1, 100, 10, label="Longitud Mínima (palabras)") | |
| max_len_input = gr.Slider(100, 5000, 2000, label="Longitud Máxima (palabras)") | |
| rep_threshold_input = gr.Slider(0, 1, 0.2, label="Umbral de Repetición de Palabras") | |
| exclude_keywords_input = gr.Textbox(label="Palabras Clave a Excluir (csv)") | |
| enable_language_filter = gr.Checkbox(label="Habilitar Filtro de Idioma", value=False) | |
| allowed_languages = gr.Textbox(label="Idiomas Permitidos (códigos ISO, csv)", value="en,es") | |
| language_detection_threshold = gr.Slider(0.5, 1.0, 0.95, label="Umbral de Detección de Idioma") | |
| enable_toxicity_filter = gr.Checkbox(label="Habilitar Filtro de Toxicidad", value=False) | |
| toxicity_threshold = gr.Slider(0.5, 1.0, 0.8, label="Umbral de Toxicidad") | |
| with gr.Tab("Deduplicación"): | |
| deduplication_method = gr.Radio(["Ninguna", "Exacta", "Semántica (MinHash)"], label="Método de Deduplicación", value="Ninguna") | |
| minhash_threshold = gr.Slider(0.5, 1.0, 0.85, label="Umbral MinHash", visible=False) | |
| minhash_num_perm = gr.Slider(64, 512, 128, step=64, label="Permutaciones MinHash", visible=False) | |
| deduplication_method.change(lambda x: (gr.update(visible=x=="Semántica (MinHash)"), gr.update(visible=x=="Semántica (MinHash)")), inputs=[deduplication_method], outputs=[minhash_threshold, minhash_num_perm]) | |
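                        # Hedged sketch of the MinHash deduplication these
                        # controls configure, using the datasketch API imported
                        # at the top of the file (illustrative values):
                        #
                        #     m = MinHash(num_perm=128)
                        #     for token in text.split():
                        #         m.update(token.encode("utf8"))
                        #     lsh = MinHashLSH(threshold=0.85, num_perm=128)
                        #     lsh.insert("doc-0", m)
                        #     near_dupes = lsh.query(m)  # keys with estimated Jaccard >= threshold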
| with gr.Tab("Aumentación"): | |
| enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False) | |
| cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["she", "he"], ["woman", "man"]]') | |
| enable_back_translation = gr.Checkbox(label="Habilitar Retrotraducción", value=False) | |
| bt_augmentation_ratio = gr.Slider(0.0, 1.0, 0.1, label="Ratio de Aumentación BT") | |
| bt_model_id = gr.Textbox(label="Modelo de Traducción (p.ej. a DE)", value="Helsinki-NLP/opus-mt-en-de") | |
| bt_reverse_model_id = gr.Textbox(label="Modelo de Traducción Inversa (p.ej. a EN)", value="Helsinki-NLP/opus-mt-de-en") | |
| with gr.Tab("Generación Sintética"): | |
| enable_synthetic_data = gr.Checkbox(label="Habilitar Generación de Datos Sintéticos", value=False) | |
| synthetic_model_id = gr.Textbox(label="ID del Modelo Generador", placeholder="p.ej. 'mistralai/Mistral-7B-Instruct-v0.2'") | |
| num_synthetic_samples = gr.Number(label="Número de Muestras Sintéticas", value=1000) | |
| synthetic_prompt_template = gr.Textbox(label="Plantilla de Prompt Sintético", value="Given the following text, create a new, similar example.\n\nExample:\n{{ example_text }}\n\nNew example:", lines=5) | |
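                        # Hedged note: the default template above uses Jinja2
                        # syntax (imported at the top of the file); the backend
                        # is assumed to render it per sampled example roughly as:
                        #
                        #     prompt = Template(template_value).render(example_text=sample_text)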
| with gr.Accordion("📝 Configuración de Formato y Tarea", open=False): | |
| with gr.Group(visible=False) as diffusion_ui: | |
| diffusion_resolution = gr.Slider(label="Resolución de Imagen", minimum=256, maximum=1024, value=512, step=64) | |
| with gr.Group(visible=False) as classification_labels_ui: | |
| classification_labels = gr.Textbox(label="Etiquetas de Clasificación (separadas por comas)", placeholder="p.ej. positivo,negativo,neutro") | |
| with gr.Group(visible=False) as dpo_ui: | |
| prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt") | |
| dpo_chosen_col_input = gr.Textbox(label="Columna de Respuesta Elegida", value="chosen") | |
| dpo_rejected_col_input = gr.Textbox(label="Columna de Respuesta Rechazada", value="rejected") | |
| with gr.Group(visible=True) as sft_ui: | |
| format_as_conversation = gr.Checkbox(label="Formatear como Conversación (experimental)", value=False) | |
| chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5) | |
| enable_cot_input = gr.Checkbox(label="Formato Chain-of-Thought", value=False) | |
| reasoning_col_input = gr.Textbox(label="Columna de Razonamiento", value="reasoning") | |
| enable_tool_use_input = gr.Checkbox(label="Formato de Uso de Herramientas", value=False) | |
| tool_use_col_input = gr.Textbox(label="Columna de Uso de Herramientas", value="tools") | |
| response_col_input = gr.Textbox(label="Columna de Respuesta Final", value="response") | |
| with gr.Accordion("📊 Evaluación y Mitigación de Sesgos", open=False): | |
| run_evaluation = gr.Checkbox(label="Ejecutar Evaluación en el Conjunto de Validación", value=False) | |
| metric_for_best_model = gr.Textbox(label="Métrica para el Mejor Modelo", value="loss") | |
| greater_is_better = gr.Checkbox(label="¿Métrica Mayor es Mejor?", value=False) | |
| run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad al Final", value=True) | |
| with gr.Tab("Mitigación de Sesgos"): | |
| enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderación de Pérdida", value=False) | |
| reweighting_terms = gr.Textbox(label="Términos para Re-ponderar (csv)", placeholder="sesgo,injusto") | |
| reweighting_factor = gr.Slider(1.0, 10.0, 2.0, label="Factor de Re-ponderación") | |
| with gr.Accordion("🔌 Integraciones", open=False): | |
| wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password") | |
| wandb_project_input = gr.Textbox(label="Proyecto W&B") | |
            with gr.Column(scale=3):
                gr.Markdown("## 📈 Progress and Results")
                with gr.Row():
                    start_training_button = gr.Button("Start Training", variant="primary", scale=3)
                    stop_training_button = gr.Button("Stop", variant="stop", visible=False, scale=1)
                training_phase = gr.Label(label="Current Phase", value="Idle")
                training_logs = gr.Textbox(label="Training Logs", lines=35, interactive=False)
                repo_link_output = gr.Markdown(label="Model Repository Link")
        all_inputs = [
            training_mode, model_base_input, tokenizer_name_input, repo_name_input, train_from_scratch, auto_config_scratch, scratch_architecture,
            enable_multi_adapter_merge, multi_adapter_model_ids, multi_adapter_weights, multi_adapter_combination_type,
            datasets_hf_text, uploads, dataset_weights, eval_dataset_hf,
            learning_rate, epochs, batch_size, gradient_accumulation, block_size, max_train_samples, optimizer, scheduler, mixed_precision,
            warmup_ratio, weight_decay, max_grad_norm, logging_steps, save_steps, save_total_limit,
            adam_beta1, adam_beta2, adam_epsilon,
            disable_gradient_checkpointing, group_by_length, packing, neftune_noise_alpha, optim_args, attn_implementation,
            peft, merge_adapter, quantization, lora_r, lora_alpha, lora_dropout, auto_find_target_modules, target_modules, modules_to_save, use_dora, use_rslora, init_lora_weights,
            remove_html_tags, normalize_whitespace, remove_urls_emails, redact_pii,
            enable_quality_filter, min_len_input, max_len_input, rep_threshold_input, exclude_keywords_input,
            enable_language_filter, allowed_languages, language_detection_threshold, enable_toxicity_filter, toxicity_threshold,
            deduplication_method, minhash_threshold, minhash_num_perm,
            enable_cda, cda_json_config, enable_back_translation, bt_augmentation_ratio, bt_model_id, bt_reverse_model_id,
            enable_synthetic_data, synthetic_model_id, num_synthetic_samples, synthetic_prompt_template,
            format_as_conversation, chat_template_jinja, prompt_col_input, dpo_chosen_col_input, dpo_rejected_col_input,
            enable_cot_input, reasoning_col_input, enable_tool_use_input, tool_use_col_input, response_col_input,
            classification_labels,
            diffusion_resolution,
            run_evaluation, metric_for_best_model, greater_is_better, run_perplexity_evaluation,
            enable_loss_reweighting, reweighting_terms, reweighting_factor,
            wandb_api_key_input, wandb_project_input,
        ]
        all_outputs = [training_logs, training_phase, repo_link_output, start_training_button, stop_training_button]
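        # NOTE: all_inputs must stay index-aligned with all_input_keys inside
        # gradio_train_wrapper above. Gradio passes these component values
        # positionally, so reordering either list silently shifts every
        # hyperparameter onto the wrong keyword.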
        train_from_scratch.change(
            toggle_training_mode_ui,
            inputs=[train_from_scratch],
            outputs=[model_base_input, tokenizer_name_input, multi_adapter_accordion, peft_accordion, auto_config_scratch, scratch_architecture],
        )
        training_mode.change(
            toggle_task_specific_ui,
            inputs=[training_mode],
            outputs=[classification_labels_ui, dpo_ui, sft_ui, diffusion_ui, peft_accordion],
        )
        auto_find_target_modules.change(
            toggle_auto_modules_ui,
            inputs=[auto_find_target_modules],
            outputs=[target_modules],
        )
        train_event = start_training_button.click(
            gradio_train_wrapper,
            inputs=all_inputs,
            outputs=all_outputs,
        )
        stop_training_button.click(fn=None, inputs=None, outputs=None, cancels=[train_event])
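        # Gradio's `cancels` mechanism runs through the event queue: clicking
        # "Stop" cancels the queued generator event, cutting off the stream of
        # yielded log updates. A minimal sketch of the same idiom
        # (hypothetical components):
        #
        #     ev = go_btn.click(long_running_generator, inputs=[...], outputs=[...])
        #     stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[ev])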
| with gr.Tab("3. Inferencia"): | |
| gr.Markdown("## 🧪 Probar un Modelo del Hub") | |
| gr.Markdown("Carga cualquier modelo compatible desde el Hub de Hugging Face y pruébalo directamente aquí.") | |
| with gr.Row(): | |
| inf_task_mode = gr.Dropdown(TRAINING_MODES, label="Tipo de Tarea", value=TRAINING_MODES[0]) | |
| inf_model_id = gr.Textbox(label="ID del Modelo en el Hub", placeholder="TuUsuario/TuModeloEntrenado") | |
| with gr.Group(): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| inf_text_in = gr.Textbox(label="Entrada de Texto / Prompt", lines=5) | |
| inf_context_in = gr.Textbox(label="Contexto (para QA)", lines=3, visible=False) | |
| inf_image_in = gr.Image(label="Entrada de Imagen", type="pil", visible=False) | |
| inf_audio_in = gr.Audio(label="Entrada de Audio", type="filepath", visible=False) | |
| with gr.Row(): | |
| run_inference_btn = gr.Button("Ejecutar Inferencia", variant="primary") | |
| with gr.Column(scale=3): | |
| inf_text_out = gr.Textbox(label="Salida de Texto", lines=10, interactive=False) | |
| inf_image_out = gr.Image(label="Salida de Imagen", visible=False) | |
| inf_audio_out = gr.Audio(label="Salida de Audio", visible=False) | |
| inf_task_mode.change( | |
| update_inference_ui, | |
| inputs=[inf_task_mode], | |
| outputs=[inf_text_in, inf_context_in, inf_image_in, inf_audio_in] | |
| ) | |
| run_inference_btn.click( | |
| run_inference, | |
| inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in], | |
| outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_out, inf_audio_out] | |
| ) | |
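        # run_inference returns six values: the text result, the echoed model
        # ID, and four no-op gr.update() placeholders, matching the outputs
        # list above in order (inf_text_out, inf_model_id, inf_text_in,
        # inf_context_in, inf_image_out, inf_audio_out).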
demo.launch(debug=True)