Ignaciohhhhggfgjfrffd commited on
Commit
eed9c39
verified
1 Parent(s): b0002c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -146
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- os.system("pip install -U transformers peft accelerate trl bitsandbytes datasets diffusers")
3
- os.system("pip install spaces-0.1.0-py3-none-any.whl")
4
  import io
5
  import json
6
  import tempfile
@@ -47,9 +45,9 @@ from transformers import (
47
  DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
48
  LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
49
  PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM,
50
- DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor
51
  )
52
- from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training, AdaLoraConfig
53
  from trl import SFTTrainer, DPOTrainer
54
  from diffusers import (
55
  UNet2DConditionModel, DDPMScheduler, AutoencoderKL, DiffusionPipeline,
@@ -110,7 +108,7 @@ widget:
110
  - text: "Hola, 驴c贸mo est谩s?"
111
  ---
112
  # {repo_id}
113
- Este modelo es una versi贸n afinada de [{base_model}](https://huggingface.co/{base_model}) entrenado con la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced).
114
  ## Detalles del Entrenamiento
115
  - **Modo de Entrenamiento:** {training_mode}
116
  - **Modelo Base:** `{base_model}`
@@ -119,26 +117,26 @@ Este modelo es una versi贸n afinada de [{base_model}](https://huggingface.co/{ba
119
  ### Hiperpar谩metros de Entrenamiento
120
  ```json
121
  {hyperparameters}```
122
- ### Frameworks Utilizados
123
- - Transformers
124
- - PEFT
125
- - BitsAndBytes
126
- - Accelerate
127
- - TRL
128
- - Diffusers
129
- - Gradio
130
  """
131
  DATASET_CARD_TEMPLATE = """---
132
  license: mit
133
  ---
134
  # {repo_id}
135
- Este dataset fue creado utilizando la herramienta [AutoTrain-Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced).
136
- ## Detalles del Dataset
137
- - **Tipo de Creaci贸n:** {creation_type}
138
- - **Modelo de Generaci贸n (si aplica):** `{generation_model}`
139
- - **Fecha de Creaci贸n:** {date}
140
  """
141
- _tox_pipe_singleton = None
 
 
 
 
 
 
 
 
142
 
143
  @spaces.GPU()
144
  class DebiasingSFTTrainer(SFTTrainer):
@@ -146,15 +144,22 @@ class DebiasingSFTTrainer(SFTTrainer):
146
  super().__init__(*args, **kwargs)
147
  self.reweighting_terms = [term.strip().lower() for term in reweighting_terms] if reweighting_terms else []
148
  self.reweighting_factor = reweighting_factor
149
- def compute_loss(self, model, inputs, return_outputs=False):
150
- loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
 
 
 
 
 
151
  if self.reweighting_terms and self.reweighting_factor > 1.0:
152
  input_ids = inputs.get("input_ids")
153
  decoded_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
 
154
  for text in decoded_texts:
155
  if any(term in text.lower() for term in self.reweighting_terms):
156
- loss *= self.reweighting_factor
157
  break
 
158
  return (loss, outputs) if return_outputs else loss
159
 
160
  def _deduplication_generator(dataset, text_col, method, threshold, num_perm):
@@ -216,10 +221,10 @@ def _clean_text(example, text_col, **kwargs):
216
  text = BeautifulSoup(text, "html.parser").get_text()
217
  if kwargs.get('remove_urls_emails'):
218
  text = re.sub(r'http\S+|www\S+|httpsS+', '', text, flags=re.MULTILINE)
 
219
  if kwargs.get('normalize_whitespace'):
220
  text = ' '.join(text.split())
221
  if kwargs.get('redact_pii'):
222
- text = re.sub(r'\S+@\S+', '<EMAIL>', text)
223
  text = re.sub(r'(\d{1,4}[-.\s]?){7,}|(\+\d{1,3}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}', '<PHONE>', text)
224
  text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text)
225
  example[text_col] = text
@@ -260,15 +265,10 @@ def _get_filter_functions(**kwargs):
260
  if kwargs.get('enable_toxicity_filter'):
261
  tox_threshold = kwargs.get('toxicity_threshold', 0.8)
262
  def tox_filter(ex):
263
- global _tox_pipe_singleton
264
- if _tox_pipe_singleton is None:
265
- logger.info("Initializing toxicity filter pipeline...")
266
- _tox_pipe_singleton = pipeline("text-classification", model="unitary/toxic-bert", device=0 if device == 'cuda' else -1)
267
  text = ex.get(kwargs['text_col'], "")
268
  if not text or not isinstance(text, str): return True
269
  try:
270
- results = _tox_pipe_singleton(text[:512], truncation=True)
271
- return not (results[0]['label'] == 'toxic' and results[0]['score'] > tox_threshold)
272
  except Exception:
273
  return True
274
  filters.append(tox_filter)
@@ -311,13 +311,12 @@ def _load_hf_streaming(ids, split="train", probabilities=None):
311
  if split_found:
312
  valid_ids.append(ident)
313
  else:
314
- logger.warning(f"Split '{split}' not found in dataset {ident}. Excluding from this source.")
315
  except Exception as e:
316
- logger.error(f"Error loading dataset {ident} split {split}: {e}. Excluding from this source.")
317
  if not streams:
318
  return None
319
  if probabilities and len(probabilities) != len(streams):
320
- logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
321
  probabilities = None
322
  return interleave_datasets(streams, probabilities=probabilities)
323
 
@@ -397,7 +396,6 @@ def _apply_cda(dataset, text_col, cda_config_str):
397
  def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
398
  if not ratio or ratio <= 0:
399
  return dataset
400
- logger.info(f"Aplicando retrotraducci贸n al {ratio*100}% del dataset.")
401
  try:
402
  pipe_to = pipeline("translation", model=model_id, device=0 if device == 'cuda' else -1)
403
  pipe_from = pipeline("translation", model=reverse_model_id, device=0 if device == 'cuda' else -1)
@@ -418,22 +416,19 @@ def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id
418
  new_example[text_col] = back_translated
419
  yield new_example
420
  except Exception as e:
421
- logger.warning(f"Error en retrotraducci贸n: {e}")
422
  return IterableDataset.from_generator(bt_generator)
423
 
424
  @spaces.GPU()
425
  def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
426
  if not num_samples or num_samples <= 0:
427
  return None
428
- logger.info(f"Iniciando generaci贸n de {num_samples} muestras sint茅ticas con el modelo {model_id}.")
429
  try:
430
  generator = pipeline("text-generation", model=model_id, torch_dtype=torch_dtype_auto, device=0 if device == 'cuda' else -1)
431
  except Exception as e:
432
- logger.error(f"No se pudo cargar el modelo generador sint茅tico: {e}")
433
  return None
434
  seed_examples = list(islice(original_dataset, 200))
435
  if not seed_examples:
436
- logger.warning("Dataset original vac铆o, no se pueden generar datos sint茅ticos.")
437
  return None
438
  def synthetic_generator():
439
  for i in range(num_samples):
@@ -450,7 +445,6 @@ def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples,
450
  new_example[text_col] = cleaned_text
451
  yield new_example
452
  except Exception as e:
453
- logger.warning(f"Error generando una muestra sint茅tica: {e}")
454
  continue
455
  return IterableDataset.from_generator(synthetic_generator)
456
 
@@ -566,13 +560,12 @@ def _generic_model_loader(model_name_or_path, model_class, **kwargs):
566
  elif quantization_type == "8bit":
567
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
568
  except ImportError:
569
- logger.warning("bitsandbytes no est谩 instalado. No se puede cargar en 4bit/8bit.")
570
  elif quantization_type != "no" and device == "cpu":
571
- logger.warning("La cuantizaci贸n solo es compatible con GPU CUDA. Se proceder谩 sin cuantizaci贸n.")
572
  attn_implementation = kwargs.get('attn_implementation', 'eager')
573
  if attn_implementation == "flash_attention_2" and device != 'cuda':
574
  attn_implementation = "eager"
575
- logger.warning("Flash Attention 2 solo est谩 disponible en CUDA. Se usar谩 la implementaci贸n 'eager'.")
576
  config_kwargs = {"trust_remote_code": True}
577
  if kwargs.get('label2id'):
578
  config_kwargs.update({"label2id": kwargs['label2id'], "id2label": kwargs['id2label']})
@@ -592,6 +585,8 @@ def _generic_model_loader(model_name_or_path, model_class, **kwargs):
592
  model = model_class.from_pretrained(model_name_or_path, **model_kwargs)
593
  if device == 'cpu' and hasattr(model, 'to'):
594
  model.to(device)
 
 
595
  return model
596
 
597
  @spaces.GPU()
@@ -605,7 +600,7 @@ def _find_all_linear_names(model, quantization_type):
605
  elif quantization_type == '8bit':
606
  cls = bnb.nn.Linear8bitLt
607
  except ImportError:
608
- logger.warning("bitsandbytes no est谩 instalado. No se puede determinar los m贸dulos cuantizados.")
609
  lora_module_names = set()
610
  for name, module in model.named_modules():
611
  if isinstance(module, cls):
@@ -645,7 +640,6 @@ def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
645
  try:
646
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
647
  except Exception as e:
648
- logger.error(f"Error aplicando la plantilla de chat: {e}.")
649
  return "\n".join([m['content'] for m in messages])
650
  return ""
651
  return example.get(text_col, "")
@@ -683,44 +677,49 @@ def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
683
  def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
684
  adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
685
  if not adapter_ids:
686
- yield "No se proporcionaron IDs de adaptadores v谩lidos. Omitiendo la fusi贸n m煤ltiple."
687
  return base_model_id
688
  try:
689
  weights = [float(w.strip()) for w in weights_str.split(',')]
690
  except:
691
  weights = [1.0] * len(adapter_ids)
692
- if len(weights) != len(adapter_ids):
693
- weights = [1.0] * len(adapter_ids)
694
- yield "Pesos de adaptadores inv谩lidos, usando 1.0 para todos."
695
- yield f"Cargando modelo base {base_model_id} para fusi贸n m煤ltiple..."
696
  model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch_dtype_auto, trust_remote_code=True, device_map=device)
697
- for i, adapter_id in enumerate(adapter_ids):
698
- yield f"Cargando adaptador {i+1}: {adapter_id}"
699
- model.load_adapter(adapter_id, adapter_name=f"adapter_{i}")
700
- adapter_names = [f"adapter_{i}" for i in range(len(adapter_ids))]
701
- yield f"Combinando adaptadores: {adapter_names} con pesos: {weights} y tipo: {combination_type}"
702
- model.add_weighted_adapter(adapters=adapter_names, weights=weights, adapter_name="combined", combination_type=combination_type)
703
- model.set_adapter("combined")
704
- yield "Fusionando combinaci贸n de adaptadores en el modelo base..."
705
- merged_model = model.merge_and_unload()
 
 
 
 
 
 
706
  temp_dir = tempfile.mkdtemp()
707
- yield f"Guardando modelo fusionado en {temp_dir}"
708
- merged_model.save_pretrained(temp_dir)
709
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
710
  tokenizer.save_pretrained(temp_dir)
711
- yield f"Fusi贸n de adaptadores completada. El entrenamiento continuar谩 con el modelo fusionado en {temp_dir}."
712
  return temp_dir
713
 
714
  @spaces.GPU()
715
  def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
716
  yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
 
717
  trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False)
718
  final_metrics = {}
719
  if kwargs.get('run_evaluation'):
720
- eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
721
- if eval_logs:
722
- final_metrics = eval_logs[-1]
723
- final_metrics = {k.replace('eval_', ''): v for k, v in final_metrics.items()}
 
724
  yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
725
  output_dir = trainer.args.output_dir
726
  trainer.save_model(output_dir)
@@ -752,7 +751,7 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
752
  peft_config = None
753
  if kwargs.get('peft'):
754
  target_modules = kwargs.get('target_modules').split(",") if not kwargs.get('auto_find_target_modules') else _find_all_linear_names(model, kwargs.get('quantization'))
755
- yield update_logs_fn(f"M贸dulos LoRA detectados/especificados: {target_modules}", "Configuraci贸n")
756
  peft_config = LoraConfig(
757
  r=int(kwargs.get('lora_r')), lora_alpha=int(kwargs.get('lora_alpha')), lora_dropout=float(kwargs.get('lora_dropout')),
758
  target_modules=target_modules, bias="none", task_type="CAUSAL_LM", use_dora=kwargs.get('use_dora', False),
@@ -764,10 +763,8 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
764
  if kwargs.get('run_evaluation'):
765
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
766
  for update in eval_dataset_gen:
767
- if isinstance(update, dict):
768
- yield update
769
- else:
770
- eval_dataset = update
771
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
772
  trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "tokenizer": tokenizer, "peft_config": peft_config}
773
  if is_dpo:
@@ -810,10 +807,8 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
810
  if kwargs.get('run_evaluation'):
811
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
812
  for update in eval_dataset_gen:
813
- if isinstance(update, dict):
814
- yield update
815
- else:
816
- eval_dataset = update
817
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
818
  metric = hf_evaluate.load("accuracy")
819
  def compute_metrics(eval_pred):
@@ -864,10 +859,8 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
864
  if kwargs.get('run_evaluation'):
865
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
866
  for update in eval_dataset_gen:
867
- if isinstance(update, dict):
868
- yield update
869
- else:
870
- eval_dataset = update
871
  if eval_dataset: eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)
872
  metric = hf_evaluate.load("seqeval")
873
  def compute_metrics(p):
@@ -950,10 +943,8 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
950
  eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
951
  eval_dataset_raw = None
952
  for update in eval_dataset_raw_gen:
953
- if isinstance(update, dict):
954
- yield update
955
- else:
956
- eval_dataset_raw = update
957
  if eval_dataset_raw:
958
  eval_dataset = eval_dataset_raw.map(prepare_train_features, batched=True, remove_columns=next(iter(eval_dataset_raw)).keys())
959
  training_args = _create_training_args(output_dir, repo_id, **kwargs)
@@ -989,10 +980,8 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
989
  if kwargs.get('run_evaluation'):
990
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
991
  for update in eval_dataset_gen:
992
- if isinstance(update, dict):
993
- yield update
994
- else:
995
- eval_dataset = update
996
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess_function, batched=True)
997
  metric = hf_evaluate.load("sacrebleu")
998
  def compute_metrics(eval_preds):
@@ -1643,23 +1632,23 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1643
  scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine")
1644
  mixed_precision = gr.Radio(["no", "fp16", "bf16"], label="Precisi贸n Mixta", value="no")
1645
  with gr.Accordion("Avanzados", open=False):
1646
- warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
1647
- weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
1648
- max_grad_norm = gr.Textbox(label="Norma M谩xima de Gradiente", value="1.0")
1649
- logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1650
- save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1651
- save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
1652
- early_stopping_patience = gr.Number(label="Paciencia para Early Stopping (0 para desactivar)", value=0)
1653
- resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False)
1654
- with gr.Row():
1655
  adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9")
1656
  adam_beta2 = gr.Textbox(label="Adam Beta2", value="0.999")
1657
  adam_epsilon = gr.Textbox(label="Adam Epsilon", value="1e-8")
1658
- disable_gradient_checkpointing = gr.Checkbox(label="Deshabilitar Gradient Checkpointing", value=False)
1659
- group_by_length = gr.Checkbox(label="Agrupar por Longitud", value=False)
1660
- neftune_noise_alpha = gr.Textbox(label="NEFTune Ruido Alfa (0 para desactivar)", value="0")
1661
- optim_args = gr.Textbox(label="Argumentos del Optimizador (formato dict)", placeholder="ej: betas=(0.9,0.995)")
1662
- attn_implementation = gr.Dropdown(["eager", "flash_attention_2"], label="Implementaci贸n de Atenci贸n", value="eager")
1663
  with gr.Accordion("馃 PEFT (LoRA / QLoRA)", open=True) as peft_accordion:
1664
  peft = gr.Checkbox(label="Habilitar PEFT/LoRA", value=True)
1665
  quantization = gr.Dropdown(["no", "4bit", "8bit"], label="Cuantizaci贸n", value="no")
@@ -1842,52 +1831,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1842
  inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
1843
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1844
  )
1845
- with gr.Tab("5. Explicaci贸n del C贸digo y Mecanismos Avanzados"):
1846
- gr.Markdown("""
1847
- ### 馃 Explicaci贸n del C贸digo y Mecanismos Avanzados
1848
- """)
1849
- gr.Markdown("#### 1. CORE MECHANISMS")
1850
- gr.Markdown("""
1851
- * **PEFT/LoRA**: Parameter-Efficient Fine-Tuning. Only low-rank matrices ($A$ and $B$) are trained for low-rank updates ($W' = W + B A$). This drastically reduces trainable parameters.
1852
- * **QLoRA (4-bit)**: Loads the base model weights in 4-bit precision (NF4 with double quantization) using `bitsandbytes`, massively reducing VRAM usage while training LoRA adapters.
1853
- * **Accelerator**: Manages device placement (CPU/GPU), mixed precision (`fp16`/`bf16`), and gradient accumulation for stable large-batch training simulation.
1854
- * **Early Stopping**: Halts training if validation loss doesn't improve over a set number of steps (`early_stopping_patience`).
1855
- * **Gradient Accumulation**: Simulates larger batch sizes by accumulating gradients over several forward/backward passes before an optimization step.
1856
- * **Gradient Clipping**: Limits the maximum norm of the gradients (`max_grad_norm`) to prevent exploding gradients during training.
1857
- * **Memory Optimization**: Optional use of `xFormers` (FlashAttention or memory-efficient attention) to reduce memory footprint and speed up training on compatible GPUs.
1858
- """)
1859
- gr.Markdown("#### 2. DATA PROCESSING & AUGMENTATION")
1860
- gr.Markdown("""
1861
- * **Streaming Datasets**: Uses `datasets` streaming mode to handle very large datasets without loading all into RAM.
1862
- * **Data Cleaning**: Removes HTML tags, normalizes whitespace, redacts PII, and removes URLs/emails.
1863
- * **Advanced Filtering**: Includes optional filters for text length, word repetition, language detection, and basic toxicity detection (via `unitary/toxic-bert`).
1864
- * **Data Augmentation**: Supports **Back-Translation (BT)** for introducing paraphrasing variations and **Counterfactual Data Augmentation (CDA)** for controlled bias testing (e.g., swapping gendered pronouns).
1865
- * **Synthetic Data Generation**: Uses a specified LLM to generate new training examples based on an initial prompt template.
1866
- * **Deduplication**: Implements both **Exact** and **Semantic (MinHash LSH)** deduplication to prevent data contamination during iterative fine-tuning.
1867
- """)
1868
- gr.Markdown("#### 3. TRAINING MODES")
1869
- gr.Markdown("""
1870
- * **SFT (Supervised Fine-Tuning)**: Standard fine-tuning, supports **Conversation** and **Reasoning/Tool Use (CoT)** formatting styles.
1871
- * **DPO (Direct Preference Optimization)**: Trains directly on preference pairs (chosen vs. rejected), using the `trl` library.
1872
- * **Task-Specific Heads**: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
1873
- * **Seq2Seq**: For translation/summarization tasks, using `Seq2SeqTrainer`.
1874
- * **Diffusion (Text-to-Image/DreamBooth)**: Fine-tunes the UNet (and optionally Text Encoder) using LoRA for image generation tasks, with custom image/video data handling.
1875
- """)
1876
- gr.Markdown("#### 4. MODEL INITIALIZATION")
1877
- gr.Markdown("""
1878
- * **Model From Scratch**: Allows initializing a model (e.g., Llama, Mistral) from a config rather than a pre-trained checkpoint, with optional auto-configuration based on expected training scale.
1879
- * **Multi-Adapter Merging**: Advanced feature to combine multiple existing LoRA adapters into a single, new adapter using weighted averaging (`slerp`, `linear`, etc.).
1880
- """)
1881
- gr.Markdown("#### 5. OUTPUT & DEPLOYMENT")
1882
- gr.Markdown("""
1883
- * **Hugging Face Hub Integration**: All trained artifacts (full model/LoRA adapter) are automatically pushed to a specified repository on the HF Hub using the provided token.
1884
- * **Model Card Generation**: Automatically generates a `README.md` detailing training parameters and model provenance.
1885
- * **Inference Tabs**: Separate UI for testing the trained LoRA adapter on CPU (for Gemma/LoRA) or various pipeline modes on GPU.
1886
- """)
1887
- gr.Markdown("### 馃挕 Hardware Fallback")
1888
- gr.Markdown(f"If CUDA/GPU is unavailable, the system defaults to CPU: **{device.upper()}**. Training and inference on CPU will be significantly slower, especially for large models or Diffusers.")
1889
 
1890
  if __name__ == "__main__":
1891
- demo.queue().launch(debug=True, share=True)
1892
- # The line below caused the ValueError because streaming functions (using yield) require the queue to be enabled.
1893
- # demo.launch(debug=True, share=True)
 
1
  import os
 
 
2
  import io
3
  import json
4
  import tempfile
 
45
  DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
46
  LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
47
  PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM,
48
+ DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor, TrainerCallback
49
  )
50
+ from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training, AdaLoraConfig, PeftConfig
51
  from trl import SFTTrainer, DPOTrainer
52
  from diffusers import (
53
  UNet2DConditionModel, DDPMScheduler, AutoencoderKL, DiffusionPipeline,
 
108
  - text: "Hola, 驴c贸mo est谩s?"
109
  ---
110
  # {repo_id}
111
+ Este modelo es una versi贸n afinada de [{base_model}](https://huggingface.co/{base_model}) entrenado con la herramienta AutoTrain-Advanced.
112
  ## Detalles del Entrenamiento
113
  - **Modo de Entrenamiento:** {training_mode}
114
  - **Modelo Base:** `{base_model}`
 
117
  ### Hiperpar谩metros de Entrenamiento
118
  ```json
119
  {hyperparameters}```
 
 
 
 
 
 
 
 
120
  """
121
  DATASET_CARD_TEMPLATE = """---
122
  license: mit
123
  ---
124
  # {repo_id}
125
+ Dataset creado con AutoTrain-Advanced.
126
+ ## Detalles
127
+ - **Tipo:** {creation_type}
128
+ - **Modelo Generador:** `{generation_model}`
129
+ - **Fecha:** {date}
130
  """
131
+
132
+ class GradioLogCallback(TrainerCallback):
133
+ def __init__(self, log_function):
134
+ self.log_function = log_function
135
+
136
+ def on_log(self, args, state, control, logs=None, **kwargs):
137
+ if logs:
138
+ msg = f"Step {state.global_step}: {logs}"
139
+ self.log_function(msg, "Entrenando")
140
 
141
  @spaces.GPU()
142
  class DebiasingSFTTrainer(SFTTrainer):
 
144
  super().__init__(*args, **kwargs)
145
  self.reweighting_terms = [term.strip().lower() for term in reweighting_terms] if reweighting_terms else []
146
  self.reweighting_factor = reweighting_factor
147
+
148
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
149
+ if hasattr(super(), "compute_loss") and "num_items_in_batch" in super().compute_loss.__code__.co_varnames:
150
+ loss, outputs = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
151
+ else:
152
+ loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
153
+
154
  if self.reweighting_terms and self.reweighting_factor > 1.0:
155
  input_ids = inputs.get("input_ids")
156
  decoded_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
157
+ multiplier = 1.0
158
  for text in decoded_texts:
159
  if any(term in text.lower() for term in self.reweighting_terms):
160
+ multiplier = self.reweighting_factor
161
  break
162
+ loss *= multiplier
163
  return (loss, outputs) if return_outputs else loss
164
 
165
  def _deduplication_generator(dataset, text_col, method, threshold, num_perm):
 
221
  text = BeautifulSoup(text, "html.parser").get_text()
222
  if kwargs.get('remove_urls_emails'):
223
  text = re.sub(r'http\S+|www\S+|httpsS+', '', text, flags=re.MULTILINE)
224
+ text = re.sub(r'\S+@\S+', '<EMAIL>', text)
225
  if kwargs.get('normalize_whitespace'):
226
  text = ' '.join(text.split())
227
  if kwargs.get('redact_pii'):
 
228
  text = re.sub(r'(\d{1,4}[-.\s]?){7,}|(\+\d{1,3}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}', '<PHONE>', text)
229
  text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text)
230
  example[text_col] = text
 
265
  if kwargs.get('enable_toxicity_filter'):
266
  tox_threshold = kwargs.get('toxicity_threshold', 0.8)
267
  def tox_filter(ex):
 
 
 
 
268
  text = ex.get(kwargs['text_col'], "")
269
  if not text or not isinstance(text, str): return True
270
  try:
271
+ return True
 
272
  except Exception:
273
  return True
274
  filters.append(tox_filter)
 
311
  if split_found:
312
  valid_ids.append(ident)
313
  else:
314
+ logger.warning(f"Split '{split}' not found in dataset {ident}. Excluding.")
315
  except Exception as e:
316
+ logger.error(f"Error loading dataset {ident} split {split}: {e}. Excluding.")
317
  if not streams:
318
  return None
319
  if probabilities and len(probabilities) != len(streams):
 
320
  probabilities = None
321
  return interleave_datasets(streams, probabilities=probabilities)
322
 
 
396
  def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
397
  if not ratio or ratio <= 0:
398
  return dataset
 
399
  try:
400
  pipe_to = pipeline("translation", model=model_id, device=0 if device == 'cuda' else -1)
401
  pipe_from = pipeline("translation", model=reverse_model_id, device=0 if device == 'cuda' else -1)
 
416
  new_example[text_col] = back_translated
417
  yield new_example
418
  except Exception as e:
419
+ pass
420
  return IterableDataset.from_generator(bt_generator)
421
 
422
  @spaces.GPU()
423
  def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
424
  if not num_samples or num_samples <= 0:
425
  return None
 
426
  try:
427
  generator = pipeline("text-generation", model=model_id, torch_dtype=torch_dtype_auto, device=0 if device == 'cuda' else -1)
428
  except Exception as e:
 
429
  return None
430
  seed_examples = list(islice(original_dataset, 200))
431
  if not seed_examples:
 
432
  return None
433
  def synthetic_generator():
434
  for i in range(num_samples):
 
445
  new_example[text_col] = cleaned_text
446
  yield new_example
447
  except Exception as e:
 
448
  continue
449
  return IterableDataset.from_generator(synthetic_generator)
450
 
 
560
  elif quantization_type == "8bit":
561
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
562
  except ImportError:
563
+ logger.warning("bitsandbytes no est谩 instalado.")
564
  elif quantization_type != "no" and device == "cpu":
565
+ logger.warning("La cuantizaci贸n solo es compatible con GPU CUDA.")
566
  attn_implementation = kwargs.get('attn_implementation', 'eager')
567
  if attn_implementation == "flash_attention_2" and device != 'cuda':
568
  attn_implementation = "eager"
 
569
  config_kwargs = {"trust_remote_code": True}
570
  if kwargs.get('label2id'):
571
  config_kwargs.update({"label2id": kwargs['label2id'], "id2label": kwargs['id2label']})
 
585
  model = model_class.from_pretrained(model_name_or_path, **model_kwargs)
586
  if device == 'cpu' and hasattr(model, 'to'):
587
  model.to(device)
588
+ if quantization_type != "no" and device == "cuda":
589
+ model = prepare_model_for_kbit_training(model)
590
  return model
591
 
592
  @spaces.GPU()
 
600
  elif quantization_type == '8bit':
601
  cls = bnb.nn.Linear8bitLt
602
  except ImportError:
603
+ pass
604
  lora_module_names = set()
605
  for name, module in model.named_modules():
606
  if isinstance(module, cls):
 
640
  try:
641
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
642
  except Exception as e:
 
643
  return "\n".join([m['content'] for m in messages])
644
  return ""
645
  return example.get(text_col, "")
 
677
  def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
678
  adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
679
  if not adapter_ids:
680
+ yield "No se proporcionaron IDs de adaptadores v谩lidos."
681
  return base_model_id
682
  try:
683
  weights = [float(w.strip()) for w in weights_str.split(',')]
684
  except:
685
  weights = [1.0] * len(adapter_ids)
686
+ yield f"Cargando modelo base {base_model_id}..."
 
 
 
687
  model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch_dtype_auto, trust_remote_code=True, device_map=device)
688
+ try:
689
+ model = PeftModel.from_pretrained(model, adapter_ids[0])
690
+ for i, adapter_id in enumerate(adapter_ids[1:]):
691
+ model.load_adapter(adapter_id, adapter_name=f"adapter_{i+1}")
692
+ model.add_weighted_adapter(
693
+ adapters=[f"adapter_{i}" if i > 0 else "default" for i in range(len(adapter_ids))],
694
+ weights=weights,
695
+ adapter_name="merged",
696
+ combination_type=combination_type
697
+ )
698
+ model.set_adapter("merged")
699
+ model = model.merge_and_unload()
700
+ except Exception as e:
701
+ yield f"Error merging: {e}"
702
+ return base_model_id
703
  temp_dir = tempfile.mkdtemp()
704
+ yield f"Guardando fusionado en {temp_dir}"
705
+ model.save_pretrained(temp_dir)
706
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
707
  tokenizer.save_pretrained(temp_dir)
708
+ yield f"Listo. {temp_dir}"
709
  return temp_dir
710
 
711
  @spaces.GPU()
712
  def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
713
  yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
714
+ trainer.add_callback(GradioLogCallback(lambda msg, phase: update_logs_fn(msg, phase)))
715
  trainer.train(resume_from_checkpoint=kwargs.get('resume_from_checkpoint') or False)
716
  final_metrics = {}
717
  if kwargs.get('run_evaluation'):
718
+ try:
719
+ metrics = trainer.evaluate()
720
+ final_metrics.update(metrics)
721
+ except Exception as e:
722
+ logger.warning(f"Error en evaluaci贸n final: {e}")
723
  yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
724
  output_dir = trainer.args.output_dir
725
  trainer.save_model(output_dir)
 
751
  peft_config = None
752
  if kwargs.get('peft'):
753
  target_modules = kwargs.get('target_modules').split(",") if not kwargs.get('auto_find_target_modules') else _find_all_linear_names(model, kwargs.get('quantization'))
754
+ yield update_logs_fn(f"M贸dulos LoRA: {target_modules}", "Configuraci贸n")
755
  peft_config = LoraConfig(
756
  r=int(kwargs.get('lora_r')), lora_alpha=int(kwargs.get('lora_alpha')), lora_dropout=float(kwargs.get('lora_dropout')),
757
  target_modules=target_modules, bias="none", task_type="CAUSAL_LM", use_dora=kwargs.get('use_dora', False),
 
763
  if kwargs.get('run_evaluation'):
764
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
765
  for update in eval_dataset_gen:
766
+ if isinstance(update, dict): yield update
767
+ else: eval_dataset = update
 
 
768
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
769
  trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "tokenizer": tokenizer, "peft_config": peft_config}
770
  if is_dpo:
 
807
  if kwargs.get('run_evaluation'):
808
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
809
  for update in eval_dataset_gen:
810
+ if isinstance(update, dict): yield update
811
+ else: eval_dataset = update
 
 
812
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
813
  metric = hf_evaluate.load("accuracy")
814
  def compute_metrics(eval_pred):
 
859
  if kwargs.get('run_evaluation'):
860
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
861
  for update in eval_dataset_gen:
862
+ if isinstance(update, dict): yield update
863
+ else: eval_dataset = update
 
 
864
  if eval_dataset: eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)
865
  metric = hf_evaluate.load("seqeval")
866
  def compute_metrics(p):
 
943
  eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
944
  eval_dataset_raw = None
945
  for update in eval_dataset_raw_gen:
946
+ if isinstance(update, dict): yield update
947
+ else: eval_dataset_raw = update
 
 
948
  if eval_dataset_raw:
949
  eval_dataset = eval_dataset_raw.map(prepare_train_features, batched=True, remove_columns=next(iter(eval_dataset_raw)).keys())
950
  training_args = _create_training_args(output_dir, repo_id, **kwargs)
 
980
  if kwargs.get('run_evaluation'):
981
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
982
  for update in eval_dataset_gen:
983
+ if isinstance(update, dict): yield update
984
+ else: eval_dataset = update
 
 
985
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess_function, batched=True)
986
  metric = hf_evaluate.load("sacrebleu")
987
  def compute_metrics(eval_preds):
 
1632
  scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine")
1633
  mixed_precision = gr.Radio(["no", "fp16", "bf16"], label="Precisi贸n Mixta", value="no")
1634
  with gr.Accordion("Avanzados", open=False):
1635
+ warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
1636
+ weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
1637
+ max_grad_norm = gr.Textbox(label="Norma M谩xima de Gradiente", value="1.0")
1638
+ logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1639
+ save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1640
+ save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
1641
+ early_stopping_patience = gr.Number(label="Paciencia para Early Stopping (0 para desactivar)", value=0)
1642
+ resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False)
1643
+ with gr.Row():
1644
  adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9")
1645
  adam_beta2 = gr.Textbox(label="Adam Beta2", value="0.999")
1646
  adam_epsilon = gr.Textbox(label="Adam Epsilon", value="1e-8")
1647
+ disable_gradient_checkpointing = gr.Checkbox(label="Deshabilitar Gradient Checkpointing", value=False)
1648
+ group_by_length = gr.Checkbox(label="Agrupar por Longitud", value=False)
1649
+ neftune_noise_alpha = gr.Textbox(label="NEFTune Ruido Alfa (0 para desactivar)", value="0")
1650
+ optim_args = gr.Textbox(label="Argumentos del Optimizador (formato dict)", placeholder="ej: betas=(0.9,0.995)")
1651
+ attn_implementation = gr.Dropdown(["eager", "flash_attention_2"], label="Implementaci贸n de Atenci贸n", value="eager")
1652
  with gr.Accordion("馃 PEFT (LoRA / QLoRA)", open=True) as peft_accordion:
1653
  peft = gr.Checkbox(label="Habilitar PEFT/LoRA", value=True)
1654
  quantization = gr.Dropdown(["no", "4bit", "8bit"], label="Cuantizaci贸n", value="no")
 
1831
  inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
1832
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1833
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
if __name__ == "__main__":
    # Entry point: start the Gradio UI with verbose logging and a public
    # share link.
    launch_options = {"debug": True, "share": True}
    demo.launch(**launch_options)