Ignaciohhhhggfgjfrffd committed on
Commit
5e9eb8a
verified
1 Parent(s): 8da19b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -82
app.py CHANGED
@@ -170,51 +170,40 @@ class DebiasingSFTTrainer(SFTTrainer):
170
  return (loss, outputs) if return_outputs else loss
171
 
172
  @spaces.GPU()
173
- class DeduplicatedIterableDataset(IterableDataset):
174
- def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
175
- super().__init__(ex_iterable=iter([]))
176
- self.dataset = dataset
177
- self.text_col = text_col
178
- self.method = method
179
- self.threshold = threshold
180
- self.num_perm = num_perm
181
- if hasattr(dataset, '_info'):
182
- self._info = dataset._info
183
- elif hasattr(dataset, 'info'):
184
- self._info = dataset.info
185
-
186
- def __iter__(self):
187
- if self.method == 'Exacta':
188
- return self._exact_iter()
189
- elif self.method == 'Semántica (MinHash)':
190
- return self._minhash_iter()
191
- else:
192
- return iter(self.dataset)
193
-
194
- def _exact_iter(self):
195
- seen_texts = set()
196
- for example in self.dataset:
197
- text = example.get(self.text_col, "")
198
- if text and isinstance(text, str):
199
- if text not in seen_texts:
200
- seen_texts.add(text)
201
  yield example
202
- else:
203
- yield example
204
-
205
- def _minhash_iter(self):
206
- lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
207
- for i, example in enumerate(self.dataset):
208
- text = example.get(self.text_col, "")
209
- if text and isinstance(text, str) and text.strip():
210
- m = MinHash(num_perm=self.num_perm)
211
- for d in text.split():
212
- m.update(d.encode('utf8'))
213
- if not lsh.query(m):
214
- lsh.insert(f"key_{i}", m)
215
  yield example
216
- else:
217
- yield example
 
 
 
 
 
 
 
218
 
219
  @spaces.GPU()
220
  def hf_login(token):
@@ -558,6 +547,10 @@ def _create_training_args(output_dir, repo_id, **kwargs):
558
  "adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)),
559
  "no_cuda": device == 'cpu'
560
  }
 
 
 
 
561
 
562
  is_diffusion_task = kwargs.get('training_mode', '') in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
563
  if is_diffusion_task:
@@ -652,35 +645,43 @@ def _find_all_linear_names(model, quantization_type):
652
 
653
  return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
654
 
655
- @spaces.GPU()
656
- def _conversation_formatting_func(example, tokenizer, **kwargs):
657
- conv_col = ""
658
- for key in ["messages", "conversations", "turns"]:
659
- if key in example: conv_col = key; break
660
- if not conv_col: return ""
661
- conversation = example[conv_col]
662
- if isinstance(conversation, str):
663
- try: conversation = ast.literal_eval(conversation)
664
- except: return ""
665
- return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
666
-
667
  @spaces.GPU()
668
  def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
669
- if kwargs.get('enable_cot_input') or kwargs.get('enable_tool_use_input'):
 
 
 
 
 
 
 
 
 
 
 
670
  messages = []
671
  prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "")
672
  if prompt: messages.append({"role": "user", "content": prompt})
 
673
  response_parts = []
674
- if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')): response_parts.append(f"<thinking>{example[kwargs.get('reasoning_col_input', 'reasoning')]}</thinking>")
675
- if kwargs.get('enable_tool_use_input') and example.get(kwargs.get('tool_use_col_input', 'tools')): response_parts.append(f"<tool_code>{example[kwargs.get('tool_use_col_input', 'tools')]}</tool_code>")
676
- if example.get(kwargs.get('response_col_input', 'response')): response_parts.append(example[kwargs.get('response_col_input', 'response')])
677
- if response_parts: messages.append({"role": "assistant", "content": "\n".join(response_parts)})
 
 
 
 
 
 
678
  if messages:
679
  try:
680
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
681
  except Exception as e:
682
- logger.error(f"Error applying chat template: {e}.")
683
  return "\n".join([m['content'] for m in messages])
 
 
684
  return example.get(text_col, "")
685
 
686
  @spaces.GPU()
@@ -811,17 +812,20 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
811
  eval_dataset = update
812
 
813
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
814
- trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config}
815
 
816
  if is_dpo:
817
  trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
 
 
818
  if eval_dataset:
819
  eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
 
820
  else:
821
  sft_kwargs = kwargs.copy()
822
- trainer_kwargs.update({"formatting_func": lambda ex, tc=text_col, skw=sft_kwargs: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=tc, **skw)})
823
  if kwargs.get('enable_loss_reweighting'):
824
- trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
825
 
826
  trainer = TrainerClass(**trainer_kwargs)
827
  final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
@@ -1277,7 +1281,7 @@ def _get_data_processing_pipeline(**kwargs):
1277
  if kwargs.get('uploads'):
1278
  uploaded_data_map = _load_uploaded_stream(kwargs.get('uploads'))
1279
  if uploaded_data_map and uploaded_data_map["train"]:
1280
- train_dataset = Dataset.from_list(uploaded_data_map["train"])
1281
  uploaded_val_data = uploaded_data_map["validation"]
1282
 
1283
  if hf_ids:
@@ -1325,7 +1329,7 @@ def _get_data_processing_pipeline(**kwargs):
1325
 
1326
  dedup_method = kwargs.get('deduplication_method')
1327
  if dedup_method != 'Ninguna':
1328
- train_dataset = DeduplicatedIterableDataset(
1329
  dataset=train_dataset,
1330
  text_col=text_col,
1331
  method=dedup_method,
@@ -1659,8 +1663,10 @@ def gradio_preview_data_wrapper(*args):
1659
  formatted_text = ""
1660
  if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
1661
  formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
1662
- else:
1663
  formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
 
 
1664
 
1665
  preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
1666
 
@@ -1696,14 +1702,19 @@ def toggle_task_specific_ui(training_mode):
1696
  gr.update(visible=is_classification or is_ner),
1697
  gr.update(visible=is_dpo),
1698
  gr.update(visible=is_sft),
1699
- gr.update(visible=training_mode == "Text-to-Image (LoRA)"),
1700
  gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
1701
  gr.update(visible=not is_diffusion),
1702
  gr.update(visible=is_diffusion),
 
1703
  gr.update(visible=not is_streaming),
1704
- gr.update(visible=is_streaming)
1705
  )
1706
 
 
 
 
 
 
1707
  @spaces.GPU()
1708
  def toggle_auto_modules_ui(is_auto):
1709
  return gr.update(visible=not is_auto)
@@ -1784,10 +1795,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1784
  gradient_accumulation = gr.Textbox(label="Acumulación de Gradiente", value="8")
1785
  with gr.Row():
1786
  block_size = gr.Textbox(label="Longitud de Secuencia", value="1024")
 
 
1787
  with gr.Group(visible=False) as epochs_ui:
1788
  epochs = gr.Textbox(label="Épocas", value="1")
1789
- with gr.Group(visible=True) as max_steps_ui:
1790
- max_steps = gr.Textbox(label="Máximos Pasos de Entrenamiento", value="100")
1791
  with gr.Row():
1792
  optimizer = gr.Dropdown(["adamw_torch", "adafactor", "sgd", "adagrad"], label="Optimizador", value="adamw_torch")
1793
  scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine")
@@ -1799,6 +1810,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1799
  logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1800
  save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1801
  save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1")
 
1802
  resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False)
1803
  with gr.Row():
1804
  adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9")
@@ -1851,25 +1863,34 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1851
  num_synthetic_samples = gr.Number(label="Número de Muestras", value=1000)
1852
 
1853
  with gr.Accordion("📝 Configuración de Formato y Tarea", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1854
  with gr.Group(visible=False) as diffusion_ui:
1855
  diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución")
1856
  with gr.Group(visible=False) as dreambooth_ui:
1857
  dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
1858
  dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
1859
- with gr.Group(visible=False) as classification_labels_ui:
1860
- classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo")
1861
- with gr.Group(visible=False) as dpo_ui:
1862
- prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt")
1863
- dpo_chosen_col_input = gr.Textbox(label="Columna Elegida", value="chosen")
1864
- dpo_rejected_col_input = gr.Textbox(label="Columna Rechazada", value="rejected")
1865
- with gr.Group(visible=True) as sft_ui:
1866
- chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5)
1867
 
1868
  with gr.Accordion("📊 Evaluación y Mitigación de Sesgos", open=False):
1869
  run_evaluation = gr.Checkbox(label="Ejecutar Evaluaci贸n", value=False)
1870
  run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad", value=True)
1871
  enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderaci贸n de P茅rdida", value=False)
1872
  reweighting_terms = gr.Textbox(label="Términos para Re-ponderar (csv)", placeholder="sesgo,injusto")
 
1873
  enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False)
1874
  cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]')
1875
 
@@ -1901,6 +1922,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1901
  "adam_beta1": adam_beta1, "adam_beta2": adam_beta2, "adam_epsilon": adam_epsilon,
1902
  "disable_gradient_checkpointing": disable_gradient_checkpointing, "group_by_length": group_by_length,
1903
  "neftune_noise_alpha": neftune_noise_alpha, "optim_args": optim_args, "attn_implementation": attn_implementation,
 
1904
  "peft": peft, "quantization": quantization, "lora_r": lora_r, "lora_alpha": lora_alpha,
1905
  "lora_dropout": lora_dropout, "auto_find_target_modules": auto_find_target_modules, "target_modules": target_modules,
1906
  "modules_to_save": modules_to_save, "use_dora": use_dora, "use_rslora": use_rslora, "init_lora_weights": init_lora_weights,
@@ -1912,10 +1934,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1912
  "enable_back_translation": enable_back_translation, "bt_model_id": bt_model_id,
1913
  "bt_reverse_model_id": bt_reverse_model_id, "enable_synthetic_data": enable_synthetic_data,
1914
  "synthetic_model_id": synthetic_model_id, "num_synthetic_samples": num_synthetic_samples,
1915
- "chat_template_jinja": chat_template_jinja, "prompt_col_input": prompt_col_input, "dpo_chosen_col_input": dpo_chosen_col_input,
 
 
 
 
1916
  "dpo_rejected_col_input": dpo_rejected_col_input, "classification_labels": classification_labels,
1917
  "diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
1918
- "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
1919
  "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
1920
  "dreambooth_instance_prompt": dreambooth_instance_prompt,
1921
  "dreambooth_train_text_encoder": dreambooth_train_text_encoder
@@ -1939,7 +1965,13 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1939
  training_mode.change(
1940
  toggle_task_specific_ui,
1941
  inputs=[training_mode],
1942
- outputs=[classification_labels_ui, dpo_ui, sft_ui, diffusion_ui, dreambooth_ui, peft_accordion, epochs_ui, max_steps_ui]
 
 
 
 
 
 
1943
  )
1944
 
1945
  auto_find_target_modules.change(
 
170
  return (loss, outputs) if return_outputs else loss
171
 
172
  @spaces.GPU()
173
+ def _create_deduplicated_iterable_dataset(dataset, text_col, method, threshold=0.85, num_perm=128):
174
+ def gen():
175
+ if method == 'Exacta':
176
+ seen_texts = set()
177
+ for example in dataset:
178
+ text = example.get(text_col, "")
179
+ if text and isinstance(text, str):
180
+ if text not in seen_texts:
181
+ seen_texts.add(text)
182
+ yield example
183
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  yield example
185
+ elif method == 'Semántica (MinHash)':
186
+ lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
187
+ for i, example in enumerate(dataset):
188
+ text = example.get(text_col, "")
189
+ if text and isinstance(text, str) and text.strip():
190
+ m = MinHash(num_perm=num_perm)
191
+ for d in text.split():
192
+ m.update(d.encode('utf8'))
193
+ if not lsh.query(m):
194
+ lsh.insert(f"key_{i}", m)
195
+ yield example
196
+ else:
 
197
  yield example
198
+ else:
199
+ yield from dataset
200
+
201
+ new_ds = IterableDataset.from_generator(gen)
202
+ if hasattr(dataset, 'info'):
203
+ new_ds.info = dataset.info
204
+ elif hasattr(dataset, '_info'):
205
+ new_ds.info = dataset._info
206
+ return new_ds
207
 
208
  @spaces.GPU()
209
  def hf_login(token):
 
547
  "adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)),
548
  "no_cuda": device == 'cpu'
549
  }
550
+
551
+ if kwargs.get('early_stopping_patience', 0) > 0 and kwargs.get('run_evaluation', False):
552
+ args_dict['early_stopping_patience'] = int(kwargs['early_stopping_patience'])
553
+ args_dict['load_best_model_at_end'] = True
554
 
555
  is_diffusion_task = kwargs.get('training_mode', '') in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
556
  if is_diffusion_task:
 
645
 
646
  return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
647
 
 
 
 
 
 
 
 
 
 
 
 
 
648
  @spaces.GPU()
649
  def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
650
+ if kwargs.get('sft_format_style') == "Conversacional":
651
+ conv_col = ""
652
+ for key in ["messages", "conversations", "turns"]:
653
+ if key in example: conv_col = key; break
654
+ if not conv_col: return ""
655
+ conversation = example[conv_col]
656
+ if isinstance(conversation, str):
657
+ try: conversation = ast.literal_eval(conversation)
658
+ except: return ""
659
+ return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
660
+
661
+ if kwargs.get('sft_format_style') == "Razonamiento/Herramientas":
662
  messages = []
663
  prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "")
664
  if prompt: messages.append({"role": "user", "content": prompt})
665
+
666
  response_parts = []
667
+ if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')):
668
+ response_parts.append(f"<thinking>{example[kwargs.get('reasoning_col_input', 'reasoning')]}</thinking>")
669
+ if kwargs.get('enable_tool_use_input') and example.get(kwargs.get('tool_use_col_input', 'tools')):
670
+ response_parts.append(f"<tool_code>{example[kwargs.get('tool_use_col_input', 'tools')]}</tool_code>")
671
+ if example.get(kwargs.get('response_col_input', 'response')):
672
+ response_parts.append(example[kwargs.get('response_col_input', 'response')])
673
+
674
+ if response_parts:
675
+ messages.append({"role": "assistant", "content": "\n".join(response_parts)})
676
+
677
  if messages:
678
  try:
679
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
680
  except Exception as e:
681
+ logger.error(f"Error aplicando la plantilla de chat: {e}.")
682
  return "\n".join([m['content'] for m in messages])
683
+ return ""
684
+
685
  return example.get(text_col, "")
686
 
687
  @spaces.GPU()
 
812
  eval_dataset = update
813
 
814
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
815
+ trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "tokenizer": tokenizer, "peft_config": peft_config}
816
 
817
  if is_dpo:
818
  trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
819
+ if train_dataset:
820
+ train_dataset = train_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
821
  if eval_dataset:
822
  eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
823
+ trainer_kwargs.update({"train_dataset": train_dataset, "eval_dataset": eval_dataset})
824
  else:
825
  sft_kwargs = kwargs.copy()
826
+ trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs), "max_seq_length": int(kwargs.get('block_size'))})
827
  if kwargs.get('enable_loss_reweighting'):
828
+ trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': float(kwargs.get('reweighting_factor', 2.0))})
829
 
830
  trainer = TrainerClass(**trainer_kwargs)
831
  final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
 
1281
  if kwargs.get('uploads'):
1282
  uploaded_data_map = _load_uploaded_stream(kwargs.get('uploads'))
1283
  if uploaded_data_map and uploaded_data_map["train"]:
1284
+ train_dataset = IterableDataset.from_generator(lambda: iter(uploaded_data_map["train"]))
1285
  uploaded_val_data = uploaded_data_map["validation"]
1286
 
1287
  if hf_ids:
 
1329
 
1330
  dedup_method = kwargs.get('deduplication_method')
1331
  if dedup_method != 'Ninguna':
1332
+ train_dataset = _create_deduplicated_iterable_dataset(
1333
  dataset=train_dataset,
1334
  text_col=text_col,
1335
  method=dedup_method,
 
1663
  formatted_text = ""
1664
  if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
1665
  formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
1666
+ elif kwargs['training_mode'] == "Causal Language Modeling (SFT/LoRA)":
1667
  formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
1668
+ else:
1669
+ formatted_text = str(example)
1670
 
1671
  preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
1672
 
 
1702
  gr.update(visible=is_classification or is_ner),
1703
  gr.update(visible=is_dpo),
1704
  gr.update(visible=is_sft),
1705
+ gr.update(visible=is_diffusion),
1706
  gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
1707
  gr.update(visible=not is_diffusion),
1708
  gr.update(visible=is_diffusion),
1709
+ gr.update(visible=is_streaming),
1710
  gr.update(visible=not is_streaming),
 
1711
  )
1712
 
1713
+ @spaces.GPU()
1714
+ def toggle_sft_format_ui(format_style):
1715
+ is_tool = format_style == "Razonamiento/Herramientas"
1716
+ return gr.update(visible=is_tool)
1717
+
1718
  @spaces.GPU()
1719
  def toggle_auto_modules_ui(is_auto):
1720
  return gr.update(visible=not is_auto)
 
1795
  gradient_accumulation = gr.Textbox(label="Acumulaci贸n de Gradiente", value="8")
1796
  with gr.Row():
1797
  block_size = gr.Textbox(label="Longitud de Secuencia", value="1024")
1798
+ with gr.Group(visible=True) as max_steps_ui:
1799
+ max_steps = gr.Textbox(label="Máximos Pasos de Entrenamiento", value="100")
1800
  with gr.Group(visible=False) as epochs_ui:
1801
  epochs = gr.Textbox(label="Épocas", value="1")
 
 
1802
  with gr.Row():
1803
  optimizer = gr.Dropdown(["adamw_torch", "adafactor", "sgd", "adagrad"], label="Optimizador", value="adamw_torch")
1804
  scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine")
 
1810
  logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1811
  save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1812
  save_total_limit = gr.Textbox(label="Límite Total de Guardado", value="1")
1813
+ early_stopping_patience = gr.Number(label="Paciencia para Early Stopping (0 para desactivar)", value=0)
1814
  resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False)
1815
  with gr.Row():
1816
  adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9")
 
1863
  num_synthetic_samples = gr.Number(label="Número de Muestras", value=1000)
1864
 
1865
  with gr.Accordion("📝 Configuración de Formato y Tarea", open=False):
1866
+ with gr.Group(visible=True) as sft_ui:
1867
+ sft_format_style = gr.Radio(["Columna de Texto", "Conversacional", "Razonamiento/Herramientas"], label="Formato de Datos SFT", value="Columna de Texto")
1868
+ chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5)
1869
+ with gr.Group(visible=False) as sft_tool_ui:
1870
+ enable_cot_input = gr.Checkbox(label="Habilitar Razonamiento (CoT)", value=True)
1871
+ enable_tool_use_input = gr.Checkbox(label="Habilitar Uso de Herramientas", value=True)
1872
+ prompt_col_input = gr.Textbox(label="Columna de Prompt/Usuario", value="prompt")
1873
+ response_col_input = gr.Textbox(label="Columna de Respuesta Final", value="response")
1874
+ reasoning_col_input = gr.Textbox(label="Columna de Razonamiento", value="reasoning")
1875
+ tool_use_col_input = gr.Textbox(label="Columna de Uso de Herramientas", value="tools")
1876
+ with gr.Group(visible=False) as dpo_ui:
1877
+ dpo_prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt")
1878
+ dpo_chosen_col_input = gr.Textbox(label="Columna Elegida", value="chosen")
1879
+ dpo_rejected_col_input = gr.Textbox(label="Columna Rechazada", value="rejected")
1880
+ with gr.Group(visible=False) as classification_labels_ui:
1881
  classification_labels = gr.Textbox(label="Etiquetas de Clasificación (csv)", placeholder="p.ej. positivo,negativo")
1882
  with gr.Group(visible=False) as diffusion_ui:
1883
  diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resolución")
1884
  with gr.Group(visible=False) as dreambooth_ui:
1885
  dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
1886
  dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
 
 
 
 
 
 
 
 
1887
 
1888
  with gr.Accordion("📊 Evaluación y Mitigación de Sesgos", open=False):
1889
  run_evaluation = gr.Checkbox(label="Ejecutar Evaluaci贸n", value=False)
1890
  run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad", value=True)
1891
  enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderación de Pérdida", value=False)
1892
  reweighting_terms = gr.Textbox(label="Términos para Re-ponderar (csv)", placeholder="sesgo,injusto")
1893
+ reweighting_factor = gr.Slider(1.1, 10.0, 2.0, label="Factor de Re-ponderación")
1894
  enable_cda = gr.Checkbox(label="Habilitar Aumentación Contrafactual (CDA)", value=False)
1895
  cda_json_config = gr.Textbox(label="Configuración CDA (JSON)", placeholder='[["ella", "él"], ["mujer", "hombre"]]')
1896
 
 
1922
  "adam_beta1": adam_beta1, "adam_beta2": adam_beta2, "adam_epsilon": adam_epsilon,
1923
  "disable_gradient_checkpointing": disable_gradient_checkpointing, "group_by_length": group_by_length,
1924
  "neftune_noise_alpha": neftune_noise_alpha, "optim_args": optim_args, "attn_implementation": attn_implementation,
1925
+ "early_stopping_patience": early_stopping_patience,
1926
  "peft": peft, "quantization": quantization, "lora_r": lora_r, "lora_alpha": lora_alpha,
1927
  "lora_dropout": lora_dropout, "auto_find_target_modules": auto_find_target_modules, "target_modules": target_modules,
1928
  "modules_to_save": modules_to_save, "use_dora": use_dora, "use_rslora": use_rslora, "init_lora_weights": init_lora_weights,
 
1934
  "enable_back_translation": enable_back_translation, "bt_model_id": bt_model_id,
1935
  "bt_reverse_model_id": bt_reverse_model_id, "enable_synthetic_data": enable_synthetic_data,
1936
  "synthetic_model_id": synthetic_model_id, "num_synthetic_samples": num_synthetic_samples,
1937
+ "sft_format_style": sft_format_style, "chat_template_jinja": chat_template_jinja,
1938
+ "enable_cot_input": enable_cot_input, "enable_tool_use_input": enable_tool_use_input,
1939
+ "prompt_col_input": prompt_col_input, "response_col_input": response_col_input,
1940
+ "reasoning_col_input": reasoning_col_input, "tool_use_col_input": tool_use_col_input,
1941
+ "dpo_prompt_col_input": dpo_prompt_col_input, "dpo_chosen_col_input": dpo_chosen_col_input,
1942
  "dpo_rejected_col_input": dpo_rejected_col_input, "classification_labels": classification_labels,
1943
  "diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
1944
+ "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms, "reweighting_factor": reweighting_factor,
1945
  "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
1946
  "dreambooth_instance_prompt": dreambooth_instance_prompt,
1947
  "dreambooth_train_text_encoder": dreambooth_train_text_encoder
 
1965
  training_mode.change(
1966
  toggle_task_specific_ui,
1967
  inputs=[training_mode],
1968
+ outputs=[classification_labels_ui, dpo_ui, sft_ui, diffusion_ui, dreambooth_ui, peft_accordion, epochs_ui, max_steps_ui, peft_accordion]
1969
+ )
1970
+
1971
+ sft_format_style.change(
1972
+ toggle_sft_format_ui,
1973
+ inputs=[sft_format_style],
1974
+ outputs=[sft_tool_ui]
1975
  )
1976
 
1977
  auto_find_target_modules.change(