Spaces:
Runtime error
Runtime error
import os
import subprocess
import sys

# Install runtime dependencies before the heavy imports below.
# pip is invoked through the interpreter running this script so the packages
# land in the environment that will import them, and the arguments are passed
# as a list so no shell parsing is involved.
for _pip_args in (
    ["-U", "transformers", "peft", "accelerate", "trl", "bitsandbytes", "datasets", "diffusers"],
    ["spaces-0.1.0-py3-none-any.whl"],
):
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *_pip_args],
        check=False,  # best effort, as before: a failed install must not abort start-up
    )
    if _pip_result.returncode != 0:
        # Surface the failure instead of silently ignoring the exit status.
        print(
            f"pip install {' '.join(_pip_args)} failed with exit code {_pip_result.returncode}",
            file=sys.stderr,
        )
| import io | |
| import json | |
| import tempfile | |
| import string | |
| import gc | |
| import math | |
| import uuid | |
| import logging | |
| import traceback | |
| import importlib | |
| import random | |
| import re | |
| import ast | |
| from itertools import islice | |
| from pathlib import Path | |
| from collections import defaultdict | |
| from datetime import datetime | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
| import numpy as np | |
| import pandas as pd | |
| import accelerate | |
| from PIL import Image | |
| import torchvision | |
| import torchvision.transforms as T | |
| from torchvision import transforms | |
| import torchaudio | |
| from bs4 import BeautifulSoup | |
| from langdetect import detect_langs | |
| import textstat | |
| from datasketch import MinHash, MinHashLSH | |
| import gradio as gr | |
| from datasets import load_dataset, IterableDataset, Dataset as HFDataset, DatasetDict, interleave_datasets, Audio | |
| from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi, hf_hub_download, list_repo_files | |
| from transformers import ( | |
| AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer, | |
| AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, | |
| SpeechT5ForTextToSpeech, SpeechT5Processor, SpeechT5HifiGan, AutoModelForImageClassification, | |
| AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification, | |
| DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq, | |
| AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer, | |
| DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig, | |
| LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel, | |
| PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM, | |
| DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor | |
| ) | |
| from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training, AdaLoraConfig | |
| from trl import SFTTrainer, DPOTrainer | |
| from diffusers import ( | |
| UNet2DConditionModel, DDPMScheduler, AutoencoderKL, DiffusionPipeline, | |
| get_scheduler as get_diffusers_scheduler, StableDiffusionPipeline as StableDiffusionText2ImagePipeline, | |
| StableDiffusionImg2ImgPipeline as StableDiffusionImage2ImagePipeline, | |
| get_cosine_schedule_with_warmup | |
| ) | |
| import evaluate as hf_evaluate | |
| from jinja2 import Template | |
| import spaces | |
| from tqdm.auto import tqdm | |
# Module-wide logger.
logger = logging.getLogger(__name__)

# Compute target and default dtype shared by the loaders in this file:
# half precision when CUDA is available, full precision on CPU.
device, torch_dtype_auto = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
# Architecture name -> (config class, causal-LM model class).
ARCHITECTURE_MAP = {
    "Llama": (LlamaConfig, LlamaForCausalLM),
    "Mistral": (MistralConfig, MistralForCausalLM),
    "Gemma": (GemmaConfig, GemmaForCausalLM),
    "GPT2": (GPT2Config, GPT2LMHeadModel),
    "Phi": (PhiConfig, PhiForCausalLM),
    "Qwen2": (Qwen2Config, Qwen2ForCausalLM),
}
# Architecture name -> Hub repository id associated with that architecture's tokenizer.
SCRATCH_TOKENIZER_MAP = {
    "Llama": "meta-llama/Llama-2-7b-hf",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Gemma": "google/gemma-2b",
    "GPT2": "gpt2",
    "Phi": "microsoft/phi-2",
    "Qwen2": "Qwen/Qwen2-0.5B",
}

# Names of the supported training modes, in their declared order.
TRAINING_MODES = [
    # Text
    "Causal Language Modeling (SFT/LoRA)",
    "DPO (Direct Preference Optimization)",
    "Question Answering (Text)",
    "Token Classification (NER)",
    "Sequence Classification (Text)",
    # Image generation
    "Text-to-Image (LoRA)",
    "DreamBooth LoRA (Text-to-Image)",
    # Vision / audio
    "Image Classification (Vision)",
    "Audio Classification (Speech)",
    "ASR (Speech-to-Text)",
    # Seq2seq
    "Text2Text Generation",
]

# Training mode -> pipeline task name.
TASK_TO_PIPELINE_MAP = {
    "Causal Language Modeling (SFT/LoRA)": "text-generation",
    "DPO (Direct Preference Optimization)": "text-generation",
    "Question Answering (Text)": "question-answering",
    "Token Classification (NER)": "token-classification",
    "Sequence Classification (Text)": "text-classification",
    "Image Classification (Vision)": "image-classification",
    "Audio Classification (Speech)": "audio-classification",
    "ASR (Speech-to-Text)": "automatic-speech-recognition",
    "Text2Text Generation": "text2text-generation",
    "Text-to-Image (LoRA)": "text-to-image",
    "DreamBooth LoRA (Text-to-Image)": "text-to-image",
}
| MODEL_CARD_TEMPLATE = """--- | |
| language: es | |
| license: apache-2.0 | |
| tags: | |
| - autotrain-advanced | |
| - fine-tuned | |
| - {base_model_name} | |
| widget: | |
| - text: "Hola, ¿cómo estás?" | |
| --- | |
| # {repo_id} | |
| Este modelo es una versión afinada de [{base_model}](https://huggingface.co/{base_model}) entrenado con la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced). | |
| ## Detalles del Entrenamiento | |
| - **Modo de Entrenamiento:** {training_mode} | |
| - **Modelo Base:** `{base_model}` | |
| - **Datasets:** `{datasets}` | |
| - **Entrenado en:** {date} | |
| ### Hiperparámetros de Entrenamiento | |
| ```json | |
| {hyperparameters}``` | |
| ### Frameworks Utilizados | |
| - Transformers | |
| - PEFT | |
| - BitsAndBytes | |
| - Accelerate | |
| - TRL | |
| - Diffusers | |
| - Gradio | |
| """ | |
| DATASET_CARD_TEMPLATE = """--- | |
| license: mit | |
| --- | |
| # {repo_id} | |
| Este dataset fue creado utilizando la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced). | |
| ## Detalles del Dataset | |
| - **Tipo de Creación:** {creation_type} | |
| - **Modelo de Generación (si aplica):** `{generation_model}` | |
| - **Fecha de Creación:** {date} | |
| """ | |
# Cache slot for the toxicity classification pipeline; starts empty and is
# filled on first use by the toxicity filter so the model is loaded only once.
_tox_pipe_singleton = None
class DebiasingSFTTrainer(SFTTrainer):
    """SFT trainer that scales up the loss of batches mentioning given terms.

    Args:
        reweighting_terms: Iterable of strings. Matching is a case-insensitive
            substring test against the decoded ``input_ids`` of the batch.
            Blank terms are ignored (a blank term would match every batch).
        reweighting_factor: Multiplier applied to the batch loss when any term
            is found. Values ``<= 1.0`` disable re-weighting.
    """

    def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.reweighting_terms = [
            term.strip().lower()
            for term in (reweighting_terms or [])
            if term and term.strip()
        ]
        self.reweighting_factor = reweighting_factor

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the parent loss and re-weight it when a term is present.

        Extra keyword arguments (for example ``num_items_in_batch``, which
        recent Trainer versions pass) are forwarded untouched to the parent
        implementation so the override works across library versions.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
        if self.reweighting_terms and self.reweighting_factor > 1.0:
            input_ids = inputs.get("input_ids")
            # Newer Trainer versions expose the tokenizer as `processing_class`;
            # fall back to the older `tokenizer` attribute otherwise.
            decoder = getattr(self, "processing_class", None)
            if decoder is None:
                decoder = getattr(self, "tokenizer", None)
            if input_ids is not None and decoder is not None:
                decoded_texts = decoder.batch_decode(input_ids, skip_special_tokens=True)
                if any(
                    term in text.lower()
                    for text in decoded_texts
                    for term in self.reweighting_terms
                ):
                    # Out-of-place multiply: never mutate the tensor returned by the parent.
                    loss = loss * self.reweighting_factor
        return (loss, outputs) if return_outputs else loss
| def _deduplication_generator(dataset, text_col, method, threshold, num_perm): | |
| if method == 'Exacta': | |
| seen_texts = set() | |
| for example in dataset: | |
| text = example.get(text_col, "") | |
| if text and isinstance(text, str): | |
| if text not in seen_texts: | |
| seen_texts.add(text) | |
| yield example | |
| else: | |
| yield example | |
| elif method == 'Semántica (MinHash)': | |
| lsh = MinHashLSH(threshold=threshold, num_perm=num_perm) | |
| for i, example in enumerate(dataset): | |
| text = example.get(text_col, "") | |
| if text and isinstance(text, str) and text.strip(): | |
| m = MinHash(num_perm=num_perm) | |
| for d in text.split(): | |
| m.update(d.encode('utf8')) | |
| if not lsh.query(m): | |
| lsh.insert(f"key_{i}", m) | |
| yield example | |
| else: | |
| yield example | |
| else: | |
| yield from dataset | |
def _create_deduplicated_iterable_dataset(dataset, text_col, method, threshold=0.85, num_perm=128):
    """Wrap ``_deduplication_generator`` in a lazily evaluated ``IterableDataset``."""
    generator_arguments = dict(
        dataset=dataset,
        text_col=text_col,
        method=method,
        threshold=threshold,
        num_perm=num_perm,
    )
    return IterableDataset.from_generator(_deduplication_generator, gen_kwargs=generator_arguments)
def hf_login(token):
    """Log in to the Hugging Face Hub and return a status message.

    Returns a prompt to enter a token when ``token`` is empty, a success
    message with the account name when login works, or the error text when
    any exception is raised during login.
    """
    if token:
        try:
            login(token=token, add_to_git_credential=True)
            account = whoami()
            message = f"✅ Conectado como: {account['name']}"
        except Exception as error:
            message = f"❌ Error en la conexión: {error}"
        return message
    return "Por favor, introduce un token."
| def _clean_text(example, text_col, **kwargs): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str): | |
| return example | |
| if kwargs.get('remove_html_tags'): | |
| text = BeautifulSoup(text, "html.parser").get_text() | |
| if kwargs.get('remove_urls_emails'): | |
| text = re.sub(r'http\S+|www\S+|httpsS+', '', text, flags=re.MULTILINE) | |
| if kwargs.get('normalize_whitespace'): | |
| text = ' '.join(text.split()) | |
| if kwargs.get('redact_pii'): | |
| text = re.sub(r'\S+@\S+', '<EMAIL>', text) | |
| text = re.sub(r'(\d{1,4}[-.\s]?){7,}|(\+\d{1,3}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}', '<PHONE>', text) | |
| text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text) | |
| example[text_col] = text | |
| return example | |
| def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str): return False | |
| text_len = len(text.split()) | |
| if not (min_len <= text_len <= max_len): return False | |
| words = text.split() | |
| if not words: return False | |
| word_counts = {} | |
| for word in words: word_counts[word] = word_counts.get(word, 0) + 1 | |
| if not word_counts or (max(word_counts.values()) / len(words)) > rep_threshold: return False | |
| lower_text = text.lower() | |
| return not any(keyword in lower_text for keyword in exclude_keywords) | |
| def _get_filter_functions(**kwargs): | |
| filters = [] | |
| if kwargs.get('enable_quality_filter'): | |
| exclude_list = [k.strip().lower() for k in (kwargs.get('exclude_keywords_input', '') + ',' + kwargs.get('bias_keywords_input', '')).split(",") if k.strip()] | |
| filters.append(lambda ex: _apply_quality_filters(ex, kwargs['text_col'], kwargs['min_len_input'], kwargs['max_len_input'], kwargs['rep_threshold_input'], exclude_list)) | |
| if kwargs.get('enable_language_filter'): | |
| allowed_langs = [lang.strip() for lang in kwargs.get('allowed_languages', 'en').split(',')] | |
| lang_threshold = kwargs.get('language_detection_threshold', 0.95) | |
| def lang_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not text or not isinstance(text, str) or len(text.split()) < 5: return True | |
| try: | |
| detected = detect_langs(text) | |
| return any(lang.lang in allowed_langs and lang.prob > lang_threshold for lang in detected) | |
| except: | |
| return False | |
| filters.append(lang_filter) | |
| if kwargs.get('enable_toxicity_filter'): | |
| tox_threshold = kwargs.get('toxicity_threshold', 0.8) | |
| def tox_filter(ex): | |
| global _tox_pipe_singleton | |
| if _tox_pipe_singleton is None: | |
| logger.info("Initializing toxicity filter pipeline...") | |
| _tox_pipe_singleton = pipeline("text-classification", model="unitary/toxic-bert", device=0 if device == 'cuda' else -1) | |
| text = ex.get(kwargs['text_col'], "") | |
| if not text or not isinstance(text, str): return True | |
| try: | |
| results = _tox_pipe_singleton(text[:512], truncation=True) | |
| return not (results[0]['label'] == 'toxic' and results[0]['score'] > tox_threshold) | |
| except Exception: | |
| return True | |
| filters.append(tox_filter) | |
| if any([kwargs.get('enable_readability_filter'), kwargs.get('enable_stopword_filter'), kwargs.get('enable_uniqueness_filter')]): | |
| stop_words = set(textstat.DEFAULT_stopwords) | |
| def stats_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not isinstance(text, str) or not text: return True | |
| words = text.split() | |
| num_words = len(words) | |
| if num_words == 0: return True | |
| if kwargs.get('enable_readability_filter'): | |
| score = textstat.flesch_reading_ease(text) | |
| if not (kwargs['min_readability'] <= score <= kwargs['max_readability']): return False | |
| if kwargs.get('enable_stopword_filter'): | |
| if (textstat.stopword_count(text) / num_words) > kwargs['max_stopword_ratio']: return False | |
| if kwargs.get('enable_uniqueness_filter'): | |
| if (len(set(words)) / num_words) < kwargs['min_uniqueness_ratio']: return False | |
| return True | |
| filters.append(stats_filter) | |
| return filters | |
| def _load_hf_streaming(ids, split="train", probabilities=None): | |
| streams = [] | |
| valid_ids = [] | |
| for ident in ids: | |
| try: | |
| d = load_dataset(ident, streaming=True, trust_remote_code=True) | |
| split_found = False | |
| if isinstance(d, dict): | |
| for s_name, ds in d.items(): | |
| if s_name.lower() == split or (split == "train" and "train" in s_name.lower()): | |
| streams.append(ds) | |
| split_found = True | |
| break | |
| else: | |
| streams.append(d) | |
| split_found = True | |
| if split_found: | |
| valid_ids.append(ident) | |
| else: | |
| logger.warning(f"Split '{split}' not found in dataset {ident}. Excluding from this source.") | |
| except Exception as e: | |
| logger.error(f"Error loading dataset {ident} split {split}: {e}. Excluding from this source.") | |
| if not streams: | |
| return None | |
| if probabilities and len(probabilities) != len(streams): | |
| logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.") | |
| probabilities = None | |
| return interleave_datasets(streams, probabilities=probabilities) | |
| def _load_uploaded_stream(files): | |
| all_rows = [] | |
| for f in files or []: | |
| content = f.read().decode("utf-8", errors="ignore") | |
| name = f.name.lower() | |
| if name.endswith(".csv"): | |
| import csv | |
| all_rows.extend(list(csv.DictReader(io.StringIO(content)))) | |
| elif name.endswith(".jsonl"): | |
| all_rows.extend([json.loads(line) for line in io.StringIO(content) if line.strip()]) | |
| elif name.endswith(".json"): | |
| data = json.loads(content) | |
| all_rows.extend(data if isinstance(data, list) else [data]) | |
| elif name.endswith(".txt"): | |
| all_rows.extend([{"text": line} for line in io.StringIO(content) if line.strip()]) | |
| if not all_rows: | |
| return None | |
| val_size = max(1, int(len(all_rows) * 0.01)) | |
| random.shuffle(all_rows) | |
| return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []} | |
| def _guess_columns(sample): | |
| text_col, image_col, audio_col, label_col = "text", "image", "audio", "label" | |
| if not isinstance(sample, dict): | |
| return text_col, image_col, audio_col, label_col | |
| keys = {k.lower(): k for k in sample.keys()} | |
| if "text" in keys: text_col = keys["text"] | |
| elif "content" in keys: text_col = keys["content"] | |
| elif "prompt" in keys: text_col = keys["prompt"] | |
| if "image" in keys: image_col = keys["image"] | |
| elif "img" in keys: image_col = keys["img"] | |
| if "audio" in keys: audio_col = keys["audio"] | |
| elif "speech" in keys: audio_col = keys["speech"] | |
| if "label" in keys: label_col = keys["label"] | |
| elif "labels" in keys: label_col = keys["labels"] | |
| return text_col, image_col, audio_col, label_col | |
def _apply_cda(dataset, text_col, cda_config_str):
    """Augment ``dataset`` with counterfactual variants (CDA).

    ``cda_config_str`` is a JSON list of swap groups, each a list of
    interchangeable terms, e.g. ``[["he", "she"], ["actor", "actress"]]``.
    Every original example is yielded first; then, for each group, each text
    produced so far that contains a term of the group yields one new example
    per alternative term, skipping texts already produced.

    Terms are matched as whole words (not preceded or followed by a word
    character), so swapping "he" does not corrupt "the" or "she".

    Returns:
        An ``IterableDataset`` over originals plus variants, or ``dataset``
        unchanged when the configuration is not valid.
    """
    try:
        swap_groups = json.loads(cda_config_str)
    except (json.JSONDecodeError, ValueError, TypeError) as e:
        logger.error(f"Configuración de CDA inválida: {e}.")
        return dataset

    well_formed = isinstance(swap_groups, list) and all(
        isinstance(group, list) and all(isinstance(word, str) and word for word in group)
        for group in swap_groups
    )
    if not well_formed:
        logger.error("Configuración de CDA inválida: se esperaba una lista de listas de palabras.")
        return dataset

    # One compiled whole-word pattern per term, built once outside the loop.
    # Lookarounds are used instead of \b so terms that start or end with a
    # non-word character still match.
    patterns = {
        word: re.compile(rf"(?<!\w){re.escape(word)}(?!\w)")
        for group in swap_groups
        for word in group
    }

    def cda_generator():
        for example in dataset:
            original_text = example.get(text_col, "")
            yield example
            if not isinstance(original_text, str):
                continue
            generated_texts = {original_text}
            current_texts = {original_text}
            for group in swap_groups:
                next_texts = set()
                for text in current_texts:
                    for word_to_replace in group:
                        pattern = patterns[word_to_replace]
                        if not pattern.search(text):
                            continue
                        for replacement_word in group:
                            if word_to_replace == replacement_word:
                                continue
                            # Function replacement: the term is inserted literally,
                            # with no backslash/group-reference processing.
                            new_text = pattern.sub(lambda _m, _r=replacement_word: _r, text)
                            if new_text not in generated_texts:
                                new_example = example.copy()
                                new_example[text_col] = new_text
                                yield new_example
                                generated_texts.add(new_text)
                                next_texts.add(new_text)
                current_texts.update(next_texts)

    return IterableDataset.from_generator(cda_generator)
| def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id): | |
| if not ratio or ratio <= 0: | |
| return dataset | |
| logger.info(f"Aplicando retrotraducción al {ratio*100}% del dataset.") | |
| try: | |
| pipe_to = pipeline("translation", model=model_id, device=0 if device == 'cuda' else -1) | |
| pipe_from = pipeline("translation", model=reverse_model_id, device=0 if device == 'cuda' else -1) | |
| except Exception as e: | |
| logger.error(f"No se pudieron cargar los modelos de traducción: {e}") | |
| return dataset | |
| def bt_generator(): | |
| for example in dataset: | |
| yield example | |
| if random.random() < ratio: | |
| original_text = example.get(text_col, "") | |
| if isinstance(original_text, str) and original_text: | |
| try: | |
| translated = pipe_to(original_text, max_length=512)[0]['translation_text'] | |
| back_translated = pipe_from(translated, max_length=512)[0]['translation_text'] | |
| if back_translated: | |
| new_example = example.copy() | |
| new_example[text_col] = back_translated | |
| yield new_example | |
| except Exception as e: | |
| logger.warning(f"Error en retrotraducción: {e}") | |
| return IterableDataset.from_generator(bt_generator) | |
| def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template): | |
| if not num_samples or num_samples <= 0: | |
| return None | |
| logger.info(f"Iniciando generación de {num_samples} muestras sintéticas con el modelo {model_id}.") | |
| try: | |
| generator = pipeline("text-generation", model=model_id, torch_dtype=torch_dtype_auto, device=0 if device == 'cuda' else -1) | |
| except Exception as e: | |
| logger.error(f"No se pudo cargar el modelo generador sintético: {e}") | |
| return None | |
| seed_examples = list(islice(original_dataset, 200)) | |
| if not seed_examples: | |
| logger.warning("Dataset original vacío, no se pueden generar datos sintéticos.") | |
| return None | |
| def synthetic_generator(): | |
| for i in range(num_samples): | |
| seed_example = random.choice(seed_examples) | |
| seed_text = seed_example.get(text_col, "") | |
| prompt = Template(prompt_template).render(example_text=seed_text) | |
| try: | |
| generated_output = generator(prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.9, top_p=0.95) | |
| cleaned_text = generated_output[0]['generated_text'][len(prompt):].strip() | |
| if "new example:" in cleaned_text.lower(): | |
| cleaned_text = re.split("new example:", cleaned_text, flags=re.IGNORECASE)[-1].strip() | |
| if cleaned_text: | |
| new_example = seed_example.copy() | |
| new_example[text_col] = cleaned_text | |
| yield new_example | |
| except Exception as e: | |
| logger.warning(f"Error generando una muestra sintética: {e}") | |
| continue | |
| return IterableDataset.from_generator(synthetic_generator) | |
| def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation): | |
| safe_steps = int(steps_per_epoch_estimate or 10000) | |
| safe_batch_size = int(batch_size or 1) | |
| safe_grad_accum = int(gradient_accumulation or 8) | |
| safe_block_size = int(block_size or 1024) | |
| size = safe_steps * safe_batch_size * safe_grad_accum | |
| if size <= 1: | |
| size = 10000 | |
| log_size = math.log2(max(1000, size)) | |
| vocab_size = min(65536, 32000 + int(log_size * 2000)) | |
| preliminary_hidden_size = max(512, min(4096, 512 + int(log_size * 100))) | |
| heads = max(8, min(32, preliminary_hidden_size // 64)) | |
| if heads == 0: heads = 8 | |
| hidden_size = (preliminary_hidden_size // heads) * heads | |
| layers = max(8, min(32, 8 + int(log_size * 1.5))) | |
| kv_heads = heads if is_gpt2_like else (max(1, heads // 4)) | |
| return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads | |
| def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn): | |
| if eval_ds_id: | |
| yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación") | |
| return _load_hf_streaming([eval_ds_id], split="train") | |
| if uploaded_val_data: | |
| yield update_logs_fn("Usando split de validación de archivos subidos.", "Evaluación") | |
| return HFDataset.from_list(uploaded_val_data) | |
| if train_ds_id: | |
| yield update_logs_fn("Intentando cargar split 'validation' o 'test' del dataset de entrenamiento.", "Evaluación") | |
| try: | |
| for split_name in ["validation", "test"]: | |
| eval_ds = _load_hf_streaming([train_ds_id], split=split_name) | |
| if eval_ds: | |
| yield update_logs_fn(f"Split '{split_name}' encontrado y cargado.", "Evaluación") | |
| return eval_ds | |
| except Exception as e: | |
| yield update_logs_fn(f"Error cargando split de evaluación: {e}. Omitiendo.", "Evaluación") | |
| return None | |
| yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación") | |
| return None | |
def _create_training_args(output_dir, repo_id, **kwargs):
    """Build ``TrainingArguments`` from the UI options in ``kwargs``.

    Args:
        output_dir: Base directory; checkpoints go to ``<output_dir>/results``.
        repo_id: Hub repository that receives the pushed checkpoints.
        **kwargs: UI options (batch size, optimizer, scheduler, precision,
            evaluation switches, ...). Missing options fall back to the
            defaults written below.

    Returns:
        A ``TrainingArguments`` instance.

    Raises:
        ValueError: for non-diffusion modes when ``max_steps`` is not
            positive (streaming datasets have no length to derive it from).

    Notes:
        * ``optim_args`` is forwarded as the raw string the user typed.
        * When evaluation is enabled the evaluation strategy is set to
          ``"steps"`` so it lines up with ``save_strategy`` and
          ``load_best_model_at_end`` has evaluations to compare.
        * Options that the installed ``TrainingArguments`` does not declare
          (for example ``packing``) are dropped with a warning rather than
          passed to the constructor. ``early_stopping_patience`` is not a
          training argument, so it only forces ``load_best_model_at_end``
          here; the caller is expected to register the callback itself.
    """
    import dataclasses

    run_evaluation = bool(kwargs.get('run_evaluation', False))
    neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
    optim_args = (kwargs.get('optim_args') or '').strip() or None

    args_dict = {
        "output_dir": os.path.join(output_dir, "results"),
        "per_device_train_batch_size": int(kwargs.get('batch_size', 1)),
        "gradient_accumulation_steps": int(kwargs.get('gradient_accumulation', 8)),
        "optim": kwargs.get('optimizer', 'adamw_torch'),
        "optim_args": optim_args,
        "save_strategy": "steps",
        "logging_steps": int(kwargs.get('logging_steps', 10)),
        "save_steps": int(kwargs.get('save_steps', 50)),
        # Evaluate on the save cadence so the best checkpoint can be selected.
        "eval_steps": int(kwargs.get('save_steps', 50)) if run_evaluation else None,
        "learning_rate": float(kwargs.get('learning_rate', 2e-5)),
        "fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
        "bf16": kwargs.get('mixed_precision') == 'bf16' and device == 'cuda',
        "max_grad_norm": float(kwargs.get('max_grad_norm', 1.0)),
        "warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)),
        "lr_scheduler_type": kwargs.get('scheduler', 'cosine'),
        "weight_decay": float(kwargs.get('weight_decay', 0.01)),
        "load_best_model_at_end": run_evaluation,
        "save_total_limit": int(kwargs.get('save_total_limit', 1)),
        "gradient_checkpointing": not kwargs.get('disable_gradient_checkpointing', False) and device == 'cuda',
        "push_to_hub": True,
        "hub_model_id": repo_id,
        "hub_strategy": kwargs.get('hub_strategy', 'every_save'),
        "dataloader_num_workers": 2,
        "report_to": "wandb" if kwargs.get('wandb_api_key_input') else "none",
        "remove_unused_columns": False,
        "group_by_length": kwargs.get('group_by_length', False),
        "packing": kwargs.get('packing', False),
        "metric_for_best_model": kwargs.get('metric_for_best_model', 'loss') if run_evaluation else None,
        "greater_is_better": kwargs.get('greater_is_better', False),
        "neftune_noise_alpha": neftune_alpha if neftune_alpha > 0 else None,
        "adam_beta1": float(kwargs.get('adam_beta1', 0.9)),
        "adam_beta2": float(kwargs.get('adam_beta2', 0.999)),
        "adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)),
    }

    if (kwargs.get('early_stopping_patience') or 0) > 0 and run_evaluation:
        args_dict['load_best_model_at_end'] = True

    is_diffusion_task = kwargs.get('training_mode', '') in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
    if is_diffusion_task:
        args_dict["num_train_epochs"] = float(kwargs.get('epochs', 1.0))
    else:
        max_steps_val = int(kwargs.get('max_steps', -1))
        if max_steps_val > 0:
            args_dict["max_steps"] = max_steps_val
        else:
            raise ValueError("Para datasets en streaming se requiere un valor positivo para 'Máximos Pasos de Entrenamiento'.")

    # Some argument names differ between library versions; pick whichever the
    # installed TrainingArguments dataclass declares.
    supported = {field.name for field in dataclasses.fields(TrainingArguments)}
    args_dict["use_cpu" if "use_cpu" in supported else "no_cuda"] = device == 'cpu'
    if run_evaluation:
        strategy_key = "eval_strategy" if "eval_strategy" in supported else "evaluation_strategy"
        args_dict[strategy_key] = "steps"

    unsupported = sorted(set(args_dict) - supported)
    if unsupported:
        logger.warning(f"Ignoring options not supported by TrainingArguments: {unsupported}")
    return TrainingArguments(**{key: value for key, value in args_dict.items() if key in supported})
def _generic_model_loader(model_name_or_path, model_class, **kwargs):
    """Load ``model_class`` from ``model_name_or_path`` with the shared settings.

    Options read from ``kwargs``:
        quantization: ``"no"`` (default), ``"4bit"`` or ``"8bit"``; applied
            only on CUDA and only when bitsandbytes can be imported.
        attn_implementation: defaults to ``"eager"``; ``"flash_attention_2"``
            is downgraded to ``"eager"`` off CUDA.
        label2id / id2label: forwarded to the config when ``label2id`` is set.
        attention_dropout / hidden_dropout: written to the config when > 0.
        num_labels: forwarded together with ``ignore_mismatched_sizes=True``.

    NOTE(review): ``trust_remote_code=True`` runs code shipped with the model
    repository; only use with trusted model ids.
    """
    requested_quantization = kwargs.get('quantization', 'no')
    bnb_config = None
    if requested_quantization != "no":
        if device == "cuda":
            try:
                import bitsandbytes as bnb  # availability probe only
                if requested_quantization == "4bit":
                    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype_auto, bnb_4bit_use_double_quant=True)
                elif requested_quantization == "8bit":
                    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            except ImportError:
                logger.warning("bitsandbytes no está instalado. No se puede cargar en 4bit/8bit.")
        elif device == "cpu":
            logger.warning("La cuantización solo es compatible con GPU CUDA. Se procederá sin cuantización.")

    attention_backend = kwargs.get('attn_implementation', 'eager')
    if attention_backend == "flash_attention_2" and device != 'cuda':
        attention_backend = "eager"
        logger.warning("Flash Attention 2 solo está disponible en CUDA. Se usará la implementación 'eager'.")

    config_options = {"trust_remote_code": True}
    if kwargs.get('label2id'):
        config_options["label2id"] = kwargs['label2id']
        config_options["id2label"] = kwargs['id2label']
    config = AutoConfig.from_pretrained(model_name_or_path, **config_options)
    for dropout_name in ('attention_dropout', 'hidden_dropout'):
        if kwargs.get(dropout_name, 0) > 0:
            setattr(config, dropout_name, kwargs[dropout_name])

    load_options = dict(
        trust_remote_code=True,
        config=config,
        attn_implementation=attention_backend,
        torch_dtype=torch_dtype_auto,
        quantization_config=bnb_config,
    )
    # A device map is set for unquantized CUDA loads and for CPU loads only.
    if device == "cuda" and bnb_config is None:
        load_options["device_map"] = "auto"
    elif device == "cpu":
        load_options["device_map"] = "cpu"
    if kwargs.get('num_labels'):
        load_options["num_labels"] = kwargs['num_labels']
        load_options["ignore_mismatched_sizes"] = True

    model = model_class.from_pretrained(model_name_or_path, **load_options)
    if device == 'cpu' and hasattr(model, 'to'):
        model.to(device)
    return model
| def _find_all_linear_names(model, quantization_type): | |
| cls = torch.nn.Linear | |
| if quantization_type != 'no' and device == "cuda": | |
| try: | |
| import bitsandbytes as bnb | |
| if quantization_type == '4bit': | |
| cls = bnb.nn.Linear4bit | |
| elif quantization_type == '8bit': | |
| cls = bnb.nn.Linear8bitLt | |
| except ImportError: | |
| logger.warning("bitsandbytes no está instalado. No se puede determinar los módulos cuantizados.") | |
| lora_module_names = set() | |
| for name, module in model.named_modules(): | |
| if isinstance(module, cls): | |
| names = name.split('.') | |
| lora_module_names.add(names[-1]) | |
| if 'lm_head' in lora_module_names: | |
| lora_module_names.remove('lm_head') | |
| common_targets = {'q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'} | |
| return list(lora_module_names.intersection(common_targets)) or list(lora_module_names) | |
| def _sft_formatting_func(example, text_col, tokenizer, **kwargs): | |
| if kwargs.get('sft_format_style') == "Conversacional": | |
| conv_col = "" | |
| for key in ["messages", "conversations", "turns"]: | |
| if key in example: conv_col = key; break | |
| if not conv_col: return "" | |
| conversation = example[conv_col] | |
| if isinstance(conversation, str): | |
| try: conversation = ast.literal_eval(conversation) | |
| except: return "" | |
| return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) | |
| if kwargs.get('sft_format_style') == "Razonamiento/Herramientas": | |
| messages = [] | |
| prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "") | |
| if prompt: messages.append({"role": "user", "content": prompt}) | |
| response_parts = [] | |
| if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')): | |
| response_parts.append(f"<thinking>{example[kwargs.get('reasoning_col_input', 'reasoning')]}</thinking>") | |
| if kwargs.get('enable_tool_use_input') and example.get(kwargs.get('tool_use_col_input', 'tools')): | |
| response_parts.append(f"<tool_code>{example[kwargs.get('tool_use_col_input', 'tools')]}</tool_code>") | |
| if example.get(kwargs.get('response_col_input', 'response')): | |
| response_parts.append(example.get(kwargs.get('response_col_input', 'response'))) | |
| if response_parts: | |
| messages.append({"role": "assistant", "content": "\n".join(response_parts)}) | |
| if messages: | |
| try: | |
| return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) | |
| except Exception as e: | |
| logger.error(f"Error aplicando la plantilla de chat: {e}.") | |
| return "\n".join([m['content'] for m in messages]) | |
| return "" | |
| return example.get(text_col, "") | |
| def _dpo_formatting_func(example, **kwargs): | |
| return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")} | |
| def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col): | |
| model.eval() | |
| encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt").to(model.device) | |
| max_length = model.config.max_position_embeddings | |
| stride = 512 | |
| seq_len = encodings.input_ids.size(1) | |
| nlls = [] | |
| prev_end_loc = 0 | |
| with torch.no_grad(): | |
| for begin_loc in range(0, seq_len, stride): | |
| end_loc = min(begin_loc + max_length, seq_len) | |
| trg_len = end_loc - prev_end_loc | |
| input_ids = encodings.input_ids[:, begin_loc:end_loc] | |
| target_ids = input_ids.clone() | |
| target_ids[:, :-trg_len] = -100 | |
| outputs = model(input_ids, labels=target_ids) | |
| neg_log_likelihood = outputs.loss | |
| nlls.append(neg_log_likelihood) | |
| prev_end_loc = end_loc | |
| if end_loc == seq_len: | |
| break | |
| ppl = torch.exp(torch.stack(nlls).mean()) | |
| return ppl.item() | |
def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
    """Combine several LoRA adapters into a base causal LM and save the result.

    This is a generator: it yields human-readable status strings while it
    works, and its return value (available as ``StopIteration.value``) is the
    directory holding the merged model, or ``base_model_id`` unchanged when no
    adapter id was supplied.

    Args:
        base_model_id: Hub id or local path of the base model.
        adapter_ids_str: Comma-separated adapter ids.
        weights_str: Comma-separated weights, one per adapter. Unparsable or
            mismatched weights fall back to 1.0 for every adapter.
        combination_type: Forwarded to ``add_weighted_adapter``.
    """
    adapter_ids = [s.strip() for s in (adapter_ids_str or "").split(',') if s.strip()]
    if not adapter_ids:
        yield "No se proporcionaron IDs de adaptadores válidos. Omitiendo la fusión múltiple."
        return base_model_id

    try:
        weights = [float(w.strip()) for w in weights_str.split(',')]
    except (AttributeError, TypeError, ValueError):
        weights = None
    if weights is None or len(weights) != len(adapter_ids):
        weights = [1.0] * len(adapter_ids)
        yield "Pesos de adaptadores inválidos, usando 1.0 para todos."

    yield f"Cargando modelo base {base_model_id} para fusión múltiple..."
    # add_weighted_adapter / merge_and_unload are provided by PEFT's model
    # wrapper, not by a plain transformers model, so the base model has to be
    # wrapped in a PeftModel before adapters can be combined.
    from peft import PeftModel

    model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch_dtype_auto, trust_remote_code=True, device_map=device)
    adapter_names = [f"adapter_{i}" for i in range(len(adapter_ids))]
    for i, (adapter_id, adapter_name) in enumerate(zip(adapter_ids, adapter_names)):
        yield f"Cargando adaptador {i+1}: {adapter_id}"
        if i == 0:
            model = PeftModel.from_pretrained(model, adapter_id, adapter_name=adapter_name)
        else:
            model.load_adapter(adapter_id, adapter_name=adapter_name)

    yield f"Combinando adaptadores: {adapter_names} con pesos: {weights} y tipo: {combination_type}"
    model.add_weighted_adapter(adapters=adapter_names, weights=weights, adapter_name="combined", combination_type=combination_type)
    model.set_adapter("combined")

    yield "Fusionando combinación de adaptadores en el modelo base..."
    merged_model = model.merge_and_unload()
    temp_dir = tempfile.mkdtemp()
    yield f"Guardando modelo fusionado en {temp_dir}"
    merged_model.save_pretrained(temp_dir)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    tokenizer.save_pretrained(temp_dir)
    yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
    return temp_dir
def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Run a prepared trainer, save its artifacts and push them to the Hub.

    Generator: yields whatever ``update_logs_fn`` returns at each phase
    (training, saving, uploading). Its return value is the tuple
    ``(output_dir, final_metrics)``, where ``final_metrics`` is the last logged
    evaluation entry with ``eval_`` removed from its keys, or an empty dict
    when evaluation was not requested or nothing was logged.
    """
    yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
    checkpoint = kwargs.get('resume_from_checkpoint') or False
    trainer.train(resume_from_checkpoint=checkpoint)

    final_metrics = {}
    if kwargs.get('run_evaluation'):
        # Keep only the most recent log entry that carries an evaluation loss.
        last_eval = None
        for entry in trainer.state.log_history:
            if 'eval_loss' in entry:
                last_eval = entry
        if last_eval is not None:
            final_metrics = {name.replace('eval_', ''): value for name, value in last_eval.items()}

    yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
    output_dir = trainer.args.output_dir
    trainer.save_model(output_dir)
    if tokenizer:
        tokenizer.save_pretrained(output_dir)
    readme_path = os.path.join(output_dir, "README.md")
    with open(readme_path, "w", encoding="utf-8") as readme:
        readme.write(model_card_content)

    yield update_logs_fn("Subiendo al Hub...", "Subiendo")
    upload_folder(folder_path=output_dir, repo_id=repo_id, commit_message="Fin de entrenamiento")

    # Drop the local reference and reclaim memory before returning.
    del trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return output_dir, final_metrics
def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Train a causal LM with SFT or DPO and upload the result.

    Generator: yields progress updates and returns
    ``(final_model_path, final_metrics)`` from ``_run_trainer_and_upload``.
    DPO is selected when ``kwargs['training_mode']`` equals
    "DPO (Direct Preference Optimization)"; otherwise SFT is used (with the
    debiasing trainer when ``enable_loss_reweighting`` is set).

    Raises:
        Exception: Any failure is re-raised with the phase name and traceback
            text in the message, chained to the original exception.
    """
    output_dir = tempfile.mkdtemp()
    is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)"
    text_col = kwargs.get('text_col')
    try:
        tokenizer_id = kwargs.get('tokenizer_name') or model_name
        yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, use_fast=False)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # The option may be present but None, so guard before stripping.
        if (kwargs.get('chat_template_jinja') or '').strip():
            tokenizer.chat_template = kwargs['chat_template_jinja']

        yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
        model = _generic_model_loader(model_name, AutoModelForCausalLM, **kwargs)

        peft_config = None
        if kwargs.get('peft'):
            if kwargs.get('auto_find_target_modules'):
                target_modules = _find_all_linear_names(model, kwargs.get('quantization'))
            else:
                target_modules = [m.strip() for m in (kwargs.get('target_modules') or '').split(",") if m.strip()]
                if not target_modules:
                    raise ValueError("No se especificaron módulos objetivo para LoRA.")
            yield update_logs_fn(f"Módulos LoRA detectados/especificados: {target_modules}", "Configuración")
            modules_to_save = [m.strip() for m in (kwargs.get('modules_to_save') or '').split(',') if m.strip()]
            peft_config = LoraConfig(
                r=int(kwargs.get('lora_r')), lora_alpha=int(kwargs.get('lora_alpha')), lora_dropout=float(kwargs.get('lora_dropout')),
                target_modules=target_modules, bias="none", task_type="CAUSAL_LM", use_dora=kwargs.get('use_dora', False),
                use_rslora=kwargs.get('use_rslora', False), init_lora_weights=kwargs.get('init_lora_weights', 'gaussian'),
                modules_to_save=modules_to_save or None
            )

        training_args = _create_training_args(output_dir, repo_id, **kwargs)

        eval_dataset = None
        if kwargs.get('run_evaluation'):
            # datasets_hf_text is absent/None when only uploaded files are used.
            hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
            eval_dataset_gen = _get_eval_dataset(hf_ids, kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
            for update in eval_dataset_gen:
                if isinstance(update, dict):
                    yield update
                else:
                    eval_dataset = update

        TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
        trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "tokenizer": tokenizer, "peft_config": peft_config}
        if is_dpo:
            trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
            if train_dataset:
                train_dataset = train_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
            if eval_dataset:
                eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
            trainer_kwargs.update({"train_dataset": train_dataset, "eval_dataset": eval_dataset})
        else:
            # example/tokenizer/text_col are passed explicitly below; leaving
            # them in the forwarded kwargs would raise "multiple values for
            # keyword argument" as soon as the formatting function is called.
            sft_kwargs = {k: v for k, v in kwargs.items() if k not in ('example', 'tokenizer', 'text_col')}
            trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs), "max_seq_length": int(kwargs.get('block_size'))})
            if kwargs.get('enable_loss_reweighting'):
                trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': float(kwargs.get('reweighting_factor', 2.0))})

        trainer = TrainerClass(**trainer_kwargs)
        final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
        return final_model_path, final_metrics
    except Exception as e:
        raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}") from e
def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Fine-tune a sequence-classification model and upload the result.

    Generator: yields progress updates and returns
    ``(final_model_path, final_metrics)`` from ``_run_trainer_and_upload``.
    Class names come from the comma-separated ``classification_labels``
    option; text is read from the column named by ``kwargs['text_col']`` and
    truncated to 512 tokens. Accuracy is reported during evaluation.

    Raises:
        Exception: Any failure is re-raised with the task name and traceback
            text in the message, chained to the original exception.
    """
    output_dir = tempfile.mkdtemp()
    try:
        labels = [s.strip() for s in (kwargs.get('classification_labels') or '').split(',') if s.strip()]
        if not labels:
            raise ValueError("No se proporcionaron etiquetas de clasificación.")
        label2id = {l: i for i, l in enumerate(labels)}
        id2label = {i: l for i, l in enumerate(labels)}

        tokenizer_id = kwargs.get('tokenizer_name') or model_name
        yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
        model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
        model.config.pad_token_id = tokenizer.pad_token_id

        # NOTE(review): only the text is tokenized here; the label column is
        # assumed to already hold integer ids under a name the Trainer
        # recognises ("label"/"labels") — confirm against the data loaders.
        def preprocess(examples):
            return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)

        train_dataset = train_dataset.map(preprocess, batched=True)

        eval_dataset = None
        if kwargs.get('run_evaluation'):
            # datasets_hf_text is absent/None when only uploaded files are used.
            hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
            eval_dataset_gen = _get_eval_dataset(hf_ids, kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
            for update in eval_dataset_gen:
                if isinstance(update, dict):
                    yield update
                else:
                    eval_dataset = update
        if eval_dataset:
            eval_dataset = eval_dataset.map(preprocess, batched=True)

        metric = hf_evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            logits, references = eval_pred
            predictions = np.argmax(logits, axis=-1)
            return metric.compute(predictions=predictions, references=references)

        training_args = _create_training_args(output_dir, repo_id, **kwargs)
        trainer = Trainer(
            model=model, args=training_args, train_dataset=train_dataset,
            eval_dataset=eval_dataset, compute_metrics=compute_metrics,
            tokenizer=tokenizer, data_collator=DataCollatorWithPadding(tokenizer=tokenizer)
        )
        final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
        return final_model_path, final_metrics
    except Exception as e:
        raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}") from e
def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Fine-tune a token-classification (NER-style) model and upload the result.

    Generator: yields progress updates and returns
    ``(final_model_path, final_metrics)`` from ``_run_trainer_and_upload``.
    Examples are read from the fixed columns ``"tokens"`` (pre-split words)
    and ``"ner_tags"`` (one label per word). Only the first sub-token of each
    word keeps its label; special tokens and continuation sub-tokens get -100.
    Evaluation reports seqeval precision/recall/F1/accuracy.

    Raises:
        Exception: Any failure is re-raised with the task name and traceback
            text in the message, chained to the original exception.
    """
    output_dir = tempfile.mkdtemp()
    try:
        labels = [s.strip() for s in (kwargs.get('classification_labels') or '').split(',') if s.strip()]
        if not labels:
            raise ValueError("No se proporcionaron etiquetas de clasificación.")
        label2id = {l: i for i, l in enumerate(labels)}
        id2label = {i: l for i, l in enumerate(labels)}

        tokenizer_id = kwargs.get('tokenizer_name') or model_name
        yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, add_prefix_space=True)

        yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
        model = _generic_model_loader(model_name, AutoModelForTokenClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)

        def tokenize_and_align_labels(examples):
            tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
            aligned_labels = []
            for i, word_labels in enumerate(examples["ner_tags"]):
                word_ids = tokenized_inputs.word_ids(batch_index=i)
                previous_word_idx = None
                label_ids = []
                for word_idx in word_ids:
                    # -100 is ignored by the loss: used for special tokens and
                    # for every sub-token after the first one of a word.
                    if word_idx is None or word_idx == previous_word_idx:
                        label_ids.append(-100)
                    else:
                        label_ids.append(word_labels[word_idx])
                    previous_word_idx = word_idx
                aligned_labels.append(label_ids)
            tokenized_inputs["labels"] = aligned_labels
            return tokenized_inputs

        train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True)

        eval_dataset = None
        if kwargs.get('run_evaluation'):
            # datasets_hf_text is absent/None when only uploaded files are used.
            hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
            eval_dataset_gen = _get_eval_dataset(hf_ids, kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
            for update in eval_dataset_gen:
                if isinstance(update, dict):
                    yield update
                else:
                    eval_dataset = update
        if eval_dataset:
            eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)

        metric = hf_evaluate.load("seqeval")

        def compute_metrics(p):
            predictions, references = p
            predictions = np.argmax(predictions, axis=2)
            # Drop positions labelled -100 and map ids back to label names.
            true_predictions = [[id2label[pred] for (pred, ref) in zip(pred_row, ref_row) if ref != -100] for pred_row, ref_row in zip(predictions, references)]
            true_labels = [[id2label[ref] for (pred, ref) in zip(pred_row, ref_row) if ref != -100] for pred_row, ref_row in zip(predictions, references)]
            results = metric.compute(predictions=true_predictions, references=true_labels)
            return {"precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"]}

        training_args = _create_training_args(output_dir, repo_id, **kwargs)
        trainer = Trainer(
            model=model, args=training_args, train_dataset=train_dataset,
            eval_dataset=eval_dataset, tokenizer=tokenizer,
            data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
            compute_metrics=compute_metrics
        )
        final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
        return final_model_path, final_metrics
    except Exception as e:
        raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}") from e
def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Fine-tune an extractive question-answering model and upload the result.

    Generator: yields progress updates and returns
    ``(final_model_path, final_metrics)`` from ``_run_trainer_and_upload``.
    Examples are read from the fixed columns ``"question"``, ``"context"``
    and ``"answers"`` (with ``"answer_start"`` and ``"text"`` lists). Long
    contexts are split into overlapping 384-token features with a stride of
    128; features that do not contain the answer point at the CLS token.

    Raises:
        Exception: Any failure is re-raised with the task name and traceback
            text in the message, chained to the original exception.
    """
    output_dir = tempfile.mkdtemp()
    try:
        tokenizer_id = kwargs.get('tokenizer_name') or model_name
        yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)

        yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
        model = _generic_model_loader(model_name, AutoModelForQuestionAnswering, **kwargs)

        max_length = 384
        doc_stride = 128

        def prepare_train_features(examples):
            tokenized_examples = tokenizer(
                examples["question"],
                examples["context"],
                truncation="only_second",
                max_length=max_length,
                stride=doc_stride,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding="max_length",
            )
            # One example may produce several features; this maps each feature
            # back to the example it was cut from.
            sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
            offset_mapping = tokenized_examples.pop("offset_mapping")
            tokenized_examples["start_positions"] = []
            tokenized_examples["end_positions"] = []
            for i, offsets in enumerate(offset_mapping):
                input_ids = tokenized_examples["input_ids"][i]
                cls_index = input_ids.index(tokenizer.cls_token_id)
                sequence_ids = tokenized_examples.sequence_ids(i)
                sample_index = sample_mapping[i]
                answers = examples["answers"][sample_index]
                if len(answers["answer_start"]) == 0:
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    start_char = answers["answer_start"][0]
                    end_char = start_char + len(answers["text"][0])
                    # Locate the first and last token of the context segment
                    # (sequence id 1) within this feature.
                    token_start_index = 0
                    while sequence_ids[token_start_index] != 1:
                        token_start_index += 1
                    token_end_index = len(input_ids) - 1
                    while sequence_ids[token_end_index] != 1:
                        token_end_index -= 1
                    if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                        # The answer lies outside this feature's span.
                        tokenized_examples["start_positions"].append(cls_index)
                        tokenized_examples["end_positions"].append(cls_index)
                    else:
                        # Walk inwards to the tokens that bracket the answer.
                        while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                            token_start_index += 1
                        tokenized_examples["start_positions"].append(token_start_index - 1)
                        while offsets[token_end_index][1] >= end_char:
                            token_end_index -= 1
                        tokenized_examples["end_positions"].append(token_end_index + 1)
            return tokenized_examples

        # remove_columns is given a concrete list rather than a dict_keys view.
        train_columns = list(next(iter(train_dataset)).keys())
        train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_columns)

        eval_dataset = None
        if kwargs.get('run_evaluation'):
            # datasets_hf_text is absent/None when only uploaded files are used.
            hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
            eval_dataset_raw_gen = _get_eval_dataset(hf_ids, kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
            eval_dataset_raw = None
            for update in eval_dataset_raw_gen:
                if isinstance(update, dict):
                    yield update
                else:
                    eval_dataset_raw = update
            if eval_dataset_raw:
                eval_columns = list(next(iter(eval_dataset_raw)).keys())
                eval_dataset = eval_dataset_raw.map(prepare_train_features, batched=True, remove_columns=eval_columns)

        training_args = _create_training_args(output_dir, repo_id, **kwargs)
        data_collator = DefaultDataCollator()
        trainer = Trainer(
            model=model, args=training_args, train_dataset=train_dataset,
            eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=data_collator
        )
        final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
        return final_model_path, final_metrics
    except Exception as e:
        raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}") from e
def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Fine-tune a sequence-to-sequence model and upload the result.

    Generator: yields progress updates and returns
    ``(final_model_path, final_metrics)`` from ``_run_trainer_and_upload``.
    Each example is expected to carry a ``"translation"`` mapping; the source
    text is read from ``kwargs['text_col']`` and the target text from
    ``kwargs['label_col']`` inside it. Both sides are truncated to 128 tokens.
    Evaluation generates text and reports the sacreBLEU score as ``"bleu"``.

    Raises:
        Exception: Any failure is re-raised with the task name and traceback
            text in the message, chained to the original exception.
    """
    output_dir = tempfile.mkdtemp()
    try:
        tokenizer_id = kwargs.get('tokenizer_name') or model_name
        yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)

        yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
        model = _generic_model_loader(model_name, AutoModelForSeq2SeqLM, **kwargs)

        def preprocess_function(examples):
            inputs = [ex[kwargs['text_col']] for ex in examples["translation"]]
            targets = [ex[kwargs['label_col']] for ex in examples["translation"]]
            model_inputs = tokenizer(inputs, max_length=128, truncation=True)
            # text_target replaces the deprecated as_target_tokenizer() context.
            labels = tokenizer(text_target=targets, max_length=128, truncation=True)
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        train_dataset = train_dataset.map(preprocess_function, batched=True)

        eval_dataset = None
        if kwargs.get('run_evaluation'):
            # datasets_hf_text is absent/None when only uploaded files are used.
            hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
            eval_dataset_gen = _get_eval_dataset(hf_ids, kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
            for update in eval_dataset_gen:
                if isinstance(update, dict):
                    yield update
                else:
                    eval_dataset = update
        if eval_dataset:
            eval_dataset = eval_dataset.map(preprocess_function, batched=True)

        metric = hf_evaluate.load("sacrebleu")

        def compute_metrics(eval_preds):
            preds, labels = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
            # -100 marks ignored/padded positions and is not a valid token id,
            # so it must be replaced before decoding either array.
            preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            decoded_preds = [pred.strip() for pred in decoded_preds]
            decoded_labels = [[label.strip()] for label in decoded_labels]
            result = metric.compute(predictions=decoded_preds, references=decoded_labels)
            return {"bleu": result["score"]}

        base_args = _create_training_args(output_dir, repo_id, **kwargs)
        training_args_dict = base_args.to_dict()
        # to_dict() masks every "*_token" field with a placeholder string;
        # restore the real values so the rebuilt arguments do not carry a
        # fake credential.
        for key in list(training_args_dict):
            if key.endswith("_token"):
                training_args_dict[key] = getattr(base_args, key, None)
        training_args_dict["predict_with_generate"] = True
        training_args = Seq2SeqTrainingArguments(**training_args_dict)

        trainer = Seq2SeqTrainer(
            model=model, args=training_args, train_dataset=train_dataset,
            eval_dataset=eval_dataset, tokenizer=tokenizer,
            data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model),
            compute_metrics=compute_metrics
        )
        final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
        return final_model_path, final_metrics
    except Exception as e:
        raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}") from e
def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Fine-tune a Stable Diffusion UNet (optionally the text encoder) with LoRA.

    Generator: yields progress updates and returns
    ``(output_dir, {"final_loss": <last batch loss>})``. Images are resized and
    centre-cropped to ``diffusion_resolution`` and captions are tokenized with
    the model's CLIP tokenizer. The training budget is
    ``epochs * ceil(batches_per_epoch / gradient_accumulation)`` optimizer
    updates; the step counter advances once per optimizer update.

    Raises:
        ValueError: When the module-level ``device`` is ``'cpu'``, or when the
            dataset has no length and no ``steps_per_epoch_estimate`` is given.
    """
    if device == 'cpu':
        raise ValueError("El entrenamiento de Text-to-Image solo es compatible con GPU CUDA.")
    output_dir = tempfile.mkdtemp()
    grad_accum_steps = int(kwargs.get('gradient_accumulation', 8))
    train_text_encoder = bool(kwargs.get('dreambooth_train_text_encoder', False))
    logging_steps = max(1, int(kwargs.get('logging_steps', 10)))
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=grad_accum_steps,
        mixed_precision=kwargs.get('mixed_precision', 'no')
    )

    yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=torch_dtype_auto)
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=torch_dtype_auto)
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=torch_dtype_auto)
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    yield update_logs_fn("Agregando adaptadores LoRA al UNet...", "Text-to-Image (LoRA)")
    unet_lora_config = LoraConfig(
        r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    unet.add_adapter(unet_lora_config)
    if train_text_encoder:
        yield update_logs_fn("Agregando adaptadores LoRA al Text Encoder...", "DreamBooth LoRA")
        text_encoder_lora_config = LoraConfig(
            r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        text_encoder.add_adapter(text_encoder_lora_config)

    yield update_logs_fn("Procesando dataset de imágenes...", "Text-to-Image (LoRA)")
    resolution = int(kwargs.get('diffusion_resolution', 512))
    train_transforms = transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[kwargs.get('image_col', 'image')]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
        return examples

    with accelerator.main_process_first():
        processed_dataset = train_dataset.map(
            function=preprocess_train,
            batched=True,
            remove_columns=[col for col in next(iter(train_dataset)).keys() if col not in ['pixel_values', 'input_ids']],
        )

    def collate_fn(examples):
        # Values may arrive as tensors or as nested lists depending on the
        # dataset backend, and input_ids as (L,) or (1, L); normalise both.
        pixel_values = torch.stack([torch.as_tensor(ex["pixel_values"]) for ex in examples])
        input_ids = torch.stack([torch.as_tensor(ex["input_ids"]).reshape(-1) for ex in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoader rejects shuffle=True for iterable (streaming) datasets.
    is_streaming = isinstance(processed_dataset, torch.utils.data.IterableDataset)
    train_dataloader = DataLoader(processed_dataset, shuffle=not is_streaming, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))

    # Only parameters that can receive gradients are given to the optimizer.
    params_to_optimize = [p for p in unet.parameters() if p.requires_grad]
    if train_text_encoder:
        params_to_optimize += [p for p in text_encoder.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        params_to_optimize, lr=float(kwargs.get('learning_rate', 2e-5)),
        betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
        weight_decay=float(kwargs.get('weight_decay', 0.01)),
        eps=float(kwargs.get('adam_epsilon', 1e-8)),
    )

    num_epochs = int(kwargs.get('epochs', 1))
    try:
        batches_per_epoch = len(train_dataloader)
    except TypeError:
        # Streaming datasets have no length.
        # NOTE(review): steps_per_epoch_estimate is assumed here to count
        # batches per epoch — confirm against where the option is defined.
        batches_per_epoch = int(kwargs.get('steps_per_epoch_estimate') or 0)
        if batches_per_epoch <= 0:
            raise ValueError("No se puede determinar la longitud del dataset en streaming; proporciona 'steps_per_epoch_estimate'.") from None
    num_update_steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    max_train_steps = num_epochs * num_update_steps_per_epoch
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=int(max_train_steps * float(kwargs.get('warmup_ratio', 0.03))),
        num_training_steps=max_train_steps,
    )

    unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler
    )
    vae.to(accelerator.device, dtype=torch_dtype_auto)
    max_grad_norm = float(kwargs.get('max_grad_norm', 1.0))

    global_step = 0
    final_loss = 0
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                latents = vae.encode(batch["pixel_values"].to(dtype=torch_dtype_auto)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                final_loss = loss.detach().item()
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = [p for p in unet.parameters() if p.requires_grad]
                    if train_text_encoder:
                        params_to_clip += [p for p in text_encoder.parameters() if p.requires_grad]
                    accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # max_train_steps is a budget of optimizer updates, so the counter
            # advances only when gradients were actually synchronised.
            if accelerator.sync_gradients:
                if accelerator.is_main_process and global_step % logging_steps == 0:
                    yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss:.4f}", "Entrenando Difusión")
                global_step += 1
            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break

    accelerator.wait_for_everyone()
    # Bound on every process so the cleanup below never hits an unbound name.
    sd_pipeline = None
    if accelerator.is_main_process:
        # NOTE(review): diffusers' public class is StableDiffusionPipeline;
        # confirm StableDiffusionText2ImagePipeline is defined/aliased in this file.
        sd_pipeline = StableDiffusionText2ImagePipeline.from_pretrained(
            model_name,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
            torch_dtype=torch_dtype_auto,
        )
        sd_pipeline.save_pretrained(output_dir)
        with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card_content)
        yield update_logs_fn("Subiendo al Hub...", "Subiendo")
        upload_folder(folder_path=output_dir, repo_id=repo_id, commit_message="Fin de entrenamiento de difusión")

    del unet, vae, text_encoder, optimizer, train_dataloader, lr_scheduler, sd_pipeline
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return output_dir, {"final_loss": final_loss}
def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
    """Run DreamBooth-style LoRA training by captioning every image identically.

    Generator: writes the instance prompt into the caption column of every
    example, then delegates to ``train_text_to_image`` and returns its
    ``(final_model_path, final_metrics)``.

    Raises:
        ValueError: When the module-level ``device`` is ``'cpu'`` or no
            ``dreambooth_instance_prompt`` was provided.
    """
    if device == 'cpu':
        raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
    dreambooth_prompt = kwargs.get('dreambooth_instance_prompt')
    if not dreambooth_prompt:
        raise ValueError("Se requiere un 'instance prompt' para el entrenamiento de DreamBooth.")

    # 'text_col' can be present but None (e.g. an image-only dataset), in which
    # case dict.get's default is not used. Resolve the column once and pass the
    # resolved name on so the delegated trainer reads the same column.
    caption_col = kwargs.get('text_col') or 'text'
    kwargs['text_col'] = caption_col

    def add_prompt(example):
        example[caption_col] = dreambooth_prompt
        return example

    train_dataset = train_dataset.map(add_prompt)
    yield update_logs_fn(f"Usando el prompt de instancia para todas las imágenes: '{dreambooth_prompt}'", "DreamBooth LoRA")
    final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
    return final_model_path, final_metrics
def _get_data_processing_pipeline(**kwargs):
    """Build the training dataset stream and resolve its column names.

    Training data comes from uploaded files (``uploads``), Hub datasets
    (comma-separated ``datasets_hf_text``), or both; when both are present
    they are interleaved 50/50. For text tasks the stream is then optionally
    cleaned, filtered, augmented and deduplicated according to the options.

    Returns:
        ``(train_dataset, kwargs)`` where ``kwargs`` has been extended with
        ``text_col``, ``image_col``, ``audio_col``, ``label_col`` and
        ``uploaded_val_data``.

    Raises:
        ValueError: When no dataset is given, none can be loaded, or the
            resulting stream is empty.
    """
    hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
    uploads = kwargs.get('uploads')
    if not hf_ids and not uploads:
        raise ValueError("No se proporcionaron datasets.")

    probabilities = None
    dataset_weights_str = kwargs.get('dataset_weights', '')
    if dataset_weights_str:
        try:
            probabilities = [float(w.strip()) for w in dataset_weights_str.split(',')]
        except (AttributeError, ValueError):
            probabilities = None
    if probabilities:
        # Sampling probabilities must be non-negative and sum to 1; user
        # supplied weights such as "2,1" are rescaled, invalid ones dropped.
        total = sum(probabilities)
        if total > 0 and all(p >= 0 for p in probabilities):
            probabilities = [p / total for p in probabilities]
        else:
            probabilities = None

    train_dataset, uploaded_val_data = None, None
    if uploads:
        uploaded_data_map = _load_uploaded_stream(uploads)
        if uploaded_data_map and uploaded_data_map["train"]:
            train_dataset = IterableDataset.from_generator(lambda: iter(uploaded_data_map["train"]))
            uploaded_val_data = uploaded_data_map["validation"]

    if hf_ids:
        # Per-dataset weights only apply when the Hub datasets are the sole source.
        hf_train_dataset = _load_hf_streaming(hf_ids, split="train", probabilities=probabilities if not train_dataset else None)
        if hf_train_dataset:
            if train_dataset is None:
                train_dataset = hf_train_dataset
            else:
                all_streams = [train_dataset, hf_train_dataset]
                all_probs = [0.5, 0.5]
                train_dataset = interleave_datasets(all_streams, probabilities=all_probs)

    if train_dataset is None:
        raise ValueError("No se pudieron cargar datos de entrenamiento válidos.")
    try:
        first_example = next(iter(train_dataset))
    except StopIteration:
        raise ValueError("El dataset de entrenamiento está vacío después del procesamiento.") from None

    text_col, image_col, audio_col, label_col = _guess_columns(first_example)
    kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})

    is_text_task = kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)", "Image Classification (Vision)", "Audio Classification (Speech)"]
    if is_text_task:
        if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
            clean_kwargs = {k: v for k, v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
            train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))

        filters = _get_filter_functions(**kwargs)
        if filters:
            for f in filters:
                train_dataset = train_dataset.filter(f)

        if kwargs.get('enable_back_translation'):
            train_dataset = _apply_back_translation(train_dataset, text_col, kwargs['bt_augmentation_ratio'], kwargs['bt_model_id'], kwargs['bt_reverse_model_id'])

        if kwargs.get('enable_synthetic_data'):
            synthetic_ds = _generate_synthetic_data(train_dataset, text_col, kwargs['synthetic_model_id'], int(kwargs['num_synthetic_samples']), kwargs['synthetic_prompt_template'])
            if synthetic_ds:
                train_dataset = interleave_datasets([train_dataset, synthetic_ds])

        if kwargs.get('enable_cda') and kwargs.get('cda_json_config'):
            train_dataset = _apply_cda(train_dataset, text_col, kwargs['cda_json_config'])

        dedup_method = kwargs.get('deduplication_method')
        # A missing option (None) means "no deduplication", like 'Ninguna'.
        if dedup_method and dedup_method != 'Ninguna':
            train_dataset = _create_deduplicated_iterable_dataset(
                dataset=train_dataset,
                text_col=text_col,
                method=dedup_method,
                threshold=float(kwargs.get('minhash_threshold', 0.85)),
                num_perm=int(kwargs.get('minhash_num_perm', 128))
            )

    return train_dataset, kwargs
def _train_and_upload(**kwargs):
    """Run the full training workflow as a generator of Gradio UI updates.

    Steps, in order: verify Hub authentication, optionally merge several
    adapters, create/verify the destination repository, optionally initialise
    a model from scratch, build the training dataset, dispatch to the trainer
    for the selected mode, and optionally evaluate perplexity.

    Args:
        **kwargs: UI values keyed by the names in ``all_input_components_dict``.

    Yields:
        6-tuples ``(logs, phase, repo_link_markdown, metrics_update,
        start_button_update, stop_button_update)``.

    Any exception is caught and reported through a final "Error" update
    instead of propagating to the caller.
    """
    logs, repo_link, final_model_path, final_metrics = "", "", None, {}
    yield (
        "Iniciando...",
        "Inicio",
        "",
        gr.update(value=None),
        gr.update(value="Entrenando...", interactive=False),
        gr.update(visible=True)
    )

    def update_logs(new_msg, phase_msg):
        """Append a log line and return the 4-tuple shared by all progress updates."""
        nonlocal logs, repo_link, final_metrics
        logs += f"[{phase_msg}] {new_msg}\n"
        return (
            logs,
            phase_msg,
            repo_link,
            gr.update(value=final_metrics if final_metrics else None)
        )

    def progress(new_msg, phase_msg):
        """Build a full 6-tuple update that leaves both buttons untouched."""
        return update_logs(new_msg, phase_msg) + (gr.update(), gr.update())

    try:
        yield progress("Verificando autenticación...", "Inicio")
        user = whoami()
        username = user.get("name")
        if not username:
            raise ValueError("No se pudo obtener el nombre de usuario de Hugging Face. Por favor, verifica tu token.")
        # `or ''` guards against a None value coming from an empty component.
        model_name = (kwargs.get('model_base_input') or '').strip()
        if kwargs.get('enable_multi_adapter_merge'):
            temp_model_path = model_name
            lora_merge_generator = _merge_multiple_loras(model_name, kwargs['multi_adapter_model_ids'], kwargs['multi_adapter_weights'], kwargs['multi_adapter_combination_type'])
            # Drive the generator by hand so its return value (the merged
            # model path) can be read from StopIteration.
            try:
                while True:
                    status = next(lora_merge_generator)
                    yield progress(status, "Fusión Múltiple")
            except StopIteration as e:
                temp_model_path = e.value
            model_name = temp_model_path
        repo_name_input = (kwargs.get('repo_name_input') or '').strip()
        if repo_name_input:
            repo_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', repo_name_input)
            repo_base = re.sub(r'^[.-]+|[.-]+$', '', repo_base)
        else:
            model_name_base = model_name.split('/')[-1] if model_name else "finetuned-model"
            sanitized_model_name_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', model_name_base)
            sanitized_model_name_base = re.sub(r'^[.-]+|[.-]+$', '', sanitized_model_name_base)
            repo_base = f"{sanitized_model_name_base}-{uuid.uuid4().hex[:6]}"
        if not repo_base:
            repo_base = f"autotrain-model-{uuid.uuid4().hex[:8]}"
        # Keep "<username>/<repo_base>" within 96 characters.
        max_repo_base_len = 96 - (len(username) + 1)
        repo_base = repo_base[:max_repo_base_len]
        repo_id = f"{username}/{repo_base}"
        yield progress(f"Creando o verificando repositorio: '{repo_id}'", "Inicio")
        create_repo(repo_id, exist_ok=True, private=kwargs.get('private_repo', False))
        repo_link = f"https://huggingface.co/{repo_id}"
        yield progress("Repositorio listo.", "Inicio")
        base_model_id_for_training = model_name
        if kwargs.get('train_from_scratch'):
            yield progress("Preparando entrenamiento desde cero...", "Modelo Cero")
            architecture = kwargs.get('scratch_architecture')
            if not architecture or architecture not in ARCHITECTURE_MAP:
                raise ValueError(f"Arquitectura '{architecture}' no es válida o no está soportada para entrenamiento desde cero. Opciones válidas: {list(ARCHITECTURE_MAP.keys())}")
            config_class, model_class = ARCHITECTURE_MAP[architecture]
            if kwargs.get('auto_config_scratch'):
                vocab_size, hidden_size, intermediate_size, layers, heads, block_size_val, tie_word_embeddings, kv_heads = _calculate_auto_config(kwargs.get('block_size'), architecture == "GPT2", kwargs.get('steps_per_epoch_estimate'), kwargs.get('batch_size'), kwargs.get('gradient_accumulation'))
            else:
                vocab_size, hidden_size, intermediate_size, layers, heads, kv_heads, tie_word_embeddings = 32000, 1024, 2048, 8, 8, 8, False
            config = config_class(vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=layers, num_attention_heads=heads, num_key_value_heads=kv_heads, max_position_embeddings=int(kwargs.get('block_size', 1024)), tie_word_embeddings=tie_word_embeddings)
            model = model_class(config)
            temp_model_dir = tempfile.mkdtemp()
            model.save_pretrained(temp_model_dir)
            # The UI component is registered as 'tokenizer_name_input'; the
            # plain 'tokenizer_name' key is still honoured for direct callers.
            tokenizer_id = (
                kwargs.get('tokenizer_name')
                or (kwargs.get('tokenizer_name_input') or '').strip()
                or SCRATCH_TOKENIZER_MAP.get(architecture, "gpt2")
            )
            try:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
                tokenizer.save_pretrained(temp_model_dir)
            except Exception as e:
                raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}") from e
            yield progress(f"Tokenizer base '{tokenizer_id}' guardado para el modelo desde cero.", "Modelo Cero")
            base_model_id_for_training = temp_model_dir
            # A freshly initialised model is trained in full, not through adapters.
            kwargs["peft"] = False
            kwargs['tokenizer_name'] = temp_model_dir
            yield progress(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero")
        yield progress("Procesando y cargando datasets...", "Datos")
        train_dataset, kwargs = _get_data_processing_pipeline(**kwargs)
        yield progress(f"Columnas detectadas (texto: {kwargs['text_col']}, imagen: {kwargs['image_col']})", "Datos")
        if kwargs.get('wandb_api_key_input'):
            os.environ["WANDB_API_KEY"] = kwargs['wandb_api_key_input']
            os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
            os.environ["WANDB_LOG_MODEL"] = "checkpoint"
        model_card_content = MODEL_CARD_TEMPLATE.format(
            repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
            training_mode=kwargs.get('training_mode'),
            datasets=', '.join([x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]) or "Archivos locales",
            # Only plain scalars are published; anything whose key looks like
            # a credential ('token' / 'key') is left out of the model card.
            hyperparameters=json.dumps({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool)) and 'token' not in k and 'key' not in k and v is not None}, indent=2),
            date=datetime.now().strftime("%Y-%m-%d")
        )
        training_mode = kwargs.get('training_mode')
        training_function_map = {
            "Causal Language Modeling (SFT/LoRA)": train_sft_dpo,
            "DPO (Direct Preference Optimization)": train_sft_dpo,
            "Question Answering (Text)": train_question_answering,
            "Sequence Classification (Text)": train_sequence_classification,
            "Token Classification (NER)": train_token_classification,
            "Text2Text Generation": train_seq2seq,
            "Text-to-Image (LoRA)": train_text_to_image,
            "DreamBooth LoRA (Text-to-Image)": train_dreambooth_lora,
        }
        train_func = training_function_map.get(training_mode)
        if not train_func:
            raise ValueError(f"El modo de entrenamiento '{training_mode}' no está implementado.")
        train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
        while True:
            try:
                update = next(train_generator)
            except StopIteration as e:
                final_model_path, final_metrics = e.value
                break
            # Only well-formed 4-tuple progress updates are forwarded to the UI.
            if isinstance(update, tuple) and len(update) == 4:
                yield update + (gr.update(), gr.update())
        if final_metrics is None:
            # The metrics dict is written to below; never leave it as None.
            final_metrics = {}
        if kwargs.get('run_perplexity_evaluation') and final_model_path and training_mode in ["Causal Language Modeling (SFT/LoRA)", "DPO (Direct Preference Optimization)"]:
            yield progress("Iniciando evaluación de perplejidad...", "Evaluación Final")
            model = AutoModelForCausalLM.from_pretrained(final_model_path, torch_dtype=torch_dtype_auto, device_map=device)
            tokenizer = AutoTokenizer.from_pretrained(final_model_path)
            eval_dataset_perp = None
            eval_gen = _get_eval_dataset((kwargs.get('datasets_hf_text') or "").split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs)
            for update in eval_gen:
                # Progress updates are the 4-tuples built by update_logs;
                # any other yielded object is taken to be the dataset.
                if isinstance(update, tuple) and len(update) == 4:
                    yield update + (gr.update(), gr.update())
                else:
                    eval_dataset_perp = update
            if eval_dataset_perp:
                ppl = _evaluate_perplexity(model, tokenizer, eval_dataset_perp, kwargs['text_col'])
                final_metrics['perplexity'] = ppl
                yield progress(f"Evaluación de Perplejidad completada. Perplejidad: {ppl:.4f}", "Evaluación Final")
        final_logs, final_phase, final_repo_link, _ = update_logs(f"✅ Entrenamiento y subida completados: {repo_link}", "Listo")
        yield (
            final_logs,
            final_phase,
            f"### ✅ [Modelo Finalizado: Visita el Repositorio en el Hub]({final_repo_link})",
            gr.update(value=final_metrics),
            gr.update(value="Iniciar Entrenamiento", interactive=True),
            gr.update(visible=False)
        )
    except Exception as e:
        # Top-level boundary: surface the failure in the UI and re-enable the
        # start button instead of letting the event handler crash.
        err_msg = f"❌ Error fatal: {type(e).__name__}: {e}\n{traceback.format_exc()}"
        error_logs, error_phase, _, _ = update_logs(err_msg, "Error")
        yield (
            error_logs,
            error_phase,
            "",
            gr.update(value=None),
            gr.update(value="Iniciar Entrenamiento", interactive=True),
            gr.update(visible=False)
        )
def _inference_json_default(obj):
    """Convert pipeline output values that ``json`` cannot encode natively."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Display-only fallback: an unknown object is shown by its string form
    # rather than failing the whole response.
    return str(obj)


def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
    """Run a single inference with a Hub model and format the result for the UI.

    Args:
        task_mode: UI task label; mapped to a pipeline task through
            ``TASK_TO_PIPELINE_MAP``.
        model_id: Hub model identifier to load.
        text_in: Prompt, question or input text.
        context_in: Context passage (question answering only).
        image_in: Image input (image classification only).
        audio_in: Audio input (audio classification / speech recognition).
        temperature: Sampling temperature; a value <= 0 disables sampling.
        top_p: Nucleus sampling parameter.
        max_new_tokens: Generation length limit (text generation only).

    Returns:
        A 6-tuple: the message or JSON-formatted result, ``model_id``, and
        four no-op ``gr.update()`` placeholders.
    """
    def reply(message):
        return message, model_id, gr.update(), gr.update(), gr.update(), gr.update()

    if not model_id:
        return reply("Por favor, introduce un ID de modelo del Hub.")
    task_name = TASK_TO_PIPELINE_MAP.get(task_mode)
    if not task_name:
        return reply(f"La inferencia para el modo '{task_mode}' no está soportada.")

    text_only_tasks = ["token-classification", "text2text-generation", "text-classification"]
    media_tasks = ["image-classification", "audio-classification", "automatic-speech-recognition"]

    # Validate the inputs first so a missing field does not cost a model load.
    if task_name == "text-generation":
        if not text_in:
            return reply("Por favor, introduce un prompt de texto.")
    elif task_name == "question-answering":
        if not text_in or not context_in:
            return reply("Por favor, introduce una pregunta y un contexto.")
    elif task_name in text_only_tasks:
        if not text_in:
            return reply(f"Por favor, introduce texto para {task_name}.")
    elif task_name in media_tasks:
        input_data = image_in if "image" in task_name else audio_in
        if input_data is None:
            return reply(f"Por favor, proporciona una entrada de { 'imagen' if 'image' in task_name else 'audio' }.")
    else:
        # The task is mapped but has no inference branch here.
        return reply(f"La inferencia para el modo '{task_mode}' no está soportada.")

    try:
        # NOTE(security): trust_remote_code=True executes code shipped in the
        # model repository; model_id is user-supplied.
        pipe = pipeline(task_name, model=model_id, torch_dtype=torch_dtype_auto, trust_remote_code=True, device=0 if device == 'cuda' else -1)
        if task_name == "text-generation":
            # Sampling needs a strictly positive temperature; otherwise
            # fall back to greedy decoding.
            if temperature is not None and temperature > 0:
                result = pipe(text_in, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature, top_p=top_p)
            else:
                result = pipe(text_in, max_new_tokens=int(max_new_tokens), do_sample=False)
        elif task_name == "question-answering":
            result = pipe(question=text_in, context=context_in)
        elif task_name in text_only_tasks:
            result = pipe(text_in)
        else:
            result = pipe(input_data)
        formatted = json.dumps(result, indent=2, ensure_ascii=False, default=_inference_json_default)
        return reply(f"Resultado:\n\n{formatted}")
    except Exception as e:
        return reply(f"Error en Inferencia: {e}\n{traceback.format_exc()}")
def update_inference_ui(task_mode):
    """Return visibility updates for the inference inputs of the given task.

    The five updates are, in order: text box (with its label), context box,
    image input, audio input, and the text-generation-only controls.
    """
    task = TASK_TO_PIPELINE_MAP.get(task_mode, "")
    is_qa = task == "question-answering"

    text_tasks = (
        "text-generation",
        "text2text-generation",
        "token-classification",
        "question-answering",
        "text-classification",
        "text-to-image",
    )
    audio_tasks = ("audio-classification", "automatic-speech-recognition")

    if is_qa:
        label = "Pregunta"
    else:
        label = "Entrada de Texto / Prompt"

    text_update = gr.update(visible=task in text_tasks, label=label)
    context_update = gr.update(visible=is_qa)
    image_update = gr.update(visible=task in ("image-classification",))
    audio_update = gr.update(visible=task in audio_tasks)
    generation_update = gr.update(visible=task == "text-generation")
    return text_update, context_update, image_update, audio_update, generation_update
def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
    """Build a dataset (synthetic or from uploaded files) and push it to the Hub.

    Args:
        hf_token: Hugging Face token used to log in.
        repo_name: Desired dataset repository name (sanitised before use).
        creation_type: ``"Sintético"`` or ``"Basado en Archivo"``.
        synth_model: Generator model ID (synthetic mode).
        synth_prompt: Prompt fed to the generator (synthetic mode).
        synth_num_samples: Number of samples to generate (synthetic mode).
        file_uploads: Uploaded files (file-based mode).
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, link_markdown)``; the link is empty on failure.
        Errors are reported in the status message rather than raised.
    """
    if not hf_token:
        return "Error: Se requiere un token de Hugging Face.", ""
    if not repo_name:
        return "Error: Se requiere un nombre de repositorio para el dataset.", ""
    try:
        login(token=hf_token)
        user = whoami()
        username = user.get("name")
        if not username:
            raise ValueError("No se pudo obtener el nombre de usuario de Hugging Face. Por favor, verifica tu token.")
        # Same sanitisation as the model-training path: replace invalid
        # characters, then drop leading/trailing '.' and '-'.
        repo_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', repo_name.strip())
        repo_base = re.sub(r'^[.-]+|[.-]+$', '', repo_base)
        # Never longer than the previous 90-character cap, and short enough
        # that "<username>/<repo_base>" stays within 96 characters.
        max_repo_base_len = min(90, 96 - (len(username) + 1))
        repo_base = repo_base[:max_repo_base_len].rstrip('.-')
        if not repo_base:
            # Nothing usable survived sanitisation; fall back to a generated name.
            repo_base = f"{username}-{uuid.uuid4().hex[:6]}"
        repo_id = f"{username}/{repo_base}"
        create_repo(repo_id, repo_type="dataset", exist_ok=True)
        all_data = []
        if creation_type == "Sintético":
            if not synth_model or not synth_prompt or not synth_num_samples:
                return "Error: Para la generación sintética se requiere un modelo, un prompt y un número de muestras.", ""
            progress(0, desc="Cargando modelo generador...")
            generator = pipeline("text-generation", model=synth_model, torch_dtype=torch_dtype_auto, device=0 if device == 'cuda' else -1)
            for i in progress.tqdm(range(int(synth_num_samples)), desc="Generando muestras"):
                # Best effort: one failed sample is logged and skipped.
                try:
                    generated_output = generator(synth_prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.9, top_p=0.95)
                    # The pipeline output starts with the prompt; keep only the continuation.
                    cleaned_text = generated_output[0]['generated_text'][len(synth_prompt):].strip()
                    if cleaned_text:
                        all_data.append({"text": cleaned_text})
                except Exception as e:
                    logger.warning(f"Error al generar muestra {i}: {e}")
        elif creation_type == "Basado en Archivo":
            if not file_uploads:
                return "Error: Por favor, sube al menos un archivo.", ""
            progress(0.5, desc="Procesando archivos subidos...")
            file_data = _load_uploaded_stream(file_uploads)
            all_data = file_data.get("train", []) + file_data.get("validation", [])
        if not all_data:
            return "Error: No se generaron o procesaron datos.", ""
        progress(0.8, desc="Guardando y subiendo al Hub...")
        with tempfile.TemporaryDirectory() as temp_dir:
            data_file = os.path.join(temp_dir, "data.jsonl")
            with open(data_file, "w", encoding="utf-8") as f:
                f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in all_data)
            readme_content = DATASET_CARD_TEMPLATE.format(
                repo_id=repo_id,
                creation_type=creation_type,
                generation_model=synth_model if creation_type == "Sintético" else "N/A",
                date=datetime.now().strftime("%Y-%m-%d")
            )
            readme_file = os.path.join(temp_dir, "README.md")
            with open(readme_file, "w", encoding="utf-8") as f:
                f.write(readme_content)
            # Upload while the temporary directory still exists.
            api = HfApi()
            api.upload_folder(
                folder_path=temp_dir,
                repo_id=repo_id,
                repo_type="dataset",
                commit_message="Creación de dataset con AutoTrain-Advanced"
            )
        dataset_link = f"https://huggingface.co/datasets/{repo_id}"
        return f"✅ Dataset creado y subido exitosamente a {repo_id}", f"### ✅ [Dataset Disponible: Visita el Repositorio]({dataset_link})"
    except Exception as e:
        return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
def gradio_train_wrapper(*args):
    """Pair the positional Gradio inputs with their names and run training.

    Values are matched, in order, with the keys of
    ``all_input_components_dict`` and forwarded as keyword arguments.
    """
    named_inputs = {name: value for name, value in zip(all_input_components_dict, args)}
    yield from _train_and_upload(**named_inputs)
def gradio_preview_data_wrapper(*args):
    """Yield a text preview of the first five processed dataset samples.

    The positional Gradio inputs are matched, in order, with the keys of
    ``all_input_components_dict``. A tokenizer is loaded only for the SFT
    mode, the only mode whose formatter takes one; the other modes no
    longer need a base model ID to be previewed.

    Yields:
        A "processing" placeholder first, then the preview text, or an
        error message with traceback if anything fails.
    """
    kwargs = dict(zip(all_input_components_dict.keys(), args))
    try:
        preview_text = "Procesando vista previa...\n"
        yield preview_text
        dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
        text_col = processed_kwargs.get('text_col')
        training_mode = kwargs['training_mode']
        is_dpo = training_mode == "DPO (Direct Preference Optimization)"
        is_sft = training_mode == "Causal Language Modeling (SFT/LoRA)"

        tokenizer = None
        if is_sft:
            model_id_for_tokenizer = kwargs.get('model_base_input')
            if not model_id_for_tokenizer:
                raise ValueError("Se necesita un ID de modelo base para cargar el tokenizer para la vista previa.")
            # The UI component is registered as 'tokenizer_name_input'; the
            # plain 'tokenizer_name' key is still honoured for direct callers.
            tokenizer_id = (
                kwargs.get('tokenizer_name')
                or (kwargs.get('tokenizer_name_input') or '').strip()
                or model_id_for_tokenizer
            )
            # NOTE(security): trust_remote_code=True executes code shipped in
            # the repository; the ID is user-supplied.
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_id, trust_remote_code=True, use_fast=False
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # `or ''` guards against a None value from an empty component.
            if (kwargs.get('chat_template_jinja') or '').strip():
                tokenizer.chat_template = kwargs['chat_template_jinja']

        preview_samples = []
        for i, example in enumerate(islice(dataset, 5)):
            if is_dpo:
                formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
            elif is_sft:
                formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
            else:
                formatted_text = str(example)
            preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
        if preview_samples:
            preview_text = "\n".join(preview_samples)
        else:
            preview_text = "No se pudieron generar muestras. Revisa la configuración del dataset, los filtros y el formato."
        yield preview_text
    except Exception as e:
        yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
def toggle_training_mode_ui(is_scratch):
    """Return six visibility updates for the from-scratch checkbox.

    The first four outputs are shown only when NOT training from scratch;
    the last two only when training from scratch.
    """
    visibility = [not is_scratch] * 4 + [is_scratch] * 2
    return tuple(gr.update(visible=flag) for flag in visibility)
def toggle_task_specific_ui(training_mode):
    """Return nine visibility updates derived from the selected training mode.

    Each flag is computed from the mode label by substring or exact match;
    the diffusion modes also flip the last four outputs in pairs.
    """
    diffusion_modes = ("Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)")
    diffusion = training_mode in diffusion_modes
    streaming = not diffusion

    visibility = (
        ("Classification" in training_mode) or ("Token Classification" in training_mode),
        "DPO" in training_mode,
        "Causal" in training_mode,
        diffusion,
        training_mode == "DreamBooth LoRA (Text-to-Image)",
        not diffusion,
        diffusion,
        streaming,
        not streaming,
    )
    return tuple(gr.update(visible=flag) for flag in visibility)
def toggle_sft_format_ui(format_style):
    """Return an update visible only for the reasoning/tools SFT format."""
    return gr.update(visible=(format_style == "Razonamiento/Herramientas"))
def toggle_auto_modules_ui(is_auto):
    """Return an update that is visible only when ``is_auto`` is falsy."""
    show_manual_field = not is_auto
    return gr.update(visible=show_manual_field)
def toggle_dataset_creator_ui(choice):
    """Return two updates: the first visible for "Sintético", the second otherwise."""
    synthetic_selected = (choice == "Sintético")
    synth_group_update = gr.update(visible=synthetic_selected)
    file_group_update = gr.update(visible=not synthetic_selected)
    return synth_group_update, file_group_update
| with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo: | |
| gr.Markdown("# 🚀 AutoTrain-Advanced: Tu Plataforma de Entrenamiento de Modelos") | |
| gr.Markdown("### Una interfaz completa para fine-tuning, PEFT (LoRA, QLoRA), y despliegue de modelos en Hugging Face.") | |
| with gr.Tab("1. Autenticación"): | |
| gr.Markdown("#### Conecta tu cuenta de Hugging Face para guardar y cargar modelos.") | |
| with gr.Row(): | |
| hf_token_input = gr.Textbox(label="Token de Hugging Face (con permisos de escritura)", type="password", placeholder="hf_...", scale=3) | |
| login_button = gr.Button("Conectar", variant="primary", scale=1) | |
| login_status = gr.Textbox(label="Estado de Conexión", interactive=False) | |
| login_button.click(hf_login, inputs=[hf_token_input], outputs=[login_status]) | |
| with gr.Tab("2. Creación de Dataset"): | |
| gr.Markdown("## 🧩 Genera o Procesa Datasets y Súbelos al Hub") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| dset_repo_name = gr.Textbox(label="Nombre del Repositorio del Dataset", placeholder="mi-nuevo-dataset") | |
| dset_creation_type = gr.Radio(["Sintético", "Basado en Archivo"], label="Tipo de Creación", value="Sintético") | |
| with gr.Group(visible=True) as dset_synth_group: | |
| dset_synth_model = gr.Textbox(label="Modelo Generador", placeholder="p.ej. 'mistralai/Mistral-7B-Instruct-v0.2'") | |
| dset_synth_prompt = gr.Textbox(label="Prompt de Generación", lines=5, placeholder="Escribe una reseña de producto de 5 estrellas para...") | |
| dset_synth_num_samples = gr.Number(label="Número de Muestras", value=100) | |
| with gr.Group(visible=False) as dset_file_group: | |
| dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple") | |
| dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary") | |
| with gr.Column(scale=2): | |
| dset_status_output = gr.Textbox(label="Estado", lines=10, interactive=False) | |
| dset_link_output = gr.Markdown() | |
| dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group]) | |
| dset_create_button.click( | |
| create_and_upload_dataset, | |
| inputs=[hf_token_input, dset_repo_name, dset_creation_type, dset_synth_model, dset_synth_prompt, dset_synth_num_samples, dset_file_uploads], | |
| outputs=[dset_status_output, dset_link_output] | |
| ) | |
| with gr.Tab("3. Entrenamiento"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| gr.Markdown("## ⚙️ Configuración del Entrenamiento") | |
| training_mode = gr.Dropdown(TRAINING_MODES, label="Modo de Entrenamiento", value=TRAINING_MODES[0]) | |
| with gr.Accordion("📦 Modelo y Repositorio", open=True): | |
| model_base_input = gr.Textbox(label="ID del Modelo Base", placeholder="p.ej. 'mistralai/Mistral-7B-v0.1'") | |
| tokenizer_name_input = gr.Textbox(label="ID del Tokenizer (opcional)", placeholder="p.ej. si el modelo no tiene tokenizer") | |
| repo_name_input = gr.Textbox(label="Nombre del Repositorio de Destino", placeholder="p.ej. 'mi-modelo-afinado'") | |
| train_from_scratch = gr.Checkbox(label="Entrenar desde Cero", value=False) | |
| auto_config_scratch = gr.Checkbox(label="Auto-Configuración", value=True, visible=False) | |
| scratch_architecture = gr.Textbox(label="Arquitectura (p.ej. Llama, Mistral)", value="Llama", visible=False) | |
| with gr.Accordion("🔄 Fusión de Múltiples Adaptadores (Avanzado)", open=False) as multi_adapter_accordion: | |
| enable_multi_adapter_merge = gr.Checkbox(label="Habilitar Fusión Múltiple", value=False) | |
| multi_adapter_model_ids = gr.Textbox(label="IDs de Adaptadores (csv)", placeholder="org/adapter1,org/adapter2") | |
| multi_adapter_weights = gr.Textbox(label="Pesos (csv)", placeholder="0.5,0.5") | |
| multi_adapter_combination_type = gr.Dropdown(["slerp", "linear", "cat", "svd", "dare_linear", "dare_ties", "ties"], label="Tipo de Combinación", value="slerp") | |
| with gr.Accordion("📚 Dataset", open=True): | |
| datasets_hf_text = gr.Textbox(label="Datasets de Hugging Face (csv)", placeholder="p.ej. 'databricks/dolly-15k'") | |
| uploads = gr.File(label="Subir Archivos Locales (.jsonl, .csv, .txt)", file_count="multiple") | |
| dataset_weights = gr.Textbox(label="Pesos de los Datasets (csv)", placeholder="p.ej. 0.7, 0.3") | |
| eval_dataset_hf = gr.Textbox(label="Dataset de Evaluación (opcional)", placeholder="p.ej. 'nombre/dataset_eval'") | |
| preview_data_button = gr.Button("Previsualizar Datos Procesados") | |
| data_preview_output = gr.Textbox(label="Vista Previa de Datos", lines=8, interactive=False) | |
| with gr.Accordion("🎓 Hiperparámetros", open=False): | |
| with gr.Row(): | |
| learning_rate = gr.Textbox(label="Tasa de Aprendizaje", value="2e-5") | |
| batch_size = gr.Textbox(label="Tamaño de Lote", value="1") | |
| gradient_accumulation = gr.Textbox(label="Acumulación de Gradiente", value="8") | |
| with gr.Row(): | |
| block_size = gr.Textbox(label="Longitud de Secuencia", value="1024") | |
| with gr.Group(visible=True) as max_steps_ui: | |
| max_steps = gr.Textbox(label="Máximos Pasos de Entrenamiento", value="100") | |
| with gr.Group(visible=False) as epochs_ui: | |
| epochs = gr.Textbox(label="Épocas", value="1") | |
| with gr.Row(): | |
| optimizer = gr.Dropdown(["adamw_torch", "adafactor", "sgd", "adagrad"], label="Optimizador", value="adamw_torch") | |
| scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine") | |
| mixed_precision = gr.Radio(["no", "fp16", "bf16"], label="Precisión Mixta", value="no") | |
| with gr.Accordion("Avanzados", open=False): | |
| warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento") | |
| weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01") | |
| max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="1.0") | |
| logging_steps = gr.Textbox(label="Pasos de Registro", value="10") | |
| save_steps = gr.Textbox(label="Pasos de Guardado", value="50") | |
| save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1") | |
| early_stopping_patience = gr.Number(label="Paciencia para Early Stopping (0 para desactivar)", value=0) | |
| resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False) | |
| with gr.Row(): | |
| adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9") | |
| adam_beta2 = gr.Textbox(label="Adam Beta2", value="0.999") | |
| adam_epsilon = gr.Textbox(label="Adam Epsilon", value="1e-8") | |
| disable_gradient_checkpointing = gr.Checkbox(label="Deshabilitar Gradient Checkpointing", value=False) | |
| group_by_length = gr.Checkbox(label="Agrupar por Longitud", value=False) | |
| neftune_noise_alpha = gr.Textbox(label="NEFTune Ruido Alfa (0 para desactivar)", value="0") | |
| optim_args = gr.Textbox(label="Argumentos del Optimizador (formato dict)", placeholder="ej: betas=(0.9,0.995)") | |
| attn_implementation = gr.Dropdown(["eager", "flash_attention_2"], label="Implementación de Atención", value="eager") | |
| with gr.Accordion("🦋 PEFT (LoRA / QLoRA)", open=True) as peft_accordion: | |
| peft = gr.Checkbox(label="Habilitar PEFT/LoRA", value=True) | |
| quantization = gr.Dropdown(["no", "4bit", "8bit"], label="Cuantización", value="no") | |
| with gr.Row(): | |
| lora_r = gr.Textbox(label="LoRA r", value="16") | |
| lora_alpha = gr.Textbox(label="LoRA alpha", value="32") | |
| lora_dropout = gr.Textbox(label="LoRA dropout", value="0.05") | |
| auto_find_target_modules = gr.Checkbox(label="Auto-encontrar Módulos de Destino", value=True) | |
| target_modules = gr.Textbox(label="Módulos de Destino (csv)", placeholder="q_proj,v_proj", visible=False) | |
| modules_to_save = gr.Textbox(label="Módulos a Guardar (csv)", placeholder="embed_tokens,lm_head") | |
| with gr.Row(): | |
| use_dora = gr.Checkbox(label="Usar DoRA", value=False) | |
| use_rslora = gr.Checkbox(label="Usar RSLora", value=False) | |
| init_lora_weights = gr.Dropdown(["gaussian", "loftq", "pissa"], label="Inicialización de Pesos LoRA", value="gaussian") | |
| with gr.Accordion("🧹 Procesamiento y Aumentación de Datos", open=False): | |
| with gr.Tab("Limpieza y Normalización"): | |
| remove_html_tags = gr.Checkbox(label="Eliminar Etiquetas HTML", value=True) | |
| normalize_whitespace = gr.Checkbox(label="Normalizar Espacios en Blanco", value=True) | |
| remove_urls_emails = gr.Checkbox(label="Eliminar URLs/Emails", value=True) | |
| redact_pii = gr.Checkbox(label="Redactar PII", value=True) | |
| with gr.Tab("Filtrado"): | |
| enable_quality_filter = gr.Checkbox(label="Habilitar Filtros de Calidad", value=True) | |
| min_len_input = gr.Slider(1, 100, 10, label="Longitud Mínima (palabras)") | |
| max_len_input = gr.Slider(100, 5000, 2000, label="Longitud Máxima (palabras)") | |
| rep_threshold_input = gr.Slider(0, 1, 0.2, label="Umbral de Repetición") | |
| exclude_keywords_input = gr.Textbox(label="Palabras Clave a Excluir (csv)") | |
| with gr.Tab("Deduplicación"): | |
| deduplication_method = gr.Radio(["Ninguna", "Exacta", "Semántica (MinHash)"], label="Método de Deduplicación", value="Ninguna") | |
| minhash_threshold = gr.Slider(0.7, 0.99, 0.85, label="Umbral MinHash") | |
| minhash_num_perm = gr.Slider(64, 256, 128, step=16, label="Permutaciones MinHash") | |
| with gr.Tab("Aumentación"): | |
| enable_back_translation = gr.Checkbox(label="Habilitar Retrotraducción", value=False) | |
| bt_model_id = gr.Textbox(label="Modelo de Traducción", value="Helsinki-NLP/opus-mt-en-de") | |
| bt_reverse_model_id = gr.Textbox(label="Modelo Inverso", value="Helsinki-NLP/opus-mt-de-en") | |
| with gr.Tab("Generación Sintética"): | |
| enable_synthetic_data = gr.Checkbox(label="Habilitar Datos Sintéticos", value=False) | |
| synthetic_model_id = gr.Textbox(label="ID del Modelo Generador", placeholder="p.ej. 'mistralai/Mistral-7B-Instruct-v0.2'") | |
| num_synthetic_samples = gr.Number(label="Número de Muestras", value=1000) | |
| with gr.Accordion("📝 Configuración de Formato y Tarea", open=False): | |
| with gr.Group(visible=True) as sft_ui: | |
| sft_format_style = gr.Radio(["Columna de Texto", "Conversacional", "Razonamiento/Herramientas"], label="Formato de Datos SFT", value="Columna de Texto") | |
| chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5) | |
| with gr.Group(visible=False) as sft_tool_ui: | |
| enable_cot_input = gr.Checkbox(label="Habilitar Razonamiento (CoT)", value=True) | |
| enable_tool_use_input = gr.Checkbox(label="Habilitar Uso de Herramientas", value=True) | |
| prompt_col_input = gr.Textbox(label="Columna de Prompt/Usuario", value="prompt") | |
| response_col_input = gr.Textbox(label="Columna de Respuesta Final", value="response") | |
| reasoning_col_input = gr.Textbox(label="Columna de Razonamiento", value="reasoning") | |
| tool_use_col_input = gr.Textbox(label="Columna de Uso de Herramientas", value="tools") | |
| with gr.Group(visible=False) as dpo_ui: | |
| dpo_prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt") | |
| dpo_chosen_col_input = gr.Textbox(label="Columna Elegida", value="chosen") | |
| dpo_rejected_col_input = gr.Textbox(label="Columna Rechazada", value="rejected") | |
| with gr.Group(visible=False) as classification_labels_ui: | |
| classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo") | |
| with gr.Group(visible=False) as diffusion_ui: | |
| diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución") | |
| with gr.Group(visible=False) as dreambooth_ui: | |
| dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'") | |
| dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True) | |
| with gr.Accordion("📊 Evaluación y Mitigación de Sesgos", open=False): | |
| run_evaluation = gr.Checkbox(label="Ejecutar Evaluación", value=False) | |
| run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad", value=True) | |
| enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderación de Pérdida", value=False) | |
| reweighting_terms = gr.Textbox(label="Términos para Re-ponderar (csv)", placeholder="sesgo,injusto") | |
| reweighting_factor = gr.Slider(1.1, 10.0, 2.0, label="Factor de Re-ponderación") | |
| enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False) | |
| cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]') | |
| with gr.Accordion("🔌 Integraciones", open=False): | |
| wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password") | |
| wandb_project_input = gr.Textbox(label="Proyecto W&B") | |
| with gr.Column(scale=3): | |
| gr.Markdown("## 📈 Progreso y Resultados") | |
| with gr.Row(): | |
| start_training_button = gr.Button("Iniciar Entrenamiento", variant="primary", scale=3) | |
| stop_training_button = gr.Button("Detener", variant="stop", visible=False, scale=1) | |
| training_phase = gr.Label(label="Fase Actual", value="En espera") | |
| repo_link_output = gr.Markdown(label="Enlace al Repositorio del Modelo") | |
| final_eval_results = gr.JSON(label="Resultados de Evaluación Final") | |
| training_logs = gr.Textbox(label="Registros de Entrenamiento", lines=35, interactive=False) | |
# Registry of every UI control whose value is forwarded to the preview and
# training callbacks. Insertion order is significant: the values are flattened
# into `all_input_components_list` below and passed positionally as `inputs=`.
# NOTE(review): indentation was lost in this copy of the file; re-indent this
# block to the level of the enclosing Blocks/Tab context when applying.
all_input_components_dict = {
    # Mode, base model and destination repository
    "training_mode": training_mode,
    "model_base_input": model_base_input,
    "tokenizer_name_input": tokenizer_name_input,
    "repo_name_input": repo_name_input,
    "train_from_scratch": train_from_scratch,
    "auto_config_scratch": auto_config_scratch,
    "scratch_architecture": scratch_architecture,
    # multi_adapter_* controls
    "enable_multi_adapter_merge": enable_multi_adapter_merge,
    "multi_adapter_model_ids": multi_adapter_model_ids,
    "multi_adapter_weights": multi_adapter_weights,
    "multi_adapter_combination_type": multi_adapter_combination_type,
    # Dataset sources
    "datasets_hf_text": datasets_hf_text,
    "uploads": uploads,
    "dataset_weights": dataset_weights,
    "eval_dataset_hf": eval_dataset_hf,
    # Trainer hyper-parameters
    "learning_rate": learning_rate,
    "epochs": epochs,
    "max_steps": max_steps,
    "batch_size": batch_size,
    "gradient_accumulation": gradient_accumulation,
    "block_size": block_size,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "mixed_precision": mixed_precision,
    "warmup_ratio": warmup_ratio,
    "weight_decay": weight_decay,
    "max_grad_norm": max_grad_norm,
    "logging_steps": logging_steps,
    "save_steps": save_steps,
    "save_total_limit": save_total_limit,
    "resume_from_checkpoint": resume_from_checkpoint,
    "adam_beta1": adam_beta1,
    "adam_beta2": adam_beta2,
    "adam_epsilon": adam_epsilon,
    "disable_gradient_checkpointing": disable_gradient_checkpointing,
    "group_by_length": group_by_length,
    "neftune_noise_alpha": neftune_noise_alpha,
    "optim_args": optim_args,
    "attn_implementation": attn_implementation,
    "early_stopping_patience": early_stopping_patience,
    # PEFT / LoRA / quantization controls
    "peft": peft,
    "quantization": quantization,
    "lora_r": lora_r,
    "lora_alpha": lora_alpha,
    "lora_dropout": lora_dropout,
    "auto_find_target_modules": auto_find_target_modules,
    "target_modules": target_modules,
    "modules_to_save": modules_to_save,
    "use_dora": use_dora,
    "use_rslora": use_rslora,
    "init_lora_weights": init_lora_weights,
    # Text cleaning, filtering and deduplication controls
    "remove_html_tags": remove_html_tags,
    "normalize_whitespace": normalize_whitespace,
    "remove_urls_emails": remove_urls_emails,
    "redact_pii": redact_pii,
    "enable_quality_filter": enable_quality_filter,
    "min_len_input": min_len_input,
    "max_len_input": max_len_input,
    "rep_threshold_input": rep_threshold_input,
    "exclude_keywords_input": exclude_keywords_input,
    "deduplication_method": deduplication_method,
    "minhash_threshold": minhash_threshold,
    "minhash_num_perm": minhash_num_perm,
    # Augmentation and synthetic-data controls
    "enable_cda": enable_cda,
    "cda_json_config": cda_json_config,
    "enable_back_translation": enable_back_translation,
    "bt_model_id": bt_model_id,
    "bt_reverse_model_id": bt_reverse_model_id,
    "enable_synthetic_data": enable_synthetic_data,
    "synthetic_model_id": synthetic_model_id,
    "num_synthetic_samples": num_synthetic_samples,
    # SFT / DPO formatting and column mapping
    "sft_format_style": sft_format_style,
    "chat_template_jinja": chat_template_jinja,
    "enable_cot_input": enable_cot_input,
    "enable_tool_use_input": enable_tool_use_input,
    "prompt_col_input": prompt_col_input,
    "response_col_input": response_col_input,
    "reasoning_col_input": reasoning_col_input,
    "tool_use_col_input": tool_use_col_input,
    "dpo_prompt_col_input": dpo_prompt_col_input,
    "dpo_chosen_col_input": dpo_chosen_col_input,
    "dpo_rejected_col_input": dpo_rejected_col_input,
    # Remaining task-specific, evaluation and logging controls
    "classification_labels": classification_labels,
    "diffusion_resolution": diffusion_resolution,
    "run_evaluation": run_evaluation,
    "run_perplexity_evaluation": run_perplexity_evaluation,
    "enable_loss_reweighting": enable_loss_reweighting,
    "reweighting_terms": reweighting_terms,
    "reweighting_factor": reweighting_factor,
    "wandb_api_key_input": wandb_api_key_input,
    "wandb_project_input": wandb_project_input,
    "dreambooth_instance_prompt": dreambooth_instance_prompt,
    "dreambooth_train_text_encoder": dreambooth_train_text_encoder,
}

# Flat, order-preserving view of the registry, used as the `inputs=` list.
all_input_components_list = [*all_input_components_dict.values()]

# Components updated by the training callback, in the order it emits them.
all_output_components = [
    training_logs,
    training_phase,
    repo_link_output,
    final_eval_results,
    start_training_button,
    stop_training_button,
]
# Event wiring for the training tab. Registration order is kept as in the
# original so that the generated event indices do not change.
# NOTE(review): indentation was lost in this copy of the file; re-indent this
# block to the level of the enclosing Blocks/Tab context when applying.

# Data preview receives the same full input list as training.
preview_data_button.click(
    fn=gradio_preview_data_wrapper,
    inputs=all_input_components_list,
    outputs=[data_preview_output],
)

# Toggling "train from scratch" updates the model/tokenizer inputs, both
# accordions and the two scratch-specific controls.
train_from_scratch.change(
    fn=toggle_training_mode_ui,
    inputs=[train_from_scratch],
    outputs=[
        model_base_input,
        tokenizer_name_input,
        multi_adapter_accordion,
        peft_accordion,
        auto_config_scratch,
        scratch_architecture,
    ],
)

# Changing the training mode updates the task-specific panels.
# NOTE(review): `peft_accordion` is listed twice (positions 6 and 9). It is
# kept as-is because the number of values returned by
# `toggle_task_specific_ui` must match this list - confirm against that
# function before deduplicating.
training_mode.change(
    fn=toggle_task_specific_ui,
    inputs=[training_mode],
    outputs=[
        classification_labels_ui,
        dpo_ui,
        sft_ui,
        diffusion_ui,
        dreambooth_ui,
        peft_accordion,
        epochs_ui,
        max_steps_ui,
        peft_accordion,
    ],
)

sft_format_style.change(
    fn=toggle_sft_format_ui,
    inputs=[sft_format_style],
    outputs=[sft_tool_ui],
)

auto_find_target_modules.change(
    fn=toggle_auto_modules_ui,
    inputs=[auto_find_target_modules],
    outputs=[target_modules],
)

# Keep a handle on the training event so the Stop button can cancel it.
train_event = start_training_button.click(
    fn=gradio_train_wrapper,
    inputs=all_input_components_list,
    outputs=all_output_components,
)

# The Stop button runs no function of its own; it only cancels `train_event`.
stop_training_button.click(
    fn=None,
    inputs=None,
    outputs=None,
    cancels=[train_event],
)
with gr.Tab("4. Inferencia"):
    # Inference tab: pick a task type and a Hub model id, supply an input,
    # and run `run_inference`.
    # NOTE(review): indentation was lost in this copy of the file; the nesting
    # below is reconstructed from the order of the `with` statements - confirm
    # it against the original layout.
    gr.Markdown("## 🧪 Probar un Modelo del Hub")

    with gr.Row():
        inf_task_mode = gr.Dropdown(
            choices=TRAINING_MODES,
            value=TRAINING_MODES[0],
            label="Tipo de Tarea",
        )
        inf_model_id = gr.Textbox(
            label="ID del Modelo en el Hub",
            placeholder="TuUsuario/TuModeloEntrenado",
        )

    with gr.Group():
        with gr.Row():
            with gr.Column(scale=2):
                # Only the text prompt starts visible; the other inputs are
                # created hidden (visible=False).
                inf_text_in = gr.Textbox(
                    label="Entrada de Texto / Prompt",
                    lines=5,
                )
                inf_context_in = gr.Textbox(
                    label="Contexto (para QA)",
                    lines=3,
                    visible=False,
                )
                inf_image_in = gr.Image(
                    label="Entrada de Imagen",
                    type="pil",
                    visible=False,
                )
                inf_audio_in = gr.Audio(
                    label="Entrada de Audio",
                    type="filepath",
                    visible=False,
                )
                with gr.Accordion(
                    "Opciones Avanzadas de Generación",
                    open=False,
                    visible=True,
                ) as inf_advanced_options:
                    inf_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        label="Temperatura",
                    )
                    inf_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.95,
                        label="Top-p",
                    )
                    inf_max_new_tokens = gr.Slider(
                        minimum=10,
                        maximum=1024,
                        value=100,
                        step=1,
                        label="Máximos Tokens Nuevos",
                    )
                with gr.Row():
                    run_inference_btn = gr.Button(
                        "Ejecutar Inferencia",
                        variant="primary",
                    )
            with gr.Column(scale=3):
                inf_text_out = gr.Textbox(
                    label="Salida de Texto",
                    lines=15,
                    interactive=False,
                )

    # Changing the task type lets `update_inference_ui` update the four input
    # widgets and the advanced-options accordion.
    inf_task_mode.change(
        fn=update_inference_ui,
        inputs=[inf_task_mode],
        outputs=[
            inf_text_in,
            inf_context_in,
            inf_image_in,
            inf_audio_in,
            inf_advanced_options,
        ],
    )

    # `run_inference` receives every input plus the generation settings; its
    # outputs include the model id and input widgets as well as the text
    # result, so it can write back to them.
    run_inference_btn.click(
        fn=run_inference,
        inputs=[
            inf_task_mode,
            inf_model_id,
            inf_text_in,
            inf_context_in,
            inf_image_in,
            inf_audio_in,
            inf_temperature,
            inf_top_p,
            inf_max_new_tokens,
        ],
        outputs=[
            inf_text_out,
            inf_model_id,
            inf_text_in,
            inf_context_in,
            inf_image_in,
            inf_audio_in,
        ],
    )
with gr.Tab("5. Explicación del Código y Mecanismos Avanzados"):
    # Static documentation tab: one title, five (heading, body) Markdown
    # pairs, and a closing hardware note.
    # NOTE(review): indentation was lost in this copy of the file, so the
    # original leading whitespace inside the triple-quoted strings is unknown;
    # the text is reproduced exactly as visible, flush-left.
    gr.Markdown("""
### 🧠 Explicación del Código y Mecanismos Avanzados
""")

    # NOTE(review): section 1 uses `$...$` for inline math; whether this
    # renders depends on the Markdown component's LaTeX delimiter settings in
    # the installed Gradio version - confirm.
    _explainer_sections = [
        (
            "#### 1. CORE MECHANISMS",
            """
* **PEFT/LoRA**: Parameter-Efficient Fine-Tuning. Only low-rank matrices ($A$ and $B$) are trained for low-rank updates ($W' = W + B A$). This drastically reduces trainable parameters.
* **QLoRA (4-bit)**: Loads the base model weights in 4-bit precision (NF4 with double quantization) using `bitsandbytes`, massively reducing VRAM usage while training LoRA adapters.
* **Accelerator**: Manages device placement (CPU/GPU), mixed precision (`fp16`/`bf16`), and gradient accumulation for stable large-batch training simulation.
* **Early Stopping**: Halts training if validation loss doesn't improve over a set number of steps (`early_stopping_patience`).
* **Gradient Accumulation**: Simulates larger batch sizes by accumulating gradients over several forward/backward passes before an optimization step.
* **Gradient Clipping**: Limits the maximum norm of the gradients (`max_grad_norm`) to prevent exploding gradients during training.
* **Memory Optimization**: Optional use of `xFormers` (FlashAttention or memory-efficient attention) to reduce memory footprint and speed up training on compatible GPUs.
""",
        ),
        (
            "#### 2. DATA PROCESSING & AUGMENTATION",
            """
* **Streaming Datasets**: Uses `datasets` streaming mode to handle very large datasets without loading all into RAM.
* **Data Cleaning**: Removes HTML tags, normalizes whitespace, redacts PII, and removes URLs/emails.
* **Advanced Filtering**: Includes optional filters for text length, word repetition, language detection, and basic toxicity detection (via `unitary/toxic-bert`).
* **Data Augmentation**: Supports **Back-Translation (BT)** for introducing paraphrasing variations and **Counterfactual Data Augmentation (CDA)** for controlled bias testing (e.g., swapping gendered pronouns).
* **Synthetic Data Generation**: Uses a specified LLM to generate new training examples based on an initial prompt template.
* **Deduplication**: Implements both **Exact** and **Semantic (MinHash LSH)** deduplication to prevent data contamination during iterative fine-tuning.
""",
        ),
        (
            "#### 3. TRAINING MODES",
            """
* **SFT (Supervised Fine-Tuning)**: Standard fine-tuning, supports **Conversation** and **Reasoning/Tool Use (CoT)** formatting styles.
* **DPO (Direct Preference Optimization)**: Trains directly on preference pairs (chosen vs. rejected), using the `trl` library.
* **Task-Specific Heads**: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
* **Seq2Seq**: For translation/summarization tasks, using `Seq2SeqTrainer`.
* **Diffusion (Text-to-Image/DreamBooth)**: Fine-tunes the UNet (and optionally Text Encoder) using LoRA for image generation tasks, with custom image/video data handling.
""",
        ),
        (
            "#### 4. MODEL INITIALIZATION",
            """
* **Model From Scratch**: Allows initializing a model (e.g., Llama, Mistral) from a config rather than a pre-trained checkpoint, with optional auto-configuration based on expected training scale.
* **Multi-Adapter Merging**: Advanced feature to combine multiple existing LoRA adapters into a single, new adapter using weighted averaging (`slerp`, `linear`, etc.).
""",
        ),
        (
            "#### 5. OUTPUT & DEPLOYMENT",
            """
* **Hugging Face Hub Integration**: All trained artifacts (full model/LoRA adapter) are automatically pushed to a specified repository on the HF Hub using the provided token.
* **Model Card Generation**: Automatically generates a `README.md` detailing training parameters and model provenance.
* **Inference Tabs**: Separate UI for testing the trained LoRA adapter on CPU (for Gemma/LoRA) or various pipeline modes on GPU.
""",
        ),
    ]

    # Each section is rendered as two Markdown components (heading, then
    # body), in the same order as the original hand-written calls.
    for _section_heading, _section_body in _explainer_sections:
        gr.Markdown(_section_heading)
        gr.Markdown(_section_body)

    gr.Markdown("### 💡 Hardware Fallback")
    # NOTE(review): assumes `device` is a str (it must support `.upper()`);
    # its definition is outside this chunk - confirm.
    gr.Markdown(f"If CUDA/GPU is unavailable, the system defaults to CPU: **{device.upper()}**. Training and inference on CPU will be significantly slower, especially for large models or Diffusers.")
if __name__ == "__main__":
    # Script entry point: start the Gradio app.
    #
    # The queue is enabled explicitly because the Stop button is registered
    # with `cancels=[train_event]`. On Gradio versions where the queue is not
    # on by default, an event that cancels another event requires `.queue()`
    # before `.launch()`; on newer versions the call is a harmless no-op
    # reconfiguration.
    demo.queue().launch(debug=True, share=True)