Spaces:
Runtime error
Runtime error
| import os | |
| os.system("pip install -U torchao transformers peft accelerate trl gradio_huggingfacehub_search bitsandbytes datasets diffusers") | |
| os.system("pip install spaces-0.1.0-py3-none-any.whl") | |
| import io | |
| import json | |
| import tempfile | |
| import string | |
| import gc | |
| import math | |
| import uuid | |
| import logging | |
| import traceback | |
| import importlib | |
| import random | |
| import re | |
| import ast | |
| import shutil | |
| from itertools import islice | |
| from pathlib import Path | |
| from collections import defaultdict | |
| from datetime import datetime | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
| import numpy as np | |
| import pandas as pd | |
| import accelerate | |
| from PIL import Image | |
| import torchvision | |
| import torchvision.transforms as T | |
| from torchvision import transforms | |
| import torchaudio | |
| from bs4 import BeautifulSoup | |
| from langdetect import detect_langs | |
| import textstat | |
| from datasketch import MinHash, MinHashLSH | |
| import gradio as gr | |
| from datasets import load_dataset, IterableDataset, Dataset as HFDataset, DatasetDict, interleave_datasets, Audio | |
| from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi, hf_hub_download, list_repo_files, snapshot_download, list_models | |
| from transformers import ( | |
| AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer, | |
| AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, | |
| AutoModelForImageClassification, AutoModel, TorchAoConfig, | |
| AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification, | |
| DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq, | |
| AutoProcessor, DataCollatorWithPadding, pipeline, | |
| DataCollatorForSeq2Seq, AutoModelForSequenceClassification, | |
| LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel, | |
| PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM, | |
| DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor | |
| ) | |
| from peft import LoraConfig, get_peft_model, PeftModel | |
| from trl import SFTTrainer, DPOTrainer | |
| import evaluate as hf_evaluate | |
| from jinja2 import Template | |
| import spaces | |
| from tqdm.auto import tqdm | |
| from diffusers import ( | |
| UNet2DConditionModel, DDPMScheduler, AutoencoderKL, | |
| get_scheduler as get_diffusers_scheduler, StableDiffusionPipeline as StableDiffusionText2ImagePipeline, | |
| StableDiffusionImg2ImgPipeline as StableDiffusionImage2ImagePipeline | |
| ) | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from packaging import version | |
| from torchao.quantization import ( | |
| Int4WeightOnlyConfig, | |
| Int8WeightOnlyConfig, | |
| Int8DynamicActivationInt8WeightConfig, | |
| Float8WeightOnlyConfig, | |
| Float8DynamicActivationFloat8WeightConfig, | |
| GemliteUIntXWeightOnlyConfig, | |
| ) | |
| from torchao.dtypes import Int4CPULayout | |
| from llmcompressor import oneshot | |
| from llmcompressor.modifiers.awq import AWQModifier | |
| logger = logging.getLogger(__name__) | |
| torch_dtype_auto = torch.float32 | |
| def _sanitize_model_name_for_yaml(model_name): | |
| name = model_name.split('/')[-1] if '/' in model_name else model_name | |
| sanitized = re.sub(r'[^a-zA-Z0-9\-_\.]', '-', name) | |
| return sanitized if sanitized else "model" | |
| ARCHITECTURE_MAP = {"Llama": (LlamaConfig, LlamaForCausalLM), "Mistral": (MistralConfig, MistralForCausalLM), "Gemma": (GemmaConfig, GemmaForCausalLM), "GPT2": (GPT2Config, GPT2LMHeadModel), "Phi": (PhiConfig, PhiForCausalLM), "Qwen2": (Qwen2Config, Qwen2ForCausalLM)} | |
| SCRATCH_TOKENIZER_MAP = {"Llama": "meta-llama/Llama-2-7b-hf", "Mistral": "mistralai/Mistral-7B-v0.1", "Gemma": "google/gemma-2b", "GPT2": "gpt2", "Phi": "microsoft/phi-2", "Qwen2": "Qwen/Qwen2-0.5B"} | |
| TRAINING_MODES = [ | |
| "Causal Language Modeling (SFT/LoRA)", | |
| "DPO (Direct Preference Optimization)", | |
| "Question Answering (Text)", | |
| "Token Classification (NER)", | |
| "Sequence Classification (Text)", | |
| "Text-to-Image Generation", | |
| "Image Classification (Vision)", | |
| "Audio Classification (Speech)", | |
| "ASR (Speech-to-Text)", | |
| "Text2Text Generation" | |
| ] | |
| TASK_TO_PIPELINE_MAP = { | |
| "Causal Language Modeling (SFT/LoRA)": "text-generation", | |
| "DPO (Direct Preference Optimization)": "text-generation", | |
| "Question Answering (Text)": "question-answering", | |
| "Token Classification (NER)": "token-classification", | |
| "Sequence Classification (Text)": "text-classification", | |
| "Image Classification (Vision)": "image-classification", | |
| "Audio Classification (Speech)": "audio-classification", | |
| "ASR (Speech-to-Text)": "automatic-speech-recognition", | |
| "Text2Text Generation": "text2text-generation", | |
| "Text-to-Image Generation": "text-to-image", | |
| } | |
| MODEL_CARD_TEMPLATE = """--- | |
| language: es | |
| license: apache-2.0 | |
| tags: | |
| - autotrain-advanced | |
| - fine-tuned | |
| - {base_model_name} | |
| widget: | |
| - text: "Hola, ¿cómo estás?" | |
| --- | |
| # {repo_id} | |
| Este modelo es una versión afinada de [{base_model}](https://huggingface.co/{base_model}) entrenado con la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced). | |
| ## Detalles del Entrenamiento | |
| - **Modo de Entrenamiento:** {training_mode} | |
| - **Modelo Base:** `{base_model}` | |
| - **Datasets:** `{datasets}` | |
| - **Entrenado en:** {date} | |
| ### Hiperparámetros de Entrenamiento | |
| ```json | |
| {hyperparameters}``` | |
| ### Frameworks Utilizados | |
| - Transformers | |
| - PEFT | |
| - Accelerate | |
| - TRL | |
| - Gradio | |
| """ | |
| DATASET_CARD_TEMPLATE = """--- | |
| license: mit | |
| --- | |
| # {repo_id} | |
| Este dataset fue creado utilizando la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced). | |
| ## Detalles del Dataset | |
| - **Tipo de Creación:** {creation_type} | |
| - **Modelo de Generación (si aplica):** `{generation_model}` | |
| - **Fecha de Creación:** {date} | |
| """ | |
| MAP_QUANT_TYPE_TO_NAME = { | |
| "Int4WeightOnly": "int4wo", | |
| "GemliteUIntXWeightOnly": "intxwo-gemlite", | |
| "Int8WeightOnly": "int8wo", | |
| "Int8DynamicActivationInt8Weight": "int8da8w8", | |
| "Float8WeightOnly": "float8wo", | |
| "Float8DynamicActivationFloat8Weight": "float8da8w8", | |
| "autoquant": "autoquant", | |
| } | |
| MAP_QUANT_TYPE_TO_CONFIG = { | |
| "Int4WeightOnly": Int4WeightOnlyConfig, | |
| "GemliteUIntXWeightOnly": GemliteUIntXWeightOnlyConfig, | |
| "Int8WeightOnly": Int8WeightOnlyConfig, | |
| "Int8DynamicActivationInt8Weight": Int8DynamicActivationInt8WeightConfig, | |
| "Float8WeightOnly": Float8WeightOnlyConfig, | |
| "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig, | |
| } | |
| _tox_pipe_singleton = None | |
| class DebiasingSFTTrainer(SFTTrainer): | |
| def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.reweighting_terms = [term.strip().lower() for term in reweighting_terms] if reweighting_terms else [] | |
| self.reweighting_factor = reweighting_factor | |
| def compute_loss(self, model, inputs, return_outputs=False): | |
| loss, outputs = super().compute_loss(model, inputs, return_outputs=True) | |
| if self.reweighting_terms and self.reweighting_factor > 1.0: | |
| input_ids = inputs.get("input_ids") | |
| decoded_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) | |
| for text in decoded_texts: | |
| if any(term in text.lower() for term in self.reweighting_terms): | |
| loss *= self.reweighting_factor | |
| break | |
| return (loss, outputs) if return_outputs else loss | |
| class DeduplicatedIterableDataset(IterableDataset): | |
| def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128): | |
| super().__init__(ex_iterable=iter([])) | |
| self.dataset = dataset | |
| self.text_col = text_col | |
| self.method = method | |
| self.threshold = threshold | |
| self.num_perm = num_perm | |
| if hasattr(dataset, '_info'): | |
| self._info = dataset._info | |
| elif hasattr(dataset, 'info'): | |
| self._info = dataset.info | |
| def __iter__(self): | |
| if self.method == 'Exacta': | |
| return self._exact_iter() | |
| elif self.method == 'Semántica (MinHash)': | |
| return self._minhash_iter() | |
| else: | |
| return iter(self.dataset) | |
| def _exact_iter(self): | |
| seen_texts = set() | |
| for example in self.dataset: | |
| text = example.get(self.text_col, "") | |
| if text and isinstance(text, str): | |
| if text not in seen_texts: | |
| seen_texts.add(text) | |
| yield example | |
| else: | |
| yield example | |
| def _minhash_iter(self): | |
| lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm) | |
| for i, example in enumerate(self.dataset): | |
| text = example.get(self.text_col, "") | |
| if text and isinstance(text, str) and text.strip(): | |
| m = MinHash(num_perm=self.num_perm) | |
| for d in text.split(): | |
| m.update(d.encode('utf8')) | |
| if not lsh.query(m): | |
| lsh.insert(f"key_{i}", m) | |
| yield example | |
| else: | |
| yield example | |
| def hf_login(token): | |
| if not token: | |
| return "Por favor, introduce un token." | |
| try: | |
| login(token=token, add_to_git_credential=True) | |
| user = whoami() | |
| return f"✅ Conectado como: {user['name']}" | |
| except Exception as e: | |
| return f"❌ Error en la conexión: {e}" | |
| def _clean_text(example, text_col, **kwargs): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str): | |
| return example | |
| if kwargs.get('remove_html_tags'): | |
| text = BeautifulSoup(text, "html.parser").get_text() | |
| if kwargs.get('remove_urls_emails'): | |
| text = re.sub(r'http\S+|www\S+|httpsS+', '', text, flags=re.MULTILINE) | |
| if kwargs.get('normalize_whitespace'): | |
| text = ' '.join(text.split()) | |
| if kwargs.get('redact_pii'): | |
| text = re.sub(r'\S+@\S+', '<EMAIL>', text) | |
| text = re.sub(r'(\d{1,4}[-.\s]?){7,}|(\+\d{1,3}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}', '<PHONE>', text) | |
| text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text) | |
| example[text_col] = text | |
| return example | |
| def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str): return False | |
| text_len = len(text.split()) | |
| if not (min_len <= text_len <= max_len): return False | |
| words = text.split() | |
| if not words: return False | |
| word_counts = {} | |
| for word in words: word_counts[word] = word_counts.get(word, 0) + 1 | |
| if not word_counts or (max(word_counts.values()) / len(words)) > rep_threshold: return False | |
| lower_text = text.lower() | |
| return not any(keyword in lower_text for keyword in exclude_keywords) | |
| def _apply_coherence_filter(example, text_col, char_rep_threshold, ngram_rep_threshold, entropy_threshold): | |
| text = example.get(text_col, "") | |
| if not isinstance(text, str) or not text: | |
| return False | |
| char_repetition_ratio = 0 | |
| if len(text) > 0: | |
| for char in set(text): | |
| if char.isalnum() or char in '.,;:!?': | |
| char_count = text.count(char) | |
| char_ratio = char_count / len(text) | |
| char_repetition_ratio = max(char_repetition_ratio, char_ratio) | |
| if char_repetition_ratio > char_rep_threshold: | |
| return False | |
| text_lower = text.lower() | |
| repeated_chars = 0 | |
| ngram_counts = {} | |
| for n in [3, 4, 5]: | |
| if len(text_lower) >= n: | |
| for i in range(len(text_lower) - n + 1): | |
| ngram = text_lower[i:i+n] | |
| if ngram.isalpha(): | |
| ngram_counts[ngram] = ngram_counts.get(ngram, 0) + 1 | |
| if ngram_counts: | |
| highly_repeated_ngrams = {ng for ng, count in ngram_counts.items() if count > 3} | |
| if highly_repeated_ngrams: | |
| covered_positions = set() | |
| for i in range(len(text_lower)): | |
| for n in [3, 4, 5]: | |
| if i + n <= len(text_lower): | |
| ngram = text_lower[i:i+n] | |
| if ngram in highly_repeated_ngrams: | |
| for j in range(i, i+n): | |
| covered_positions.add(j) | |
| repetition_coverage = len(covered_positions) / len(text_lower) | |
| if repetition_coverage > ngram_rep_threshold: | |
| return False | |
| if len(text) > 10: | |
| char_freq = {} | |
| for char in text: | |
| char_freq[char] = char_freq.get(char, 0) + 1 | |
| entropy = 0 | |
| for count in char_freq.values(): | |
| p = count / len(text) | |
| if p > 0: | |
| entropy -= p * math.log2(p) | |
| max_entropy = math.log2(len(char_freq)) if len(char_freq) > 0 else 1 | |
| normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0 | |
| if normalized_entropy < entropy_threshold: | |
| return False | |
| if len(text) > 0: | |
| alnum_count = sum(1 for c in text if c.isalnum() or c.isspace()) | |
| alnum_ratio = alnum_count / len(text) | |
| if alnum_ratio < 0.7: | |
| return False | |
| scripts = { | |
| 'greek': sum(1 for c in text if '\u0370' <= c <= '\u03FF'), | |
| 'cyrillic': sum(1 for c in text if '\u0400' <= c <= '\u04FF'), | |
| 'arabic': sum(1 for c in text if '\u0600' <= c <= '\u06FF'), | |
| 'chinese': sum(1 for c in text if '\u4E00' <= c <= '\u9FFF'), | |
| } | |
| non_latin_chars = sum(scripts.values()) | |
| latin_chars = sum(1 for c in text if c.isalpha() and not any(scripts.values())) | |
| if non_latin_chars > 2 and latin_chars > 10: | |
| return False | |
| return True | |
| def _get_filter_functions(**kwargs): | |
| filters = [] | |
| if kwargs.get('enable_quality_filter'): | |
| exclude_list = [k.strip().lower() for k in (kwargs.get('exclude_keywords_input', '') + ',' + kwargs.get('bias_keywords_input', '')).split(",") if k.strip()] | |
| filters.append(lambda ex: _apply_quality_filters(ex, kwargs['text_col'], kwargs['min_len_input'], kwargs['max_len_input'], kwargs['rep_threshold_input'], exclude_list)) | |
| if kwargs.get('enable_language_filter'): | |
| allowed_langs = [lang.strip() for lang in kwargs.get('allowed_languages', 'en').split(',')] | |
| lang_threshold = kwargs.get('language_detection_threshold', 0.95) | |
| def lang_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not text or not isinstance(text, str) or len(text.split()) < 5: return True | |
| try: | |
| detected = detect_langs(text) | |
| return any(lang.lang in allowed_langs and lang.prob > lang_threshold for lang in detected) | |
| except: | |
| return False | |
| filters.append(lang_filter) | |
| if kwargs.get('enable_toxicity_filter'): | |
| tox_threshold = kwargs.get('toxicity_threshold', 0.8) | |
| def tox_filter(ex): | |
| global _tox_pipe_singleton | |
| if _tox_pipe_singleton is None: | |
| logger.info("Initializing toxicity filter pipeline...") | |
| _tox_pipe_singleton = pipeline("text-classification", model="unitary/toxic-bert") | |
| text = ex.get(kwargs['text_col'], "") | |
| if not text or not isinstance(text, str): return True | |
| try: | |
| results = _tox_pipe_singleton(text[:512], truncation=True) | |
| return not (results[0]['label'] == 'toxic' and results[0]['score'] > tox_threshold) | |
| except Exception: | |
| return True | |
| filters.append(tox_filter) | |
| if kwargs.get('enable_coherence_filter'): | |
| char_rep_thresh = kwargs.get('coherence_char_repetition_threshold', 0.4) | |
| ngram_rep_thresh = kwargs.get('coherence_ngram_repetition_threshold', 0.3) | |
| entropy_thresh = kwargs.get('coherence_entropy_threshold', 0.5) | |
| filters.append(lambda ex: _apply_coherence_filter(ex, kwargs['text_col'], char_rep_thresh, ngram_rep_thresh, entropy_thresh)) | |
| if any([kwargs.get('enable_readability_filter'), kwargs.get('enable_stopword_filter'), kwargs.get('enable_uniqueness_filter')]): | |
| stop_words = set(['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should', 'could', 'can', 'may', 'might', 'must', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which', 'who', 'when', 'where', 'why', 'how']) | |
| def stats_filter(ex): | |
| text = ex.get(kwargs['text_col'], "") | |
| if not isinstance(text, str) or not text: return True | |
| words = text.split() | |
| num_words = len(words) | |
| if num_words == 0: return True | |
| if kwargs.get('enable_readability_filter'): | |
| try: | |
| score = textstat.flesch_reading_ease(text) | |
| if not (kwargs['min_readability'] <= score <= kwargs['max_readability']): return False | |
| except: | |
| pass | |
| if kwargs.get('enable_stopword_filter'): | |
| stopword_count = sum(1 for word in words if word.lower() in stop_words) | |
| if num_words > 0 and (stopword_count / num_words) > kwargs['max_stopword_ratio']: return False | |
| if kwargs.get('enable_uniqueness_filter'): | |
| if (len(set(words)) / num_words) < kwargs['min_uniqueness_ratio']: return False | |
| return True | |
| filters.append(stats_filter) | |
| return filters | |
| def _load_hf_streaming(ids, split="train", probabilities=None): | |
| streams = [] | |
| valid_ids = [] | |
| for ident in ids: | |
| try: | |
| d = load_dataset(ident, streaming=True, trust_remote_code=True) | |
| split_found = False | |
| if isinstance(d, dict): | |
| for s_name, ds in d.items(): | |
| if s_name.lower() == split or (split == "train" and "train" in s_name.lower()): | |
| streams.append(ds) | |
| split_found = True | |
| break | |
| else: | |
| streams.append(d) | |
| split_found = True | |
| if split_found: | |
| valid_ids.append(ident) | |
| else: | |
| logger.warning(f"Split '{split}' not found in dataset {ident}. Excluding from this source.") | |
| except Exception as e: | |
| logger.error(f"Error loading dataset {ident} split {split}: {e}. Excluding from this source.") | |
| if not streams: | |
| return None | |
| if probabilities and len(probabilities) != len(streams): | |
| logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.") | |
| probabilities = None | |
| return interleave_datasets(streams, probabilities=probabilities) | |
| def _load_uploaded_stream(files): | |
| all_rows = [] | |
| for f in files or []: | |
| content = f.read().decode("utf-8", errors="ignore") | |
| name = f.name.lower() | |
| if name.endswith(".csv"): | |
| import csv | |
| all_rows.extend(list(csv.DictReader(io.StringIO(content)))) | |
| elif name.endswith(".jsonl"): | |
| all_rows.extend([json.loads(line) for line in io.StringIO(content) if line.strip()]) | |
| elif name.endswith(".json"): | |
| data = json.loads(content) | |
| all_rows.extend(data if isinstance(data, list) else [data]) | |
| elif name.endswith(".txt"): | |
| all_rows.extend([{"text": line} for line in io.StringIO(content) if line.strip()]) | |
| if not all_rows: | |
| return None | |
| val_size = max(1, int(len(all_rows) * 0.01)) | |
| random.shuffle(all_rows) | |
| return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []} | |
| def _guess_columns(sample): | |
| text_col, image_col, audio_col, label_col = "text", "image", "audio", "label" | |
| if not isinstance(sample, dict): | |
| return text_col, image_col, audio_col, label_col | |
| keys = {k.lower(): k for k in sample.keys()} | |
| if "text" in keys: text_col = keys["text"] | |
| elif "content" in keys: text_col = keys["content"] | |
| elif "prompt" in keys: text_col = keys["prompt"] | |
| if "image" in keys: image_col = keys["image"] | |
| elif "img" in keys: image_col = keys["img"] | |
| if "audio" in keys: audio_col = keys["audio"] | |
| elif "speech" in keys: audio_col = keys["speech"] | |
| if "label" in keys: label_col = keys["label"] | |
| elif "labels" in keys: label_col = keys["labels"] | |
| return text_col, image_col, audio_col, label_col | |
| def _apply_cda(dataset, text_col, cda_config_str): | |
| try: | |
| swap_groups = json.loads(cda_config_str) | |
| except (json.JSONDecodeError, ValueError) as e: | |
| logger.error(f"Configuración de CDA inválida: {e}.") | |
| return dataset | |
| def cda_generator(): | |
| for example in dataset: | |
| original_text = example.get(text_col, "") | |
| if not isinstance(original_text, str): | |
| yield example | |
| continue | |
| yield example | |
| generated_texts = {original_text} | |
| current_texts = {original_text} | |
| for group in swap_groups: | |
| next_texts = set() | |
| for text in current_texts: | |
| for word_to_replace in group: | |
| if word_to_replace in text: | |
| for replacement_word in group: | |
| if word_to_replace != replacement_word: | |
| new_text = text.replace(word_to_replace, replacement_word) | |
| if new_text not in generated_texts: | |
| new_example = example.copy() | |
| new_example[text_col] = new_text | |
| yield new_example | |
| generated_texts.add(new_text) | |
| next_texts.add(new_text) | |
| current_texts.update(next_texts) | |
| return IterableDataset.from_generator(cda_generator) | |
| def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id): | |
| if not ratio or ratio <= 0: | |
| return dataset | |
| logger.info(f"Aplicando retrotraducción al {ratio*100}% del dataset.") | |
| try: | |
| pipe_to = pipeline("translation", model=model_id, ) | |
| pipe_from = pipeline("translation", model=reverse_model_id, ) | |
| except Exception as e: | |
| logger.error(f"No se pudieron cargar los modelos de traducción: {e}") | |
| return dataset | |
| def bt_generator(): | |
| for example in dataset: | |
| yield example | |
| if random.random() < ratio: | |
| original_text = example.get(text_col, "") | |
| if isinstance(original_text, str) and original_text: | |
| try: | |
| translated = pipe_to(original_text, max_length=512)[0]['translation_text'] | |
| back_translated = pipe_from(translated, max_length=512)[0]['translation_text'] | |
| if back_translated: | |
| new_example = example.copy() | |
| new_example[text_col] = back_translated | |
| yield new_example | |
| except Exception as e: | |
| logger.warning(f"Error en retrotraducción: {e}") | |
| return IterableDataset.from_generator(bt_generator) | |
| def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template): | |
| if not num_samples or num_samples <= 0: | |
| return None | |
| logger.info(f"Iniciando generación de {num_samples} muestras sintéticas con el modelo {model_id}.") | |
| try: | |
| generator = pipeline("text-generation", model=model_id, ) | |
| except Exception as e: | |
| logger.error(f"No se pudo cargar el modelo generador sintético: {e}") | |
| return None | |
| seed_examples = list(islice(original_dataset, 200)) | |
| if not seed_examples: | |
| logger.warning("Dataset original vacío, no se pueden generar datos sintéticos.") | |
| return None | |
| def synthetic_generator(): | |
| for i in range(num_samples): | |
| seed_example = random.choice(seed_examples) | |
| seed_text = seed_example.get(text_col, "") | |
| prompt = Template(prompt_template).render(example_text=seed_text) | |
| try: | |
| generated_output = generator(prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.9, top_p=0.95) | |
| cleaned_text = generated_output[0]['generated_text'][len(prompt):].strip() | |
| if "new example:" in cleaned_text.lower(): | |
| cleaned_text = re.split("new example:", cleaned_text, flags=re.IGNORECASE)[-1].strip() | |
| if cleaned_text: | |
| new_example = seed_example.copy() | |
| new_example[text_col] = cleaned_text | |
| yield new_example | |
| except Exception as e: | |
| logger.warning(f"Error generando una muestra sintética: {e}") | |
| continue | |
| return IterableDataset.from_generator(synthetic_generator) | |
| def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation): | |
| safe_steps = int(steps_per_epoch_estimate or 10000) | |
| safe_batch_size = int(batch_size or 1) | |
| safe_grad_accum = int(gradient_accumulation or 8) | |
| safe_block_size = int(block_size or 1024) | |
| size = safe_steps * safe_batch_size * safe_grad_accum | |
| if size <= 1: | |
| size = 10000 | |
| log_size = math.log2(max(1000, size)) | |
| vocab_size = min(65536, 32000 + int(log_size * 2000)) | |
| preliminary_hidden_size = max(512, min(4096, 512 + int(log_size * 100))) | |
| heads = max(8, min(32, preliminary_hidden_size // 64)) | |
| if heads == 0: heads = 8 | |
| hidden_size = (preliminary_hidden_size // heads) * heads | |
| layers = max(8, min(32, 8 + int(log_size * 1.5))) | |
| kv_heads = heads if is_gpt2_like else (max(1, heads // 4)) | |
| return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads | |
| def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn): | |
| if eval_ds_id: | |
| yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación") | |
| return _load_hf_streaming([eval_ds_id], split="train") | |
| if uploaded_val_data: | |
| yield update_logs_fn("Usando split de validación de archivos subidos.", "Evaluación") | |
| return HFDataset.from_list(uploaded_val_data) | |
| if train_ds_id: | |
| yield update_logs_fn("Intentando cargar split 'validation' o 'test' del dataset de entrenamiento.", "Evaluación") | |
| try: | |
| for split_name in ["validation", "test"]: | |
| eval_ds = _load_hf_streaming([train_ds_id], split=split_name) | |
| if eval_ds: | |
| yield update_logs_fn(f"Split '{split_name}' encontrado y cargado.", "Evaluación") | |
| return eval_ds | |
| except Exception as e: | |
| yield update_logs_fn(f"Error cargando split de evaluación: {e}. Omitiendo.", "Evaluación") | |
| return None | |
| yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación") | |
| return None | |
| def _create_training_args(output_dir, repo_id, **kwargs): | |
| neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0)) | |
| optim_args_dict = {} | |
| if kwargs.get('optim_args'): | |
| try: | |
| optim_args_dict = ast.literal_eval(f"dict({kwargs['optim_args']})") | |
| except Exception as e: | |
| logger.warning(f"No se pudieron parsear los argumentos del optimizador: {e}.") | |
| args_dict = { | |
| "output_dir": os.path.join(output_dir, "results"), | |
| "per_device_train_batch_size": int(kwargs.get('batch_size', 1)), | |
| "gradient_accumulation_steps": int(kwargs.get('gradient_accumulation', 8)), | |
| "optim": kwargs.get('optimizer', 'adamw_torch'), | |
| "optim_args": optim_args_dict, | |
| "save_strategy": "steps", | |
| "logging_steps": int(kwargs.get('logging_steps', 10)), | |
| "save_steps": int(kwargs.get('save_steps', 50)), | |
| "eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None, | |
| "learning_rate": float(kwargs.get('learning_rate', 2e-5)), | |
| "max_grad_norm": float(kwargs.get('max_grad_norm', 1.0)), | |
| "warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)), | |
| "lr_scheduler_type": kwargs.get('scheduler', 'cosine'), | |
| "weight_decay": float(kwargs.get('weight_decay', 0.01)), | |
| "load_best_model_at_end": kwargs.get('run_evaluation', False), | |
| "save_total_limit": int(kwargs.get('save_total_limit', 1)), | |
| "push_to_hub": True, | |
| "hub_model_id": repo_id, | |
| "hub_strategy": kwargs.get('hub_strategy', 'every_save'), | |
| "dataloader_num_workers": 2, | |
| "report_to": "wandb" if kwargs.get('wandb_api_key_input') else "none", | |
| "remove_unused_columns": False, | |
| "group_by_length": kwargs.get('group_by_length', False), | |
| "metric_for_best_model": kwargs.get('metric_for_best_model', 'loss') if kwargs.get('run_evaluation') else None, | |
| "greater_is_better": kwargs.get('greater_is_better', False), | |
| "neftune_noise_alpha": neftune_alpha if neftune_alpha > 0 else None, | |
| "adam_beta1": float(kwargs.get('adam_beta1', 0.9)), | |
| "adam_beta2": float(kwargs.get('adam_beta2', 0.999)), | |
| "adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)), | |
| } | |
| if kwargs.get('early_stopping_patience', 0) > 0 and kwargs.get('run_evaluation', False): | |
| args_dict['early_stopping_patience'] = int(kwargs['early_stopping_patience']) | |
| args_dict['load_best_model_at_end'] = True | |
| max_steps_val = int(kwargs.get('max_steps', -1)) | |
| if max_steps_val > 0: | |
| args_dict["max_steps"] = max_steps_val | |
| else: | |
| raise ValueError("Para datasets en streaming se requiere un valor positivo para 'Máximos Pasos de Entrenamiento'.") | |
| return TrainingArguments(**args_dict) | |
| def _generic_model_loader(model_name_or_path, model_class, **kwargs): | |
| config_kwargs = {"trust_remote_code": True} | |
| if kwargs.get('label2id'): | |
| config_kwargs.update({"label2id": kwargs['label2id'], "id2label": kwargs['id2label']}) | |
| config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs) | |
| if kwargs.get('attention_dropout', 0) > 0: config.attention_dropout = kwargs['attention_dropout'] | |
| if kwargs.get('hidden_dropout', 0) > 0: config.hidden_dropout = kwargs['hidden_dropout'] | |
| model_kwargs = { | |
| "trust_remote_code": True, | |
| "config": config, | |
| "torch_dtype": torch.float32, | |
| } | |
| if kwargs.get('num_labels'): | |
| model_kwargs.update({"num_labels": kwargs['num_labels'], "ignore_mismatched_sizes": True}) | |
| model = model_class.from_pretrained(model_name_or_path, **model_kwargs) | |
| return model | |
| def _find_all_linear_names(model): | |
| cls = torch.nn.Linear | |
| lora_module_names = set() | |
| for name, module in model.named_modules(): | |
| if isinstance(module, cls): | |
| names = name.split('.') | |
| lora_module_names.add(names[-1]) | |
| if 'lm_head' in lora_module_names: | |
| lora_module_names.remove('lm_head') | |
| common_targets = {'q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'} | |
| return list(lora_module_names.intersection(common_targets)) or list(lora_module_names) | |
| def _sft_formatting_func(example, text_col, tokenizer, **kwargs): | |
| if kwargs.get('sft_format_style') == "Conversacional": | |
| conv_col = "" | |
| for key in ["messages", "conversations", "turns"]: | |
| if key in example: conv_col = key; break | |
| if not conv_col: return "" | |
| conversation = example[conv_col] | |
| if isinstance(conversation, str): | |
| try: conversation = ast.literal_eval(conversation) | |
| except: return "" | |
| return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) | |
| if kwargs.get('sft_format_style') == "Razonamiento/Herramientas": | |
| messages = [] | |
| prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "") | |
| if prompt: messages.append({"role": "user", "content": prompt}) | |
| response_parts = [] | |
| if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')): | |
| response_parts.append(f"<thinking>{example[kwargs.get('reasoning_col_input', 'reasoning')]}</thinking>") | |
| if kwargs.get('enable_tool_use_input') and example.get(kwargs.get('tool_use_col_input', 'tools')): | |
| response_parts.append(f"<tool_code>{example.get(kwargs.get('tool_use_col_input', 'tools'))}</tool_code>") | |
| if example.get(kwargs.get('response_col_input', 'response')): | |
| response_parts.append(example.get(kwargs.get('response_col_input', 'response'))) | |
| if response_parts: | |
| messages.append({"role": "assistant", "content": "\n".join(response_parts)}) | |
| if messages: | |
| try: | |
| return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) | |
| except Exception as e: | |
| logger.error(f"Error aplicando la plantilla de chat: {e}.") | |
| return "\n".join([m['content'] for m in messages]) | |
| return "" | |
| return example.get(text_col, "") | |
| def _dpo_formatting_func(example, **kwargs): | |
| return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")} | |
| def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col): | |
| model.eval() | |
| encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt") | |
| max_length = model.config.max_position_embeddings | |
| stride = 512 | |
| seq_len = encodings.input_ids.size(1) | |
| nlls = [] | |
| prev_end_loc = 0 | |
| with torch.no_grad(): | |
| for begin_loc in range(0, seq_len, stride): | |
| end_loc = min(begin_loc + max_length, seq_len) | |
| trg_len = end_loc - prev_end_loc | |
| input_ids = encodings.input_ids[:, begin_loc:end_loc] | |
| target_ids = input_ids.clone() | |
| target_ids[:, :-trg_len] = -100 | |
| outputs = model(input_ids, labels=target_ids) | |
| neg_log_likelihood = outputs.loss | |
| nlls.append(neg_log_likelihood) | |
| prev_end_loc = end_loc | |
| if end_loc == seq_len: | |
| break | |
| ppl = torch.exp(torch.stack(nlls).mean()) | |
| return ppl.item() | |
| def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type): | |
| adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()] | |
| if not adapter_ids: | |
| yield "No se proporcionaron IDs de adaptadores válidos. Omitiendo la fusión múltiple." | |
| return base_model_id | |
| try: | |
| weights = [float(w.strip()) for w in weights_str.split(',')] | |
| except: | |
| weights = [1.0] * len(adapter_ids) | |
| if len(weights) != len(adapter_ids): | |
| weights = [1.0] * len(adapter_ids) | |
| yield "Pesos de adaptadores inválidos, usando 1.0 para todos." | |
| yield f"Cargando modelo base {base_model_id} para fusión múltiple..." | |
| model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float32, trust_remote_code=True, ) | |
| for i, adapter_id in enumerate(adapter_ids): | |
| yield f"Cargando adaptador {i+1}: {adapter_id}" | |
| model.load_adapter(adapter_id, adapter_name=f"adapter_{i}") | |
| adapter_names = [f"adapter_{i}" for i in range(len(adapter_ids))] | |
| yield f"Combinando adaptadores: {adapter_names} con pesos: {weights} y tipo: {combination_type}" | |
| model.add_weighted_adapter(adapters=adapter_names, weights=weights, adapter_name="combined", combination_type=combination_type) | |
| model.set_adapter("combined") | |
| yield "Fusionando combinación de adaptadores en el modelo base..." | |
| merged_model = model.merge_and_unload() | |
| temp_dir = tempfile.mkdtemp() | |
| yield f"Guardando modelo fusionado en {temp_dir}" | |
| merged_model.save_pretrained(temp_dir) | |
| tokenizer = AutoTokenizer.from_pretrained(base_model_id) | |
| tokenizer.save_pretrained(temp_dir) | |
| yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}." | |
| return temp_dir | |
| def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs): | |
| yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando") | |
| trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False) | |
| final_metrics = {} | |
| if kwargs.get('run_evaluation'): | |
| eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log] | |
| if eval_logs: | |
| final_metrics = eval_logs[-1] | |
| final_metrics = {k.replace('eval_', ''): v for k, v in final_metrics.items()} | |
| yield update_logs_fn("Entrenamiento finalizado.", "Guardando") | |
| output_dir = trainer.args.output_dir | |
| trainer.save_model(output_dir) | |
| if tokenizer: | |
| tokenizer.save_pretrained(output_dir) | |
| with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f: | |
| f.write(model_card_content) | |
| yield update_logs_fn("Subiendo al Hub...", "Subiendo") | |
| upload_folder(folder_path=output_dir, repo_id=repo_id, commit_message="Fin de entrenamiento") | |
| return output_dir, final_metrics | |
| def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)" | |
| text_col = kwargs.get('text_col') | |
| try: | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or model_name | |
| yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, use_fast=False) | |
| if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token | |
| if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja'] | |
| yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForCausalLM, **kwargs) | |
| peft_config = None | |
| if kwargs.get('peft'): | |
| target_modules = kwargs.get('target_modules').split(",") if not kwargs.get('auto_find_target_modules') else _find_all_linear_names(model) | |
| yield update_logs_fn(f"Módulos LoRA detectados/especificados: {target_modules}", "Configuración") | |
| peft_config = LoraConfig( | |
| r=int(kwargs.get('lora_r')), lora_alpha=int(kwargs.get('lora_alpha')), lora_dropout=float(kwargs.get('lora_dropout')), | |
| target_modules=target_modules, bias="none", task_type="CAUSAL_LM", use_dora=kwargs.get('use_dora', False), | |
| use_rslora=kwargs.get('use_rslora', False), init_lora_weights=kwargs.get('init_lora_weights', 'gaussian'), | |
| modules_to_save=kwargs.get('modules_to_save').split(',') if kwargs.get('modules_to_save') else None | |
| ) | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn) | |
| for update in eval_dataset_gen: | |
| if isinstance(update, dict): | |
| yield update | |
| else: | |
| eval_dataset = update | |
| TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer) | |
| trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config} | |
| if is_dpo: | |
| trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2}) | |
| if train_dataset: | |
| train_dataset = train_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs)) | |
| if eval_dataset: | |
| eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs)) | |
| trainer_kwargs.update({"train_dataset": train_dataset, "eval_dataset": eval_dataset}) | |
| else: | |
| sft_kwargs = kwargs.copy() | |
| if 'text_col' in sft_kwargs: | |
| del sft_kwargs['text_col'] | |
| trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(ex, text_col=text_col, tokenizer=tokenizer, **sft_kwargs)}) | |
| if kwargs.get('enable_loss_reweighting'): | |
| trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': float(kwargs.get('reweighting_factor', 2.0))}) | |
| trainer = TrainerClass(**trainer_kwargs) | |
| final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs) | |
| return final_model_path, final_metrics | |
| except Exception as e: | |
| raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}") | |
| def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| labels = [s.strip() for s in kwargs['classification_labels'].split(',')] | |
| label2id = {l: i for i, l in enumerate(labels)} | |
| id2label = {i: l for i, l in enumerate(labels)} | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or model_name | |
| yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs) | |
| model.config.pad_token_id = tokenizer.pad_token_id | |
| def preprocess(examples): | |
| return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512) | |
| train_dataset = train_dataset.map(preprocess, batched=True) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn) | |
| for update in eval_dataset_gen: | |
| if isinstance(update, dict): | |
| yield update | |
| else: | |
| eval_dataset = update | |
| if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True) | |
| metric = hf_evaluate.load("accuracy") | |
| def compute_metrics(eval_pred): | |
| logits, labels = eval_pred | |
| predictions = np.argmax(logits, axis=-1) | |
| return metric.compute(predictions=predictions, references=labels) | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| trainer = Trainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, compute_metrics=compute_metrics, | |
| tokenizer=tokenizer, data_collator=DataCollatorWithPadding(tokenizer=tokenizer) | |
| ) | |
| final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs) | |
| return final_model_path, final_metrics | |
| except Exception as e: | |
| raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}") | |
| def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| labels = [s.strip() for s in kwargs['classification_labels'].split(',')] | |
| label2id = {l: i for i, l in enumerate(labels)} | |
| id2label = {i: l for i, l in enumerate(labels)} | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or model_name | |
| yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, add_prefix_space=True) | |
| yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForTokenClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs) | |
| def tokenize_and_align_labels(examples): | |
| tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True) | |
| labels = [] | |
| for i, label in enumerate(examples["ner_tags"]): | |
| word_ids = tokenized_inputs.word_ids(batch_index=i) | |
| previous_word_idx = None | |
| label_ids = [] | |
| for word_idx in word_ids: | |
| if word_idx is None or word_idx == previous_word_idx: | |
| label_ids.append(-100) | |
| else: | |
| label_ids.append(label[word_idx]) | |
| previous_word_idx = word_idx | |
| labels.append(label_ids) | |
| tokenized_inputs["labels"] = labels | |
| return tokenized_inputs | |
| train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn) | |
| for update in eval_dataset_gen: | |
| if isinstance(update, dict): | |
| yield update | |
| else: | |
| eval_dataset = update | |
| if eval_dataset: eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True) | |
| metric = hf_evaluate.load("seqeval") | |
| def compute_metrics(p): | |
| predictions, labels = p | |
| predictions = np.argmax(predictions, axis=2) | |
| true_predictions = [[id2label[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)] | |
| true_labels = [[id2label[l] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)] | |
| results = metric.compute(predictions=true_predictions, references=true_labels) | |
| return {"precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"]} | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| trainer = Trainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, tokenizer=tokenizer, | |
| data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer), | |
| compute_metrics=compute_metrics | |
| ) | |
| final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs) | |
| return final_model_path, final_metrics | |
| except Exception as e: | |
| raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}") | |
| def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or model_name | |
| yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) | |
| yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForQuestionAnswering, **kwargs) | |
| max_length = 384 | |
| doc_stride = 128 | |
| def prepare_train_features(examples): | |
| tokenized_examples = tokenizer( | |
| examples["question"], | |
| examples["context"], | |
| truncation="only_second", | |
| max_length=max_length, | |
| stride=doc_stride, | |
| return_overflowing_tokens=True, | |
| return_offsets_mapping=True, | |
| padding="max_length", | |
| ) | |
| sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") | |
| offset_mapping = tokenized_examples.pop("offset_mapping") | |
| tokenized_examples["start_positions"] = [] | |
| tokenized_examples["end_positions"] = [] | |
| for i, offsets in enumerate(offset_mapping): | |
| input_ids = tokenized_examples["input_ids"][i] | |
| cls_index = input_ids.index(tokenizer.cls_token_id) | |
| sequence_ids = tokenized_examples.sequence_ids(i) | |
| sample_index = sample_mapping[i] | |
| answers = examples["answers"][sample_index] | |
| if len(answers["answer_start"]) == 0: | |
| tokenized_examples["start_positions"].append(cls_index) | |
| tokenized_examples["end_positions"].append(cls_index) | |
| else: | |
| start_char = answers["answer_start"][0] | |
| end_char = start_char + len(answers["text"][0]) | |
| token_start_index = 0 | |
| while sequence_ids[token_start_index] != 1: | |
| token_start_index += 1 | |
| token_end_index = len(input_ids) - 1 | |
| while sequence_ids[token_end_index] != 1: | |
| token_end_index -= 1 | |
| if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): | |
| tokenized_examples["start_positions"].append(cls_index) | |
| tokenized_examples["end_positions"].append(cls_index) | |
| else: | |
| while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: | |
| token_start_index += 1 | |
| tokenized_examples["start_positions"].append(token_start_index - 1) | |
| while offsets[token_end_index][1] >= end_char: | |
| token_end_index -= 1 | |
| tokenized_examples["end_positions"].append(token_end_index + 1) | |
| return tokenized_examples | |
| train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=next(iter(train_dataset)).keys()) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn) | |
| eval_dataset_raw = None | |
| for update in eval_dataset_raw_gen: | |
| if isinstance(update, dict): | |
| yield update | |
| else: | |
| eval_dataset_raw = update | |
| if eval_dataset_raw: | |
| eval_dataset = eval_dataset_raw.map(prepare_train_features, batched=True, remove_columns=next(iter(eval_dataset_raw)).keys()) | |
| training_args = _create_training_args(output_dir, repo_id, **kwargs) | |
| data_collator = DefaultDataCollator() | |
| trainer = Trainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=data_collator | |
| ) | |
| final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs) | |
| return final_model_path, final_metrics | |
| except Exception as e: | |
| raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}") | |
| def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or model_name | |
| yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) | |
| yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración") | |
| model = _generic_model_loader(model_name, AutoModelForSeq2SeqLM, **kwargs) | |
| def preprocess_function(examples): | |
| inputs = [ex[kwargs['text_col']] for ex in examples["translation"]] | |
| targets = [ex[kwargs['label_col']] for ex in examples["translation"]] | |
| model_inputs = tokenizer(inputs, max_length=128, truncation=True) | |
| with tokenizer.as_target_tokenizer(): | |
| labels = tokenizer(targets, max_length=128, truncation=True) | |
| model_inputs["labels"] = labels["input_ids"] | |
| return model_inputs | |
| train_dataset = train_dataset.map(preprocess_function, batched=True) | |
| eval_dataset = None | |
| if kwargs.get('run_evaluation'): | |
| eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn) | |
| for update in eval_dataset_gen: | |
| if isinstance(update, dict): | |
| yield update | |
| else: | |
| eval_dataset = update | |
| if eval_dataset: eval_dataset = eval_dataset.map(preprocess_function, batched=True) | |
| metric = hf_evaluate.load("sacrebleu") | |
| def compute_metrics(eval_preds): | |
| preds, labels = eval_preds | |
| if isinstance(preds, tuple): preds = preds[0] | |
| decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | |
| labels = np.where(labels != -100, labels, tokenizer.pad_token_id) | |
| decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | |
| decoded_preds = [pred.strip() for pred in decoded_preds] | |
| decoded_labels = [[label.strip()] for label in decoded_labels] | |
| result = metric.compute(predictions=decoded_preds, references=decoded_labels) | |
| return {"bleu": result["score"]} | |
| training_args_dict = _create_training_args(output_dir, repo_id, **kwargs).to_dict() | |
| training_args_dict["predict_with_generate"] = True | |
| training_args = Seq2SeqTrainingArguments(**training_args_dict) | |
| trainer = Seq2SeqTrainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, tokenizer=tokenizer, | |
| data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model), | |
| compute_metrics=compute_metrics | |
| ) | |
| final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs) | |
| return final_model_path, final_metrics | |
| except Exception as e: | |
| raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}") | |
| def train_text_to_image(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs): | |
| output_dir = tempfile.mkdtemp() | |
| try: | |
| yield update_logs(f"Iniciando entrenamiento Text-to-Image con modelo base '{model_name}'...", "Configuración") | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| yield update_logs("Cargando componentes del modelo de difusión...", "Configuración") | |
| tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") | |
| text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder") | |
| vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae") | |
| unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet") | |
| noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler") | |
| yield update_logs("Componentes del modelo cargados exitosamente.", "Configuración") | |
| vae.requires_grad_(False) | |
| text_encoder.requires_grad_(False) | |
| unet.train() | |
| learning_rate = float(kwargs.get('learning_rate', 1e-5)) | |
| optimizer = torch.optim.AdamW( | |
| unet.parameters(), | |
| lr=learning_rate, | |
| betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))), | |
| weight_decay=float(kwargs.get('weight_decay', 0.01)), | |
| eps=float(kwargs.get('adam_epsilon', 1e-8)) | |
| ) | |
| yield update_logs("Optimizador configurado.", "Configuración") | |
| text_col = kwargs.get('text_col', 'text') | |
| image_col = kwargs.get('image_col', 'image') | |
| image_transforms = transforms.Compose([ | |
| transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), | |
| transforms.CenterCrop(512), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.5], [0.5]), | |
| ]) | |
| def preprocess_train(examples): | |
| images = [image.convert("RGB") for image in examples[image_col]] | |
| examples["pixel_values"] = [image_transforms(image) for image in images] | |
| examples["input_ids"] = tokenizer( | |
| examples[text_col], | |
| padding="max_length", | |
| max_length=tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt" | |
| ).input_ids | |
| return examples | |
| yield update_logs("Preprocesando dataset...", "Datos") | |
| train_dataset = train_dataset.map(preprocess_train, batched=True, remove_columns=[image_col]) | |
| batch_size = int(kwargs.get('batch_size', 1)) | |
| gradient_accumulation_steps = int(kwargs.get('gradient_accumulation', 4)) | |
| max_steps = int(kwargs.get('max_steps', 1000)) | |
| num_epochs = int(kwargs.get('num_epochs', 1)) | |
| train_dataloader = DataLoader( | |
| train_dataset, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=2 | |
| ) | |
| from diffusers.optimization import get_scheduler as get_diffusers_lr_scheduler | |
| lr_scheduler = get_diffusers_lr_scheduler( | |
| kwargs.get('scheduler', 'cosine'), | |
| optimizer=optimizer, | |
| num_warmup_steps=int(max_steps * float(kwargs.get('warmup_ratio', 0.03))), | |
| num_training_steps=max_steps | |
| ) | |
| yield update_logs(f"Iniciando entrenamiento: {max_steps} pasos, batch_size={batch_size}", "Entrenando") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| unet = unet.to(device) | |
| vae = vae.to(device) | |
| text_encoder = text_encoder.to(device) | |
| global_step = 0 | |
| progress_bar = tqdm(range(max_steps), desc="Entrenando") | |
| for epoch in range(num_epochs): | |
| for step, batch in enumerate(train_dataloader): | |
| if global_step >= max_steps: | |
| break | |
| pixel_values = torch.stack(batch["pixel_values"]).to(device) | |
| with torch.no_grad(): | |
| latents = vae.encode(pixel_values).latent_dist.sample() | |
| latents = latents * vae.config.scaling_factor | |
| noise = torch.randn_like(latents) | |
| timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long() | |
| noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | |
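| # DDPM forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. | |
| # Assuming the scheduler's default "epsilon" prediction type, the UNet is trained | |
| # to recover `noise` from the noisy latents, which the MSE loss below measures. | |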
| input_ids = batch["input_ids"].to(device) | |
| with torch.no_grad(): | |
| encoder_hidden_states = text_encoder(input_ids)[0] | |
| noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | |
| loss = F.mse_loss(noise_pred, noise, reduction="mean") | |
| (loss / gradient_accumulation_steps).backward() | |
| if (step + 1) % gradient_accumulation_steps == 0: | |
| torch.nn.utils.clip_grad_norm_(unet.parameters(), float(kwargs.get('max_grad_norm', 1.0))) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| global_step += 1 | |
| progress_bar.update(1) | |
| if global_step % int(kwargs.get('logging_steps', 10)) == 0: | |
| yield update_logs(f"Paso {global_step}/{max_steps} - Loss: {loss.item():.4f}", "Entrenando") | |
| if global_step % int(kwargs.get('save_steps', 500)) == 0: | |
| yield update_logs(f"Guardando checkpoint en paso {global_step}...", "Guardando") | |
| checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}") | |
| os.makedirs(checkpoint_dir, exist_ok=True) | |
| unet.save_pretrained(os.path.join(checkpoint_dir, "unet")) | |
| if kwargs.get('hub_strategy') == 'every_save': | |
| try: | |
| upload_folder( | |
| folder_path=checkpoint_dir, | |
| repo_id=repo_id, | |
| commit_message=f"Checkpoint paso {global_step}" | |
| ) | |
| except Exception as e: | |
| yield update_logs(f"Advertencia: No se pudo subir checkpoint: {e}", "Guardando") | |
| if global_step >= max_steps: | |
| break | |
| if global_step >= max_steps: | |
| break | |
| progress_bar.close() | |
| yield update_logs("Entrenamiento completado. Guardando modelo final...", "Guardando") | |
| final_output_dir = os.path.join(output_dir, "final_model") | |
| os.makedirs(final_output_dir, exist_ok=True) | |
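| # Reassemble a full text-to-image pipeline from the trained UNet and the frozen | |
| # components; the safety checker is deliberately omitted, so the pipeline is | |
| # built with requires_safety_checker=False. | |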
| sd_pipeline = StableDiffusionText2ImagePipeline( | |
| text_encoder=text_encoder, | |
| vae=vae, | |
| unet=unet, | |
| tokenizer=tokenizer, | |
| scheduler=noise_scheduler, | |
| safety_checker=None, | |
| feature_extractor=None, | |
| requires_safety_checker=False | |
| ) | |
| sd_pipeline.save_pretrained(final_output_dir) | |
| with open(os.path.join(final_output_dir, "README.md"), "w", encoding="utf-8") as f: | |
| f.write(model_card_content) | |
| yield update_logs("Modelo guardado. Subiendo al Hub...", "Subiendo") | |
| upload_folder( | |
| folder_path=final_output_dir, | |
| repo_id=repo_id, | |
| commit_message="Entrenamiento Text-to-Image completado" | |
| ) | |
| yield update_logs(f"✅ Modelo subido exitosamente a {repo_id}", "Completado") | |
| final_metrics = { | |
| "final_loss": loss.item() if global_step > 0 else None, | |
| "total_steps": global_step, | |
| "epochs_completed": epoch + 1 | |
| } | |
| del unet, vae, text_encoder, sd_pipeline | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return final_output_dir, final_metrics | |
| except Exception as e: | |
| yield update_logs(f"❌ Error en entrenamiento Text-to-Image: {str(e)}", "Error") | |
| raise Exception(f"Error en Text-to-Image: {e}\n{traceback.format_exc()}") | |
| def _get_data_processing_pipeline(**kwargs): | |
| hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()] | |
| if not hf_ids and not kwargs.get('uploads'): | |
| raise ValueError("No se proporcionaron datasets.") | |
| dataset_weights_str = kwargs.get('dataset_weights', '') | |
| probabilities = None | |
| if dataset_weights_str: | |
| try: | |
| probabilities = [float(w.strip()) for w in dataset_weights_str.split(',')] | |
| except ValueError: | |
| probabilities = None | |
| train_dataset, uploaded_val_data = None, None | |
| if kwargs.get('uploads'): | |
| uploaded_data_map = _load_uploaded_stream(kwargs.get('uploads')) | |
| if uploaded_data_map and uploaded_data_map["train"]: | |
| train_dataset = IterableDataset.from_generator(lambda: iter(uploaded_data_map["train"])) | |
| uploaded_val_data = uploaded_data_map["validation"] | |
| if hf_ids: | |
| hf_train_dataset = _load_hf_streaming(hf_ids, split="train", probabilities=probabilities if not train_dataset else None) | |
| if hf_train_dataset: | |
| if train_dataset is None: | |
| train_dataset = hf_train_dataset | |
| else: | |
| all_streams = [train_dataset, hf_train_dataset] | |
| all_probs = [0.5, 0.5] | |
| train_dataset = interleave_datasets(all_streams, probabilities=all_probs) | |
| if train_dataset is None: | |
| raise ValueError("No se pudieron cargar datos de entrenamiento válidos.") | |
| try: | |
| first_example = next(iter(train_dataset)) | |
| except StopIteration: | |
| raise ValueError("El dataset de entrenamiento está vacío después del procesamiento.") | |
| text_col, image_col, audio_col, label_col = _guess_columns(first_example) | |
| kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data}) | |
| is_text_task = kwargs['training_mode'] not in ["Image Classification (Vision)", "Audio Classification (Speech)"] | |
| if is_text_task: | |
| if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]): | |
| clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']} | |
| train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs)) | |
| filters = _get_filter_functions(**kwargs) | |
| if filters: | |
| for f in filters: | |
| train_dataset = train_dataset.filter(f) | |
| if kwargs.get('enable_back_translation'): | |
| train_dataset = _apply_back_translation(train_dataset, text_col, kwargs['bt_augmentation_ratio'], kwargs['bt_model_id'], kwargs['bt_reverse_model_id']) | |
| if kwargs.get('enable_synthetic_data'): | |
| synthetic_ds = _generate_synthetic_data(train_dataset, text_col, kwargs['synthetic_model_id'], int(kwargs['num_synthetic_samples']), kwargs['synthetic_prompt_template']) | |
| if synthetic_ds: | |
| train_dataset = interleave_datasets([train_dataset, synthetic_ds]) | |
| if kwargs.get('enable_cda') and kwargs.get('cda_json_config'): | |
| train_dataset = _apply_cda(train_dataset, text_col, kwargs['cda_json_config']) | |
| dedup_method = kwargs.get('deduplication_method') | |
| if dedup_method and dedup_method != 'Ninguna': | |
| train_dataset = DeduplicatedIterableDataset( | |
| dataset=train_dataset, | |
| text_col=text_col, | |
| method=dedup_method, | |
| threshold=kwargs.get('minhash_threshold', 0.85), | |
| num_perm=int(kwargs.get('minhash_num_perm', 128)) | |
| ) | |
| return train_dataset, kwargs | |
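| # Main training driver: a generator that yields 6-tuples of Gradio updates | |
| # (logs, phase, repo link, metrics, start button, stop button) while it | |
| # authenticates, prepares the repo and model, dispatches to the task-specific | |
| # trainer, and optionally runs a final perplexity evaluation. | |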
| def _train_and_upload(progress=gr.Progress(), **kwargs): | |
| logs, repo_link, final_model_path, final_metrics = "", "", None, {} | |
| progress(0, desc="Iniciando...") | |
| yield ( | |
| "Iniciando...", | |
| "Inicio", | |
| "", | |
| gr.update(value=None), | |
| gr.update(value="Entrenando...", interactive=False), | |
| gr.update(visible=True) | |
| ) | |
| def update_logs(new_msg, phase_msg): | |
| nonlocal logs, repo_link, final_metrics | |
| logs += f"[{phase_msg}] {new_msg}\n" | |
| progress(0, desc=f"[{phase_msg}] {new_msg}") | |
| return ( | |
| logs, | |
| phase_msg, | |
| repo_link, | |
| gr.update(value=final_metrics if final_metrics else None) | |
| ) | |
| try: | |
| yield update_logs("Verificando autenticación...", "Inicio") + (gr.update(), gr.update()) | |
| user = whoami() | |
| username = user.get("name") | |
| if not username: | |
| raise ValueError("No se pudo obtener el nombre de usuario de Hugging Face. Por favor, verifica tu token.") | |
| model_name = kwargs.get('model_base_input', '').strip() | |
| if kwargs.get('enable_multi_adapter_merge'): | |
| temp_model_path = model_name | |
| lora_merge_generator = _merge_multiple_loras(model_name, kwargs['multi_adapter_model_ids'], kwargs['multi_adapter_weights'], kwargs['multi_adapter_combination_type']) | |
| try: | |
| while True: | |
| status = next(lora_merge_generator) | |
| yield update_logs(status, "Fusión Múltiple") + (gr.update(), gr.update()) | |
| except StopIteration as e: | |
| temp_model_path = e.value | |
| model_name = temp_model_path | |
| repo_name_input = kwargs.get('repo_name_input', '').strip() | |
| if repo_name_input: | |
| repo_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', repo_name_input) | |
| repo_base = re.sub(r'^[.-]+|[.-]+$', '', repo_base) | |
| else: | |
| model_name_base = model_name.split('/')[-1] if model_name else "finetuned-model" | |
| sanitized_model_name_base = re.sub(r'[^a-zA-Z0-9_.-]+', '-', model_name_base) | |
| sanitized_model_name_base = re.sub(r'^[.-]+|[.-]+$', '', sanitized_model_name_base) | |
| repo_base = f"{sanitized_model_name_base}-{uuid.uuid4().hex[:6]}" | |
| if not repo_base: | |
| repo_base = f"autotrain-model-{uuid.uuid4().hex[:8]}" | |
| max_repo_base_len = 96 - (len(username) + 1) | |
| repo_base = repo_base[:max_repo_base_len] | |
| repo_id = f"{username}/{repo_base}" | |
| yield update_logs(f"Creando o verificando repositorio: '{repo_id}'", "Inicio") + (gr.update(), gr.update()) | |
| create_repo(repo_id, exist_ok=True, private=kwargs.get('private_repo', False)) | |
| repo_link = f"https://huggingface.co/{repo_id}" | |
| yield update_logs("Repositorio listo.", "Inicio") + (gr.update(), gr.update()) | |
| base_model_id_for_training = model_name | |
| if kwargs.get('train_from_scratch'): | |
| yield update_logs("Preparando entrenamiento desde cero...", "Modelo Cero") + (gr.update(), gr.update()) | |
| architecture = kwargs.get('scratch_architecture') | |
| if not architecture or architecture not in ARCHITECTURE_MAP: | |
| raise ValueError(f"Arquitectura '{architecture}' no es válida o no está soportada para entrenamiento desde cero. Opciones válidas: {list(ARCHITECTURE_MAP.keys())}") | |
| config_class, model_class = ARCHITECTURE_MAP[architecture] | |
| if kwargs.get('auto_config_scratch'): | |
| vocab_size, hidden_size, intermediate_size, layers, heads, block_size_val, tie_word_embeddings, kv_heads = _calculate_auto_config(kwargs.get('block_size'), architecture == "GPT2", kwargs.get('steps_per_epoch_estimate'), kwargs.get('batch_size'), kwargs.get('gradient_accumulation')) | |
| config = config_class(vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=layers, num_attention_heads=heads, num_key_value_heads=kv_heads, max_position_embeddings=block_size_val, tie_word_embeddings=tie_word_embeddings) | |
| model = model_class(config) | |
| elif kwargs.get('manual_config_scratch'): | |
| vocab_size = int(kwargs.get('scratch_vocab_size', 32000)) | |
| hidden_size = int(kwargs.get('scratch_hidden_size', 1024)) | |
| intermediate_size = int(kwargs.get('scratch_intermediate_size', 2048)) | |
| layers = int(kwargs.get('scratch_layers', 8)) | |
| heads = int(kwargs.get('scratch_heads', 8)) | |
| kv_heads = int(kwargs.get('scratch_kv_heads', 8)) | |
| block_size_val = int(kwargs.get('scratch_block_size', 1024)) | |
| tie_word_embeddings = kwargs.get('scratch_tie_word_embeddings', False) | |
| config = config_class(vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=layers, num_attention_heads=heads, num_key_value_heads=kv_heads, max_position_embeddings=block_size_val, tie_word_embeddings=tie_word_embeddings) | |
| model = model_class(config) | |
| else: | |
| raise ValueError("Debe seleccionar auto-configuración o configuración manual para entrenar desde cero.") | |
| temp_model_dir = tempfile.mkdtemp() | |
| model.save_pretrained(temp_model_dir) | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or SCRATCH_TOKENIZER_MAP.get(architecture, "gpt2") | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) | |
| tokenizer.save_pretrained(temp_model_dir) | |
| yield update_logs(f"Tokenizer base '{tokenizer_id}' guardado para el modelo desde cero.", "Modelo Cero") + (gr.update(), gr.update()) | |
| except Exception as e: | |
| raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}") | |
| base_model_id_for_training = temp_model_dir | |
| kwargs["peft"] = False | |
| kwargs['tokenizer_name'] = temp_model_dir | |
| yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update()) | |
| yield update_logs("Procesando y cargando datasets...", "Datos") + (gr.update(), gr.update()) | |
| train_dataset, kwargs = _get_data_processing_pipeline(**kwargs) | |
| yield update_logs(f"Columnas detectadas (texto: {kwargs['text_col']}, imagen: {kwargs['image_col']})", "Datos") + (gr.update(), gr.update()) | |
| if kwargs.get('wandb_api_key_input'): | |
| os.environ["WANDB_API_KEY"] = kwargs['wandb_api_key_input'] | |
| os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}" | |
| os.environ["WANDB_LOG_MODEL"] = "checkpoint" | |
| model_card_content = MODEL_CARD_TEMPLATE.format( | |
| repo_id=repo_id, base_model=model_name, base_model_name=_sanitize_model_name_for_yaml(model_name), | |
| training_mode=kwargs.get('training_mode'), | |
| datasets=', '.join([x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]) or "Archivos locales", | |
| hyperparameters=json.dumps({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool)) and 'token' not in k and 'key' not in k and v is not None}, indent=2), | |
| date=datetime.now().strftime("%Y-%m-%d") | |
| ) | |
| training_mode = kwargs.get('training_mode') | |
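| # Dispatch table from UI training mode to trainer; each trainer is a generator | |
| # that yields update_logs tuples and returns (final_model_path, final_metrics) | |
| # via StopIteration.value. | |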
| training_function_map = { | |
| "Causal Language Modeling (SFT/LoRA)": train_sft_dpo, | |
| "DPO (Direct Preference Optimization)": train_sft_dpo, | |
| "Question Answering (Text)": train_question_answering, | |
| "Sequence Classification (Text)": train_sequence_classification, | |
| "Token Classification (NER)": train_token_classification, | |
| "Text2Text Generation": train_seq2seq, | |
| "Text-to-Image Generation": train_text_to_image, | |
| } | |
| train_func = training_function_map.get(training_mode) | |
| if train_func: | |
| train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs) | |
| while True: | |
| try: | |
| update = next(train_generator) | |
| if isinstance(update, tuple) and len(update) == 4: | |
| yield update + (gr.update(), gr.update()) | |
| else: | |
| pass | |
| except StopIteration as e: | |
| final_model_path, final_metrics = e.value | |
| break | |
| else: | |
| raise ValueError(f"El modo de entrenamiento '{training_mode}' no está implementado.") | |
| if kwargs.get('run_perplexity_evaluation') and final_model_path and training_mode in ["Causal Language Modeling (SFT/LoRA)", "DPO (Direct Preference Optimization)"]: | |
| yield update_logs("Iniciando evaluación de perplejidad...", "Evaluación Final") + (gr.update(), gr.update()) | |
| model = AutoModelForCausalLM.from_pretrained(final_model_path, torch_dtype=torch.float32) | |
| tokenizer = AutoTokenizer.from_pretrained(final_model_path) | |
| eval_dataset_perp = None | |
| eval_gen = _get_eval_dataset((kwargs.get('datasets_hf_text') or "").split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p)) | |
| for update in eval_gen: | |
| if isinstance(update, tuple): | |
| yield update + (gr.update(), gr.update()) | |
| else: | |
| eval_dataset_perp = update | |
| if eval_dataset_perp: | |
| ppl = _evaluate_perplexity(model, tokenizer, eval_dataset_perp, kwargs['text_col']) | |
| final_metrics['perplexity'] = ppl | |
| yield update_logs(f"Evaluación de Perplejidad completada. Perplejidad: {ppl:.4f}", "Evaluación Final") + (gr.update(), gr.update()) | |
| final_logs, final_phase, final_repo_link, _ = update_logs(f"✅ Entrenamiento y subida completados: {repo_link}", "Listo") | |
| yield ( | |
| final_logs, | |
| final_phase, | |
| f"### ✅ [Modelo Finalizado: Visita el Repositorio en el Hub]({final_repo_link})", | |
| gr.update(value=final_metrics), | |
| gr.update(value="Iniciar Entrenamiento", interactive=True), | |
| gr.update(visible=False) | |
| ) | |
| except Exception as e: | |
| err_msg = f"❌ Error fatal: {type(e).__name__}: {e}\n{traceback.format_exc()}" | |
| error_logs, error_phase, _, _ = update_logs(err_msg, "Error") | |
| yield ( | |
| error_logs, | |
| error_phase, | |
| "", | |
| gr.update(value=None), | |
| gr.update(value="Iniciar Entrenamiento", interactive=True), | |
| gr.update(visible=False) | |
| ) | |
| def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens): | |
| if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| task_name = TASK_TO_PIPELINE_MAP.get(task_mode) | |
| if not task_name: return f"La inferencia para el modo '{task_mode}' no está soportada.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| try: | |
| pipe = pipeline(task_name, model=model_id, torch_dtype=torch.float32, trust_remote_code=True) | |
| result = None | |
| if task_name == "text-generation": | |
| if not text_in: return "Por favor, introduce un prompt de texto.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(text_in, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature, top_p=top_p) | |
| elif task_name == "question-answering": | |
| if not text_in or not context_in: return "Por favor, introduce una pregunta y un contexto.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(question=text_in, context=context_in) | |
| elif task_name in ["token-classification", "text2text-generation", "text-classification"]: | |
| if not text_in: return f"Por favor, introduce texto para {task_name}.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(text_in) | |
| elif task_name in ["image-classification", "audio-classification", "automatic-speech-recognition"]: | |
| input_data = image_in if "image" in task_name else audio_in | |
| if input_data is None: return f"Por favor, proporciona una entrada de { 'imagen' if 'image' in task_name else 'audio' }.", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| result = pipe(input_data) | |
| return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update() | |
| def update_inference_ui(task_mode): | |
| task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "") | |
| is_text_gen = task_name == "text-generation" | |
| show_text = task_name in ["text-generation", "text2text-generation", "token-classification", "question-answering", "text-classification", "text-to-image"] | |
| show_context = task_name == "question-answering" | |
| show_image = task_name in ["image-classification"] | |
| show_audio = task_name in ["audio-classification", "automatic-speech-recognition"] | |
| text_label = "Pregunta" if task_name == "question-answering" else "Entrada de Texto / Prompt" | |
| return ( | |
| gr.update(visible=show_text, label=text_label), | |
| gr.update(visible=show_context), | |
| gr.update(visible=show_image), | |
| gr.update(visible=show_audio), | |
| gr.update(visible=is_text_gen) | |
| ) | |
| def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()): | |
| if not hf_token: | |
| return "Error: Se requiere un token de Hugging Face.", "" | |
| if not repo_name: | |
| return "Error: Se requiere un nombre de repositorio para el dataset.", "" | |
| try: | |
| login(token=hf_token) | |
| user = whoami() | |
| username = user.get("name") | |
| repo_base = f"{username}-{uuid.uuid4().hex[:6]}" if not repo_name else re.sub(r'[^a-zA-Z0-9_.-]+', '-', repo_name)[:90] | |
| repo_id = f"{username}/{repo_base}" | |
| create_repo(repo_id, repo_type="dataset", exist_ok=True) | |
| all_data = [] | |
| if creation_type == "Sintético": | |
| if not synth_model or not synth_prompt or not synth_num_samples: | |
| return "Error: Para la generación sintética se requiere un modelo, un prompt y un número de muestras.", "" | |
| progress(0, desc="Cargando modelo generador...") | |
| generator = pipeline("text-generation", model=synth_model, ) | |
| for i in progress.tqdm(range(int(synth_num_samples)), desc="Generando muestras"): | |
| try: | |
| generated_output = generator(synth_prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.9, top_p=0.95) | |
| cleaned_text = generated_output[0]['generated_text'][len(synth_prompt):].strip() | |
| if cleaned_text: | |
| all_data.append({"text": cleaned_text}) | |
| except Exception as e: | |
| logger.warning(f"Error al generar muestra {i}: {e}") | |
| elif creation_type == "Basado en Archivo": | |
| if not file_uploads: | |
| return "Error: Por favor, sube al menos un archivo.", "" | |
| progress(0.5, desc="Procesando archivos subidos...") | |
| file_data = _load_uploaded_stream(file_uploads) | |
| all_data = file_data.get("train", []) + file_data.get("validation", []) | |
| if not all_data: | |
| return "Error: No se generaron o procesaron datos.", "" | |
| progress(0.8, desc="Guardando y subiendo al Hub...") | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| data_file = os.path.join(temp_dir, "data.jsonl") | |
| with open(data_file, "w", encoding="utf-8") as f: | |
| for item in all_data: | |
| f.write(json.dumps(item, ensure_ascii=False) + "\n") | |
| readme_content = DATASET_CARD_TEMPLATE.format( | |
| repo_id=repo_id, | |
| creation_type=creation_type, | |
| generation_model=synth_model if creation_type == "Sintético" else "N/A", | |
| date=datetime.now().strftime("%Y-%m-%d") | |
| ) | |
| readme_file = os.path.join(temp_dir, "README.md") | |
| with open(readme_file, "w", encoding="utf-8") as f: | |
| f.write(readme_content) | |
| api = HfApi() | |
| api.upload_folder( | |
| folder_path=temp_dir, | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| commit_message="Creación de dataset con AutoTrain-Advanced" | |
| ) | |
| dataset_link = f"https://huggingface.co/datasets/{repo_id}" | |
| return f"✅ Dataset creado y subido exitosamente a {repo_id}", f"### ✅ [Dataset Disponible: Visita el Repositorio]({dataset_link})" | |
| except Exception as e: | |
| return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", "" | |
| def gradio_train_wrapper(*args): | |
| kwargs = dict(zip(all_input_components_dict.keys(), args)) | |
| yield from _train_and_upload(**kwargs) | |
| def gradio_preview_data_wrapper(*args): | |
| kwargs = dict(zip(all_input_components_dict.keys(), args)) | |
| try: | |
| preview_text = "Procesando vista previa...\n" | |
| yield preview_text | |
| model_id_for_tokenizer = kwargs.get('model_base_input') | |
| if not model_id_for_tokenizer and not kwargs.get('train_from_scratch'): | |
| raise ValueError("Se necesita un ID de modelo base para cargar el tokenizer para la vista previa.") | |
| dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs) | |
| text_col = processed_kwargs.get('text_col') | |
| if kwargs.get('train_from_scratch'): | |
| tokenizer_id = SCRATCH_TOKENIZER_MAP.get(kwargs.get('scratch_architecture'), 'gpt2') | |
| else: | |
| tokenizer_id = kwargs.get('tokenizer_name_input') or model_id_for_tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| tokenizer_id, trust_remote_code=True, use_fast=False | |
| ) | |
| if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token | |
| if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja'] | |
| preview_samples = [] | |
| for i, example in enumerate(islice(dataset, 5)): | |
| formatted_text = "" | |
| if kwargs['training_mode'] == "DPO (Direct Preference Optimization)": | |
| formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False) | |
| elif kwargs['training_mode'] == "Causal Language Modeling (SFT/LoRA)": | |
| formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs) | |
| else: | |
| formatted_text = str(example) | |
| preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n") | |
| preview_text = "\n".join(preview_samples) | |
| if not preview_samples: | |
| preview_text = "No se pudieron generar muestras. Revisa la configuración del dataset, los filtros y el formato." | |
| yield preview_text | |
| except Exception as e: | |
| yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}" | |
| def toggle_training_mode_ui(is_scratch): | |
| return ( | |
| gr.update(visible=not is_scratch), | |
| gr.update(visible=not is_scratch), | |
| gr.update(visible=not is_scratch), | |
| gr.update(visible=not is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| gr.update(visible=is_scratch), | |
| ) | |
| def toggle_task_specific_ui(training_mode): | |
| is_classification = "Classification" in training_mode | |
| is_dpo = "DPO" in training_mode | |
| is_sft = "Causal" in training_mode | |
| is_ner = "Token Classification" in training_mode | |
| is_diffusion = "Image Generation" in training_mode | |
| return ( | |
| gr.update(visible=is_classification or is_ner), | |
| gr.update(visible=is_dpo), | |
| gr.update(visible=is_sft), | |
| gr.update(visible=is_diffusion), | |
| gr.update(visible=not is_diffusion) | |
| ) | |
| def toggle_sft_format_ui(format_style): | |
| is_tool = format_style == "Razonamiento/Herramientas" | |
| return gr.update(visible=is_tool) | |
| def toggle_auto_modules_ui(is_auto): | |
| return gr.update(visible=not is_auto) | |
| def toggle_dataset_creator_ui(choice): | |
| is_synth = choice == "Sintético" | |
| return gr.update(visible=is_synth), gr.update(visible=not is_synth) | |
| def get_ao_username(token): | |
| try: | |
| api = HfApi(token=token) | |
| info = api.whoami() | |
| return info["name"] | |
| except Exception: | |
| return "anonymous" | |
| def check_ao_model_exists(username, quantization_type, group_size, model_name, quantized_model_name, token): | |
| try: | |
| models = list_models(author=username, token=token) | |
| model_names = [model.id for model in models] | |
| if quantized_model_name: | |
| repo_name = f"{username}/{quantized_model_name}" | |
| else: | |
| if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and group_size is not None: | |
| repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}" | |
| else: | |
| repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}" | |
| if repo_name in model_names: | |
| return f"Model '{repo_name}' already exists in your repository." | |
| else: | |
| return None | |
| except Exception as e: | |
| return f"Error checking model existence: {str(e)}" | |
| def create_ao_model_card(model_name, quantization_type, group_size, token): | |
| try: | |
| model_path = snapshot_download(repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=token) | |
| readme_path = os.path.join(model_path, "README.md") | |
| original_readme = "" | |
| if os.path.exists(readme_path): | |
| with open(readme_path, "r", encoding="utf-8") as f: | |
| original_readme = f.read() | |
| except Exception: | |
| original_readme = "" | |
| yaml_header = f"""--- | |
| base_model: | |
| - {model_name} | |
| tags: | |
| - torchao-my-repo | |
| --- | |
| # {model_name} (Quantized) | |
| ## Quantization Details | |
| - **Quantization Type**: {quantization_type} | |
| - **Group Size**: {group_size} | |
| """ | |
| if original_readme: | |
| yaml_header += "\n\n# 📄 Original Model Info\n\n" + original_readme | |
| return yaml_header | |
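| # TorchAO quantization: MAP_QUANT_TYPE_TO_CONFIG (defined earlier in this file) | |
| # maps each UI choice to a torchao config class; wrapping the instance in | |
| # TorchAoConfig lets transformers quantize the weights during from_pretrained. | |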
| def quantize_ao_model(model_name, quantization_type, group_size=128, token=None, progress=gr.Progress()): | |
| print(f"Quantizing model: {quantization_type}") | |
| progress(0, desc="Preparing Quantization") | |
| if quantization_type == "GemliteUIntXWeightOnly": | |
| quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](group_size=group_size) | |
| elif quantization_type == "Int4WeightOnly": | |
| from torchao.dtypes import Int4CPULayout | |
| quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](group_size=group_size, layout=Int4CPULayout()) | |
| elif quantization_type == "autoquant": | |
| quant_config = "autoquant" | |
| else: | |
| quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]() | |
| quantization_config = TorchAoConfig(quant_config) | |
| progress(0.10, desc="Quantizing model") | |
| model = AutoModel.from_pretrained( | |
| model_name, | |
| torch_dtype="auto", | |
| quantization_config=quantization_config, | |
| device_map="cpu", | |
| token=token, | |
| ) | |
| progress(0.45, desc="Quantization completed") | |
| return model | |
| def save_ao_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, token=None, progress=gr.Progress()): | |
| username = get_ao_username(token) | |
| progress(0.50, desc="Preparing to push") | |
| print("Saving quantized model") | |
| with tempfile.TemporaryDirectory() as tmpdirname: | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) | |
| tokenizer.save_pretrained(tmpdirname) | |
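| # torchao-quantized weights are not plain tensors, so safetensors serialization | |
| # is disabled (safe_serialization=False) and the model is pickled instead. | |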
| model.save_pretrained(tmpdirname, safe_serialization=False) | |
| if quantized_model_name: | |
| repo_name = f"{username}/{quantized_model_name}" | |
| else: | |
| if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and (group_size is not None): | |
| repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}" | |
| else: | |
| repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}" | |
| progress(0.70, desc="Creating model card") | |
| model_card = create_ao_model_card(model_name, quantization_type, group_size, token) | |
| with open(os.path.join(tmpdirname, "README.md"), "w") as f: | |
| f.write(model_card) | |
| api = HfApi(token=token) | |
| api.create_repo(repo_name, exist_ok=True, private=not public) | |
| progress(0.80, desc="Pushing to Hub") | |
| api.upload_folder(folder_path=tmpdirname, repo_id=repo_name, repo_type="model") | |
| progress(1.00, desc="Done") | |
| repo_link = f""" | |
| <div class="repo-link"> | |
| <h3>🔗 Repository Link</h3> | |
| <p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a></p> | |
| </div> | |
| """ | |
| return f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}" | |
| def quantize_and_save_ao(model_name, quantization_type, group_size, quantized_model_name, public, hf_token): | |
| username = get_ao_username(hf_token) | |
| if not username or username == "anonymous": | |
| return "<div class='error-box'><h3>❌ Authentication Error</h3><p>Invalid or missing HF_TOKEN.</p></div>" | |
| if group_size and str(group_size).strip(): | |
| try: | |
| group_size = int(group_size) | |
| except ValueError: | |
| group_size = None | |
| else: | |
| group_size = None | |
| exists_message = check_ao_model_exists(username, quantization_type, group_size, model_name, quantized_model_name, hf_token) | |
| if exists_message: | |
| return f"<div class='warning-box'><h3>⚠️ Model Already Exists</h3><p>{exists_message}</p></div>" | |
| try: | |
| quantized_model = quantize_ao_model(model_name, quantization_type, group_size, token=hf_token) | |
| return save_ao_model(quantized_model, model_name, quantization_type, group_size, quantized_model_name, public, token=hf_token) | |
| except Exception as e: | |
| return f"<div class='error-box'><h3>❌ Error</h3><p>{str(e)}</p></div>" | |
| def get_awq_default_repo_name(model_id: str, scheme: str) -> str: | |
| if not model_id or not scheme: | |
| return "" | |
| model_base_name = Path(model_id).name | |
| suggested_name = f"{model_base_name}-AWQ-{scheme}" | |
| return f"<your-username>/{suggested_name}" | |
| def run_awq_compression( | |
| hf_token: str, | |
| model_id: str, | |
| scheme: str, | |
| ignore_lm_head: bool, | |
| num_calib_samples: float, | |
| max_seq_len: float, | |
| pipeline_mode: str, | |
| upload_repo: str, | |
| progress=gr.Progress(track_tqdm=True), | |
| ): | |
| logs = [] | |
| def log(msg: str) -> str: | |
| logs.append(msg) | |
| return "\n".join(logs) | |
| if not model_id: | |
| yield log("Error: Please provide a source model id (e.g. meta-llama/Llama-3.3-70B-Instruct).") | |
| return | |
| try: | |
| num_calib_samples_int = int(num_calib_samples) | |
| max_seq_len_int = int(max_seq_len) | |
| except ValueError as e: | |
| yield log(f"Error: Invalid number format for calibration settings. {e}") | |
| return | |
| temp_dir = tempfile.mkdtemp() | |
| local_output_dir = Path(temp_dir) / f"{Path(model_id).name}-AWQ-{scheme}" | |
| yield log(f"ℹ️ Quantized model will be saved temporarily to: {local_output_dir.name}") | |
| if hf_token: | |
| try: | |
| login(token=hf_token) | |
| yield log("✅ Logged in to Hugging Face Hub.") | |
| except Exception as e: | |
| yield log(f"⚠️ Hugging Face login failed: {e}") | |
| else: | |
| yield log("ℹ️ No HF token provided. You can still quantize public models and save locally.") | |
| try: | |
| progress(0.1, desc="Building AWQ recipe...") | |
| yield log("🔧 Building AWQ recipe...") | |
| ignore_patterns = ["lm_head"] if ignore_lm_head else None | |
| recipe = AWQModifier( | |
| targets="Linear", | |
| scheme=scheme, | |
| ignore=ignore_patterns, | |
| ) | |
| yield log(f"Recipe:\n scheme = {scheme}\n ignore = {ignore_patterns or '[]'}") | |
| except Exception as e: | |
| yield log(f"❌ Failed to build AWQ recipe: {e}") | |
| shutil.rmtree(temp_dir, ignore_errors=True) | |
| return | |
| try: | |
| progress(0.25, desc="Running AWQ quantization...") | |
| yield log("🚀 Starting LLM Compressor `oneshot` run (no calibration dataset)...") | |
| yield log(f" • model = {model_id}") | |
| yield log(f" • num_calibration_samples = {num_calib_samples_int}") | |
| yield log(f" • max_seq_length = {max_seq_len_int}") | |
| yield log(f" • pipeline = {pipeline_mode}") | |
| oneshot( | |
| model=model_id, | |
| dataset=None, | |
| recipe=recipe, | |
| output_dir=str(local_output_dir), | |
| max_seq_length=max_seq_len_int, | |
| num_calibration_samples=num_calib_samples_int, | |
| pipeline=pipeline_mode, | |
| trust_remote_code_model=True, | |
| device="cpu", | |
| ) | |
| progress(0.8, desc="Quantization complete. Preparing upload...") | |
| yield log("✅ AWQ quantization finished.") | |
| except Exception as e: | |
| progress(1.0, desc="Error") | |
| yield log(f"❌ CRITICAL ERROR during oneshot:\n{traceback.format_exc()}") | |
| shutil.rmtree(temp_dir, ignore_errors=True) | |
| return | |
| if upload_repo and hf_token: | |
| try: | |
| progress(0.9, desc="Uploading compressed model to Hugging Face Hub...") | |
| yield log(f"☁️ Uploading folder `{local_output_dir.name}` to repo `{upload_repo}`...") | |
| api = HfApi(token=hf_token) | |
| api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True) | |
| api.upload_folder( | |
| folder_path=str(local_output_dir), | |
| repo_id=upload_repo, | |
| repo_type="model", | |
| ) | |
| hub_url = f"https://huggingface.co/{upload_repo}" | |
| yield log(f"✅ Upload complete. Model available at:\n{hub_url}") | |
| except Exception as e: | |
| yield log(f"⚠️ Upload failed: {e}") | |
| else: | |
| yield log("ℹ️ No upload repo configured. Local files saved to temporary location.") | |
| shutil.rmtree(temp_dir, ignore_errors=True) | |
| progress(1.0, desc="Done!") | |
| yield log("🎉 Done! AWQ compression finished successfully. Local temporary files cleaned up.") | |
| with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo: | |
| gr.Markdown("# 🚀 AutoTrain-Advanced & Quantization Hub") | |
| gr.Markdown("### Una plataforma unificada para Fine-Tuning, PEFT, TorchAO y AWQ Quantization.") | |
| with gr.Tab("1. Autenticación"): | |
| gr.Markdown("#### Conecta tu cuenta de Hugging Face para guardar y cargar modelos.") | |
| with gr.Row(): | |
| hf_token_input = gr.Textbox(label="Token de Hugging Face (con permisos de escritura)", type="password", placeholder="hf_...", scale=3) | |
| login_button = gr.Button("Conectar", variant="primary", scale=1) | |
| login_status = gr.Textbox(label="Estado de Conexión", interactive=False) | |
| login_button.click(hf_login, inputs=[hf_token_input], outputs=[login_status]) | |
| with gr.Tab("2. Creación de Dataset"): | |
| gr.Markdown("## 🧩 Genera o Procesa Datasets y Súbelos al Hub") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| dset_repo_name = gr.Textbox(label="Nombre del Repositorio del Dataset", placeholder="mi-nuevo-dataset") | |
| dset_creation_type = gr.Radio(["Sintético", "Basado en Archivo"], label="Tipo de Creación", value="Sintético") | |
| with gr.Group(visible=True) as dset_synth_group: | |
| dset_synth_model = gr.Textbox(label="Modelo Generador", placeholder="p.ej. 'mistralai/Mistral-7B-Instruct-v0.2'") | |
| dset_synth_prompt = gr.Textbox(label="Prompt de Generación", lines=5, placeholder="Escribe una reseña de producto de 5 estrellas para...") | |
| dset_synth_num_samples = gr.Number(label="Número de Muestras", value=100) | |
| with gr.Group(visible=False) as dset_file_group: | |
| dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple") | |
| dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary") | |
| with gr.Column(scale=2): | |
| dset_status_output = gr.Textbox(label="Estado", lines=10, interactive=False) | |
| dset_link_output = gr.Markdown() | |
| dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group]) | |
| dset_create_button.click( | |
| create_and_upload_dataset, | |
| inputs=[hf_token_input, dset_repo_name, dset_creation_type, dset_synth_model, dset_synth_prompt, dset_synth_num_samples, dset_file_uploads], | |
| outputs=[dset_status_output, dset_link_output] | |
| ) | |
| with gr.Tab("3. Entrenamiento"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| gr.Markdown("## ⚙️ Configuración del Entrenamiento") | |
| training_mode = gr.Dropdown(TRAINING_MODES, label="Modo de Entrenamiento", value=TRAINING_MODES[0]) | |
| with gr.Accordion("📦 Modelo y Repositorio", open=True): | |
| model_base_input = gr.Textbox(label="ID del Modelo Base", placeholder="p.ej. 'mistralai/Mistral-7B-v0.1'") | |
| tokenizer_name_input = gr.Textbox(label="ID del Tokenizer (opcional)", placeholder="p.ej. si el modelo no tiene tokenizer") | |
| repo_name_input = gr.Textbox(label="Nombre del Repositorio de Destino", placeholder="p.ej. 'mi-modelo-afinado'") | |
| private_repo = gr.Checkbox(label="Repositorio Privado", value=False) | |
| train_from_scratch = gr.Checkbox(label="Entrenar desde Cero", value=False) | |
| auto_config_scratch = gr.Checkbox(label="Auto-Configuración", value=True, visible=False) | |
| manual_config_scratch = gr.Checkbox(label="Configuración Manual", value=False, visible=False) | |
| scratch_architecture = gr.Textbox(label="Arquitectura (p.ej. Llama, Mistral)", value="Llama", visible=False) | |
| scratch_vocab_size = gr.Number(label="Tamaño de Vocabulario", value=32000, visible=False) | |
| scratch_hidden_size = gr.Number(label="Tamaño Oculto", value=1024, visible=False) | |
| scratch_intermediate_size = gr.Number(label="Tamaño Intermedio", value=2048, visible=False) | |
| scratch_layers = gr.Number(label="Número de Capas", value=8, visible=False) | |
| scratch_heads = gr.Number(label="Cabezas de Atención", value=8, visible=False) | |
| scratch_kv_heads = gr.Number(label="Cabezas KV", value=8, visible=False) | |
| scratch_block_size = gr.Number(label="Tamaño de Bloque", value=1024, visible=False) | |
| scratch_tie_word_embeddings = gr.Checkbox(label="Enlazar Embeddings de Palabras", value=False, visible=False) | |
| steps_per_epoch_estimate = gr.Number(label="Estimación de Pasos por Época (para auto-config)", value=1000, visible=False) | |
| attention_dropout = gr.Slider(0.0, 0.5, 0.0, label="Dropout de Atención", visible=False) | |
| hidden_dropout = gr.Slider(0.0, 0.5, 0.0, label="Dropout Oculto", visible=False) | |
| with gr.Accordion("🔄 Fusión de Múltiples Adaptadores (Avanzado)", open=False) as multi_adapter_accordion: | |
| enable_multi_adapter_merge = gr.Checkbox(label="Habilitar Fusión Múltiple", value=False) | |
| multi_adapter_model_ids = gr.Textbox(label="IDs de Adaptadores (csv)", placeholder="org/adapter1,org/adapter2") | |
| multi_adapter_weights = gr.Textbox(label="Pesos (csv)", placeholder="0.5,0.5") | |
| multi_adapter_combination_type = gr.Dropdown(["slerp", "linear", "cat", "svd", "dare_linear", "dare_ties", "ties"], label="Tipo de Combinación", value="slerp") | |
| with gr.Accordion("📚 Dataset", open=True): | |
| datasets_hf_text = gr.Textbox(label="Datasets de Hugging Face (csv)", placeholder="p.ej. 'databricks/dolly-15k'") | |
| uploads = gr.File(label="Subir Archivos Locales (.jsonl, .csv, .txt)", file_count="multiple") | |
| dataset_weights = gr.Textbox(label="Pesos de los Datasets (csv)", placeholder="p.ej. 0.7, 0.3") | |
| eval_dataset_hf = gr.Textbox(label="Dataset de Evaluación (opcional)", placeholder="p.ej. 'nombre/dataset_eval'") | |
| preview_data_button = gr.Button("Previsualizar Datos Procesados") | |
| data_preview_output = gr.Textbox(label="Vista Previa de Datos", lines=8, interactive=False) | |
| with gr.Accordion("🎓 Hiperparámetros", open=False): | |
| with gr.Row(): | |
| learning_rate = gr.Textbox(label="Tasa de Aprendizaje", value="2e-5") | |
| batch_size = gr.Textbox(label="Tamaño de Lote", value="1") | |
| gradient_accumulation = gr.Textbox(label="Acumulación de Gradiente", value="8") | |
| with gr.Row(): | |
| block_size = gr.Textbox(label="Longitud de Secuencia", value="1024") | |
| max_steps = gr.Textbox(label="Máximos Pasos de Entrenamiento", value="100") | |
| with gr.Row(): | |
| optimizer = gr.Dropdown(["adamw_torch", "adafactor", "sgd", "adagrad"], label="Optimizador", value="adamw_torch") | |
| scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine") | |
| with gr.Accordion("Avanzados", open=False): | |
| warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento") | |
| weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01") | |
| max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="1.0") | |
| logging_steps = gr.Textbox(label="Pasos de Registro", value="10") | |
| save_steps = gr.Textbox(label="Pasos de Guardado", value="50") | |
| save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1") | |
| early_stopping_patience = gr.Number(label="Paciencia para Early Stopping (0 para desactivar)", value=0) | |
| resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False) | |
| with gr.Row(): | |
| adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9") | |
| adam_beta2 = gr.Textbox(label="Adam Beta2", value="0.999") | |
| adam_epsilon = gr.Textbox(label="Adam Epsilon", value="1e-8") | |
| group_by_length = gr.Checkbox(label="Agrupar por Longitud", value=False) | |
| neftune_noise_alpha = gr.Textbox(label="NEFTune Ruido Alfa (0 para desactivar)", value="0") | |
| optim_args = gr.Textbox(label="Argumentos del Optimizador (formato dict)", placeholder="ej: betas=(0.9,0.995)") | |
| with gr.Accordion("🦋 PEFT (LoRA)", open=True) as peft_accordion: | |
| peft = gr.Checkbox(label="Habilitar PEFT/LoRA", value=True) | |
| with gr.Row(): | |
| lora_r = gr.Textbox(label="LoRA r", value="16") | |
| lora_alpha = gr.Textbox(label="LoRA alpha", value="32") | |
| lora_dropout = gr.Textbox(label="LoRA dropout", value="0.05") | |
| auto_find_target_modules = gr.Checkbox(label="Auto-encontrar Módulos de Destino", value=True) | |
| target_modules = gr.Textbox(label="Módulos de Destino (csv)", placeholder="q_proj,v_proj", visible=False) | |
| modules_to_save = gr.Textbox(label="Módulos a Guardar (csv)", placeholder="embed_tokens,lm_head") | |
| with gr.Row(): | |
| use_dora = gr.Checkbox(label="Usar DoRA", value=False) | |
| use_rslora = gr.Checkbox(label="Usar RSLora", value=False) | |
| init_lora_weights = gr.Dropdown(["gaussian", "loftq", "pissa"], label="Inicialización de Pesos LoRA", value="gaussian") | |
| with gr.Accordion("🧹 Procesamiento y Aumentación de Datos", open=False): | |
| with gr.Tab("Limpieza y Normalización"): | |
| remove_html_tags = gr.Checkbox(label="Eliminar Etiquetas HTML", value=True) | |
| normalize_whitespace = gr.Checkbox(label="Normalizar Espacios en Blanco", value=True) | |
| remove_urls_emails = gr.Checkbox(label="Eliminar URLs/Emails", value=True) | |
| redact_pii = gr.Checkbox(label="Redactar PII", value=True) | |
| with gr.Tab("Filtrado"): | |
| enable_quality_filter = gr.Checkbox(label="Habilitar Filtros de Calidad", value=True) | |
| min_len_input = gr.Slider(1, 100, 10, label="Longitud Mínima (palabras)") | |
| max_len_input = gr.Slider(100, 5000, 2000, label="Longitud Máxima (palabras)") | |
| rep_threshold_input = gr.Slider(0, 1, 0.2, label="Umbral de Repetición") | |
| exclude_keywords_input = gr.Textbox(label="Palabras Clave a Excluir (csv)") | |
| bias_keywords_input = gr.Textbox(label="Palabras Clave de Sesgo (csv)", placeholder="p.ej. discriminación,prejuicio") | |
| enable_language_filter = gr.Checkbox(label="Habilitar Filtro de Idioma", value=False) | |
| allowed_languages = gr.Textbox(label="Idiomas Permitidos (csv)", value="es,en", placeholder="es,en") | |
| language_detection_threshold = gr.Slider(0.5, 1.0, 0.95, label="Umbral de Detección de Idioma") | |
| enable_toxicity_filter = gr.Checkbox(label="Habilitar Filtro de Toxicidad", value=False) | |
| toxicity_threshold = gr.Slider(0.5, 1.0, 0.8, label="Umbral de Toxicidad") | |
| enable_coherence_filter = gr.Checkbox(label="Habilitar Filtro de Coherencia (Anti-Gibberish)", value=True) | |
| coherence_char_repetition_threshold = gr.Slider(0.1, 0.8, 0.4, label="Umbral de Repetición de Caracteres", info="Máximo ratio de caracteres repetidos permitido") | |
| coherence_ngram_repetition_threshold = gr.Slider(0.1, 0.8, 0.3, label="Umbral de Repetición de N-gramas", info="Máximo ratio de patrones repetidos permitido") | |
| coherence_entropy_threshold = gr.Slider(0.1, 0.9, 0.5, label="Umbral de Entropía", info="Mínima entropía normalizada requerida") | |
| enable_readability_filter = gr.Checkbox(label="Habilitar Filtro de Legibilidad", value=False) | |
| min_readability = gr.Slider(0, 100, 30, label="Legibilidad Mínima (Flesch)") | |
| max_readability = gr.Slider(0, 100, 100, label="Legibilidad Máxima (Flesch)") | |
| enable_stopword_filter = gr.Checkbox(label="Habilitar Filtro de Palabras Vacías", value=False) | |
| max_stopword_ratio = gr.Slider(0.0, 1.0, 0.5, label="Ratio Máxima de Palabras Vacías") | |
| enable_uniqueness_filter = gr.Checkbox(label="Habilitar Filtro de Unicidad", value=False) | |
| min_uniqueness_ratio = gr.Slider(0.0, 1.0, 0.3, label="Ratio Mínima de Unicidad") | |
| with gr.Tab("Deduplicación"): | |
| deduplication_method = gr.Radio(["Ninguna", "Exacta", "Semántica (MinHash)"], label="Método de Deduplicación", value="Ninguna") | |
| minhash_threshold = gr.Slider(0.7, 0.99, 0.85, label="Umbral MinHash") | |
| minhash_num_perm = gr.Slider(64, 256, 128, step=16, label="Permutaciones MinHash") | |
| with gr.Tab("Aumentación"): | |
| enable_back_translation = gr.Checkbox(label="Habilitar Retrotraducción", value=False) | |
| bt_model_id = gr.Textbox(label="Modelo de Traducción", value="Helsinki-NLP/opus-mt-en-de") | |
| bt_reverse_model_id = gr.Textbox(label="Modelo Inverso", value="Helsinki-NLP/opus-mt-de-en") | |
| bt_augmentation_ratio = gr.Slider(0.0, 1.0, 0.1, label="Ratio de Aumentación BT") | |
| with gr.Tab("Generación Sintética"): | |
| enable_synthetic_data = gr.Checkbox(label="Habilitar Datos Sintéticos", value=False) | |
| synthetic_model_id = gr.Textbox(label="ID del Modelo Generador", placeholder="p.ej. 'mistralai/Mistral-7B-Instruct-v0.2'") | |
| num_synthetic_samples = gr.Number(label="Número de Muestras", value=1000) | |
| synthetic_prompt_template = gr.Textbox(label="Plantilla de Prompt", value="Genera un nuevo ejemplo basado en: {{example_text}}\n\nNuevo ejemplo:", lines=3) | |
| with gr.Accordion("📝 Configuración de Formato y Tarea", open=False): | |
| with gr.Group(visible=True) as sft_ui: | |
| sft_format_style = gr.Radio(["Columna de Texto", "Conversacional", "Razonamiento/Herramientas"], label="Formato de Datos SFT", value="Columna de Texto") | |
| chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5) | |
| with gr.Group(visible=False) as sft_tool_ui: | |
| enable_cot_input = gr.Checkbox(label="Habilitar Razonamiento (CoT)", value=True) | |
| enable_tool_use_input = gr.Checkbox(label="Habilitar Uso de Herramientas", value=True) | |
| prompt_col_input = gr.Textbox(label="Columna de Prompt/Usuario", value="prompt") | |
| response_col_input = gr.Textbox(label="Columna de Respuesta Final", value="response") | |
| reasoning_col_input = gr.Textbox(label="Columna de Razonamiento", value="reasoning") | |
| tool_use_col_input = gr.Textbox(label="Columna de Uso de Herramientas", value="tools") | |
| with gr.Group(visible=False) as dpo_ui: | |
| dpo_prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt") | |
| dpo_chosen_col_input = gr.Textbox(label="Columna Elegida", value="chosen") | |
| dpo_rejected_col_input = gr.Textbox(label="Columna Rechazada", value="rejected") | |
| with gr.Group(visible=False) as classification_labels_ui: | |
| classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo") | |
| with gr.Group(visible=False) as diffusion_ui: | |
| gr.Markdown("Opciones para Text-to-Image aparecerán aquí.") | |
| with gr.Accordion("📊 Evaluación y Mitigación de Sesgos", open=False): | |
| run_evaluation = gr.Checkbox(label="Ejecutar Evaluación", value=False) | |
| metric_for_best_model = gr.Textbox(label="Métrica para Mejor Modelo", value="loss", placeholder="loss, accuracy, f1") | |
| greater_is_better = gr.Checkbox(label="Mayor es Mejor", value=False) | |
| run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad", value=True) | |
| enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderación de Pérdida", value=False) | |
| reweighting_terms = gr.Textbox(label="Términos para Re-ponderar (csv)", placeholder="sesgo,injusto") | |
| reweighting_factor = gr.Slider(1.1, 10.0, 2.0, label="Factor de Re-ponderación") | |
| enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False) | |
| cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]') | |
| with gr.Accordion("🔌 Integraciones", open=False): | |
| hub_strategy = gr.Dropdown(["every_save", "end", "checkpoint", "all_checkpoints"], label="Estrategia de Subida al Hub", value="every_save") | |
| wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password") | |
| wandb_project_input = gr.Textbox(label="Proyecto W&B") | |
| with gr.Column(scale=3): | |
| gr.Markdown("## 📈 Progreso y Resultados") | |
| with gr.Row(): | |
| start_training_button = gr.Button("Iniciar Entrenamiento", variant="primary", scale=3) | |
| stop_training_button = gr.Button("Detener", variant="stop", visible=False, scale=1) | |
| training_phase = gr.Label(label="Fase Actual", value="En espera") | |
| repo_link_output = gr.Markdown(label="Enlace al Repositorio del Modelo") | |
| final_eval_results = gr.JSON(label="Resultados de Evaluación Final") | |
| training_logs = gr.Textbox(label="Registros de Entrenamiento", lines=35, interactive=False) | |
| all_input_components_dict = { | |
| "training_mode": training_mode, "model_base_input": model_base_input, "tokenizer_name_input": tokenizer_name_input, | |
| "repo_name_input": repo_name_input, "private_repo": private_repo, "train_from_scratch": train_from_scratch, | |
| "auto_config_scratch": auto_config_scratch, "manual_config_scratch": manual_config_scratch, | |
| "scratch_architecture": scratch_architecture, "scratch_vocab_size": scratch_vocab_size, | |
| "scratch_hidden_size": scratch_hidden_size, "scratch_intermediate_size": scratch_intermediate_size, | |
| "scratch_layers": scratch_layers, "scratch_heads": scratch_heads, "scratch_kv_heads": scratch_kv_heads, | |
| "scratch_block_size": scratch_block_size, "scratch_tie_word_embeddings": scratch_tie_word_embeddings, | |
| "steps_per_epoch_estimate": steps_per_epoch_estimate, "attention_dropout": attention_dropout, | |
| "hidden_dropout": hidden_dropout, "enable_multi_adapter_merge": enable_multi_adapter_merge, | |
| "multi_adapter_model_ids": multi_adapter_model_ids, "multi_adapter_weights": multi_adapter_weights, | |
| "multi_adapter_combination_type": multi_adapter_combination_type, "datasets_hf_text": datasets_hf_text, | |
| "uploads": uploads, "dataset_weights": dataset_weights, "eval_dataset_hf": eval_dataset_hf, | |
| "learning_rate": learning_rate, "max_steps": max_steps, "batch_size": batch_size, "gradient_accumulation": gradient_accumulation, | |
| "block_size": block_size, "optimizer": optimizer, "scheduler": scheduler, | |
| "warmup_ratio": warmup_ratio, "weight_decay": weight_decay, "max_grad_norm": max_grad_norm, | |
| "logging_steps": logging_steps, "save_steps": save_steps, "save_total_limit": save_total_limit, "resume_from_checkpoint": resume_from_checkpoint, | |
| "adam_beta1": adam_beta1, "adam_beta2": adam_beta2, "adam_epsilon": adam_epsilon, | |
| "group_by_length": group_by_length, | |
| "neftune_noise_alpha": neftune_noise_alpha, "optim_args": optim_args, | |
| "early_stopping_patience": early_stopping_patience, | |
| "peft": peft, "lora_r": lora_r, "lora_alpha": lora_alpha, | |
| "lora_dropout": lora_dropout, "auto_find_target_modules": auto_find_target_modules, "target_modules": target_modules, | |
| "modules_to_save": modules_to_save, "use_dora": use_dora, "use_rslora": use_rslora, "init_lora_weights": init_lora_weights, | |
| "remove_html_tags": remove_html_tags, "normalize_whitespace": normalize_whitespace, "remove_urls_emails": remove_urls_emails, | |
| "redact_pii": redact_pii, "enable_quality_filter": enable_quality_filter, "min_len_input": min_len_input, | |
| "max_len_input": max_len_input, "rep_threshold_input": rep_threshold_input, "exclude_keywords_input": exclude_keywords_input, | |
| "bias_keywords_input": bias_keywords_input, "enable_language_filter": enable_language_filter, | |
| "allowed_languages": allowed_languages, "language_detection_threshold": language_detection_threshold, | |
| "enable_toxicity_filter": enable_toxicity_filter, "toxicity_threshold": toxicity_threshold, | |
| "enable_coherence_filter": enable_coherence_filter, "coherence_char_repetition_threshold": coherence_char_repetition_threshold, | |
| "coherence_ngram_repetition_threshold": coherence_ngram_repetition_threshold, "coherence_entropy_threshold": coherence_entropy_threshold, | |
| "enable_readability_filter": enable_readability_filter, "min_readability": min_readability, "max_readability": max_readability, | |
| "enable_stopword_filter": enable_stopword_filter, "max_stopword_ratio": max_stopword_ratio, | |
| "enable_uniqueness_filter": enable_uniqueness_filter, "min_uniqueness_ratio": min_uniqueness_ratio, | |
| "deduplication_method": deduplication_method, "minhash_threshold": minhash_threshold, "minhash_num_perm": minhash_num_perm, | |
| "enable_cda": enable_cda, "cda_json_config": cda_json_config, | |
| "enable_back_translation": enable_back_translation, "bt_model_id": bt_model_id, | |
| "bt_reverse_model_id": bt_reverse_model_id, "bt_augmentation_ratio": bt_augmentation_ratio, | |
| "enable_synthetic_data": enable_synthetic_data, | |
| "synthetic_model_id": synthetic_model_id, "num_synthetic_samples": num_synthetic_samples, | |
| "synthetic_prompt_template": synthetic_prompt_template, | |
| "sft_format_style": sft_format_style, "chat_template_jinja": chat_template_jinja, | |
| "enable_cot_input": enable_cot_input, "enable_tool_use_input": enable_tool_use_input, | |
| "prompt_col_input": prompt_col_input, "response_col_input": response_col_input, | |
| "reasoning_col_input": reasoning_col_input, "tool_use_col_input": tool_use_col_input, | |
| "dpo_prompt_col_input": dpo_prompt_col_input, "dpo_chosen_col_input": dpo_chosen_col_input, | |
| "dpo_rejected_col_input": dpo_rejected_col_input, "classification_labels": classification_labels, | |
| "run_evaluation": run_evaluation, "metric_for_best_model": metric_for_best_model, | |
| "greater_is_better": greater_is_better, "run_perplexity_evaluation": run_perplexity_evaluation, | |
| "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms, "reweighting_factor": reweighting_factor, | |
| "hub_strategy": hub_strategy, "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input, | |
| } | |
| all_input_components_list = list(all_input_components_dict.values()) | |
| all_output_components = [training_logs, training_phase, repo_link_output, final_eval_results, start_training_button, stop_training_button] | |
| preview_data_button.click( | |
| gradio_preview_data_wrapper, | |
| inputs=all_input_components_list, | |
| outputs=[data_preview_output] | |
| ) | |
| train_from_scratch.change( | |
| toggle_training_mode_ui, | |
| inputs=[train_from_scratch], | |
| outputs=[model_base_input, tokenizer_name_input, multi_adapter_accordion, peft_accordion, | |
| auto_config_scratch, scratch_architecture, manual_config_scratch, scratch_vocab_size, | |
| scratch_hidden_size, scratch_intermediate_size, scratch_layers, scratch_heads, | |
| scratch_kv_heads, scratch_block_size, scratch_tie_word_embeddings, | |
| steps_per_epoch_estimate, attention_dropout, hidden_dropout] | |
| ) | |
| training_mode.change( | |
| toggle_task_specific_ui, | |
| inputs=[training_mode], | |
| outputs=[classification_labels_ui, dpo_ui, sft_ui, diffusion_ui, peft_accordion] | |
| ) | |
| sft_format_style.change( | |
| toggle_sft_format_ui, | |
| inputs=[sft_format_style], | |
| outputs=[sft_tool_ui] | |
| ) | |
| auto_find_target_modules.change( | |
| toggle_auto_modules_ui, | |
| inputs=[auto_find_target_modules], | |
| outputs=[target_modules] | |
| ) | |
| train_event = start_training_button.click( | |
| gradio_train_wrapper, | |
| inputs=all_input_components_list, | |
| outputs=all_output_components | |
| ) | |
| stop_training_button.click(fn=None, inputs=None, outputs=None, cancels=[train_event]) | |
| with gr.Tab("4. Inferencia"): | |
| gr.Markdown("## 🧪 Probar un Modelo del Hub") | |
| with gr.Row(): | |
| inf_task_mode = gr.Dropdown(TRAINING_MODES, label="Task Type", value=TRAINING_MODES[0]) | |
| inf_model_id = gr.Textbox(label="Hub Model ID", placeholder="YourUsername/YourTrainedModel") | |
| with gr.Group(): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| inf_text_in = gr.Textbox(label="Text Input / Prompt", lines=5) | |
| inf_context_in = gr.Textbox(label="Context (for QA)", lines=3, visible=False) | |
| inf_image_in = gr.Image(label="Image Input", type="pil", visible=False) | |
| inf_audio_in = gr.Audio(label="Audio Input", type="filepath", visible=False) | |
| with gr.Accordion("Advanced Generation Options", open=False, visible=True) as inf_advanced_options: | |
| inf_temperature = gr.Slider(0.1, 2.0, 0.7, label="Temperature") | |
| inf_top_p = gr.Slider(0.1, 1.0, 0.95, label="Top-p") | |
| inf_max_new_tokens = gr.Slider(10, 1024, 100, step=1, label="Max New Tokens") | |
| with gr.Row(): | |
| run_inference_btn = gr.Button("Run Inference", variant="primary") | |
| with gr.Column(scale=3): | |
| inf_text_out = gr.Textbox(label="Text Output", lines=15, interactive=False) | |
| inf_task_mode.change( | |
| update_inference_ui, | |
| inputs=[inf_task_mode], | |
| outputs=[inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_advanced_options] | |
| ) | |
| run_inference_btn.click( | |
| run_inference, | |
| inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens], | |
| outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in] | |
| ) | |
| with gr.Tab("5. TorchAO Quantization"): | |
| gr.Markdown("## 🔥 TorchAO Quantizer") | |
| gr.Markdown("Cuantización eficiente usando `torchao`.") | |
| with gr.Row(): | |
| ao_token = gr.Textbox(label="HF Token (if different from the main one)", type="password", placeholder="Optional") | |
| ao_model_name = HuggingfaceHubSearch(label="🔍 Hub Model ID", placeholder="Search a model", search_type="model") | |
| ao_quant_type = gr.Dropdown(choices=list(MAP_QUANT_TYPE_TO_NAME.keys()), value="Int8WeightOnly", label="Quantization Type") | |
| ao_group_size = gr.Textbox(label="Group Size (optional)", value="128") | |
| ao_custom_name = gr.Textbox(label="Custom Name (optional)", value="") | |
| ao_public = gr.Checkbox(label="Make Public", value=True) | |
| ao_output = gr.Markdown() | |
| ao_btn = gr.Button("🚀 Quantize and Upload", variant="primary") | |
| ao_btn.click( | |
| quantize_and_save_ao, | |
| inputs=[ao_model_name, ao_quant_type, ao_group_size, ao_custom_name, ao_public, hf_token_input], | |
| outputs=ao_output | |
| ) | |
| with gr.Tab("6. AWQ Quantization"): | |
| gr.Markdown("## 🧱 LLM Compressor – AWQ Quantizer") | |
| gr.Markdown("Cuantización AWQ usando `llmcompressor` (oneshot).") | |
| with gr.Row(): | |
| with gr.Column(): | |
| awq_token = gr.Textbox(label="HF Token (if different from the main one)", type="password", placeholder="Optional") | |
| awq_model_id = gr.Textbox(label="Source Model ID", value="meta-llama/Llama-3.3-70B-Instruct") | |
| awq_scheme = gr.Dropdown(label="AWQ Scheme", choices=["W4A16", "W4A16_ASYM"], value="W4A16_ASYM") | |
| awq_ignore_head = gr.Checkbox(label="Ignore lm_head", value=True) | |
| awq_calib = gr.Number(label="Calibration Samples", value=128, precision=0) | |
| awq_seq_len = gr.Number(label="Max Sequence Length", value=2048, precision=0) | |
| awq_pipeline = gr.Dropdown(label="Pipeline Mode", choices=["sequential", "default"], value="sequential") | |
| awq_repo = gr.Textbox(label="Target HF Repo", placeholder="username/model-awq") | |
| awq_btn = gr.Button("Start AWQ Compression", variant="primary") | |
| with gr.Column(): | |
| awq_logs = gr.Textbox(label="Process Logs", lines=30, interactive=False) | |
| awq_btn.click( | |
| run_awq_compression, | |
| inputs=[hf_token_input, awq_model_id, awq_scheme, awq_ignore_head, awq_calib, awq_seq_len, awq_pipeline, awq_repo], | |
| outputs=[awq_logs] | |
| ) | |
| with gr.Tab("7. Explicación del Código"): | |
| gr.Markdown(""" | |
| ### 🧠 Code Explanation and Advanced Mechanisms | |
| """) | |
| gr.Markdown("#### 1. CORE MECHANISMS") | |
| gr.Markdown(""" | |
| * PEFT/LoRA: Parameter-Efficient Fine-Tuning. The base weights $W$ stay frozen; only the low-rank matrices $A$ and $B$ are trained, so the effective update is $W' = W + B A$, drastically reducing the number of trainable parameters (see the sketch below). | |
| * Accelerator: Manages device placement and gradient accumulation for stable large-batch training simulation. | |
| * Early Stopping: Halts training if validation loss doesn't improve over a set number of steps (`early_stopping_patience`). | |
| * Gradient Accumulation: Simulates larger batch sizes by accumulating gradients over several forward/backward passes before an optimization step. | |
| * Gradient Clipping: Limits the maximum norm of the gradients (`max_grad_norm`) to prevent exploding gradients during training. | |
| """) | |
| gr.Markdown("#### 2. DATA PROCESSING & AUGMENTATION") | |
| gr.Markdown(""" | |
| * Streaming Datasets: Uses the `datasets` streaming mode to handle very large datasets without loading everything into RAM. | |
| * Data Cleaning: Removes HTML tags, normalizes whitespace, redacts PII, and removes URLs/emails. | |
| * Advanced Filtering: Includes optional filters for text length, word repetition, language detection, and basic toxicity detection (via `unitary/toxic-bert`). | |
| * Data Augmentation: Supports **Back-Translation (BT)** for introducing paraphrasing variations and **Counterfactual Data Augmentation (CDA)** for controlled bias mitigation (e.g., swapping gendered pronouns). | |
| * Synthetic Data Generation: Uses a specified LLM to generate new training examples based on an initial prompt template. | |
| * Deduplication: Implements both **Exact** and **Near-Duplicate (MinHash LSH)** deduplication to prevent data contamination during iterative fine-tuning (see the sketch below). | |
| """) | |
| gr.Markdown("#### 3. TRAINING MODES") | |
| gr.Markdown(""" | |
| * SFT (Supervised Fine-Tuning): Standard fine-tuning; supports **Conversation** and **Reasoning/Tool Use (CoT)** formatting styles. | |
| * DPO (Direct Preference Optimization): Trains directly on preference pairs (chosen vs. rejected) using the `trl` library. The record shapes both paths consume are sketched below. | |
| * Task-Specific Heads: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`). | |
| * Seq2Seq: For translation/summarization tasks, using `Seq2SeqTrainer`. | |
| """) | |
| gr.Markdown("#### 4. QUANTIZATION (TorchAO & AWQ)") | |
| gr.Markdown(""" | |
| * **TorchAO**: PyTorch-native quantization. Supports Int4, Int8, and Float8 schemes applied directly at model load time (see the sketch below). | |
| * **AWQ (Activation-aware Weight Quantization)**: Uses `llmcompressor` in oneshot mode to protect salient weights based on activation magnitude, preserving performance at 4-bit. | |
| """) | |
| gr.Markdown("#### 5. OUTPUT & DEPLOYMENT") | |
| gr.Markdown(""" | |
| * Hugging Face Hub Integration: All trained artifacts (full model or LoRA adapter) are automatically pushed to a specified repository on the HF Hub using the provided token (see the sketch below). | |
| * Model Card Generation: Automatically generates a `README.md` detailing training parameters and model provenance. | |
| * Inference Tab: A separate UI for easily testing the trained model with various inputs and generation parameters. | |
| """) | |
| if __name__ == "__main__": | |
| demo.queue().launch(debug=True, share=True) |