Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -170,51 +170,40 @@ class DebiasingSFTTrainer(SFTTrainer):
|
|
| 170 |
return (loss, outputs) if return_outputs else loss
|
| 171 |
|
| 172 |
@spaces.GPU()
|
| 173 |
-
|
| 174 |
-
def
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
self._info = dataset.info
|
| 185 |
-
|
| 186 |
-
def __iter__(self):
|
| 187 |
-
if self.method == 'Exacta':
|
| 188 |
-
return self._exact_iter()
|
| 189 |
-
elif self.method == 'Sem谩ntica (MinHash)':
|
| 190 |
-
return self._minhash_iter()
|
| 191 |
-
else:
|
| 192 |
-
return iter(self.dataset)
|
| 193 |
-
|
| 194 |
-
def _exact_iter(self):
|
| 195 |
-
seen_texts = set()
|
| 196 |
-
for example in self.dataset:
|
| 197 |
-
text = example.get(self.text_col, "")
|
| 198 |
-
if text and isinstance(text, str):
|
| 199 |
-
if text not in seen_texts:
|
| 200 |
-
seen_texts.add(text)
|
| 201 |
yield example
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
lsh.insert(f"key_{i}", m)
|
| 215 |
yield example
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
@spaces.GPU()
|
| 220 |
def hf_login(token):
|
|
@@ -558,6 +547,10 @@ def _create_training_args(output_dir, repo_id, **kwargs):
|
|
| 558 |
"adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)),
|
| 559 |
"no_cuda": device == 'cpu'
|
| 560 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
is_diffusion_task = kwargs.get('training_mode', '') in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
|
| 563 |
if is_diffusion_task:
|
|
@@ -652,35 +645,43 @@ def _find_all_linear_names(model, quantization_type):
|
|
| 652 |
|
| 653 |
return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
|
| 654 |
|
| 655 |
-
@spaces.GPU()
|
| 656 |
-
def _conversation_formatting_func(example, tokenizer, **kwargs):
|
| 657 |
-
conv_col = ""
|
| 658 |
-
for key in ["messages", "conversations", "turns"]:
|
| 659 |
-
if key in example: conv_col = key; break
|
| 660 |
-
if not conv_col: return ""
|
| 661 |
-
conversation = example[conv_col]
|
| 662 |
-
if isinstance(conversation, str):
|
| 663 |
-
try: conversation = ast.literal_eval(conversation)
|
| 664 |
-
except: return ""
|
| 665 |
-
return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
|
| 666 |
-
|
| 667 |
@spaces.GPU()
|
| 668 |
def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
|
| 669 |
-
if kwargs.get('
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
messages = []
|
| 671 |
prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "")
|
| 672 |
if prompt: messages.append({"role": "user", "content": prompt})
|
|
|
|
| 673 |
response_parts = []
|
| 674 |
-
if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')):
|
| 675 |
-
|
| 676 |
-
if
|
| 677 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
if messages:
|
| 679 |
try:
|
| 680 |
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
| 681 |
except Exception as e:
|
| 682 |
-
logger.error(f"Error
|
| 683 |
return "\n".join([m['content'] for m in messages])
|
|
|
|
|
|
|
| 684 |
return example.get(text_col, "")
|
| 685 |
|
| 686 |
@spaces.GPU()
|
|
@@ -811,17 +812,20 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
|
|
| 811 |
eval_dataset = update
|
| 812 |
|
| 813 |
TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
|
| 814 |
-
trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config}
|
| 815 |
|
| 816 |
if is_dpo:
|
| 817 |
trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
|
|
|
|
|
|
|
| 818 |
if eval_dataset:
|
| 819 |
eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
|
|
|
|
| 820 |
else:
|
| 821 |
sft_kwargs = kwargs.copy()
|
| 822 |
-
trainer_kwargs.update({"formatting_func": lambda ex
|
| 823 |
if kwargs.get('enable_loss_reweighting'):
|
| 824 |
-
trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
|
| 825 |
|
| 826 |
trainer = TrainerClass(**trainer_kwargs)
|
| 827 |
final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
|
|
@@ -1277,7 +1281,7 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1277 |
if kwargs.get('uploads'):
|
| 1278 |
uploaded_data_map = _load_uploaded_stream(kwargs.get('uploads'))
|
| 1279 |
if uploaded_data_map and uploaded_data_map["train"]:
|
| 1280 |
-
train_dataset =
|
| 1281 |
uploaded_val_data = uploaded_data_map["validation"]
|
| 1282 |
|
| 1283 |
if hf_ids:
|
|
@@ -1325,7 +1329,7 @@ def _get_data_processing_pipeline(**kwargs):
|
|
| 1325 |
|
| 1326 |
dedup_method = kwargs.get('deduplication_method')
|
| 1327 |
if dedup_method != 'Ninguna':
|
| 1328 |
-
train_dataset =
|
| 1329 |
dataset=train_dataset,
|
| 1330 |
text_col=text_col,
|
| 1331 |
method=dedup_method,
|
|
@@ -1659,8 +1663,10 @@ def gradio_preview_data_wrapper(*args):
|
|
| 1659 |
formatted_text = ""
|
| 1660 |
if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
|
| 1661 |
formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
|
| 1662 |
-
|
| 1663 |
formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
|
|
|
|
|
|
|
| 1664 |
|
| 1665 |
preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
|
| 1666 |
|
|
@@ -1696,14 +1702,19 @@ def toggle_task_specific_ui(training_mode):
|
|
| 1696 |
gr.update(visible=is_classification or is_ner),
|
| 1697 |
gr.update(visible=is_dpo),
|
| 1698 |
gr.update(visible=is_sft),
|
| 1699 |
-
gr.update(visible=
|
| 1700 |
gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
|
| 1701 |
gr.update(visible=not is_diffusion),
|
| 1702 |
gr.update(visible=is_diffusion),
|
|
|
|
| 1703 |
gr.update(visible=not is_streaming),
|
| 1704 |
-
gr.update(visible=is_streaming)
|
| 1705 |
)
|
| 1706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1707 |
@spaces.GPU()
|
| 1708 |
def toggle_auto_modules_ui(is_auto):
|
| 1709 |
return gr.update(visible=not is_auto)
|
|
@@ -1784,10 +1795,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1784 |
gradient_accumulation = gr.Textbox(label="Acumulaci贸n de Gradiente", value="8")
|
| 1785 |
with gr.Row():
|
| 1786 |
block_size = gr.Textbox(label="Longitud de Secuencia", value="1024")
|
|
|
|
|
|
|
| 1787 |
with gr.Group(visible=False) as epochs_ui:
|
| 1788 |
epochs = gr.Textbox(label="脡pocas", value="1")
|
| 1789 |
-
with gr.Group(visible=True) as max_steps_ui:
|
| 1790 |
-
max_steps = gr.Textbox(label="M谩ximos Pasos de Entrenamiento", value="100")
|
| 1791 |
with gr.Row():
|
| 1792 |
optimizer = gr.Dropdown(["adamw_torch", "adafactor", "sgd", "adagrad"], label="Optimizador", value="adamw_torch")
|
| 1793 |
scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine")
|
|
@@ -1799,6 +1810,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1799 |
logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
|
| 1800 |
save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
|
| 1801 |
save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
|
|
|
|
| 1802 |
resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False)
|
| 1803 |
with gr.Row():
|
| 1804 |
adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9")
|
|
@@ -1851,25 +1863,34 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1851 |
num_synthetic_samples = gr.Number(label="N煤mero de Muestras", value=1000)
|
| 1852 |
|
| 1853 |
with gr.Accordion("馃摑 Configuraci贸n de Formato y Tarea", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1854 |
with gr.Group(visible=False) as diffusion_ui:
|
| 1855 |
diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resoluci贸n")
|
| 1856 |
with gr.Group(visible=False) as dreambooth_ui:
|
| 1857 |
dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
|
| 1858 |
dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
|
| 1859 |
-
with gr.Group(visible=False) as classification_labels_ui:
|
| 1860 |
-
classification_labels = gr.Textbox(label="Etiquetas de Clasificaci贸n (csv)", placeholder="p.ej. positivo,negativo")
|
| 1861 |
-
with gr.Group(visible=False) as dpo_ui:
|
| 1862 |
-
prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt")
|
| 1863 |
-
dpo_chosen_col_input = gr.Textbox(label="Columna Elegida", value="chosen")
|
| 1864 |
-
dpo_rejected_col_input = gr.Textbox(label="Columna Rechazada", value="rejected")
|
| 1865 |
-
with gr.Group(visible=True) as sft_ui:
|
| 1866 |
-
chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5)
|
| 1867 |
|
| 1868 |
with gr.Accordion("馃搳 Evaluaci贸n y Mitigaci贸n de Sesgos", open=False):
|
| 1869 |
run_evaluation = gr.Checkbox(label="Ejecutar Evaluaci贸n", value=False)
|
| 1870 |
run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad", value=True)
|
| 1871 |
enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderaci贸n de P茅rdida", value=False)
|
| 1872 |
reweighting_terms = gr.Textbox(label="T茅rminos para Re-ponderar (csv)", placeholder="sesgo,injusto")
|
|
|
|
| 1873 |
enable_cda = gr.Checkbox(label="Habilitar Aumentaci贸n Contrafactual (CDA)", value=False)
|
| 1874 |
cda_json_config = gr.Textbox(label="Configuraci贸n CDA (JSON)", placeholder='[["ella", "茅l"], ["mujer", "hombre"]]')
|
| 1875 |
|
|
@@ -1901,6 +1922,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1901 |
"adam_beta1": adam_beta1, "adam_beta2": adam_beta2, "adam_epsilon": adam_epsilon,
|
| 1902 |
"disable_gradient_checkpointing": disable_gradient_checkpointing, "group_by_length": group_by_length,
|
| 1903 |
"neftune_noise_alpha": neftune_noise_alpha, "optim_args": optim_args, "attn_implementation": attn_implementation,
|
|
|
|
| 1904 |
"peft": peft, "quantization": quantization, "lora_r": lora_r, "lora_alpha": lora_alpha,
|
| 1905 |
"lora_dropout": lora_dropout, "auto_find_target_modules": auto_find_target_modules, "target_modules": target_modules,
|
| 1906 |
"modules_to_save": modules_to_save, "use_dora": use_dora, "use_rslora": use_rslora, "init_lora_weights": init_lora_weights,
|
|
@@ -1912,10 +1934,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1912 |
"enable_back_translation": enable_back_translation, "bt_model_id": bt_model_id,
|
| 1913 |
"bt_reverse_model_id": bt_reverse_model_id, "enable_synthetic_data": enable_synthetic_data,
|
| 1914 |
"synthetic_model_id": synthetic_model_id, "num_synthetic_samples": num_synthetic_samples,
|
| 1915 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1916 |
"dpo_rejected_col_input": dpo_rejected_col_input, "classification_labels": classification_labels,
|
| 1917 |
"diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
|
| 1918 |
-
"enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
|
| 1919 |
"wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
|
| 1920 |
"dreambooth_instance_prompt": dreambooth_instance_prompt,
|
| 1921 |
"dreambooth_train_text_encoder": dreambooth_train_text_encoder
|
|
@@ -1939,7 +1965,13 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
| 1939 |
training_mode.change(
|
| 1940 |
toggle_task_specific_ui,
|
| 1941 |
inputs=[training_mode],
|
| 1942 |
-
outputs=[classification_labels_ui, dpo_ui, sft_ui, diffusion_ui, dreambooth_ui, peft_accordion, epochs_ui, max_steps_ui]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1943 |
)
|
| 1944 |
|
| 1945 |
auto_find_target_modules.change(
|
|
|
|
| 170 |
return (loss, outputs) if return_outputs else loss
|
| 171 |
|
| 172 |
@spaces.GPU()
|
| 173 |
+
def _create_deduplicated_iterable_dataset(dataset, text_col, method, threshold=0.85, num_perm=128):
|
| 174 |
+
def gen():
|
| 175 |
+
if method == 'Exacta':
|
| 176 |
+
seen_texts = set()
|
| 177 |
+
for example in dataset:
|
| 178 |
+
text = example.get(text_col, "")
|
| 179 |
+
if text and isinstance(text, str):
|
| 180 |
+
if text not in seen_texts:
|
| 181 |
+
seen_texts.add(text)
|
| 182 |
+
yield example
|
| 183 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
yield example
|
| 185 |
+
elif method == 'Sem谩ntica (MinHash)':
|
| 186 |
+
lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
|
| 187 |
+
for i, example in enumerate(dataset):
|
| 188 |
+
text = example.get(text_col, "")
|
| 189 |
+
if text and isinstance(text, str) and text.strip():
|
| 190 |
+
m = MinHash(num_perm=num_perm)
|
| 191 |
+
for d in text.split():
|
| 192 |
+
m.update(d.encode('utf8'))
|
| 193 |
+
if not lsh.query(m):
|
| 194 |
+
lsh.insert(f"key_{i}", m)
|
| 195 |
+
yield example
|
| 196 |
+
else:
|
|
|
|
| 197 |
yield example
|
| 198 |
+
else:
|
| 199 |
+
yield from dataset
|
| 200 |
+
|
| 201 |
+
new_ds = IterableDataset.from_generator(gen)
|
| 202 |
+
if hasattr(dataset, 'info'):
|
| 203 |
+
new_ds.info = dataset.info
|
| 204 |
+
elif hasattr(dataset, '_info'):
|
| 205 |
+
new_ds.info = dataset._info
|
| 206 |
+
return new_ds
|
| 207 |
|
| 208 |
@spaces.GPU()
|
| 209 |
def hf_login(token):
|
|
|
|
| 547 |
"adam_epsilon": float(kwargs.get('adam_epsilon', 1e-8)),
|
| 548 |
"no_cuda": device == 'cpu'
|
| 549 |
}
|
| 550 |
+
|
| 551 |
+
if kwargs.get('early_stopping_patience', 0) > 0 and kwargs.get('run_evaluation', False):
|
| 552 |
+
args_dict['early_stopping_patience'] = int(kwargs['early_stopping_patience'])
|
| 553 |
+
args_dict['load_best_model_at_end'] = True
|
| 554 |
|
| 555 |
is_diffusion_task = kwargs.get('training_mode', '') in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
|
| 556 |
if is_diffusion_task:
|
|
|
|
| 645 |
|
| 646 |
return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
|
| 647 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
@spaces.GPU()
|
| 649 |
def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
|
| 650 |
+
if kwargs.get('sft_format_style') == "Conversacional":
|
| 651 |
+
conv_col = ""
|
| 652 |
+
for key in ["messages", "conversations", "turns"]:
|
| 653 |
+
if key in example: conv_col = key; break
|
| 654 |
+
if not conv_col: return ""
|
| 655 |
+
conversation = example[conv_col]
|
| 656 |
+
if isinstance(conversation, str):
|
| 657 |
+
try: conversation = ast.literal_eval(conversation)
|
| 658 |
+
except: return ""
|
| 659 |
+
return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
|
| 660 |
+
|
| 661 |
+
if kwargs.get('sft_format_style') == "Razonamiento/Herramientas":
|
| 662 |
messages = []
|
| 663 |
prompt = example.get(kwargs.get('prompt_col_input', 'prompt'), "")
|
| 664 |
if prompt: messages.append({"role": "user", "content": prompt})
|
| 665 |
+
|
| 666 |
response_parts = []
|
| 667 |
+
if kwargs.get('enable_cot_input') and example.get(kwargs.get('reasoning_col_input', 'reasoning')):
|
| 668 |
+
response_parts.append(f"<thinking>{example[kwargs.get('reasoning_col_input', 'reasoning')]}</thinking>")
|
| 669 |
+
if kwargs.get('enable_tool_use_input') and example.get(kwargs.get('tool_use_col_input', 'tools')):
|
| 670 |
+
response_parts.append(f"<tool_code>{example[kwargs.get('tool_use_col_input', 'tools')]}</tool_code>")
|
| 671 |
+
if example.get(kwargs.get('response_col_input', 'response')):
|
| 672 |
+
response_parts.append(example[kwargs.get('response_col_input', 'response')])
|
| 673 |
+
|
| 674 |
+
if response_parts:
|
| 675 |
+
messages.append({"role": "assistant", "content": "\n".join(response_parts)})
|
| 676 |
+
|
| 677 |
if messages:
|
| 678 |
try:
|
| 679 |
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
| 680 |
except Exception as e:
|
| 681 |
+
logger.error(f"Error aplicando la plantilla de chat: {e}.")
|
| 682 |
return "\n".join([m['content'] for m in messages])
|
| 683 |
+
return ""
|
| 684 |
+
|
| 685 |
return example.get(text_col, "")
|
| 686 |
|
| 687 |
@spaces.GPU()
|
|
|
|
| 812 |
eval_dataset = update
|
| 813 |
|
| 814 |
TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
|
| 815 |
+
trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "tokenizer": tokenizer, "peft_config": peft_config}
|
| 816 |
|
| 817 |
if is_dpo:
|
| 818 |
trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
|
| 819 |
+
if train_dataset:
|
| 820 |
+
train_dataset = train_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
|
| 821 |
if eval_dataset:
|
| 822 |
eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
|
| 823 |
+
trainer_kwargs.update({"train_dataset": train_dataset, "eval_dataset": eval_dataset})
|
| 824 |
else:
|
| 825 |
sft_kwargs = kwargs.copy()
|
| 826 |
+
trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs), "max_seq_length": int(kwargs.get('block_size'))})
|
| 827 |
if kwargs.get('enable_loss_reweighting'):
|
| 828 |
+
trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': float(kwargs.get('reweighting_factor', 2.0))})
|
| 829 |
|
| 830 |
trainer = TrainerClass(**trainer_kwargs)
|
| 831 |
final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
|
|
|
|
| 1281 |
if kwargs.get('uploads'):
|
| 1282 |
uploaded_data_map = _load_uploaded_stream(kwargs.get('uploads'))
|
| 1283 |
if uploaded_data_map and uploaded_data_map["train"]:
|
| 1284 |
+
train_dataset = IterableDataset.from_generator(lambda: iter(uploaded_data_map["train"]))
|
| 1285 |
uploaded_val_data = uploaded_data_map["validation"]
|
| 1286 |
|
| 1287 |
if hf_ids:
|
|
|
|
| 1329 |
|
| 1330 |
dedup_method = kwargs.get('deduplication_method')
|
| 1331 |
if dedup_method != 'Ninguna':
|
| 1332 |
+
train_dataset = _create_deduplicated_iterable_dataset(
|
| 1333 |
dataset=train_dataset,
|
| 1334 |
text_col=text_col,
|
| 1335 |
method=dedup_method,
|
|
|
|
| 1663 |
formatted_text = ""
|
| 1664 |
if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
|
| 1665 |
formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
|
| 1666 |
+
elif kwargs['training_mode'] == "Causal Language Modeling (SFT/LoRA)":
|
| 1667 |
formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
|
| 1668 |
+
else:
|
| 1669 |
+
formatted_text = str(example)
|
| 1670 |
|
| 1671 |
preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
|
| 1672 |
|
|
|
|
| 1702 |
gr.update(visible=is_classification or is_ner),
|
| 1703 |
gr.update(visible=is_dpo),
|
| 1704 |
gr.update(visible=is_sft),
|
| 1705 |
+
gr.update(visible=is_diffusion),
|
| 1706 |
gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
|
| 1707 |
gr.update(visible=not is_diffusion),
|
| 1708 |
gr.update(visible=is_diffusion),
|
| 1709 |
+
gr.update(visible=is_streaming),
|
| 1710 |
gr.update(visible=not is_streaming),
|
|
|
|
| 1711 |
)
|
| 1712 |
|
| 1713 |
+
@spaces.GPU()
|
| 1714 |
+
def toggle_sft_format_ui(format_style):
|
| 1715 |
+
is_tool = format_style == "Razonamiento/Herramientas"
|
| 1716 |
+
return gr.update(visible=is_tool)
|
| 1717 |
+
|
| 1718 |
@spaces.GPU()
|
| 1719 |
def toggle_auto_modules_ui(is_auto):
|
| 1720 |
return gr.update(visible=not is_auto)
|
|
|
|
| 1795 |
gradient_accumulation = gr.Textbox(label="Acumulaci贸n de Gradiente", value="8")
|
| 1796 |
with gr.Row():
|
| 1797 |
block_size = gr.Textbox(label="Longitud de Secuencia", value="1024")
|
| 1798 |
+
with gr.Group(visible=True) as max_steps_ui:
|
| 1799 |
+
max_steps = gr.Textbox(label="M谩ximos Pasos de Entrenamiento", value="100")
|
| 1800 |
with gr.Group(visible=False) as epochs_ui:
|
| 1801 |
epochs = gr.Textbox(label="脡pocas", value="1")
|
|
|
|
|
|
|
| 1802 |
with gr.Row():
|
| 1803 |
optimizer = gr.Dropdown(["adamw_torch", "adafactor", "sgd", "adagrad"], label="Optimizador", value="adamw_torch")
|
| 1804 |
scheduler = gr.Dropdown(["cosine", "linear", "constant"], label="Planificador LR", value="cosine")
|
|
|
|
| 1810 |
logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
|
| 1811 |
save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
|
| 1812 |
save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
|
| 1813 |
+
early_stopping_patience = gr.Number(label="Paciencia para Early Stopping (0 para desactivar)", value=0)
|
| 1814 |
resume_from_checkpoint = gr.Checkbox(label="Reanudar desde Checkpoint", value=False)
|
| 1815 |
with gr.Row():
|
| 1816 |
adam_beta1 = gr.Textbox(label="Adam Beta1", value="0.9")
|
|
|
|
| 1863 |
num_synthetic_samples = gr.Number(label="N煤mero de Muestras", value=1000)
|
| 1864 |
|
| 1865 |
with gr.Accordion("馃摑 Configuraci贸n de Formato y Tarea", open=False):
|
| 1866 |
+
with gr.Group(visible=True) as sft_ui:
|
| 1867 |
+
sft_format_style = gr.Radio(["Columna de Texto", "Conversacional", "Razonamiento/Herramientas"], label="Formato de Datos SFT", value="Columna de Texto")
|
| 1868 |
+
chat_template_jinja = gr.Textbox(label="Plantilla de Chat Jinja2 (opcional)", lines=5)
|
| 1869 |
+
with gr.Group(visible=False) as sft_tool_ui:
|
| 1870 |
+
enable_cot_input = gr.Checkbox(label="Habilitar Razonamiento (CoT)", value=True)
|
| 1871 |
+
enable_tool_use_input = gr.Checkbox(label="Habilitar Uso de Herramientas", value=True)
|
| 1872 |
+
prompt_col_input = gr.Textbox(label="Columna de Prompt/Usuario", value="prompt")
|
| 1873 |
+
response_col_input = gr.Textbox(label="Columna de Respuesta Final", value="response")
|
| 1874 |
+
reasoning_col_input = gr.Textbox(label="Columna de Razonamiento", value="reasoning")
|
| 1875 |
+
tool_use_col_input = gr.Textbox(label="Columna de Uso de Herramientas", value="tools")
|
| 1876 |
+
with gr.Group(visible=False) as dpo_ui:
|
| 1877 |
+
dpo_prompt_col_input = gr.Textbox(label="Columna de Prompt", value="prompt")
|
| 1878 |
+
dpo_chosen_col_input = gr.Textbox(label="Columna Elegida", value="chosen")
|
| 1879 |
+
dpo_rejected_col_input = gr.Textbox(label="Columna Rechazada", value="rejected")
|
| 1880 |
+
with gr.Group(visible=False) as classification_labels_ui:
|
| 1881 |
+
classification_labels = gr.Textbox(label="Etiquetas de Clasificaci贸n (csv)", placeholder="p.ej. positivo,negativo")
|
| 1882 |
with gr.Group(visible=False) as diffusion_ui:
|
| 1883 |
diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resoluci贸n")
|
| 1884 |
with gr.Group(visible=False) as dreambooth_ui:
|
| 1885 |
dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
|
| 1886 |
dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1887 |
|
| 1888 |
with gr.Accordion("馃搳 Evaluaci贸n y Mitigaci贸n de Sesgos", open=False):
|
| 1889 |
run_evaluation = gr.Checkbox(label="Ejecutar Evaluaci贸n", value=False)
|
| 1890 |
run_perplexity_evaluation = gr.Checkbox(label="Calcular Perplejidad", value=True)
|
| 1891 |
enable_loss_reweighting = gr.Checkbox(label="Habilitar Re-ponderaci贸n de P茅rdida", value=False)
|
| 1892 |
reweighting_terms = gr.Textbox(label="T茅rminos para Re-ponderar (csv)", placeholder="sesgo,injusto")
|
| 1893 |
+
reweighting_factor = gr.Slider(1.1, 10.0, 2.0, label="Factor de Re-ponderaci贸n")
|
| 1894 |
enable_cda = gr.Checkbox(label="Habilitar Aumentaci贸n Contrafactual (CDA)", value=False)
|
| 1895 |
cda_json_config = gr.Textbox(label="Configuraci贸n CDA (JSON)", placeholder='[["ella", "茅l"], ["mujer", "hombre"]]')
|
| 1896 |
|
|
|
|
| 1922 |
"adam_beta1": adam_beta1, "adam_beta2": adam_beta2, "adam_epsilon": adam_epsilon,
|
| 1923 |
"disable_gradient_checkpointing": disable_gradient_checkpointing, "group_by_length": group_by_length,
|
| 1924 |
"neftune_noise_alpha": neftune_noise_alpha, "optim_args": optim_args, "attn_implementation": attn_implementation,
|
| 1925 |
+
"early_stopping_patience": early_stopping_patience,
|
| 1926 |
"peft": peft, "quantization": quantization, "lora_r": lora_r, "lora_alpha": lora_alpha,
|
| 1927 |
"lora_dropout": lora_dropout, "auto_find_target_modules": auto_find_target_modules, "target_modules": target_modules,
|
| 1928 |
"modules_to_save": modules_to_save, "use_dora": use_dora, "use_rslora": use_rslora, "init_lora_weights": init_lora_weights,
|
|
|
|
| 1934 |
"enable_back_translation": enable_back_translation, "bt_model_id": bt_model_id,
|
| 1935 |
"bt_reverse_model_id": bt_reverse_model_id, "enable_synthetic_data": enable_synthetic_data,
|
| 1936 |
"synthetic_model_id": synthetic_model_id, "num_synthetic_samples": num_synthetic_samples,
|
| 1937 |
+
"sft_format_style": sft_format_style, "chat_template_jinja": chat_template_jinja,
|
| 1938 |
+
"enable_cot_input": enable_cot_input, "enable_tool_use_input": enable_tool_use_input,
|
| 1939 |
+
"prompt_col_input": prompt_col_input, "response_col_input": response_col_input,
|
| 1940 |
+
"reasoning_col_input": reasoning_col_input, "tool_use_col_input": tool_use_col_input,
|
| 1941 |
+
"dpo_prompt_col_input": dpo_prompt_col_input, "dpo_chosen_col_input": dpo_chosen_col_input,
|
| 1942 |
"dpo_rejected_col_input": dpo_rejected_col_input, "classification_labels": classification_labels,
|
| 1943 |
"diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
|
| 1944 |
+
"enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms, "reweighting_factor": reweighting_factor,
|
| 1945 |
"wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
|
| 1946 |
"dreambooth_instance_prompt": dreambooth_instance_prompt,
|
| 1947 |
"dreambooth_train_text_encoder": dreambooth_train_text_encoder
|
|
|
|
| 1965 |
training_mode.change(
|
| 1966 |
toggle_task_specific_ui,
|
| 1967 |
inputs=[training_mode],
|
| 1968 |
+
outputs=[classification_labels_ui, dpo_ui, sft_ui, diffusion_ui, dreambooth_ui, peft_accordion, epochs_ui, max_steps_ui, peft_accordion]
|
| 1969 |
+
)
|
| 1970 |
+
|
| 1971 |
+
sft_format_style.change(
|
| 1972 |
+
toggle_sft_format_ui,
|
| 1973 |
+
inputs=[sft_format_style],
|
| 1974 |
+
outputs=[sft_tool_ui]
|
| 1975 |
)
|
| 1976 |
|
| 1977 |
auto_find_target_modules.change(
|