Ignaciohhhhggfgjfrffd commited on
Commit
231ae13
·
verified ·
1 Parent(s): c4e90bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -115
app.py CHANGED
@@ -16,6 +16,9 @@ import re
16
  import ast
17
  from itertools import islice
18
  from pathlib import Path
 
 
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch.utils.data import DataLoader
@@ -29,15 +32,14 @@ from langdetect import detect_langs
29
  import textstat
30
  from datasketch import MinHash, MinHashLSH
31
  import gradio as gr
32
- import spaces
33
- from datasets import load_dataset, IterableDataset, Dataset, DatasetDict
34
  from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
35
  from transformers import (
36
  AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
37
  AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
38
- AutoModelForImageClassification,
39
  AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
40
- DataCollatorForTokenClassification, AutoModelForQuestionAnswering,
41
  AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
42
  DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
43
  LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
@@ -54,7 +56,6 @@ from diffusers import (
54
  )
55
  import evaluate as hf_evaluate
56
  from jinja2 import Template
57
- from collections import defaultdict
58
 
59
  logger = logging.getLogger(__name__)
60
 
@@ -94,7 +95,8 @@ TASK_TO_PIPELINE_MAP = {
94
  "DreamBooth LoRA (Text-to-Image)": "text-to-image",
95
  }
96
 
97
- MODEL_CARD_TEMPLATE = """---
 
98
  language: es
99
  license: apache-2.0
100
  tags:
@@ -132,7 +134,8 @@ Este modelo es una versión afinada de [{base_model}](https://huggingface.co/{ba
132
  - Gradio
133
  """
134
 
135
- DATASET_CARD_TEMPLATE = """---
 
136
  license: mit
137
  ---
138
 
@@ -147,6 +150,7 @@ Este dataset fue creado utilizando la herramienta [AutoTrain-Advanced](https://h
147
  - **Fecha de Creación:** {date}
148
  """
149
 
 
150
  class DebiasingSFTTrainer(SFTTrainer):
151
  def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
152
  super().__init__(*args, **kwargs)
@@ -164,6 +168,54 @@ class DebiasingSFTTrainer(SFTTrainer):
164
  break
165
  return (loss, outputs) if return_outputs else loss
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def hf_login(token):
168
  if not token:
169
  return "Por favor, introduce un token."
@@ -174,6 +226,7 @@ def hf_login(token):
174
  except Exception as e:
175
  return f"❌ Error en la conexión: {e}"
176
 
 
177
  def _clean_text(example, text_col, **kwargs):
178
  text = example.get(text_col, "")
179
  if not isinstance(text, str):
@@ -191,6 +244,7 @@ def _clean_text(example, text_col, **kwargs):
191
  example[text_col] = text
192
  return example
193
 
 
194
  def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords):
195
  text = example.get(text_col, "")
196
  if not isinstance(text, str): return False
@@ -204,6 +258,7 @@ def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, e
204
  lower_text = text.lower()
205
  return not any(keyword in lower_text for keyword in exclude_keywords)
206
 
 
207
  def _get_filter_functions(**kwargs):
208
  filters = []
209
  if kwargs.get('enable_quality_filter'):
@@ -252,6 +307,7 @@ def _get_filter_functions(**kwargs):
252
  filters.append(stats_filter)
253
  return filters
254
 
 
255
  def _load_hf_streaming(ids, split="train", probabilities=None):
256
  streams = []
257
  valid_ids = []
@@ -279,10 +335,9 @@ def _load_hf_streaming(ids, split="train", probabilities=None):
279
  if probabilities and len(probabilities) != len(streams):
280
  logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
281
  probabilities = None
282
-
283
- from datasets import interleave_datasets
284
  return interleave_datasets(streams, probabilities=probabilities)
285
 
 
286
  def _load_uploaded_stream(files):
287
  all_rows = []
288
  for f in files or []:
@@ -304,6 +359,7 @@ def _load_uploaded_stream(files):
304
  random.shuffle(all_rows)
305
  return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []}
306
 
 
307
  def _guess_columns(sample):
308
  text_col, image_col, audio_col, label_col = "text", "image", "audio", "label"
309
  if not isinstance(sample, dict):
@@ -320,6 +376,7 @@ def _guess_columns(sample):
320
  elif "labels" in keys: label_col = keys["labels"]
321
  return text_col, image_col, audio_col, label_col
322
 
 
323
  def _apply_cda(dataset, text_col, cda_config_str):
324
  try:
325
  swap_groups = json.loads(cda_config_str)
@@ -352,6 +409,7 @@ def _apply_cda(dataset, text_col, cda_config_str):
352
  current_texts.update(next_texts)
353
  return IterableDataset.from_generator(cda_generator)
354
 
 
355
  def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
356
  if not ratio or ratio <= 0:
357
  return dataset
@@ -379,6 +437,7 @@ def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id
379
  logger.warning(f"Error en retrotraducción: {e}")
380
  return IterableDataset.from_generator(bt_generator)
381
 
 
382
  def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
383
  if not num_samples or num_samples <= 0:
384
  return None
@@ -411,6 +470,7 @@ def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples,
411
  continue
412
  return IterableDataset.from_generator(synthetic_generator)
413
 
 
414
  def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
415
  safe_steps = int(steps_per_epoch_estimate or 10000)
416
  safe_batch_size = int(batch_size or 1)
@@ -429,6 +489,7 @@ def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, b
429
  kv_heads = heads if is_gpt2_like else (max(1, heads // 4))
430
  return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads
431
 
 
432
  def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn):
433
  if eval_ds_id:
434
  yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación")
@@ -450,6 +511,7 @@ def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn
450
  yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
451
  return None
452
 
 
453
  def _create_training_args(output_dir, repo_id, **kwargs):
454
  neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
455
  optim_args_dict = {}
@@ -468,11 +530,12 @@ def _create_training_args(output_dir, repo_id, **kwargs):
468
  "save_strategy": "steps",
469
  "logging_steps": int(kwargs.get('logging_steps', 10)),
470
  "save_steps": int(kwargs.get('save_steps', 50)),
 
471
  "eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
472
  "learning_rate": float(kwargs.get('learning_rate', 2e-5)),
473
  "fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
474
  "bf16": kwargs.get('mixed_precision') == 'bf16' and device == 'cuda',
475
- "max_grad_norm": float(kwargs.get('max_grad_norm', 0.3)),
476
  "warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)),
477
  "lr_scheduler_type": kwargs.get('scheduler', 'cosine'),
478
  "weight_decay": float(kwargs.get('weight_decay', 0.01)),
@@ -507,15 +570,21 @@ def _create_training_args(output_dir, repo_id, **kwargs):
507
 
508
  return TrainingArguments(**args_dict)
509
 
 
510
  def _generic_model_loader(model_name_or_path, model_class, **kwargs):
511
  quantization_type = kwargs.get('quantization', 'no')
512
  bnb_config = None
513
 
514
  if quantization_type != "no" and device == "cuda":
515
- if quantization_type == "4bit":
516
- bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype_auto, bnb_4bit_use_double_quant=True)
517
- elif quantization_type == "8bit":
518
- bnb_config = BitsAndBytesConfig(load_in_8bit=True)
 
 
 
 
 
519
  elif quantization_type != "no" and device == "cpu":
520
  logger.warning("La cuantización solo es compatible con GPU CUDA. Se procederá sin cuantización.")
521
 
@@ -556,6 +625,7 @@ def _generic_model_loader(model_name_or_path, model_class, **kwargs):
556
 
557
  return model
558
 
 
559
  def _find_all_linear_names(model, quantization_type):
560
  cls = torch.nn.Linear
561
  if quantization_type != 'no' and device == "cuda":
@@ -581,6 +651,7 @@ def _find_all_linear_names(model, quantization_type):
581
 
582
  return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
583
 
 
584
  def _conversation_formatting_func(example, tokenizer, **kwargs):
585
  conv_col = ""
586
  for key in ["messages", "conversations", "turns"]:
@@ -592,6 +663,7 @@ def _conversation_formatting_func(example, tokenizer, **kwargs):
592
  except: return ""
593
  return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
594
 
 
595
  def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
596
  if kwargs.get('enable_cot_input') or kwargs.get('enable_tool_use_input'):
597
  messages = []
@@ -610,9 +682,11 @@ def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
610
  return "\n".join([m['content'] for m in messages])
611
  return example.get(text_col, "")
612
 
 
613
  def _dpo_formatting_func(example, **kwargs):
614
  return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")}
615
 
 
616
  def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
617
  model.eval()
618
  encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt").to(model.device)
@@ -637,6 +711,7 @@ def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
637
  ppl = torch.exp(torch.stack(nlls).mean())
638
  return ppl.item()
639
 
 
640
  def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
641
  adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
642
  if not adapter_ids:
@@ -668,6 +743,7 @@ def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combinati
668
  yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
669
  return temp_dir
670
 
 
671
  def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
672
  yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
673
  trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False)
@@ -677,6 +753,7 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
677
  eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
678
  if eval_logs:
679
  final_metrics = eval_logs[-1]
 
680
 
681
  yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
682
  output_dir = trainer.args.output_dir
@@ -695,6 +772,7 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
695
  torch.cuda.empty_cache()
696
  return output_dir, final_metrics
697
 
 
698
  def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
699
  output_dir = tempfile.mkdtemp()
700
  is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)"
@@ -726,13 +804,13 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
726
  if kwargs.get('run_evaluation'):
727
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
728
  for update in eval_dataset_gen:
729
- if isinstance(update, tuple):
730
  yield update
731
  else:
732
  eval_dataset = update
733
 
734
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
735
- trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config, "tokenizer": tokenizer}
736
 
737
  if is_dpo:
738
  trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
@@ -743,24 +821,15 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
743
  trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
744
  if kwargs.get('enable_loss_reweighting'):
745
  trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
746
-
747
- try:
748
- trainer = TrainerClass(**trainer_kwargs)
749
- except TypeError as e:
750
- if "unexpected keyword argument 'tokenizer'" in str(e):
751
- logger.warning("Caught TypeError for tokenizer argument. Retrying without it for TRL compatibility.")
752
- trainer_kwargs.pop("tokenizer", None)
753
- trainer = TrainerClass(**trainer_kwargs)
754
- trainer.tokenizer = tokenizer
755
- else:
756
- raise e
757
-
758
  final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
759
  return final_model_path, final_metrics
760
 
761
  except Exception as e:
762
  raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}")
763
 
 
764
  def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
765
  output_dir = tempfile.mkdtemp()
766
  try:
@@ -771,23 +840,26 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
771
  tokenizer_id = kwargs.get('tokenizer_name') or model_name
772
  yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
773
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
 
 
774
 
775
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
776
  model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
 
777
 
778
  def preprocess(examples):
779
  return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
780
- train_dataset = train_dataset.map(preprocess)
781
 
782
  eval_dataset = None
783
  if kwargs.get('run_evaluation'):
784
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
785
  for update in eval_dataset_gen:
786
- if isinstance(update, tuple):
787
  yield update
788
  else:
789
  eval_dataset = update
790
- if eval_dataset: eval_dataset = eval_dataset.map(preprocess)
791
 
792
  metric = hf_evaluate.load("accuracy")
793
  def compute_metrics(eval_pred):
@@ -807,6 +879,7 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
807
  except Exception as e:
808
  raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}")
809
 
 
810
  def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
811
  output_dir = tempfile.mkdtemp()
812
  try:
@@ -843,7 +916,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
843
  if kwargs.get('run_evaluation'):
844
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
845
  for update in eval_dataset_gen:
846
- if isinstance(update, tuple):
847
  yield update
848
  else:
849
  eval_dataset = update
@@ -871,6 +944,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
871
  except Exception as e:
872
  raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}")
873
 
 
874
  def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
875
  output_dir = tempfile.mkdtemp()
876
  try:
@@ -935,7 +1009,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
935
  eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
936
  eval_dataset_raw = None
937
  for update in eval_dataset_raw_gen:
938
- if isinstance(update, tuple):
939
  yield update
940
  else:
941
  eval_dataset_raw = update
@@ -955,6 +1029,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
955
  except Exception as e:
956
  raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}")
957
 
 
958
  def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
959
  output_dir = tempfile.mkdtemp()
960
  try:
@@ -979,7 +1054,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
979
  if kwargs.get('run_evaluation'):
980
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
981
  for update in eval_dataset_gen:
982
- if isinstance(update, tuple):
983
  yield update
984
  else:
985
  eval_dataset = update
@@ -1012,6 +1087,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
1012
  except Exception as e:
1013
  raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}")
1014
 
 
1015
  def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1016
  if device == 'cpu':
1017
  raise ValueError("El entrenamiento de Text-to-Image solo es compatible con GPU CUDA.")
@@ -1023,22 +1099,30 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1023
 
1024
  yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
1025
  tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
1026
- text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
1027
- vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
1028
- unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
1029
  noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
1030
 
1031
  vae.requires_grad_(False)
1032
  text_encoder.requires_grad_(False)
1033
  unet.train()
1034
 
1035
- yield update_logs_fn("Agregando adaptadores LoRA al modelo...", "Text-to-Image (LoRA)")
1036
  unet_lora_config = LoraConfig(
1037
  r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
1038
  target_modules=["to_q", "to_k", "to_v", "to_out.0"],
1039
  )
1040
  unet.add_adapter(unet_lora_config)
1041
 
 
 
 
 
 
 
 
 
1042
  yield update_logs_fn("Procesando dataset de imágenes...", "Text-to-Image (LoRA)")
1043
  resolution = int(kwargs.get('diffusion_resolution', 512))
1044
 
@@ -1050,7 +1134,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1050
  ])
1051
 
1052
  def preprocess_train(examples):
1053
- images = [Image.open(image).convert("RGB") for image in examples[kwargs.get('image_col', 'image')]]
1054
  examples["pixel_values"] = [train_transforms(image) for image in images]
1055
  examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
1056
  return examples
@@ -1064,14 +1148,17 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1064
 
1065
  def collate_fn(examples):
1066
  pixel_values = torch.stack([example["pixel_values"] for example in examples])
1067
- input_ids = torch.stack([example["input_ids"] for example in examples])
1068
  return {"pixel_values": pixel_values, "input_ids": input_ids}
1069
 
1070
  train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
 
 
 
 
1071
 
1072
- yield update_logs_fn("Configurando optimizador y planificador...", "Text-to-Image (LoRA)")
1073
  optimizer = torch.optim.AdamW(
1074
- unet.parameters(), lr=float(kwargs.get('learning_rate', 2e-5)),
1075
  betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
1076
  weight_decay=float(kwargs.get('weight_decay', 0.01)),
1077
  eps=float(kwargs.get('adam_epsilon', 1e-8)),
@@ -1087,36 +1174,34 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1087
  num_training_steps=max_train_steps,
1088
  )
1089
 
1090
- unet, optimizer, train_dataloader, lr_scheduler, text_encoder, vae = accelerator.prepare(
1091
- unet, optimizer, train_dataloader, lr_scheduler, text_encoder, vae
1092
  )
1093
 
1094
- text_encoder.to(accelerator.device, dtype=torch_dtype_auto)
1095
  vae.to(accelerator.device, dtype=torch_dtype_auto)
1096
 
1097
- yield update_logs_fn("Iniciando bucle de entrenamiento de difusión...", "Text-to-Image (LoRA)")
1098
  global_step = 0
1099
  final_loss = 0
1100
  for epoch in range(num_epochs):
1101
  for step, batch in enumerate(train_dataloader):
1102
  with accelerator.accumulate(unet):
1103
- latents = vae.encode(batch["pixel_values"].to(torch_dtype_auto)).latent_dist.sample()
1104
  latents = latents * vae.config.scaling_factor
1105
  noise = torch.randn_like(latents)
1106
  bsz = latents.shape[0]
1107
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1108
- timesteps = timesteps.long()
1109
-
1110
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1111
- encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]
1112
-
1113
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1114
  loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
1115
  final_loss = loss.detach().item()
1116
 
1117
  accelerator.backward(loss)
1118
  if accelerator.sync_gradients:
1119
- accelerator.clip_grad_norm_(unet.parameters(), float(kwargs.get('max_grad_norm', 1.0)))
 
 
 
1120
 
1121
  optimizer.step()
1122
  lr_scheduler.step()
@@ -1124,16 +1209,21 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1124
 
1125
  if accelerator.is_main_process:
1126
  if global_step % int(kwargs.get('logging_steps', 10)) == 0:
1127
- yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss}", "Text-to-Image (LoRA)")
1128
  global_step += 1
 
 
 
 
1129
 
1130
- yield update_logs_fn("Entrenamiento completado, guardando modelo...", "Text-to-Image (LoRA)")
1131
  accelerator.wait_for_everyone()
1132
  if accelerator.is_main_process:
1133
- unwrapped_unet = accelerator.unwrap_model(unet)
1134
-
1135
- pipeline = StableDiffusionText2ImagePipeline.from_pretrained(model_name, torch_dtype=torch_dtype_auto)
1136
- pipeline.unet.load_state_dict(unwrapped_unet.state_dict())
 
 
1137
  pipeline.save_pretrained(output_dir)
1138
 
1139
  with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
@@ -1148,7 +1238,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1148
  torch.cuda.empty_cache()
1149
  return output_dir, {"final_loss": final_loss}
1150
 
1151
-
1152
  def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1153
  if device == 'cpu':
1154
  raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
@@ -1163,11 +1253,12 @@ def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, mo
1163
 
1164
  train_dataset = train_dataset.map(add_prompt)
1165
 
1166
- yield update_logs_fn(f"Usando el prompt de instancia para todas las imágenes: '{dreambooth_prompt}'", "DreamBooth LoRA (Text-to-Image)")
1167
 
1168
  final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
1169
  return final_model_path, final_metrics
1170
 
 
1171
  def _get_data_processing_pipeline(**kwargs):
1172
  hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
1173
  if not hf_ids and not kwargs.get('uploads'):
@@ -1194,9 +1285,8 @@ def _get_data_processing_pipeline(**kwargs):
1194
  if train_dataset is None:
1195
  train_dataset = hf_train_dataset
1196
  else:
1197
- from datasets import interleave_datasets
1198
  all_streams = [train_dataset, hf_train_dataset]
1199
- all_probs = [0.5, 0.5] if not probabilities else [probabilities] + probabilities[1:]
1200
  train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
1201
 
1202
  if train_dataset is None:
@@ -1206,7 +1296,8 @@ def _get_data_processing_pipeline(**kwargs):
1206
  text_col, image_col, audio_col, label_col = _guess_columns(first_example)
1207
  kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
1208
 
1209
- if kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)"]:
 
1210
  if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
1211
  clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
1212
  train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
@@ -1229,38 +1320,17 @@ def _get_data_processing_pipeline(**kwargs):
1229
 
1230
  dedup_method = kwargs.get('deduplication_method')
1231
  if dedup_method != 'Ninguna':
1232
- base_iterator = train_dataset
1233
- if dedup_method == 'Exacta':
1234
- def dedup_generator_exact():
1235
- seen_texts = set()
1236
- for example in base_iterator:
1237
- text = example.get(text_col, "")
1238
- if not isinstance(text, str) or text not in seen_texts:
1239
- if isinstance(text, str) and text:
1240
- seen_texts.add(text)
1241
- yield example
1242
- train_dataset = IterableDataset.from_generator(dedup_generator_exact)
1243
- elif dedup_method == 'Semántica (MinHash)':
1244
- threshold = kwargs.get('minhash_threshold', 0.85)
1245
- num_perm = int(kwargs.get('minhash_num_perm', 128))
1246
- def dedup_generator_minhash():
1247
- lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
1248
- for i, example in enumerate(base_iterator):
1249
- text = example.get(text_col, "")
1250
- if text and isinstance(text, str) and text.strip():
1251
- m = MinHash(num_perm=num_perm)
1252
- for d in text.split():
1253
- m.update(d.encode('utf8'))
1254
- if not lsh.query(m):
1255
- lsh.insert(f"key_{i}", m)
1256
- yield example
1257
- else:
1258
- yield example
1259
- train_dataset = IterableDataset.from_generator(dedup_generator_minhash)
1260
-
1261
 
1262
  return train_dataset, kwargs
1263
 
 
1264
  def _train_and_upload(**kwargs):
1265
  logs, repo_link, final_model_path, final_metrics = "", "", None, {}
1266
 
@@ -1350,7 +1420,6 @@ def _train_and_upload(**kwargs):
1350
  raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
1351
  base_model_id_for_training = temp_model_dir
1352
  kwargs["peft"] = False
1353
- kwargs["merge_adapter"] = False
1354
  kwargs['tokenizer_name'] = temp_model_dir
1355
  yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
1356
 
@@ -1363,7 +1432,6 @@ def _train_and_upload(**kwargs):
1363
  os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
1364
  os.environ["WANDB_LOG_MODEL"] = "checkpoint"
1365
 
1366
- from datetime import datetime
1367
  model_card_content = MODEL_CARD_TEMPLATE.format(
1368
  repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
1369
  training_mode=kwargs.get('training_mode'),
@@ -1390,8 +1458,11 @@ def _train_and_upload(**kwargs):
1390
  train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
1391
  while True:
1392
  try:
1393
- update_tuple = next(train_generator)
1394
- yield update_tuple + (gr.update(), gr.update())
 
 
 
1395
  except StopIteration as e:
1396
  final_model_path, final_metrics = e.value
1397
  break
@@ -1405,7 +1476,7 @@ def _train_and_upload(**kwargs):
1405
  eval_dataset_perp = None
1406
  eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
1407
  for update in eval_gen:
1408
- if isinstance(update, tuple):
1409
  yield update + (gr.update(), gr.update())
1410
  else:
1411
  eval_dataset_perp = update
@@ -1436,6 +1507,7 @@ def _train_and_upload(**kwargs):
1436
  gr.update(visible=False)
1437
  )
1438
 
 
1439
  def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
1440
  if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1441
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode)
@@ -1460,6 +1532,7 @@ def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in,
1460
  return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1461
  except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1462
 
 
1463
  def update_inference_ui(task_mode):
1464
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
1465
  is_text_gen = task_name == "text-generation"
@@ -1477,6 +1550,7 @@ def update_inference_ui(task_mode):
1477
  gr.update(visible=is_text_gen)
1478
  )
1479
 
 
1480
  def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
1481
  if not hf_token:
1482
  return "Error: Se requiere un token de Hugging Face.", ""
@@ -1525,7 +1599,6 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
1525
  for item in all_data:
1526
  f.write(json.dumps(item, ensure_ascii=False) + "\n")
1527
 
1528
- from datetime import datetime
1529
  readme_content = DATASET_CARD_TEMPLATE.format(
1530
  repo_id=repo_id,
1531
  creation_type=creation_type,
@@ -1550,10 +1623,12 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
1550
  except Exception as e:
1551
  return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
1552
 
 
1553
  def gradio_train_wrapper(*args):
1554
  kwargs = dict(zip(all_input_components_dict.keys(), args))
1555
  yield from _train_and_upload(**kwargs)
1556
 
 
1557
  def gradio_preview_data_wrapper(*args):
1558
  kwargs = dict(zip(all_input_components_dict.keys(), args))
1559
  try:
@@ -1563,9 +1638,13 @@ def gradio_preview_data_wrapper(*args):
1563
  dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
1564
  text_col = processed_kwargs.get('text_col')
1565
 
 
 
 
 
 
1566
  tokenizer = AutoTokenizer.from_pretrained(
1567
- kwargs.get('tokenizer_name') or kwargs.get('model_base_input') or 'gpt2',
1568
- trust_remote_code=True, use_fast=False
1569
  )
1570
  if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
1571
  if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
@@ -1574,20 +1653,21 @@ def gradio_preview_data_wrapper(*args):
1574
  for i, example in enumerate(islice(dataset, 5)):
1575
  formatted_text = ""
1576
  if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
1577
- formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2)
1578
  else:
1579
  formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
1580
 
1581
  preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
1582
 
1583
  preview_text = "\n".join(preview_samples)
1584
- if not preview_text:
1585
- preview_text = "No se pudieron generar muestras. Revisa la configuración del dataset y el formato."
1586
  yield preview_text
1587
 
1588
  except Exception as e:
1589
  yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
1590
 
 
1591
  def toggle_training_mode_ui(is_scratch):
1592
  return (
1593
  gr.update(visible=not is_scratch),
@@ -1598,12 +1678,14 @@ def toggle_training_mode_ui(is_scratch):
1598
  gr.update(visible=is_scratch)
1599
  )
1600
 
 
1601
  def toggle_task_specific_ui(training_mode):
1602
  is_classification = "Classification" in training_mode
1603
  is_dpo = "DPO" in training_mode
1604
  is_sft = "Causal" in training_mode
1605
  is_ner = "Token Classification" in training_mode
1606
  is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
 
1607
 
1608
  return (
1609
  gr.update(visible=is_classification or is_ner),
@@ -1613,13 +1695,15 @@ def toggle_task_specific_ui(training_mode):
1613
  gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
1614
  gr.update(visible=not is_diffusion),
1615
  gr.update(visible=is_diffusion),
1616
- gr.update(visible=not is_diffusion),
 
1617
  )
1618
 
1619
-
1620
def toggle_auto_modules_ui(is_auto):
    """Hide the manual LoRA target-modules control when automatic selection is enabled."""
    show_manual_controls = not is_auto
    return gr.update(visible=show_manual_controls)
1622
 
 
1623
def toggle_dataset_creator_ui(choice):
    """Swap between the synthetic-generation and the file-upload panels.

    Returns two gr.update objects: the first controls the synthetic group,
    the second the upload group; exactly one is visible at a time.
    """
    synth_selected = (choice == "Sintético")
    return (
        gr.update(visible=synth_selected),
        gr.update(visible=not synth_selected),
    )
@@ -1650,7 +1734,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1650
  dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
1651
  dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
1652
  with gr.Column(scale=2):
1653
- dset_status_output = gr.Textbox(label="Estado", lines=10)
1654
  dset_link_output = gr.Markdown()
1655
 
1656
  dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
@@ -1706,7 +1790,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1706
  with gr.Accordion("Avanzados", open=False):
1707
  warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
1708
  weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
1709
- max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="0.3")
1710
  logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1711
  save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1712
  save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1")
@@ -1766,9 +1850,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1766
  diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución")
1767
  with gr.Group(visible=False) as dreambooth_ui:
1768
  dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
1769
- dreambooth_class_prompt = gr.Textbox(label="Prompt de Clase (Opcional)", placeholder="p.ej. 'foto de perro'")
1770
- dreambooth_num_class_images = gr.Slider(0, 1000, 100, step=10, label="Nº de Imágenes de Clase")
1771
- dreambooth_prior_loss_weight = gr.Slider(0.0, 2.0, 1.0, label="Peso de Pérdida a Priori")
1772
  dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
1773
  with gr.Group(visible=False) as classification_labels_ui:
1774
  classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo")
@@ -1787,7 +1868,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1787
  enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False)
1788
  cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]')
1789
 
1790
-
1791
  with gr.Accordion("🔌 Integraciones", open=False):
1792
  wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
1793
  wandb_project_input = gr.Textbox(label="Proyecto W&B")
@@ -1832,8 +1912,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1832
  "diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
1833
  "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
1834
  "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
1835
- "dreambooth_instance_prompt": dreambooth_instance_prompt, "dreambooth_class_prompt": dreambooth_class_prompt,
1836
- "dreambooth_num_class_images": dreambooth_num_class_images, "dreambooth_prior_loss_weight": dreambooth_prior_loss_weight,
1837
  "dreambooth_train_text_encoder": dreambooth_train_text_encoder
1838
  }
1839
 
@@ -1905,4 +1984,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1905
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1906
  )
1907
 
1908
- demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
16
  import ast
17
  from itertools import islice
18
  from pathlib import Path
19
+ from collections import defaultdict
20
+ from datetime import datetime
21
+
22
  import torch
23
  import torch.nn.functional as F
24
  from torch.utils.data import DataLoader
 
32
  import textstat
33
  from datasketch import MinHash, MinHashLSH
34
  import gradio as gr
35
+ from datasets import load_dataset, IterableDataset, Dataset, DatasetDict, interleave_datasets, Audio
 
36
  from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
37
  from transformers import (
38
  AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
39
  AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
40
+ SpeechT5ForTextToSpeech, SpeechT5Processor, SpeechT5HifiGan, AutoModelForImageClassification,
41
  AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
42
+ DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq,
43
  AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
44
  DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
45
  LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
 
56
  )
57
  import evaluate as hf_evaluate
58
  from jinja2 import Template
 
59
 
60
  logger = logging.getLogger(__name__)
61
 
 
95
  "DreamBooth LoRA (Text-to-Image)": "text-to-image",
96
  }
97
 
98
+ MODEL_CARD_TEMPLATE = """
99
+ ---
100
  language: es
101
  license: apache-2.0
102
  tags:
 
134
  - Gradio
135
  """
136
 
137
+ DATASET_CARD_TEMPLATE = """
138
+ ---
139
  license: mit
140
  ---
141
 
 
150
  - **Fecha de Creación:** {date}
151
  """
152
 
153
+ @spaces.GPU()
154
  class DebiasingSFTTrainer(SFTTrainer):
155
  def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
156
  super().__init__(*args, **kwargs)
 
168
  break
169
  return (loss, outputs) if return_outputs else loss
170
 
171
class DeduplicatedIterableDataset(IterableDataset):
    """Streaming wrapper that drops duplicate rows while iterating a dataset.

    The strategy is selected by ``method``:
      - 'Exacta': exact string match on ``text_col`` (keeps a set of texts seen).
      - 'Semántica (MinHash)': near-duplicate detection via MinHash LSH at the
        given Jaccard ``threshold`` with ``num_perm`` permutations.
    Any other value passes the wrapped dataset through unchanged.

    NOTE(review): removed the ``@spaces.GPU()`` decorator this commit added —
    ``spaces.GPU`` wraps *callables* for ZeroGPU scheduling; applying it to a
    class definition breaks ``isinstance`` checks and the datasets
    IterableDataset machinery, and deduplication here is CPU-only anyway.
    """

    def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
        # NOTE(review): passing ``iter([])`` as ex_iterable leans on internal
        # datasets API — confirm it is accepted by the installed datasets version.
        super().__init__(ex_iterable=iter([]))
        self.dataset = dataset
        self.text_col = text_col
        self.method = method
        self.threshold = threshold
        self.num_perm = num_perm
        # Propagate dataset metadata so downstream consumers can still read it.
        if hasattr(dataset, '_info'):
            self._info = dataset._info
        elif hasattr(dataset, 'info'):
            self._info = dataset.info

    def __iter__(self):
        if self.method == 'Exacta':
            return self._exact_iter()
        elif self.method == 'Semántica (MinHash)':
            return self._minhash_iter()
        else:
            return iter(self.dataset)

    def _exact_iter(self):
        # Exact dedup: O(1) membership test per row; memory grows with the
        # number of distinct texts. Empty or non-string texts pass through.
        seen_texts = set()
        for example in self.dataset:
            text = example.get(self.text_col, "")
            if text and isinstance(text, str):
                if text not in seen_texts:
                    seen_texts.add(text)
                    yield example
            else:
                yield example

    def _minhash_iter(self):
        # Near-duplicate dedup: a row is emitted only when no previously
        # indexed row falls within the LSH similarity threshold.
        lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
        for i, example in enumerate(self.dataset):
            text = example.get(self.text_col, "")
            if text and isinstance(text, str) and text.strip():
                m = MinHash(num_perm=self.num_perm)
                for token in text.split():
                    m.update(token.encode('utf8'))
                if not lsh.query(m):
                    lsh.insert(f"key_{i}", m)
                    yield example
            else:
                yield example
217
+
218
+ @spaces.GPU()
219
  def hf_login(token):
220
  if not token:
221
  return "Por favor, introduce un token."
 
226
  except Exception as e:
227
  return f"❌ Error en la conexión: {e}"
228
 
229
+ @spaces.GPU()
230
  def _clean_text(example, text_col, **kwargs):
231
  text = example.get(text_col, "")
232
  if not isinstance(text, str):
 
244
  example[text_col] = text
245
  return example
246
 
247
+ @spaces.GPU()
248
  def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords):
249
  text = example.get(text_col, "")
250
  if not isinstance(text, str): return False
 
258
  lower_text = text.lower()
259
  return not any(keyword in lower_text for keyword in exclude_keywords)
260
 
261
+ @spaces.GPU()
262
  def _get_filter_functions(**kwargs):
263
  filters = []
264
  if kwargs.get('enable_quality_filter'):
 
307
  filters.append(stats_filter)
308
  return filters
309
 
310
+ @spaces.GPU()
311
  def _load_hf_streaming(ids, split="train", probabilities=None):
312
  streams = []
313
  valid_ids = []
 
335
  if probabilities and len(probabilities) != len(streams):
336
  logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
337
  probabilities = None
 
 
338
  return interleave_datasets(streams, probabilities=probabilities)
339
 
340
+ @spaces.GPU()
341
  def _load_uploaded_stream(files):
342
  all_rows = []
343
  for f in files or []:
 
359
  random.shuffle(all_rows)
360
  return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []}
361
 
362
+ @spaces.GPU()
363
  def _guess_columns(sample):
364
  text_col, image_col, audio_col, label_col = "text", "image", "audio", "label"
365
  if not isinstance(sample, dict):
 
376
  elif "labels" in keys: label_col = keys["labels"]
377
  return text_col, image_col, audio_col, label_col
378
 
379
+ @spaces.GPU()
380
  def _apply_cda(dataset, text_col, cda_config_str):
381
  try:
382
  swap_groups = json.loads(cda_config_str)
 
409
  current_texts.update(next_texts)
410
  return IterableDataset.from_generator(cda_generator)
411
 
412
+ @spaces.GPU()
413
  def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
414
  if not ratio or ratio <= 0:
415
  return dataset
 
437
  logger.warning(f"Error en retrotraducción: {e}")
438
  return IterableDataset.from_generator(bt_generator)
439
 
440
+ @spaces.GPU()
441
  def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
442
  if not num_samples or num_samples <= 0:
443
  return None
 
470
  continue
471
  return IterableDataset.from_generator(synthetic_generator)
472
 
473
+ @spaces.GPU()
474
  def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
475
  safe_steps = int(steps_per_epoch_estimate or 10000)
476
  safe_batch_size = int(batch_size or 1)
 
489
  kv_heads = heads if is_gpt2_like else (max(1, heads // 4))
490
  return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads
491
 
492
+ @spaces.GPU()
493
  def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn):
494
  if eval_ds_id:
495
  yield update_logs_fn(f"Cargando dataset de evaluación: {eval_ds_id}", "Evaluación")
 
511
  yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
512
  return None
513
 
514
+ @spaces.GPU()
515
  def _create_training_args(output_dir, repo_id, **kwargs):
516
  neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
517
  optim_args_dict = {}
 
530
  "save_strategy": "steps",
531
  "logging_steps": int(kwargs.get('logging_steps', 10)),
532
  "save_steps": int(kwargs.get('save_steps', 50)),
533
+ "evaluation_strategy": "steps" if kwargs.get('run_evaluation', False) else "no",
534
  "eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
535
  "learning_rate": float(kwargs.get('learning_rate', 2e-5)),
536
  "fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
537
  "bf16": kwargs.get('mixed_precision') == 'bf16' and device == 'cuda',
538
+ "max_grad_norm": float(kwargs.get('max_grad_norm', 1.0)),
539
  "warmup_ratio": float(kwargs.get('warmup_ratio', 0.03)),
540
  "lr_scheduler_type": kwargs.get('scheduler', 'cosine'),
541
  "weight_decay": float(kwargs.get('weight_decay', 0.01)),
 
570
 
571
  return TrainingArguments(**args_dict)
572
 
573
+ @spaces.GPU()
574
  def _generic_model_loader(model_name_or_path, model_class, **kwargs):
575
  quantization_type = kwargs.get('quantization', 'no')
576
  bnb_config = None
577
 
578
  if quantization_type != "no" and device == "cuda":
579
+ try:
580
+ import bitsandbytes as bnb
581
+ if quantization_type == "4bit":
582
+ bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype_auto, bnb_4bit_use_double_quant=True)
583
+ elif quantization_type == "8bit":
584
+ bnb_config = BitsAndBytesConfig(load_in_8bit=True)
585
+ except ImportError:
586
+ logger.warning("bitsandbytes no está instalado. No se puede cargar en 4bit/8bit.")
587
+
588
  elif quantization_type != "no" and device == "cpu":
589
  logger.warning("La cuantización solo es compatible con GPU CUDA. Se procederá sin cuantización.")
590
 
 
625
 
626
  return model
627
 
628
+ @spaces.GPU()
629
  def _find_all_linear_names(model, quantization_type):
630
  cls = torch.nn.Linear
631
  if quantization_type != 'no' and device == "cuda":
 
651
 
652
  return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
653
 
654
+ @spaces.GPU()
655
  def _conversation_formatting_func(example, tokenizer, **kwargs):
656
  conv_col = ""
657
  for key in ["messages", "conversations", "turns"]:
 
663
  except: return ""
664
  return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
665
 
666
+ @spaces.GPU()
667
  def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
668
  if kwargs.get('enable_cot_input') or kwargs.get('enable_tool_use_input'):
669
  messages = []
 
682
  return "\n".join([m['content'] for m in messages])
683
  return example.get(text_col, "")
684
 
685
+ @spaces.GPU()
686
  def _dpo_formatting_func(example, **kwargs):
687
  return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")}
688
 
689
+ @spaces.GPU()
690
  def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
691
  model.eval()
692
  encodings = tokenizer("\n\n".join(ex[text_col] for ex in islice(eval_dataset, 1000)), return_tensors="pt").to(model.device)
 
711
  ppl = torch.exp(torch.stack(nlls).mean())
712
  return ppl.item()
713
 
714
+ @spaces.GPU()
715
  def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
716
  adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
717
  if not adapter_ids:
 
743
  yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
744
  return temp_dir
745
 
746
+ @spaces.GPU()
747
  def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
748
  yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
749
  trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False)
 
753
  eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
754
  if eval_logs:
755
  final_metrics = eval_logs[-1]
756
+ final_metrics = {k.replace('eval_', ''): v for k, v in final_metrics.items()}
757
 
758
  yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
759
  output_dir = trainer.args.output_dir
 
772
  torch.cuda.empty_cache()
773
  return output_dir, final_metrics
774
 
775
+ @spaces.GPU()
776
  def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
777
  output_dir = tempfile.mkdtemp()
778
  is_dpo = kwargs.get('training_mode') == "DPO (Direct Preference Optimization)"
 
804
  if kwargs.get('run_evaluation'):
805
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
806
  for update in eval_dataset_gen:
807
+ if isinstance(update, dict):
808
  yield update
809
  else:
810
  eval_dataset = update
811
 
812
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
813
+ trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config}
814
 
815
  if is_dpo:
816
  trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
 
821
  trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
822
  if kwargs.get('enable_loss_reweighting'):
823
  trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
824
+
825
+ trainer = TrainerClass(**trainer_kwargs)
 
 
 
 
 
 
 
 
 
 
826
  final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
827
  return final_model_path, final_metrics
828
 
829
  except Exception as e:
830
  raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}")
831
 
832
+ @spaces.GPU()
833
  def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
834
  output_dir = tempfile.mkdtemp()
835
  try:
 
840
  tokenizer_id = kwargs.get('tokenizer_name') or model_name
841
  yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuración")
842
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
843
+ if tokenizer.pad_token is None:
844
+ tokenizer.pad_token = tokenizer.eos_token
845
 
846
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
847
  model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
848
+ model.config.pad_token_id = tokenizer.pad_token_id
849
 
850
  def preprocess(examples):
851
  return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
852
+ train_dataset = train_dataset.map(preprocess, batched=True)
853
 
854
  eval_dataset = None
855
  if kwargs.get('run_evaluation'):
856
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
857
  for update in eval_dataset_gen:
858
+ if isinstance(update, dict):
859
  yield update
860
  else:
861
  eval_dataset = update
862
+ if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
863
 
864
  metric = hf_evaluate.load("accuracy")
865
  def compute_metrics(eval_pred):
 
879
  except Exception as e:
880
  raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}")
881
 
882
+ @spaces.GPU()
883
  def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
884
  output_dir = tempfile.mkdtemp()
885
  try:
 
916
  if kwargs.get('run_evaluation'):
917
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
918
  for update in eval_dataset_gen:
919
+ if isinstance(update, dict):
920
  yield update
921
  else:
922
  eval_dataset = update
 
944
  except Exception as e:
945
  raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}")
946
 
947
+ @spaces.GPU()
948
  def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
949
  output_dir = tempfile.mkdtemp()
950
  try:
 
1009
  eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
1010
  eval_dataset_raw = None
1011
  for update in eval_dataset_raw_gen:
1012
+ if isinstance(update, dict):
1013
  yield update
1014
  else:
1015
  eval_dataset_raw = update
 
1029
  except Exception as e:
1030
  raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}")
1031
 
1032
+ @spaces.GPU()
1033
  def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1034
  output_dir = tempfile.mkdtemp()
1035
  try:
 
1054
  if kwargs.get('run_evaluation'):
1055
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
1056
  for update in eval_dataset_gen:
1057
+ if isinstance(update, dict):
1058
  yield update
1059
  else:
1060
  eval_dataset = update
 
1087
  except Exception as e:
1088
  raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}")
1089
 
1090
+ @spaces.GPU()
1091
  def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1092
  if device == 'cpu':
1093
  raise ValueError("El entrenamiento de Text-to-Image solo es compatible con GPU CUDA.")
 
1099
 
1100
  yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
1101
  tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
1102
+ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=torch_dtype_auto)
1103
+ vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=torch_dtype_auto)
1104
+ unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=torch_dtype_auto)
1105
  noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
1106
 
1107
  vae.requires_grad_(False)
1108
  text_encoder.requires_grad_(False)
1109
  unet.train()
1110
 
1111
+ yield update_logs_fn("Agregando adaptadores LoRA al UNet...", "Text-to-Image (LoRA)")
1112
  unet_lora_config = LoraConfig(
1113
  r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
1114
  target_modules=["to_q", "to_k", "to_v", "to_out.0"],
1115
  )
1116
  unet.add_adapter(unet_lora_config)
1117
 
1118
+ if kwargs.get('dreambooth_train_text_encoder', False):
1119
+ yield update_logs_fn("Agregando adaptadores LoRA al Text Encoder...", "DreamBooth LoRA")
1120
+ text_encoder_lora_config = LoraConfig(
1121
+ r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
1122
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
1123
+ )
1124
+ text_encoder.add_adapter(text_encoder_lora_config)
1125
+
1126
  yield update_logs_fn("Procesando dataset de imágenes...", "Text-to-Image (LoRA)")
1127
  resolution = int(kwargs.get('diffusion_resolution', 512))
1128
 
 
1134
  ])
1135
 
1136
  def preprocess_train(examples):
1137
+ images = [image.convert("RGB") for image in examples[kwargs.get('image_col', 'image')]]
1138
  examples["pixel_values"] = [train_transforms(image) for image in images]
1139
  examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
1140
  return examples
 
1148
 
1149
  def collate_fn(examples):
1150
  pixel_values = torch.stack([example["pixel_values"] for example in examples])
1151
+ input_ids = torch.stack([e["input_ids"][0] for e in examples])
1152
  return {"pixel_values": pixel_values, "input_ids": input_ids}
1153
 
1154
  train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
1155
+
1156
+ params_to_optimize = list(unet.parameters())
1157
+ if kwargs.get('dreambooth_train_text_encoder', False):
1158
+ params_to_optimize += list(text_encoder.parameters())
1159
 
 
1160
  optimizer = torch.optim.AdamW(
1161
+ params_to_optimize, lr=float(kwargs.get('learning_rate', 2e-5)),
1162
  betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
1163
  weight_decay=float(kwargs.get('weight_decay', 0.01)),
1164
  eps=float(kwargs.get('adam_epsilon', 1e-8)),
 
1174
  num_training_steps=max_train_steps,
1175
  )
1176
 
1177
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1178
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1179
  )
1180
 
 
1181
  vae.to(accelerator.device, dtype=torch_dtype_auto)
1182
 
 
1183
  global_step = 0
1184
  final_loss = 0
1185
  for epoch in range(num_epochs):
1186
  for step, batch in enumerate(train_dataloader):
1187
  with accelerator.accumulate(unet):
1188
+ latents = vae.encode(batch["pixel_values"].to(dtype=torch_dtype_auto)).latent_dist.sample()
1189
  latents = latents * vae.config.scaling_factor
1190
  noise = torch.randn_like(latents)
1191
  bsz = latents.shape[0]
1192
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
 
 
1193
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1194
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
1195
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1196
  loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
1197
  final_loss = loss.detach().item()
1198
 
1199
  accelerator.backward(loss)
1200
  if accelerator.sync_gradients:
1201
+ params_to_clip = list(unet.parameters())
1202
+ if kwargs.get('dreambooth_train_text_encoder', False):
1203
+ params_to_clip += list(text_encoder.parameters())
1204
+ accelerator.clip_grad_norm_(params_to_clip, float(kwargs.get('max_grad_norm', 1.0)))
1205
 
1206
  optimizer.step()
1207
  lr_scheduler.step()
 
1209
 
1210
  if accelerator.is_main_process:
1211
  if global_step % int(kwargs.get('logging_steps', 10)) == 0:
1212
+ yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss:.4f}", "Entrenando Difusión")
1213
  global_step += 1
1214
+ if global_step >= max_train_steps:
1215
+ break
1216
+ if global_step >= max_train_steps:
1217
+ break
1218
 
 
1219
  accelerator.wait_for_everyone()
1220
  if accelerator.is_main_process:
1221
+ pipeline = StableDiffusionText2ImagePipeline.from_pretrained(
1222
+ model_name,
1223
+ unet=accelerator.unwrap_model(unet),
1224
+ text_encoder=accelerator.unwrap_model(text_encoder),
1225
+ torch_dtype=torch_dtype_auto,
1226
+ )
1227
  pipeline.save_pretrained(output_dir)
1228
 
1229
  with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
 
1238
  torch.cuda.empty_cache()
1239
  return output_dir, {"final_loss": final_loss}
1240
 
1241
+ @spaces.GPU()
1242
  def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1243
  if device == 'cpu':
1244
  raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
 
1253
 
1254
  train_dataset = train_dataset.map(add_prompt)
1255
 
1256
+ yield update_logs_fn(f"Usando el prompt de instancia para todas las imágenes: '{dreambooth_prompt}'", "DreamBooth LoRA")
1257
 
1258
  final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
1259
  return final_model_path, final_metrics
1260
 
1261
+ @spaces.GPU()
1262
  def _get_data_processing_pipeline(**kwargs):
1263
  hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
1264
  if not hf_ids and not kwargs.get('uploads'):
 
1285
  if train_dataset is None:
1286
  train_dataset = hf_train_dataset
1287
  else:
 
1288
  all_streams = [train_dataset, hf_train_dataset]
1289
+ all_probs = [0.5, 0.5]
1290
  train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
1291
 
1292
  if train_dataset is None:
 
1296
  text_col, image_col, audio_col, label_col = _guess_columns(first_example)
1297
  kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
1298
 
1299
+ is_text_task = kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)", "Image Classification (Vision)", "Audio Classification (Speech)"]
1300
+ if is_text_task:
1301
  if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
1302
  clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
1303
  train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
 
1320
 
1321
  dedup_method = kwargs.get('deduplication_method')
1322
  if dedup_method != 'Ninguna':
1323
+ train_dataset = DeduplicatedIterableDataset(
1324
+ dataset=train_dataset,
1325
+ text_col=text_col,
1326
+ method=dedup_method,
1327
+ threshold=kwargs.get('minhash_threshold', 0.85),
1328
+ num_perm=int(kwargs.get('minhash_num_perm', 128))
1329
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1330
 
1331
  return train_dataset, kwargs
1332
 
1333
+ @spaces.GPU()
1334
  def _train_and_upload(**kwargs):
1335
  logs, repo_link, final_model_path, final_metrics = "", "", None, {}
1336
 
 
1420
  raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
1421
  base_model_id_for_training = temp_model_dir
1422
  kwargs["peft"] = False
 
1423
  kwargs['tokenizer_name'] = temp_model_dir
1424
  yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
1425
 
 
1432
  os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
1433
  os.environ["WANDB_LOG_MODEL"] = "checkpoint"
1434
 
 
1435
  model_card_content = MODEL_CARD_TEMPLATE.format(
1436
  repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
1437
  training_mode=kwargs.get('training_mode'),
 
1458
  train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
1459
  while True:
1460
  try:
1461
+ update = next(train_generator)
1462
+ if isinstance(update, tuple) and len(update) == 4:
1463
+ yield update + (gr.update(), gr.update())
1464
+ else:
1465
+ pass
1466
  except StopIteration as e:
1467
  final_model_path, final_metrics = e.value
1468
  break
 
1476
  eval_dataset_perp = None
1477
  eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
1478
  for update in eval_gen:
1479
+ if isinstance(update, dict):
1480
  yield update + (gr.update(), gr.update())
1481
  else:
1482
  eval_dataset_perp = update
 
1507
  gr.update(visible=False)
1508
  )
1509
 
1510
+ @spaces.GPU()
1511
  def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
1512
  if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1513
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode)
 
1532
  return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1533
  except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1534
 
1535
+ @spaces.GPU()
1536
  def update_inference_ui(task_mode):
1537
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
1538
  is_text_gen = task_name == "text-generation"
 
1550
  gr.update(visible=is_text_gen)
1551
  )
1552
 
1553
+ @spaces.GPU()
1554
  def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
1555
  if not hf_token:
1556
  return "Error: Se requiere un token de Hugging Face.", ""
 
1599
  for item in all_data:
1600
  f.write(json.dumps(item, ensure_ascii=False) + "\n")
1601
 
 
1602
  readme_content = DATASET_CARD_TEMPLATE.format(
1603
  repo_id=repo_id,
1604
  creation_type=creation_type,
 
1623
  except Exception as e:
1624
  return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
1625
 
1626
def gradio_train_wrapper(*args):
    """Adapter between Gradio's positional callback arguments and the
    keyword-based training entry point.

    Gradio passes component values positionally in the order of the
    ``inputs=`` list; we zip them back onto the keys of
    ``all_input_components_dict`` (which must preserve that same order)
    and stream every progress update from ``_train_and_upload``.

    NOTE(review): the ``@spaces.GPU()`` decorator was removed here on
    purpose — ``_train_and_upload`` already carries it, and stacking the
    decorator on this thin generator wrapper makes a ZeroGPU Space
    request the GPU twice (once per decorated frame) for a single
    training run. Confirm against the Space's ZeroGPU configuration.
    """
    # dict() of zip() relies on insertion order of all_input_components_dict
    # matching the Gradio inputs= list built from it.
    kwargs = dict(zip(all_input_components_dict.keys(), args))
    # Re-yield each (logs, link, ...) tuple so the UI streams live updates.
    yield from _train_and_upload(**kwargs)
1630
 
1631
+ @spaces.GPU()
1632
  def gradio_preview_data_wrapper(*args):
1633
  kwargs = dict(zip(all_input_components_dict.keys(), args))
1634
  try:
 
1638
  dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
1639
  text_col = processed_kwargs.get('text_col')
1640
 
1641
+ model_id_for_tokenizer = kwargs.get('model_base_input')
1642
+ if not model_id_for_tokenizer:
1643
+ raise ValueError("Se necesita un ID de modelo base para cargar el tokenizer para la vista previa.")
1644
+
1645
+ tokenizer_id = kwargs.get('tokenizer_name') or model_id_for_tokenizer
1646
  tokenizer = AutoTokenizer.from_pretrained(
1647
+ tokenizer_id, trust_remote_code=True, use_fast=False
 
1648
  )
1649
  if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
1650
  if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
 
1653
  for i, example in enumerate(islice(dataset, 5)):
1654
  formatted_text = ""
1655
  if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
1656
+ formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
1657
  else:
1658
  formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
1659
 
1660
  preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
1661
 
1662
  preview_text = "\n".join(preview_samples)
1663
+ if not preview_samples:
1664
+ preview_text = "No se pudieron generar muestras. Revisa la configuración del dataset, los filtros y el formato."
1665
  yield preview_text
1666
 
1667
  except Exception as e:
1668
  yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
1669
 
1670
+ @spaces.GPU()
1671
  def toggle_training_mode_ui(is_scratch):
1672
  return (
1673
  gr.update(visible=not is_scratch),
 
1678
  gr.update(visible=is_scratch)
1679
  )
1680
 
1681
+ @spaces.GPU()
1682
  def toggle_task_specific_ui(training_mode):
1683
  is_classification = "Classification" in training_mode
1684
  is_dpo = "DPO" in training_mode
1685
  is_sft = "Causal" in training_mode
1686
  is_ner = "Token Classification" in training_mode
1687
  is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
1688
+ is_streaming = not is_diffusion
1689
 
1690
  return (
1691
  gr.update(visible=is_classification or is_ner),
 
1695
  gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
1696
  gr.update(visible=not is_diffusion),
1697
  gr.update(visible=is_diffusion),
1698
+ gr.update(visible=not is_streaming),
1699
+ gr.update(visible=is_streaming)
1700
  )
1701
 
1702
def toggle_auto_modules_ui(is_auto):
    """Show the manual LoRA target-modules field only when auto-detection
    is disabled.

    Pure UI callback: it only builds a ``gr.update`` and performs no
    tensor/model work, so it must NOT be decorated with ``@spaces.GPU()``
    — on a ZeroGPU Space that decorator queues for and briefly holds a
    GPU on every checkbox toggle, wasting quota and adding latency.

    Args:
        is_auto: value of the "auto modules" checkbox.

    Returns:
        A ``gr.update`` making the manual field visible iff auto is off.
    """
    return gr.update(visible=not is_auto)
1705
 
1706
def toggle_dataset_creator_ui(choice):
    """Swap between the synthetic-generation panel and the file-upload
    panel of the dataset-creator tab.

    Pure UI callback (no GPU work), so the ``@spaces.GPU()`` decorator is
    deliberately omitted: on a ZeroGPU Space it would allocate a GPU for
    a simple radio-button change.

    Args:
        choice: selected creation type; the literal ``"Sintético"``
            selects the synthetic path.

    Returns:
        Two ``gr.update`` objects: (synthetic group visibility,
        file-upload group visibility) — always mutually exclusive.
    """
    is_synth = choice == "Sintético"
    return gr.update(visible=is_synth), gr.update(visible=not is_synth)
 
1734
  dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
1735
  dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
1736
  with gr.Column(scale=2):
1737
+ dset_status_output = gr.Textbox(label="Estado", lines=10, interactive=False)
1738
  dset_link_output = gr.Markdown()
1739
 
1740
  dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
 
1790
  with gr.Accordion("Avanzados", open=False):
1791
  warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
1792
  weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
1793
+ max_grad_norm = gr.Textbox(label="Norma Máxima de Gradiente", value="1.0")
1794
  logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1795
  save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1796
  save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1")
 
1850
  diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución")
1851
  with gr.Group(visible=False) as dreambooth_ui:
1852
  dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
 
 
 
1853
  dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
1854
  with gr.Group(visible=False) as classification_labels_ui:
1855
  classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo")
 
1868
  enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False)
1869
  cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]')
1870
 
 
1871
  with gr.Accordion("🔌 Integraciones", open=False):
1872
  wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
1873
  wandb_project_input = gr.Textbox(label="Proyecto W&B")
 
1912
  "diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
1913
  "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
1914
  "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
1915
+ "dreambooth_instance_prompt": dreambooth_instance_prompt,
 
1916
  "dreambooth_train_text_encoder": dreambooth_train_text_encoder
1917
  }
1918
 
 
1984
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1985
  )
1986
 
1987
# Enable request queuing (required for generator/streaming callbacks),
# then start the app; debug=True surfaces tracebacks in the UI/logs.
demo.queue()
demo.launch(debug=True)