Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,6 +16,9 @@ import re
|
|
| 16 |
import ast
|
| 17 |
from itertools import islice
|
| 18 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
| 19 |
import torch
|
| 20 |
import torch.nn.functional as F
|
| 21 |
from torch.utils.data import DataLoader
|
|
@@ -29,15 +32,14 @@ from langdetect import detect_langs
|
|
| 29 |
import textstat
|
| 30 |
from datasketch import MinHash, MinHashLSH
|
| 31 |
import gradio as gr
|
| 32 |
-
import
|
| 33 |
-
from datasets import load_dataset, IterableDataset, Dataset, DatasetDict
|
| 34 |
from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
|
| 35 |
from transformers import (
|
| 36 |
AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
|
| 37 |
AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
|
| 38 |
-
AutoModelForImageClassification,
|
| 39 |
AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
|
| 40 |
-
DataCollatorForTokenClassification, AutoModelForQuestionAnswering,
|
| 41 |
AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
|
| 42 |
DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
|
| 43 |
LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
|
|
@@ -54,7 +56,6 @@ from diffusers import (
|
|
| 54 |
)
|
| 55 |
import evaluate as hf_evaluate
|
| 56 |
from jinja2 import Template
|
| 57 |
-
from collections import defaultdict
|
| 58 |
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
|
@@ -94,7 +95,8 @@ TASK_TO_PIPELINE_MAP = {
|
|
| 94 |
"DreamBooth LoRA (Text-to-Image)": "text-to-image",
|
| 95 |
}
|
| 96 |
|
| 97 |
-
MODEL_CARD_TEMPLATE = """
|
|
|
|
| 98 |
language: es
|
| 99 |
license: apache-2.0
|
| 100 |
tags:
|
|
@@ -132,7 +134,8 @@ Este modelo es una versión afinada de [{base_model}](https://huggingface.co/{ba
|
|
| 132 |
- Gradio
|
| 133 |
"""
|
| 134 |
|
| 135 |
-
DATASET_CARD_TEMPLATE = """
|
|
|
|
| 136 |
license: mit
|
| 137 |
---
|
| 138 |
|
|
@@ -147,6 +150,7 @@ Este dataset fue creado utilizando la herramienta [AutoTrain-Advanced](https://h
|
|
| 147 |
- **Fecha de Creación:** {date}
|
| 148 |
"""
|
| 149 |
|
|
|
|
| 150 |
class DebiasingSFTTrainer(SFTTrainer):
|
| 151 |
def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
|
| 152 |
super().__init__(*args, **kwargs)
|
|
@@ -164,6 +168,54 @@ class DebiasingSFTTrainer(SFTTrainer):
|
|
| 164 |
break
|
| 165 |
return (loss, outputs) if return_outputs else loss
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
def hf_login(token):
|
| 168 |
if not token:
|
| 169 |
return "Por favor, introduce un token."
|
|
@@ -174,6 +226,7 @@ def hf_login(token):
|
|
| 174 |
except Exception as e:
|
| 175 |
return f"❌ Error en la conexión: {e}"
|
| 176 |
|
|
|
|
| 177 |
def _clean_text(example, text_col, **kwargs):
|
| 178 |
text = example.get(text_col, "")
|
| 179 |
if not isinstance(text, str):
|
|
@@ -191,6 +244,7 @@ def _clean_text(example, text_col, **kwargs):
|
|
| 191 |
example[text_col] = text
|
| 192 |
return example
|
| 193 |
|
|
|
|
| 194 |
def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords):
|
| 195 |
text = example.get(text_col, "")
|
| 196 |
if not isinstance(text, str): return False
|
|
@@ -204,6 +258,7 @@ def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, e
|
|
| 204 |
lower_text = text.lower()
|
| 205 |
return not any(keyword in lower_text for keyword in exclude_keywords)
|
| 206 |
|
|
|
|
| 207 |
def _get_filter_functions(**kwargs):
|
| 208 |
filters = []
|
| 209 |
if kwargs.get('enable_quality_filter'):
|
|
@@ -252,6 +307,7 @@ def _get_filter_functions(**kwargs):
|
|
| 252 |
filters.append(stats_filter)
|
| 253 |
return filters
|
| 254 |
|
|
|
|
| 255 |
def _load_hf_streaming(ids, split="train", probabilities=None):
|
| 256 |
streams = []
|
| 257 |
valid_ids = []
|
|
@@ -279,10 +335,9 @@ def _load_hf_streaming(ids, split="train", probabilities=None):
|
|
| 279 |
if probabilities and len(probabilities) != len(streams):
|
| 280 |
logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
|
| 281 |
probabilities = None
|
| 282 |
-
|
| 283 |
-
from datasets import interleave_datasets
|
| 284 |
return interleave_datasets(streams, probabilities=probabilities)
|
| 285 |
|
|
|
|
| 286 |
def _load_uploaded_stream(files):
|
| 287 |
all_rows = []
|
| 288 |
for f in files or []:
|
|
@@ -304,6 +359,7 @@ def _load_uploaded_stream(files):
|
|
| 304 |
random.shuffle(all_rows)
|
| 305 |
return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []}
|
| 306 |
|
|
|
|
| 307 |
def _guess_columns(sample):
|
| 308 |
text_col, image_col, audio_col, label_col = "text", "image", "audio", "label"
|
| 309 |
if not isinstance(sample, dict):
|
|
@@ -320,6 +376,7 @@ def _guess_columns(sample):
|
|
| 320 |
elif "labels" in keys: label_col = keys["labels"]
|
| 321 |
return text_col, image_col, audio_col, label_col
|
| 322 |
|
|
|
|
| 323 |
def _apply_cda(dataset, text_col, cda_config_str):
|
| 324 |
try:
|
| 325 |
swap_groups = json.loads(cda_config_str)
|
|
@@ -352,6 +409,7 @@ def _apply_cda(dataset, text_col, cda_config_str):
|
|
| 352 |
current_texts.update(next_texts)
|
| 353 |
return IterableDataset.from_generator(cda_generator)
|
| 354 |
|
|
|
|
| 355 |
def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
|
| 356 |
if not ratio or ratio <= 0:
|
| 357 |
return dataset
|
|
@@ -379,6 +437,7 @@ def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id
|
|
| 379 |
logger.warning(f"Error en retrotraducción: {e}")
|
| 380 |
return IterableDataset.from_generator(bt_generator)
|
| 381 |
|
|
|
|
| 382 |
def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
|
| 383 |
if not num_samples or num_samples <= 0:
|
| 384 |
return None
|
|
@@ -411,6 +470,7 @@ def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples,
|
|
| 411 |
continue
|
| 412 |
return IterableDataset.from_generator(synthetic_generator)
|
| 413 |
|
|
|
|
| 414 |
def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
|
| 415 |
safe_steps = int(steps_per_epoch_estimate or 10000)
|
| 416 |
safe_batch_size = int(batch_size or 1)
|
|
@@ -429,6 +489,7 @@ def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, b
|
|
| 429 |
kv_heads = heads if is_gpt2_like else (max(1, heads // 4))
|
| 430 |
return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads
|
| 431 |
|
|
|
|
| 432 |
def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn):
|
| 433 |
if eval_ds_id:
|
| 434 |
yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación")
|
|
@@ -450,6 +511,7 @@ def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn
|
|
| 450 |
yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
|
| 451 |
return None
|
| 452 |
|
|
|
|
| 453 |
def _create_training_args(output_dir, repo_id, **kwargs):
|
| 454 |
neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
|
| 455 |
optim_args_dict = {}
|
|
@@ -468,11 +530,12 @@ def _create_training_args(output_dir, repo_id, **kwargs):
|
|
| 468 |
"save_strategy": "steps",
|
| 469 |
"logging_steps": int(kwargs.get('logging_steps', 10)),
|
| 470 |
"save_steps": int(kwargs.get('save_steps', 50)),
|
|
|
|
| 471 |
"eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
|
| 472 |
"learning_rate": float(kwargs.get('learning_rate', 2e-5)),
|
| 473 |
"fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
|
| 474 |
"bf16": kwargs.get('mixed_precision') == 'bf16' and device == 'cuda',
|
| 475 |
-
"max_grad_norm": float(kwargs.get('max_grad_norm', 0
|
| 476 |
"warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)),
|
| 477 |
"lr_scheduler_type": kwargs.get('scheduler', 'cosine'),
|
| 478 |
"weight_decay": float(kwargs.get('weight_decay', 0.01)),
|
|
@@ -507,15 +570,21 @@ def _create_training_args(output_dir, repo_id, **kwargs):
|
|
| 507 |
|
| 508 |
return TrainingArguments(**args_dict)
|
| 509 |
|
|
|
|
| 510 |
def _generic_model_loader(model_name_or_path, model_class, **kwargs):
|
| 511 |
quantization_type = kwargs.get('quantization', 'no')
|
| 512 |
bnb_config = None
|
| 513 |
|
| 514 |
if quantization_type != "no" and device == "cuda":
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
elif quantization_type != "no" and device == "cpu":
|
| 520 |
logger.warning("La cuantización solo es compatible con GPU CUDA. Se procederá sin cuantización.")
|
| 521 |
|
|
@@ -556,6 +625,7 @@ def _generic_model_loader(model_name_or_path, model_class, **kwargs):
|
|
| 556 |
|
| 557 |
return model
|
| 558 |
|
|
|
|
| 559 |
def _find_all_linear_names(model, quantization_type):
|
| 560 |
cls = torch.nn.Linear
|
| 561 |
if quantization_type != 'no' and device == "cuda":
|
|
@@ -581,6 +651,7 @@ def _find_all_linear_names(model, quantization_type):
|
|
| 581 |
|
| 582 |
return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
|
| 583 |
|
|
|
|
| 584 |
def _conversation_formatting_func(example, tokenizer, **kwargs):
|
| 585 |
conv_col = ""
|
| 586 |
for key in ["messages", "conversations", "turns"]:
|
|
@@ -592,6 +663,7 @@ def _conversation_formatting_func(example, tokenizer, **kwargs):
|
|
| 592 |
except: return ""
|
| 593 |
return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
|
| 594 |
|
|
|
|
| 595 |
def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
|
| 596 |
if kwargs.get('enable_cot_input') or kwargs.get('enable_tool_use_input'):
|
| 597 |
messages = []
|
|
@@ -610,9 +682,11 @@ def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
|
|
| 610 |
return "\n".join([m['content'] for m in messages])
|
| 611 |
return example.get(text_col, "")
|
| 612 |
|
|
|
|
| 613 |
def _dpo_formatting_func(example, **kwargs):
|
| 614 |
return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")}
|
| 615 |
|
|
|
|
| 616 |
def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
|
| 617 |
model.eval()
|
| 618 |
encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt").to(model.device)
|
|
@@ -637,6 +711,7 @@ def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
|
|
| 637 |
ppl = torch.exp(torch.stack(nlls).mean())
|
| 638 |
return ppl.item()
|
| 639 |
|
|
|
|
| 640 |
def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
|
| 641 |
adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
|
| 642 |
if not adapter_ids:
|
|
@@ -668,6 +743,7 @@ def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combinati
|
|
| 668 |
yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
|
| 669 |
return temp_dir
|
| 670 |
|
|
|
|
| 671 |
def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 672 |
yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
|
| 673 |
trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False)
|
|
@@ -677,6 +753,7 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
|
|
| 677 |
eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
|
| 678 |
if eval_logs:
|
| 679 |
final_metrics = eval_logs[-1]
|
|
|
|
| 680 |
|
| 681 |
yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
|
| 682 |
output_dir = trainer.args.output_dir
|
|
@@ -695,6 +772,7 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
|
|
| 695 |
torch.cuda.empty_cache()
|
| 696 |
return output_dir, final_metrics
|
| 697 |
|
|
|
|
| 698 |
def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 699 |
output_dir = tempfile.mkdtemp()
|
| 700 |
is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)"
|
|
@@ -726,13 +804,13 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 726 |
if kwargs.get('run_evaluation'):
|
| 727 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 728 |
for update in eval_dataset_gen:
|
| 729 |
-
if isinstance(update,
|
| 730 |
yield update
|
| 731 |
else:
|
| 732 |
eval_dataset = update
|
| 733 |
|
| 734 |
TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
|
| 735 |
-
trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config
|
| 736 |
|
| 737 |
if is_dpo:
|
| 738 |
trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
|
|
@@ -743,24 +821,15 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 743 |
trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
|
| 744 |
if kwargs.get('enable_loss_reweighting'):
|
| 745 |
trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
trainer = TrainerClass(**trainer_kwargs)
|
| 749 |
-
except TypeError as e:
|
| 750 |
-
if "unexpected keyword argument 'tokenizer'" in str(e):
|
| 751 |
-
logger.warning("Caught TypeError for tokenizer argument. Retrying without it for TRL compatibility.")
|
| 752 |
-
trainer_kwargs.pop("tokenizer", None)
|
| 753 |
-
trainer = TrainerClass(**trainer_kwargs)
|
| 754 |
-
trainer.tokenizer = tokenizer
|
| 755 |
-
else:
|
| 756 |
-
raise e
|
| 757 |
-
|
| 758 |
final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 759 |
return final_model_path, final_metrics
|
| 760 |
|
| 761 |
except Exception as e:
|
| 762 |
raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}")
|
| 763 |
|
|
|
|
| 764 |
def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 765 |
output_dir = tempfile.mkdtemp()
|
| 766 |
try:
|
|
@@ -771,23 +840,26 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
|
|
| 771 |
tokenizer_id = kwargs.get('tokenizer_name') or model_name
|
| 772 |
yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
|
| 773 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
|
|
|
|
|
|
|
| 774 |
|
| 775 |
yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
|
| 776 |
model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
|
|
|
|
| 777 |
|
| 778 |
def preprocess(examples):
|
| 779 |
return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
|
| 780 |
-
train_dataset = train_dataset.map(preprocess)
|
| 781 |
|
| 782 |
eval_dataset = None
|
| 783 |
if kwargs.get('run_evaluation'):
|
| 784 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 785 |
for update in eval_dataset_gen:
|
| 786 |
-
if isinstance(update,
|
| 787 |
yield update
|
| 788 |
else:
|
| 789 |
eval_dataset = update
|
| 790 |
-
if eval_dataset: eval_dataset = eval_dataset.map(preprocess)
|
| 791 |
|
| 792 |
metric = hf_evaluate.load("accuracy")
|
| 793 |
def compute_metrics(eval_pred):
|
|
@@ -807,6 +879,7 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
|
|
| 807 |
except Exception as e:
|
| 808 |
raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}")
|
| 809 |
|
|
|
|
| 810 |
def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 811 |
output_dir = tempfile.mkdtemp()
|
| 812 |
try:
|
|
@@ -843,7 +916,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
|
|
| 843 |
if kwargs.get('run_evaluation'):
|
| 844 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 845 |
for update in eval_dataset_gen:
|
| 846 |
-
if isinstance(update,
|
| 847 |
yield update
|
| 848 |
else:
|
| 849 |
eval_dataset = update
|
|
@@ -871,6 +944,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
|
|
| 871 |
except Exception as e:
|
| 872 |
raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}")
|
| 873 |
|
|
|
|
| 874 |
def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 875 |
output_dir = tempfile.mkdtemp()
|
| 876 |
try:
|
|
@@ -935,7 +1009,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
|
|
| 935 |
eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 936 |
eval_dataset_raw = None
|
| 937 |
for update in eval_dataset_raw_gen:
|
| 938 |
-
if isinstance(update,
|
| 939 |
yield update
|
| 940 |
else:
|
| 941 |
eval_dataset_raw = update
|
|
@@ -955,6 +1029,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
|
|
| 955 |
except Exception as e:
|
| 956 |
raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}")
|
| 957 |
|
|
|
|
| 958 |
def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 959 |
output_dir = tempfile.mkdtemp()
|
| 960 |
try:
|
|
@@ -979,7 +1054,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 979 |
if kwargs.get('run_evaluation'):
|
| 980 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 981 |
for update in eval_dataset_gen:
|
| 982 |
-
if isinstance(update,
|
| 983 |
yield update
|
| 984 |
else:
|
| 985 |
eval_dataset = update
|
|
@@ -1012,6 +1087,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 1012 |
except Exception as e:
|
| 1013 |
raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}")
|
| 1014 |
|
|
|
|
| 1015 |
def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1016 |
if device == 'cpu':
|
| 1017 |
raise ValueError("El entrenamiento de Text-to-Image solo es compatible con GPU CUDA.")
|
|
@@ -1023,22 +1099,30 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1023 |
|
| 1024 |
yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
|
| 1025 |
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
| 1026 |
-
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
|
| 1027 |
-
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
|
| 1028 |
-
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
|
| 1029 |
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
|
| 1030 |
|
| 1031 |
vae.requires_grad_(False)
|
| 1032 |
text_encoder.requires_grad_(False)
|
| 1033 |
unet.train()
|
| 1034 |
|
| 1035 |
-
yield update_logs_fn("Agregando adaptadores LoRA al
|
| 1036 |
unet_lora_config = LoraConfig(
|
| 1037 |
r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
|
| 1038 |
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
| 1039 |
)
|
| 1040 |
unet.add_adapter(unet_lora_config)
|
| 1041 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1042 |
yield update_logs_fn("Procesando dataset de imágenes...", "Text-to-Image (LoRA)")
|
| 1043 |
resolution = int(kwargs.get('diffusion_resolution', 512))
|
| 1044 |
|
|
@@ -1050,7 +1134,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1050 |
])
|
| 1051 |
|
| 1052 |
def preprocess_train(examples):
|
| 1053 |
-
images = [
|
| 1054 |
examples["pixel_values"] = [train_transforms(image) for image in images]
|
| 1055 |
examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
|
| 1056 |
return examples
|
|
@@ -1064,14 +1148,17 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1064 |
|
| 1065 |
def collate_fn(examples):
|
| 1066 |
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 1067 |
-
input_ids = torch.stack([
|
| 1068 |
return {"pixel_values": pixel_values, "input_ids": input_ids}
|
| 1069 |
|
| 1070 |
train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1071 |
|
| 1072 |
-
yield update_logs_fn("Configurando optimizador y planificador...", "Text-to-Image (LoRA)")
|
| 1073 |
optimizer = torch.optim.AdamW(
|
| 1074 |
-
|
| 1075 |
betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
|
| 1076 |
weight_decay=float(kwargs.get('weight_decay', 0.01)),
|
| 1077 |
eps=float(kwargs.get('adam_epsilon', 1e-8)),
|
|
@@ -1087,36 +1174,34 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1087 |
num_training_steps=max_train_steps,
|
| 1088 |
)
|
| 1089 |
|
| 1090 |
-
unet, optimizer, train_dataloader, lr_scheduler
|
| 1091 |
-
unet, optimizer, train_dataloader, lr_scheduler
|
| 1092 |
)
|
| 1093 |
|
| 1094 |
-
text_encoder.to(accelerator.device, dtype=torch_dtype_auto)
|
| 1095 |
vae.to(accelerator.device, dtype=torch_dtype_auto)
|
| 1096 |
|
| 1097 |
-
yield update_logs_fn("Iniciando bucle de entrenamiento de difusión...", "Text-to-Image (LoRA)")
|
| 1098 |
global_step = 0
|
| 1099 |
final_loss = 0
|
| 1100 |
for epoch in range(num_epochs):
|
| 1101 |
for step, batch in enumerate(train_dataloader):
|
| 1102 |
with accelerator.accumulate(unet):
|
| 1103 |
-
latents = vae.encode(batch["pixel_values"].to(torch_dtype_auto)).latent_dist.sample()
|
| 1104 |
latents = latents * vae.config.scaling_factor
|
| 1105 |
noise = torch.randn_like(latents)
|
| 1106 |
bsz = latents.shape[0]
|
| 1107 |
-
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 1108 |
-
timesteps = timesteps.long()
|
| 1109 |
-
|
| 1110 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 1111 |
-
encoder_hidden_states = text_encoder(batch["input_ids"]
|
| 1112 |
-
|
| 1113 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 1114 |
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 1115 |
final_loss = loss.detach().item()
|
| 1116 |
|
| 1117 |
accelerator.backward(loss)
|
| 1118 |
if accelerator.sync_gradients:
|
| 1119 |
-
|
|
|
|
|
|
|
|
|
|
| 1120 |
|
| 1121 |
optimizer.step()
|
| 1122 |
lr_scheduler.step()
|
|
@@ -1124,16 +1209,21 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1124 |
|
| 1125 |
if accelerator.is_main_process:
|
| 1126 |
if global_step % int(kwargs.get('logging_steps', 10)) == 0:
|
| 1127 |
-
yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss}", "
|
| 1128 |
global_step += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1129 |
|
| 1130 |
-
yield update_logs_fn("Entrenamiento completado, guardando modelo...", "Text-to-Image (LoRA)")
|
| 1131 |
accelerator.wait_for_everyone()
|
| 1132 |
if accelerator.is_main_process:
|
| 1133 |
-
|
| 1134 |
-
|
| 1135 |
-
|
| 1136 |
-
|
|
|
|
|
|
|
| 1137 |
pipeline.save_pretrained(output_dir)
|
| 1138 |
|
| 1139 |
with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
|
|
@@ -1148,7 +1238,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
|
|
| 1148 |
torch.cuda.empty_cache()
|
| 1149 |
return output_dir, {"final_loss": final_loss}
|
| 1150 |
|
| 1151 |
-
|
| 1152 |
def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1153 |
if device == 'cpu':
|
| 1154 |
raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
|
|
@@ -1163,11 +1253,12 @@ def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, mo
|
|
| 1163 |
|
| 1164 |
train_dataset = train_dataset.map(add_prompt)
|
| 1165 |
|
| 1166 |
-
yield update_logs_fn(f"Usando el prompt de instancia para todas las imágenes: '{dreambooth_prompt}'", "DreamBooth LoRA
|
| 1167 |
|
| 1168 |
final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 1169 |
return final_model_path, final_metrics
|
| 1170 |
|
|
|
|
| 1171 |
def _get_data_processing_pipeline(**kwargs):
|
| 1172 |
hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
|
| 1173 |
if not hf_ids and not kwargs.get('uploads'):
|
|
@@ -1194,9 +1285,8 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1194 |
if train_dataset is None:
|
| 1195 |
train_dataset = hf_train_dataset
|
| 1196 |
else:
|
| 1197 |
-
from datasets import interleave_datasets
|
| 1198 |
all_streams = [train_dataset, hf_train_dataset]
|
| 1199 |
-
all_probs = [0.5, 0.5]
|
| 1200 |
train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
|
| 1201 |
|
| 1202 |
if train_dataset is None:
|
|
@@ -1206,7 +1296,8 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1206 |
text_col, image_col, audio_col, label_col = _guess_columns(first_example)
|
| 1207 |
kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
|
| 1208 |
|
| 1209 |
-
|
|
|
|
| 1210 |
if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
|
| 1211 |
clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
|
| 1212 |
train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
|
|
@@ -1229,38 +1320,17 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1229 |
|
| 1230 |
dedup_method = kwargs.get('deduplication_method')
|
| 1231 |
if dedup_method != 'Ninguna':
|
| 1232 |
-
|
| 1233 |
-
|
| 1234 |
-
|
| 1235 |
-
|
| 1236 |
-
|
| 1237 |
-
|
| 1238 |
-
|
| 1239 |
-
if isinstance(text, str) and text:
|
| 1240 |
-
seen_texts.add(text)
|
| 1241 |
-
yield example
|
| 1242 |
-
train_dataset = IterableDataset.from_generator(dedup_generator_exact)
|
| 1243 |
-
elif dedup_method == 'Semántica (MinHash)':
|
| 1244 |
-
threshold = kwargs.get('minhash_threshold', 0.85)
|
| 1245 |
-
num_perm = int(kwargs.get('minhash_num_perm', 128))
|
| 1246 |
-
def dedup_generator_minhash():
|
| 1247 |
-
lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
|
| 1248 |
-
for i, example in enumerate(base_iterator):
|
| 1249 |
-
text = example.get(text_col, "")
|
| 1250 |
-
if text and isinstance(text, str) and text.strip():
|
| 1251 |
-
m = MinHash(num_perm=num_perm)
|
| 1252 |
-
for d in text.split():
|
| 1253 |
-
m.update(d.encode('utf8'))
|
| 1254 |
-
if not lsh.query(m):
|
| 1255 |
-
lsh.insert(f"key_{i}", m)
|
| 1256 |
-
yield example
|
| 1257 |
-
else:
|
| 1258 |
-
yield example
|
| 1259 |
-
train_dataset = IterableDataset.from_generator(dedup_generator_minhash)
|
| 1260 |
-
|
| 1261 |
|
| 1262 |
return train_dataset, kwargs
|
| 1263 |
|
|
|
|
| 1264 |
def _train_and_upload(**kwargs):
|
| 1265 |
logs, repo_link, final_model_path, final_metrics = "", "", None, {}
|
| 1266 |
|
|
@@ -1350,7 +1420,6 @@ def _train_and_upload(**kwargs):
|
|
| 1350 |
raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
|
| 1351 |
base_model_id_for_training = temp_model_dir
|
| 1352 |
kwargs["peft"] = False
|
| 1353 |
-
kwargs["merge_adapter"] = False
|
| 1354 |
kwargs['tokenizer_name'] = temp_model_dir
|
| 1355 |
yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
|
| 1356 |
|
|
@@ -1363,7 +1432,6 @@ def _train_and_upload(**kwargs):
|
|
| 1363 |
os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
|
| 1364 |
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
|
| 1365 |
|
| 1366 |
-
from datetime import datetime
|
| 1367 |
model_card_content = MODEL_CARD_TEMPLATE.format(
|
| 1368 |
repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
|
| 1369 |
training_mode=kwargs.get('training_mode'),
|
|
@@ -1390,8 +1458,11 @@ def _train_and_upload(**kwargs):
|
|
| 1390 |
train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
|
| 1391 |
while True:
|
| 1392 |
try:
|
| 1393 |
-
|
| 1394 |
-
|
|
|
|
|
|
|
|
|
|
| 1395 |
except StopIteration as e:
|
| 1396 |
final_model_path, final_metrics = e.value
|
| 1397 |
break
|
|
@@ -1405,7 +1476,7 @@ def _train_and_upload(**kwargs):
|
|
| 1405 |
eval_dataset_perp = None
|
| 1406 |
eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
|
| 1407 |
for update in eval_gen:
|
| 1408 |
-
if isinstance(update,
|
| 1409 |
yield update + (gr.update(), gr.update())
|
| 1410 |
else:
|
| 1411 |
eval_dataset_perp = update
|
|
@@ -1436,6 +1507,7 @@ def _train_and_upload(**kwargs):
|
|
| 1436 |
gr.update(visible=False)
|
| 1437 |
)
|
| 1438 |
|
|
|
|
| 1439 |
def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
|
| 1440 |
if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update()
|
| 1441 |
task_name = TASK_TO_PIPELINE_MAP.get(task_mode)
|
|
@@ -1460,6 +1532,7 @@ def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in,
|
|
| 1460 |
return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
|
| 1461 |
except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
|
| 1462 |
|
|
|
|
| 1463 |
def update_inference_ui(task_mode):
|
| 1464 |
task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
|
| 1465 |
is_text_gen = task_name == "text-generation"
|
|
@@ -1477,6 +1550,7 @@ def update_inference_ui(task_mode):
|
|
| 1477 |
gr.update(visible=is_text_gen)
|
| 1478 |
)
|
| 1479 |
|
|
|
|
| 1480 |
def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
|
| 1481 |
if not hf_token:
|
| 1482 |
return "Error: Se requiere un token de Hugging Face.", ""
|
|
@@ -1525,7 +1599,6 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
|
|
| 1525 |
for item in all_data:
|
| 1526 |
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 1527 |
|
| 1528 |
-
from datetime import datetime
|
| 1529 |
readme_content = DATASET_CARD_TEMPLATE.format(
|
| 1530 |
repo_id=repo_id,
|
| 1531 |
creation_type=creation_type,
|
|
@@ -1550,10 +1623,12 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
|
|
| 1550 |
except Exception as e:
|
| 1551 |
return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
|
| 1552 |
|
|
|
|
| 1553 |
def gradio_train_wrapper(*args):
    """Adapt Gradio's positional component values into keyword args for training.

    Pairs each positional value with its component name (in declaration order)
    and forwards everything to ``_train_and_upload``, relaying its yields.
    """
    named_args = dict(zip(all_input_components_dict.keys(), args))
    yield from _train_and_upload(**named_args)
|
| 1556 |
|
|
|
|
| 1557 |
def gradio_preview_data_wrapper(*args):
|
| 1558 |
kwargs = dict(zip(all_input_components_dict.keys(), args))
|
| 1559 |
try:
|
|
@@ -1563,9 +1638,13 @@ def gradio_preview_data_wrapper(*args):
|
|
| 1563 |
dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
|
| 1564 |
text_col = processed_kwargs.get('text_col')
|
| 1565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 1567 |
-
|
| 1568 |
-
trust_remote_code=True, use_fast=False
|
| 1569 |
)
|
| 1570 |
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
| 1571 |
if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
|
|
@@ -1574,20 +1653,21 @@ def gradio_preview_data_wrapper(*args):
|
|
| 1574 |
for i, example in enumerate(islice(dataset, 5)):
|
| 1575 |
formatted_text = ""
|
| 1576 |
if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
|
| 1577 |
-
formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2)
|
| 1578 |
else:
|
| 1579 |
formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
|
| 1580 |
|
| 1581 |
preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
|
| 1582 |
|
| 1583 |
preview_text = "\n".join(preview_samples)
|
| 1584 |
-
if not
|
| 1585 |
-
preview_text = "No se pudieron generar muestras. Revisa la configuración del dataset y el formato."
|
| 1586 |
yield preview_text
|
| 1587 |
|
| 1588 |
except Exception as e:
|
| 1589 |
yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
|
| 1590 |
|
|
|
|
| 1591 |
def toggle_training_mode_ui(is_scratch):
|
| 1592 |
return (
|
| 1593 |
gr.update(visible=not is_scratch),
|
|
@@ -1598,12 +1678,14 @@ def toggle_training_mode_ui(is_scratch):
|
|
| 1598 |
gr.update(visible=is_scratch)
|
| 1599 |
)
|
| 1600 |
|
|
|
|
| 1601 |
def toggle_task_specific_ui(training_mode):
|
| 1602 |
is_classification = "Classification" in training_mode
|
| 1603 |
is_dpo = "DPO" in training_mode
|
| 1604 |
is_sft = "Causal" in training_mode
|
| 1605 |
is_ner = "Token Classification" in training_mode
|
| 1606 |
is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
|
|
|
|
| 1607 |
|
| 1608 |
return (
|
| 1609 |
gr.update(visible=is_classification or is_ner),
|
|
@@ -1613,13 +1695,15 @@ def toggle_task_specific_ui(training_mode):
|
|
| 1613 |
gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
|
| 1614 |
gr.update(visible=not is_diffusion),
|
| 1615 |
gr.update(visible=is_diffusion),
|
| 1616 |
-
gr.update(visible=not
|
|
|
|
| 1617 |
)
|
| 1618 |
|
| 1619 |
-
|
| 1620 |
def toggle_auto_modules_ui(is_auto):
    """Show the manual target-modules controls only when auto-detection is off."""
    show_manual = not is_auto
    return gr.update(visible=show_manual)
|
| 1622 |
|
|
|
|
| 1623 |
def toggle_dataset_creator_ui(choice):
    """Switch between the synthetic-generation panel and the file-upload panel."""
    synthetic_selected = (choice == "Sintético")
    synth_visibility = gr.update(visible=synthetic_selected)
    upload_visibility = gr.update(visible=not synthetic_selected)
    return synth_visibility, upload_visibility
|
|
@@ -1650,7 +1734,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1650 |
dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
|
| 1651 |
dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
|
| 1652 |
with gr.Column(scale=2):
|
| 1653 |
-
dset_status_output = gr.Textbox(label="Estado", lines=10)
|
| 1654 |
dset_link_output = gr.Markdown()
|
| 1655 |
|
| 1656 |
dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
|
|
@@ -1706,7 +1790,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1706 |
with gr.Accordion("Avanzados", open=False):
|
| 1707 |
warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
|
| 1708 |
weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
|
| 1709 |
-
max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="0
|
| 1710 |
logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
|
| 1711 |
save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
|
| 1712 |
save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1")
|
|
@@ -1766,9 +1850,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1766 |
diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución")
|
| 1767 |
with gr.Group(visible=False) as dreambooth_ui:
|
| 1768 |
dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
|
| 1769 |
-
dreambooth_class_prompt = gr.Textbox(label="Prompt de Clase (Opcional)", placeholder="p.ej. 'foto de perro'")
|
| 1770 |
-
dreambooth_num_class_images = gr.Slider(0, 1000, 100, step=10, label="Nº de Imágenes de Clase")
|
| 1771 |
-
dreambooth_prior_loss_weight = gr.Slider(0.0, 2.0, 1.0, label="Peso de Pérdida a Priori")
|
| 1772 |
dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
|
| 1773 |
with gr.Group(visible=False) as classification_labels_ui:
|
| 1774 |
classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo")
|
|
@@ -1787,7 +1868,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1787 |
enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False)
|
| 1788 |
cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]')
|
| 1789 |
|
| 1790 |
-
|
| 1791 |
with gr.Accordion("🔌 Integraciones", open=False):
|
| 1792 |
wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
|
| 1793 |
wandb_project_input = gr.Textbox(label="Proyecto W&B")
|
|
@@ -1832,8 +1912,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1832 |
"diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
|
| 1833 |
"enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
|
| 1834 |
"wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
|
| 1835 |
-
"dreambooth_instance_prompt": dreambooth_instance_prompt,
|
| 1836 |
-
"dreambooth_num_class_images": dreambooth_num_class_images, "dreambooth_prior_loss_weight": dreambooth_prior_loss_weight,
|
| 1837 |
"dreambooth_train_text_encoder": dreambooth_train_text_encoder
|
| 1838 |
}
|
| 1839 |
|
|
@@ -1905,4 +1984,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1905 |
outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
|
| 1906 |
)
|
| 1907 |
|
| 1908 |
-
demo.queue().launch(
|
|
|
|
| 16 |
import ast
|
| 17 |
from itertools import islice
|
| 18 |
from pathlib import Path
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
|
| 22 |
import torch
|
| 23 |
import torch.nn.functional as F
|
| 24 |
from torch.utils.data import DataLoader
|
|
|
|
| 32 |
import textstat
|
| 33 |
from datasketch import MinHash, MinHashLSH
|
| 34 |
import gradio as gr
|
| 35 |
+
from datasets import load_dataset, IterableDataset, Dataset, DatasetDict, interleave_datasets, Audio
|
|
|
|
| 36 |
from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
|
| 37 |
from transformers import (
|
| 38 |
AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
|
| 39 |
AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
|
| 40 |
+
SpeechT5ForTextToSpeech, SpeechT5Processor, SpeechT5HifiGan, AutoModelForImageClassification,
|
| 41 |
AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
|
| 42 |
+
DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq,
|
| 43 |
AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
|
| 44 |
DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
|
| 45 |
LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
|
|
|
|
| 56 |
)
|
| 57 |
import evaluate as hf_evaluate
|
| 58 |
from jinja2 import Template
|
|
|
|
| 59 |
|
| 60 |
logger = logging.getLogger(__name__)
|
| 61 |
|
|
|
|
| 95 |
"DreamBooth LoRA (Text-to-Image)": "text-to-image",
|
| 96 |
}
|
| 97 |
|
| 98 |
+
MODEL_CARD_TEMPLATE = """
|
| 99 |
+
---
|
| 100 |
language: es
|
| 101 |
license: apache-2.0
|
| 102 |
tags:
|
|
|
|
| 134 |
- Gradio
|
| 135 |
"""
|
| 136 |
|
| 137 |
+
DATASET_CARD_TEMPLATE = """
|
| 138 |
+
---
|
| 139 |
license: mit
|
| 140 |
---
|
| 141 |
|
|
|
|
| 150 |
- **Fecha de Creación:** {date}
|
| 151 |
"""
|
| 152 |
|
| 153 |
+
@spaces.GPU()
|
| 154 |
class DebiasingSFTTrainer(SFTTrainer):
|
| 155 |
def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
|
| 156 |
super().__init__(*args, **kwargs)
|
|
|
|
| 168 |
break
|
| 169 |
return (loss, outputs) if return_outputs else loss
|
| 170 |
|
| 171 |
+
class DeduplicatedIterableDataset(IterableDataset):
    """Streaming wrapper that filters duplicate examples out of an iterable dataset.

    Two strategies are supported, selected via ``method``:
      - ``'Exacta'``: drop examples whose text was already seen verbatim.
      - ``'Semántica (MinHash)'``: drop near-duplicates using MinHash LSH.
    Any other value passes the wrapped dataset through unchanged.

    NOTE(review): the original code decorated this *class* with ``@spaces.GPU()``.
    ``spaces.GPU`` is designed to wrap functions for ZeroGPU dispatch; applied to
    a class it replaces the class object with a wrapper function, which breaks
    instantiation and ``isinstance`` checks. Deduplication is pure CPU work, so
    the decorator has been removed.
    """

    def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
        # Iteration is fully overridden below, so an empty iterator is passed
        # to satisfy the parent constructor.
        # TODO(review): `datasets` normally expects an _BaseExamplesIterable for
        # ex_iterable — confirm a plain iterator is accepted by the pinned version.
        super().__init__(ex_iterable=iter([]))
        self.dataset = dataset        # wrapped (iterable) dataset
        self.text_col = text_col      # column holding the text used for comparison
        self.method = method          # 'Exacta' | 'Semántica (MinHash)' | passthrough
        self.threshold = threshold    # Jaccard similarity threshold for the LSH index
        self.num_perm = num_perm      # number of MinHash permutations
        # Propagate dataset metadata so downstream consumers keep feature info.
        if hasattr(dataset, '_info'):
            self._info = dataset._info
        elif hasattr(dataset, 'info'):
            self._info = dataset.info

    def __iter__(self):
        if self.method == 'Exacta':
            return self._exact_iter()
        elif self.method == 'Semántica (MinHash)':
            return self._minhash_iter()
        else:
            return iter(self.dataset)

    def _exact_iter(self):
        """Yield examples whose text has not been seen verbatim before."""
        seen_texts = set()
        for example in self.dataset:
            text = example.get(self.text_col, "")
            if text and isinstance(text, str):
                if text not in seen_texts:
                    seen_texts.add(text)
                    yield example
            else:
                # Examples without usable text are passed through untouched.
                yield example

    def _minhash_iter(self):
        """Yield examples that are not near-duplicates under MinHash LSH."""
        lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
        for i, example in enumerate(self.dataset):
            text = example.get(self.text_col, "")
            if text and isinstance(text, str) and text.strip():
                m = MinHash(num_perm=self.num_perm)
                # Token-level shingles: one MinHash update per whitespace token.
                for d in text.split():
                    m.update(d.encode('utf8'))
                if not lsh.query(m):
                    lsh.insert(f"key_{i}", m)
                    yield example
            else:
                # Empty/non-string texts are kept rather than deduplicated.
                yield example
| 217 |
+
|
| 218 |
+
@spaces.GPU()
|
| 219 |
def hf_login(token):
|
| 220 |
if not token:
|
| 221 |
return "Por favor, introduce un token."
|
|
|
|
| 226 |
except Exception as e:
|
| 227 |
return f"❌ Error en la conexión: {e}"
|
| 228 |
|
| 229 |
+
@spaces.GPU()
|
| 230 |
def _clean_text(example, text_col, **kwargs):
|
| 231 |
text = example.get(text_col, "")
|
| 232 |
if not isinstance(text, str):
|
|
|
|
| 244 |
example[text_col] = text
|
| 245 |
return example
|
| 246 |
|
| 247 |
+
@spaces.GPU()
|
| 248 |
def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords):
|
| 249 |
text = example.get(text_col, "")
|
| 250 |
if not isinstance(text, str): return False
|
|
|
|
| 258 |
lower_text = text.lower()
|
| 259 |
return not any(keyword in lower_text for keyword in exclude_keywords)
|
| 260 |
|
| 261 |
+
@spaces.GPU()
|
| 262 |
def _get_filter_functions(**kwargs):
|
| 263 |
filters = []
|
| 264 |
if kwargs.get('enable_quality_filter'):
|
|
|
|
| 307 |
filters.append(stats_filter)
|
| 308 |
return filters
|
| 309 |
|
| 310 |
+
@spaces.GPU()
|
| 311 |
def _load_hf_streaming(ids, split="train", probabilities=None):
|
| 312 |
streams = []
|
| 313 |
valid_ids = []
|
|
|
|
| 335 |
if probabilities and len(probabilities) != len(streams):
|
| 336 |
logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
|
| 337 |
probabilities = None
|
|
|
|
|
|
|
| 338 |
return interleave_datasets(streams, probabilities=probabilities)
|
| 339 |
|
| 340 |
+
@spaces.GPU()
|
| 341 |
def _load_uploaded_stream(files):
|
| 342 |
all_rows = []
|
| 343 |
for f in files or []:
|
|
|
|
| 359 |
random.shuffle(all_rows)
|
| 360 |
return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []}
|
| 361 |
|
| 362 |
+
@spaces.GPU()
|
| 363 |
def _guess_columns(sample):
|
| 364 |
text_col, image_col, audio_col, label_col = "text", "image", "audio", "label"
|
| 365 |
if not isinstance(sample, dict):
|
|
|
|
| 376 |
elif "labels" in keys: label_col = keys["labels"]
|
| 377 |
return text_col, image_col, audio_col, label_col
|
| 378 |
|
| 379 |
+
@spaces.GPU()
|
| 380 |
def _apply_cda(dataset, text_col, cda_config_str):
|
| 381 |
try:
|
| 382 |
swap_groups = json.loads(cda_config_str)
|
|
|
|
| 409 |
current_texts.update(next_texts)
|
| 410 |
return IterableDataset.from_generator(cda_generator)
|
| 411 |
|
| 412 |
+
@spaces.GPU()
|
| 413 |
def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
|
| 414 |
if not ratio or ratio <= 0:
|
| 415 |
return dataset
|
|
|
|
| 437 |
logger.warning(f"Error en retrotraducción: {e}")
|
| 438 |
return IterableDataset.from_generator(bt_generator)
|
| 439 |
|
| 440 |
+
@spaces.GPU()
|
| 441 |
def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
|
| 442 |
if not num_samples or num_samples <= 0:
|
| 443 |
return None
|
|
|
|
| 470 |
continue
|
| 471 |
return IterableDataset.from_generator(synthetic_generator)
|
| 472 |
|
| 473 |
+
@spaces.GPU()
|
| 474 |
def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
|
| 475 |
safe_steps = int(steps_per_epoch_estimate or 10000)
|
| 476 |
safe_batch_size = int(batch_size or 1)
|
|
|
|
| 489 |
kv_heads = heads if is_gpt2_like else (max(1, heads // 4))
|
| 490 |
return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads
|
| 491 |
|
| 492 |
+
@spaces.GPU()
|
| 493 |
def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn):
|
| 494 |
if eval_ds_id:
|
| 495 |
yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación")
|
|
|
|
| 511 |
yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
|
| 512 |
return None
|
| 513 |
|
| 514 |
+
@spaces.GPU()
|
| 515 |
def _create_training_args(output_dir, repo_id, **kwargs):
|
| 516 |
neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
|
| 517 |
optim_args_dict = {}
|
|
|
|
| 530 |
"save_strategy": "steps",
|
| 531 |
"logging_steps": int(kwargs.get('logging_steps', 10)),
|
| 532 |
"save_steps": int(kwargs.get('save_steps', 50)),
|
| 533 |
+
"evaluation_strategy": "steps" if kwargs.get('run_evaluation', False) else "no",
|
| 534 |
"eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
|
| 535 |
"learning_rate": float(kwargs.get('learning_rate', 2e-5)),
|
| 536 |
"fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
|
| 537 |
"bf16": kwargs.get('mixed_precision') == 'bf16' and device == 'cuda',
|
| 538 |
+
"max_grad_norm": float(kwargs.get('max_grad_norm', 1.0)),
|
| 539 |
"warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)),
|
| 540 |
"lr_scheduler_type": kwargs.get('scheduler', 'cosine'),
|
| 541 |
"weight_decay": float(kwargs.get('weight_decay', 0.01)),
|
|
|
|
| 570 |
|
| 571 |
return TrainingArguments(**args_dict)
|
| 572 |
|
| 573 |
+
@spaces.GPU()
|
| 574 |
def _generic_model_loader(model_name_or_path, model_class, **kwargs):
|
| 575 |
quantization_type = kwargs.get('quantization', 'no')
|
| 576 |
bnb_config = None
|
| 577 |
|
| 578 |
if quantization_type != "no" and device == "cuda":
|
| 579 |
+
try:
|
| 580 |
+
import bitsandbytes as bnb
|
| 581 |
+
if quantization_type == "4bit":
|
| 582 |
+
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype_auto, bnb_4bit_use_double_quant=True)
|
| 583 |
+
elif quantization_type == "8bit":
|
| 584 |
+
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 585 |
+
except ImportError:
|
| 586 |
+
logger.warning("bitsandbytes no está instalado. No se puede cargar en 4bit/8bit.")
|
| 587 |
+
|
| 588 |
elif quantization_type != "no" and device == "cpu":
|
| 589 |
logger.warning("La cuantización solo es compatible con GPU CUDA. Se procederá sin cuantización.")
|
| 590 |
|
|
|
|
| 625 |
|
| 626 |
return model
|
| 627 |
|
| 628 |
+
@spaces.GPU()
|
| 629 |
def _find_all_linear_names(model, quantization_type):
|
| 630 |
cls = torch.nn.Linear
|
| 631 |
if quantization_type != 'no' and device == "cuda":
|
|
|
|
| 651 |
|
| 652 |
return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
|
| 653 |
|
| 654 |
+
@spaces.GPU()
|
| 655 |
def _conversation_formatting_func(example, tokenizer, **kwargs):
|
| 656 |
conv_col = ""
|
| 657 |
for key in ["messages", "conversations", "turns"]:
|
|
|
|
| 663 |
except: return ""
|
| 664 |
return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
|
| 665 |
|
| 666 |
+
@spaces.GPU()
|
| 667 |
def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
|
| 668 |
if kwargs.get('enable_cot_input') or kwargs.get('enable_tool_use_input'):
|
| 669 |
messages = []
|
|
|
|
| 682 |
return "\n".join([m['content'] for m in messages])
|
| 683 |
return example.get(text_col, "")
|
| 684 |
|
| 685 |
+
@spaces.GPU()
|
| 686 |
def _dpo_formatting_func(example, **kwargs):
|
| 687 |
return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")}
|
| 688 |
|
| 689 |
+
@spaces.GPU()
|
| 690 |
def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
|
| 691 |
model.eval()
|
| 692 |
encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt").to(model.device)
|
|
|
|
| 711 |
ppl = torch.exp(torch.stack(nlls).mean())
|
| 712 |
return ppl.item()
|
| 713 |
|
| 714 |
+
@spaces.GPU()
|
| 715 |
def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
|
| 716 |
adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
|
| 717 |
if not adapter_ids:
|
|
|
|
| 743 |
yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
|
| 744 |
return temp_dir
|
| 745 |
|
| 746 |
+
@spaces.GPU()
|
| 747 |
def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 748 |
yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
|
| 749 |
trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False)
|
|
|
|
| 753 |
eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
|
| 754 |
if eval_logs:
|
| 755 |
final_metrics = eval_logs[-1]
|
| 756 |
+
final_metrics = {k.replace('eval_', ''): v for k, v in final_metrics.items()}
|
| 757 |
|
| 758 |
yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
|
| 759 |
output_dir = trainer.args.output_dir
|
|
|
|
| 772 |
torch.cuda.empty_cache()
|
| 773 |
return output_dir, final_metrics
|
| 774 |
|
| 775 |
+
@spaces.GPU()
|
| 776 |
def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 777 |
output_dir = tempfile.mkdtemp()
|
| 778 |
is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)"
|
|
|
|
| 804 |
if kwargs.get('run_evaluation'):
|
| 805 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 806 |
for update in eval_dataset_gen:
|
| 807 |
+
if isinstance(update, dict):
|
| 808 |
yield update
|
| 809 |
else:
|
| 810 |
eval_dataset = update
|
| 811 |
|
| 812 |
TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
|
| 813 |
+
trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config}
|
| 814 |
|
| 815 |
if is_dpo:
|
| 816 |
trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
|
|
|
|
| 821 |
trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
|
| 822 |
if kwargs.get('enable_loss_reweighting'):
|
| 823 |
trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
|
| 824 |
+
|
| 825 |
+
trainer = TrainerClass(**trainer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 827 |
return final_model_path, final_metrics
|
| 828 |
|
| 829 |
except Exception as e:
|
| 830 |
raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}")
|
| 831 |
|
| 832 |
+
@spaces.GPU()
|
| 833 |
def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 834 |
output_dir = tempfile.mkdtemp()
|
| 835 |
try:
|
|
|
|
| 840 |
tokenizer_id = kwargs.get('tokenizer_name') or model_name
|
| 841 |
yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
|
| 842 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
|
| 843 |
+
if tokenizer.pad_token is None:
|
| 844 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 845 |
|
| 846 |
yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
|
| 847 |
model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
|
| 848 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 849 |
|
| 850 |
def preprocess(examples):
|
| 851 |
return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
|
| 852 |
+
train_dataset = train_dataset.map(preprocess, batched=True)
|
| 853 |
|
| 854 |
eval_dataset = None
|
| 855 |
if kwargs.get('run_evaluation'):
|
| 856 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 857 |
for update in eval_dataset_gen:
|
| 858 |
+
if isinstance(update, dict):
|
| 859 |
yield update
|
| 860 |
else:
|
| 861 |
eval_dataset = update
|
| 862 |
+
if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
|
| 863 |
|
| 864 |
metric = hf_evaluate.load("accuracy")
|
| 865 |
def compute_metrics(eval_pred):
|
|
|
|
| 879 |
except Exception as e:
|
| 880 |
raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}")
|
| 881 |
|
| 882 |
+
@spaces.GPU()
|
| 883 |
def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 884 |
output_dir = tempfile.mkdtemp()
|
| 885 |
try:
|
|
|
|
| 916 |
if kwargs.get('run_evaluation'):
|
| 917 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 918 |
for update in eval_dataset_gen:
|
| 919 |
+
if isinstance(update, dict):
|
| 920 |
yield update
|
| 921 |
else:
|
| 922 |
eval_dataset = update
|
|
|
|
| 944 |
except Exception as e:
|
| 945 |
raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}")
|
| 946 |
|
| 947 |
+
@spaces.GPU()
|
| 948 |
def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 949 |
output_dir = tempfile.mkdtemp()
|
| 950 |
try:
|
|
|
|
| 1009 |
eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 1010 |
eval_dataset_raw = None
|
| 1011 |
for update in eval_dataset_raw_gen:
|
| 1012 |
+
if isinstance(update, dict):
|
| 1013 |
yield update
|
| 1014 |
else:
|
| 1015 |
eval_dataset_raw = update
|
|
|
|
| 1029 |
except Exception as e:
|
| 1030 |
raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}")
|
| 1031 |
|
| 1032 |
+
@spaces.GPU()
|
| 1033 |
def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1034 |
output_dir = tempfile.mkdtemp()
|
| 1035 |
try:
|
|
|
|
| 1054 |
if kwargs.get('run_evaluation'):
|
| 1055 |
eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
|
| 1056 |
for update in eval_dataset_gen:
|
| 1057 |
+
if isinstance(update, dict):
|
| 1058 |
yield update
|
| 1059 |
else:
|
| 1060 |
eval_dataset = update
|
|
|
|
| 1087 |
except Exception as e:
|
| 1088 |
raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}")
|
| 1089 |
|
| 1090 |
+
@spaces.GPU()
|
| 1091 |
def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1092 |
if device == 'cpu':
|
| 1093 |
raise ValueError("El entrenamiento de Text-to-Image solo es compatible con GPU CUDA.")
|
|
|
|
| 1099 |
|
| 1100 |
yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
|
| 1101 |
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
| 1102 |
+
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=torch_dtype_auto)
|
| 1103 |
+
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=torch_dtype_auto)
|
| 1104 |
+
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=torch_dtype_auto)
|
| 1105 |
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
|
| 1106 |
|
| 1107 |
vae.requires_grad_(False)
|
| 1108 |
text_encoder.requires_grad_(False)
|
| 1109 |
unet.train()
|
| 1110 |
|
| 1111 |
+
yield update_logs_fn("Agregando adaptadores LoRA al UNet...", "Text-to-Image (LoRA)")
|
| 1112 |
unet_lora_config = LoraConfig(
|
| 1113 |
r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
|
| 1114 |
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
| 1115 |
)
|
| 1116 |
unet.add_adapter(unet_lora_config)
|
| 1117 |
|
| 1118 |
+
if kwargs.get('dreambooth_train_text_encoder', False):
|
| 1119 |
+
yield update_logs_fn("Agregando adaptadores LoRA al Text Encoder...", "DreamBooth LoRA")
|
| 1120 |
+
text_encoder_lora_config = LoraConfig(
|
| 1121 |
+
r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
|
| 1122 |
+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
| 1123 |
+
)
|
| 1124 |
+
text_encoder.add_adapter(text_encoder_lora_config)
|
| 1125 |
+
|
| 1126 |
yield update_logs_fn("Procesando dataset de imágenes...", "Text-to-Image (LoRA)")
|
| 1127 |
resolution = int(kwargs.get('diffusion_resolution', 512))
|
| 1128 |
|
|
|
|
| 1134 |
])
|
| 1135 |
|
| 1136 |
def preprocess_train(examples):
|
| 1137 |
+
images = [image.convert("RGB") for image in examples[kwargs.get('image_col', 'image')]]
|
| 1138 |
examples["pixel_values"] = [train_transforms(image) for image in images]
|
| 1139 |
examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
|
| 1140 |
return examples
|
|
|
|
| 1148 |
|
| 1149 |
def collate_fn(examples):
|
| 1150 |
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 1151 |
+
input_ids = torch.stack([e["input_ids"][0] for e in examples])
|
| 1152 |
return {"pixel_values": pixel_values, "input_ids": input_ids}
|
| 1153 |
|
| 1154 |
train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
|
| 1155 |
+
|
| 1156 |
+
params_to_optimize = list(unet.parameters())
|
| 1157 |
+
if kwargs.get('dreambooth_train_text_encoder', False):
|
| 1158 |
+
params_to_optimize += list(text_encoder.parameters())
|
| 1159 |
|
|
|
|
| 1160 |
optimizer = torch.optim.AdamW(
|
| 1161 |
+
params_to_optimize, lr=float(kwargs.get('learning_rate', 2e-5)),
|
| 1162 |
betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
|
| 1163 |
weight_decay=float(kwargs.get('weight_decay', 0.01)),
|
| 1164 |
eps=float(kwargs.get('adam_epsilon', 1e-8)),
|
|
|
|
| 1174 |
num_training_steps=max_train_steps,
|
| 1175 |
)
|
| 1176 |
|
| 1177 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 1178 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 1179 |
)
|
| 1180 |
|
|
|
|
| 1181 |
vae.to(accelerator.device, dtype=torch_dtype_auto)
|
| 1182 |
|
|
|
|
| 1183 |
global_step = 0
|
| 1184 |
final_loss = 0
|
| 1185 |
for epoch in range(num_epochs):
|
| 1186 |
for step, batch in enumerate(train_dataloader):
|
| 1187 |
with accelerator.accumulate(unet):
|
| 1188 |
+
latents = vae.encode(batch["pixel_values"].to(dtype=torch_dtype_auto)).latent_dist.sample()
|
| 1189 |
latents = latents * vae.config.scaling_factor
|
| 1190 |
noise = torch.randn_like(latents)
|
| 1191 |
bsz = latents.shape[0]
|
| 1192 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
|
|
|
|
|
|
|
| 1193 |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 1194 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
|
|
|
| 1195 |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 1196 |
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 1197 |
final_loss = loss.detach().item()
|
| 1198 |
|
| 1199 |
accelerator.backward(loss)
|
| 1200 |
if accelerator.sync_gradients:
|
| 1201 |
+
params_to_clip = list(unet.parameters())
|
| 1202 |
+
if kwargs.get('dreambooth_train_text_encoder', False):
|
| 1203 |
+
params_to_clip += list(text_encoder.parameters())
|
| 1204 |
+
accelerator.clip_grad_norm_(params_to_clip, float(kwargs.get('max_grad_norm', 1.0)))
|
| 1205 |
|
| 1206 |
optimizer.step()
|
| 1207 |
lr_scheduler.step()
|
|
|
|
| 1209 |
|
| 1210 |
if accelerator.is_main_process:
|
| 1211 |
if global_step % int(kwargs.get('logging_steps', 10)) == 0:
|
| 1212 |
+
yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss:.4f}", "Entrenando Difusión")
|
| 1213 |
global_step += 1
|
| 1214 |
+
if global_step >= max_train_steps:
|
| 1215 |
+
break
|
| 1216 |
+
if global_step >= max_train_steps:
|
| 1217 |
+
break
|
| 1218 |
|
|
|
|
| 1219 |
accelerator.wait_for_everyone()
|
| 1220 |
if accelerator.is_main_process:
|
| 1221 |
+
pipeline = StableDiffusionText2ImagePipeline.from_pretrained(
|
| 1222 |
+
model_name,
|
| 1223 |
+
unet=accelerator.unwrap_model(unet),
|
| 1224 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 1225 |
+
torch_dtype=torch_dtype_auto,
|
| 1226 |
+
)
|
| 1227 |
pipeline.save_pretrained(output_dir)
|
| 1228 |
|
| 1229 |
with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
|
|
|
|
| 1238 |
torch.cuda.empty_cache()
|
| 1239 |
return output_dir, {"final_loss": final_loss}
|
| 1240 |
|
| 1241 |
+
@spaces.GPU()
|
| 1242 |
def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
|
| 1243 |
if device == 'cpu':
|
| 1244 |
raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
|
|
|
|
| 1253 |
|
| 1254 |
train_dataset = train_dataset.map(add_prompt)
|
| 1255 |
|
| 1256 |
+
yield update_logs_fn(f"Usando el prompt de instancia para todas las imágenes: '{dreambooth_prompt}'", "DreamBooth LoRA")
|
| 1257 |
|
| 1258 |
final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
|
| 1259 |
return final_model_path, final_metrics
|
| 1260 |
|
| 1261 |
+
@spaces.GPU()
|
| 1262 |
def _get_data_processing_pipeline(**kwargs):
|
| 1263 |
hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
|
| 1264 |
if not hf_ids and not kwargs.get('uploads'):
|
|
|
|
| 1285 |
if train_dataset is None:
|
| 1286 |
train_dataset = hf_train_dataset
|
| 1287 |
else:
|
|
|
|
| 1288 |
all_streams = [train_dataset, hf_train_dataset]
|
| 1289 |
+
all_probs = [0.5, 0.5]
|
| 1290 |
train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
|
| 1291 |
|
| 1292 |
if train_dataset is None:
|
|
|
|
| 1296 |
text_col, image_col, audio_col, label_col = _guess_columns(first_example)
|
| 1297 |
kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
|
| 1298 |
|
| 1299 |
+
is_text_task = kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)", "Image Classification (Vision)", "Audio Classification (Speech)"]
|
| 1300 |
+
if is_text_task:
|
| 1301 |
if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
|
| 1302 |
clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
|
| 1303 |
train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
|
|
|
|
| 1320 |
|
| 1321 |
dedup_method = kwargs.get('deduplication_method')
|
| 1322 |
if dedup_method != 'Ninguna':
|
| 1323 |
+
train_dataset = DeduplicatedIterableDataset(
|
| 1324 |
+
dataset=train_dataset,
|
| 1325 |
+
text_col=text_col,
|
| 1326 |
+
method=dedup_method,
|
| 1327 |
+
threshold=kwargs.get('minhash_threshold', 0.85),
|
| 1328 |
+
num_perm=int(kwargs.get('minhash_num_perm', 128))
|
| 1329 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1330 |
|
| 1331 |
return train_dataset, kwargs
|
| 1332 |
|
| 1333 |
+
@spaces.GPU()
|
| 1334 |
def _train_and_upload(**kwargs):
|
| 1335 |
logs, repo_link, final_model_path, final_metrics = "", "", None, {}
|
| 1336 |
|
|
|
|
| 1420 |
raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
|
| 1421 |
base_model_id_for_training = temp_model_dir
|
| 1422 |
kwargs["peft"] = False
|
|
|
|
| 1423 |
kwargs['tokenizer_name'] = temp_model_dir
|
| 1424 |
yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
|
| 1425 |
|
|
|
|
| 1432 |
os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
|
| 1433 |
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
|
| 1434 |
|
|
|
|
| 1435 |
model_card_content = MODEL_CARD_TEMPLATE.format(
|
| 1436 |
repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
|
| 1437 |
training_mode=kwargs.get('training_mode'),
|
|
|
|
| 1458 |
train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
|
| 1459 |
while True:
|
| 1460 |
try:
|
| 1461 |
+
update = next(train_generator)
|
| 1462 |
+
if isinstance(update, tuple) and len(update) == 4:
|
| 1463 |
+
yield update + (gr.update(), gr.update())
|
| 1464 |
+
else:
|
| 1465 |
+
pass
|
| 1466 |
except StopIteration as e:
|
| 1467 |
final_model_path, final_metrics = e.value
|
| 1468 |
break
|
|
|
|
| 1476 |
eval_dataset_perp = None
|
| 1477 |
eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
|
| 1478 |
for update in eval_gen:
|
| 1479 |
+
if isinstance(update, dict):
|
| 1480 |
yield update + (gr.update(), gr.update())
|
| 1481 |
else:
|
| 1482 |
eval_dataset_perp = update
|
|
|
|
| 1507 |
gr.update(visible=False)
|
| 1508 |
)
|
| 1509 |
|
| 1510 |
+
@spaces.GPU()
|
| 1511 |
def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
|
| 1512 |
if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update()
|
| 1513 |
task_name = TASK_TO_PIPELINE_MAP.get(task_mode)
|
|
|
|
| 1532 |
return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
|
| 1533 |
except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
|
| 1534 |
|
| 1535 |
+
@spaces.GPU()
|
| 1536 |
def update_inference_ui(task_mode):
|
| 1537 |
task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
|
| 1538 |
is_text_gen = task_name == "text-generation"
|
|
|
|
| 1550 |
gr.update(visible=is_text_gen)
|
| 1551 |
)
|
| 1552 |
|
| 1553 |
+
@spaces.GPU()
|
| 1554 |
def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
|
| 1555 |
if not hf_token:
|
| 1556 |
return "Error: Se requiere un token de Hugging Face.", ""
|
|
|
|
| 1599 |
for item in all_data:
|
| 1600 |
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 1601 |
|
|
|
|
| 1602 |
readme_content = DATASET_CARD_TEMPLATE.format(
|
| 1603 |
repo_id=repo_id,
|
| 1604 |
creation_type=creation_type,
|
|
|
|
| 1623 |
except Exception as e:
|
| 1624 |
return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
|
| 1625 |
|
| 1626 |
+
@spaces.GPU()
|
| 1627 |
def gradio_train_wrapper(*args):
|
| 1628 |
kwargs = dict(zip(all_input_components_dict.keys(), args))
|
| 1629 |
yield from _train_and_upload(**kwargs)
|
| 1630 |
|
| 1631 |
+
@spaces.GPU()
|
| 1632 |
def gradio_preview_data_wrapper(*args):
|
| 1633 |
kwargs = dict(zip(all_input_components_dict.keys(), args))
|
| 1634 |
try:
|
|
|
|
| 1638 |
dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
|
| 1639 |
text_col = processed_kwargs.get('text_col')
|
| 1640 |
|
| 1641 |
+
model_id_for_tokenizer = kwargs.get('model_base_input')
|
| 1642 |
+
if not model_id_for_tokenizer:
|
| 1643 |
+
raise ValueError("Se necesita un ID de modelo base para cargar el tokenizer para la vista previa.")
|
| 1644 |
+
|
| 1645 |
+
tokenizer_id = kwargs.get('tokenizer_name') or model_id_for_tokenizer
|
| 1646 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 1647 |
+
tokenizer_id, trust_remote_code=True, use_fast=False
|
|
|
|
| 1648 |
)
|
| 1649 |
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
| 1650 |
if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
|
|
|
|
| 1653 |
for i, example in enumerate(islice(dataset, 5)):
|
| 1654 |
formatted_text = ""
|
| 1655 |
if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
|
| 1656 |
+
formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
|
| 1657 |
else:
|
| 1658 |
formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
|
| 1659 |
|
| 1660 |
preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
|
| 1661 |
|
| 1662 |
preview_text = "\n".join(preview_samples)
|
| 1663 |
+
if not preview_samples:
|
| 1664 |
+
preview_text = "No se pudieron generar muestras. Revisa la configuración del dataset, los filtros y el formato."
|
| 1665 |
yield preview_text
|
| 1666 |
|
| 1667 |
except Exception as e:
|
| 1668 |
yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
|
| 1669 |
|
| 1670 |
+
@spaces.GPU()
|
| 1671 |
def toggle_training_mode_ui(is_scratch):
|
| 1672 |
return (
|
| 1673 |
gr.update(visible=not is_scratch),
|
|
|
|
| 1678 |
gr.update(visible=is_scratch)
|
| 1679 |
)
|
| 1680 |
|
| 1681 |
+
@spaces.GPU()
|
| 1682 |
def toggle_task_specific_ui(training_mode):
|
| 1683 |
is_classification = "Classification" in training_mode
|
| 1684 |
is_dpo = "DPO" in training_mode
|
| 1685 |
is_sft = "Causal" in training_mode
|
| 1686 |
is_ner = "Token Classification" in training_mode
|
| 1687 |
is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
|
| 1688 |
+
is_streaming = not is_diffusion
|
| 1689 |
|
| 1690 |
return (
|
| 1691 |
gr.update(visible=is_classification or is_ner),
|
|
|
|
| 1695 |
gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
|
| 1696 |
gr.update(visible=not is_diffusion),
|
| 1697 |
gr.update(visible=is_diffusion),
|
| 1698 |
+
gr.update(visible=not is_streaming),
|
| 1699 |
+
gr.update(visible=is_streaming)
|
| 1700 |
)
|
| 1701 |
|
| 1702 |
+
@spaces.GPU()
|
| 1703 |
def toggle_auto_modules_ui(is_auto):
|
| 1704 |
return gr.update(visible=not is_auto)
|
| 1705 |
|
| 1706 |
+
@spaces.GPU()
|
| 1707 |
def toggle_dataset_creator_ui(choice):
|
| 1708 |
is_synth = choice == "Sintético"
|
| 1709 |
return gr.update(visible=is_synth), gr.update(visible=not is_synth)
|
|
|
|
| 1734 |
dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
|
| 1735 |
dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
|
| 1736 |
with gr.Column(scale=2):
|
| 1737 |
+
dset_status_output = gr.Textbox(label="Estado", lines=10, interactive=False)
|
| 1738 |
dset_link_output = gr.Markdown()
|
| 1739 |
|
| 1740 |
dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
|
|
|
|
| 1790 |
with gr.Accordion("Avanzados", open=False):
|
| 1791 |
warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
|
| 1792 |
weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
|
| 1793 |
+
max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="1.0")
|
| 1794 |
logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
|
| 1795 |
save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
|
| 1796 |
save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1")
|
|
|
|
| 1850 |
diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución")
|
| 1851 |
with gr.Group(visible=False) as dreambooth_ui:
|
| 1852 |
dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
|
|
|
|
|
|
|
|
|
|
| 1853 |
dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
|
| 1854 |
with gr.Group(visible=False) as classification_labels_ui:
|
| 1855 |
classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo")
|
|
|
|
| 1868 |
enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False)
|
| 1869 |
cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]')
|
| 1870 |
|
|
|
|
| 1871 |
with gr.Accordion("🔌 Integraciones", open=False):
|
| 1872 |
wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
|
| 1873 |
wandb_project_input = gr.Textbox(label="Proyecto W&B")
|
|
|
|
| 1912 |
"diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
|
| 1913 |
"enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
|
| 1914 |
"wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
|
| 1915 |
+
"dreambooth_instance_prompt": dreambooth_instance_prompt,
|
|
|
|
| 1916 |
"dreambooth_train_text_encoder": dreambooth_train_text_encoder
|
| 1917 |
}
|
| 1918 |
|
|
|
|
| 1984 |
outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
|
| 1985 |
)
|
| 1986 |
|
| 1987 |
+
demo.queue().launch(debug=True)
|