Ignaciohhhhggfgjfrffd commited on
Commit
c4e90bf
verified
1 Parent(s): 0af4c13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -146
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  os.system("pip install -U transformers peft accelerate trl bitsandbytes datasets diffusers")
3
  os.system("pip install spaces-0.1.0-py3-none-any.whl")
4
- import os
5
  import io
6
  import json
7
  import tempfile
@@ -17,9 +16,6 @@ import re
17
  import ast
18
  from itertools import islice
19
  from pathlib import Path
20
- from collections import defaultdict
21
- from datetime import datetime
22
-
23
  import torch
24
  import torch.nn.functional as F
25
  from torch.utils.data import DataLoader
@@ -33,14 +29,15 @@ from langdetect import detect_langs
33
  import textstat
34
  from datasketch import MinHash, MinHashLSH
35
  import gradio as gr
36
- from datasets import load_dataset, IterableDataset, Dataset, DatasetDict, interleave_datasets, Audio
 
37
  from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
38
  from transformers import (
39
  AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
40
  AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
41
- SpeechT5ForTextToSpeech, SpeechT5Processor, SpeechT5HifiGan, AutoModelForImageClassification,
42
  AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
43
- DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq,
44
  AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
45
  DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
46
  LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
@@ -57,8 +54,8 @@ from diffusers import (
57
  )
58
  import evaluate as hf_evaluate
59
  from jinja2 import Template
 
60
 
61
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
62
  logger = logging.getLogger(__name__)
63
 
64
  if torch.cuda.is_available():
@@ -97,8 +94,7 @@ TASK_TO_PIPELINE_MAP = {
97
  "DreamBooth LoRA (Text-to-Image)": "text-to-image",
98
  }
99
 
100
- MODEL_CARD_TEMPLATE = """
101
- ---
102
  language: es
103
  license: apache-2.0
104
  tags:
@@ -136,8 +132,7 @@ Este modelo es una versi贸n afinada de [{base_model}](https://huggingface.co/{ba
136
  - Gradio
137
  """
138
 
139
- DATASET_CARD_TEMPLATE = """
140
- ---
141
  license: mit
142
  ---
143
 
@@ -169,52 +164,6 @@ class DebiasingSFTTrainer(SFTTrainer):
169
  break
170
  return (loss, outputs) if return_outputs else loss
171
 
172
- class DeduplicatedIterableDataset(IterableDataset):
173
- def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
174
- super().__init__(ex_iterable=iter([]))
175
- self.dataset = dataset
176
- self.text_col = text_col
177
- self.method = method
178
- self.threshold = threshold
179
- self.num_perm = num_perm
180
- if hasattr(dataset, '_info'):
181
- self._info = dataset._info
182
- elif hasattr(dataset, 'info'):
183
- self._info = dataset.info
184
-
185
- def __iter__(self):
186
- if self.method == 'Exacta':
187
- return self._exact_iter()
188
- elif self.method == 'Sem谩ntica (MinHash)':
189
- return self._minhash_iter()
190
- else:
191
- return iter(self.dataset)
192
-
193
- def _exact_iter(self):
194
- seen_texts = set()
195
- for example in self.dataset:
196
- text = example.get(self.text_col, "")
197
- if text and isinstance(text, str):
198
- if text not in seen_texts:
199
- seen_texts.add(text)
200
- yield example
201
- else:
202
- yield example
203
-
204
- def _minhash_iter(self):
205
- lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
206
- for i, example in enumerate(self.dataset):
207
- text = example.get(self.text_col, "")
208
- if text and isinstance(text, str) and text.strip():
209
- m = MinHash(num_perm=self.num_perm)
210
- for d in text.split():
211
- m.update(d.encode('utf8'))
212
- if not lsh.query(m):
213
- lsh.insert(f"key_{i}", m)
214
- yield example
215
- else:
216
- yield example
217
-
218
  def hf_login(token):
219
  if not token:
220
  return "Por favor, introduce un token."
@@ -330,6 +279,8 @@ def _load_hf_streaming(ids, split="train", probabilities=None):
330
  if probabilities and len(probabilities) != len(streams):
331
  logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
332
  probabilities = None
 
 
333
  return interleave_datasets(streams, probabilities=probabilities)
334
 
335
  def _load_uploaded_stream(files):
@@ -517,7 +468,6 @@ def _create_training_args(output_dir, repo_id, **kwargs):
517
  "save_strategy": "steps",
518
  "logging_steps": int(kwargs.get('logging_steps', 10)),
519
  "save_steps": int(kwargs.get('save_steps', 50)),
520
- "evaluation_strategy": "steps" if kwargs.get('run_evaluation', False) else "no",
521
  "eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
522
  "learning_rate": float(kwargs.get('learning_rate', 2e-5)),
523
  "fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
@@ -727,7 +677,6 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
727
  eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
728
  if eval_logs:
729
  final_metrics = eval_logs[-1]
730
- final_metrics = {k.replace('eval_', ''): v for k, v in final_metrics.items()}
731
 
732
  yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
733
  output_dir = trainer.args.output_dir
@@ -777,16 +726,16 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
777
  if kwargs.get('run_evaluation'):
778
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
779
  for update in eval_dataset_gen:
780
- if isinstance(update, dict):
781
  yield update
782
  else:
783
  eval_dataset = update
784
 
785
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
786
- trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config, "tokenizer": tokenizer, "max_seq_length": int(kwargs.get('block_size'))}
787
 
788
  if is_dpo:
789
- trainer_kwargs.update({"beta": 0.1, "max_prompt_length": int(kwargs.get('block_size')) // 2})
790
  if eval_dataset:
791
  eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
792
  else:
@@ -794,8 +743,18 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
794
  trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
795
  if kwargs.get('enable_loss_reweighting'):
796
  trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
797
-
798
- trainer = TrainerClass(**trainer_kwargs)
 
 
 
 
 
 
 
 
 
 
799
  final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
800
  return final_model_path, final_metrics
801
 
@@ -812,26 +771,23 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
812
  tokenizer_id = kwargs.get('tokenizer_name') or model_name
813
  yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuraci贸n")
814
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
815
- if tokenizer.pad_token is None:
816
- tokenizer.pad_token = tokenizer.eos_token
817
 
818
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuraci贸n")
819
  model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
820
- model.config.pad_token_id = tokenizer.pad_token_id
821
 
822
  def preprocess(examples):
823
  return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
824
- train_dataset = train_dataset.map(preprocess, batched=True)
825
 
826
  eval_dataset = None
827
  if kwargs.get('run_evaluation'):
828
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
829
  for update in eval_dataset_gen:
830
- if isinstance(update, dict):
831
  yield update
832
  else:
833
  eval_dataset = update
834
- if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
835
 
836
  metric = hf_evaluate.load("accuracy")
837
  def compute_metrics(eval_pred):
@@ -887,7 +843,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
887
  if kwargs.get('run_evaluation'):
888
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
889
  for update in eval_dataset_gen:
890
- if isinstance(update, dict):
891
  yield update
892
  else:
893
  eval_dataset = update
@@ -979,7 +935,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
979
  eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
980
  eval_dataset_raw = None
981
  for update in eval_dataset_raw_gen:
982
- if isinstance(update, dict):
983
  yield update
984
  else:
985
  eval_dataset_raw = update
@@ -1023,7 +979,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
1023
  if kwargs.get('run_evaluation'):
1024
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
1025
  for update in eval_dataset_gen:
1026
- if isinstance(update, dict):
1027
  yield update
1028
  else:
1029
  eval_dataset = update
@@ -1067,30 +1023,22 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1067
 
1068
  yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
1069
  tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
1070
- text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=torch_dtype_auto)
1071
- vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=torch_dtype_auto)
1072
- unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=torch_dtype_auto)
1073
  noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
1074
 
1075
  vae.requires_grad_(False)
1076
  text_encoder.requires_grad_(False)
1077
  unet.train()
1078
 
1079
- yield update_logs_fn("Agregando adaptadores LoRA al UNet...", "Text-to-Image (LoRA)")
1080
  unet_lora_config = LoraConfig(
1081
  r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
1082
  target_modules=["to_q", "to_k", "to_v", "to_out.0"],
1083
  )
1084
  unet.add_adapter(unet_lora_config)
1085
 
1086
- if kwargs.get('dreambooth_train_text_encoder', False):
1087
- yield update_logs_fn("Agregando adaptadores LoRA al Text Encoder...", "DreamBooth LoRA")
1088
- text_encoder_lora_config = LoraConfig(
1089
- r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
1090
- target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
1091
- )
1092
- text_encoder.add_adapter(text_encoder_lora_config)
1093
-
1094
  yield update_logs_fn("Procesando dataset de im谩genes...", "Text-to-Image (LoRA)")
1095
  resolution = int(kwargs.get('diffusion_resolution', 512))
1096
 
@@ -1102,7 +1050,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1102
  ])
1103
 
1104
  def preprocess_train(examples):
1105
- images = [image.convert("RGB") for image in examples[kwargs.get('image_col', 'image')]]
1106
  examples["pixel_values"] = [train_transforms(image) for image in images]
1107
  examples["input_ids"] = tokenizer(examples[kwargs.get('text_col', 'text')], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
1108
  return examples
@@ -1116,17 +1064,14 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1116
 
1117
  def collate_fn(examples):
1118
  pixel_values = torch.stack([example["pixel_values"] for example in examples])
1119
- input_ids = torch.stack([e["input_ids"][0] for e in examples])
1120
  return {"pixel_values": pixel_values, "input_ids": input_ids}
1121
 
1122
  train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
1123
-
1124
- params_to_optimize = list(unet.parameters())
1125
- if kwargs.get('dreambooth_train_text_encoder', False):
1126
- params_to_optimize += list(text_encoder.parameters())
1127
 
 
1128
  optimizer = torch.optim.AdamW(
1129
- params_to_optimize, lr=float(kwargs.get('learning_rate', 2e-5)),
1130
  betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
1131
  weight_decay=float(kwargs.get('weight_decay', 0.01)),
1132
  eps=float(kwargs.get('adam_epsilon', 1e-8)),
@@ -1142,34 +1087,36 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1142
  num_training_steps=max_train_steps,
1143
  )
1144
 
1145
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1146
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1147
  )
1148
 
 
1149
  vae.to(accelerator.device, dtype=torch_dtype_auto)
1150
 
 
1151
  global_step = 0
1152
  final_loss = 0
1153
  for epoch in range(num_epochs):
1154
  for step, batch in enumerate(train_dataloader):
1155
  with accelerator.accumulate(unet):
1156
- latents = vae.encode(batch["pixel_values"].to(dtype=torch_dtype_auto)).latent_dist.sample()
1157
  latents = latents * vae.config.scaling_factor
1158
  noise = torch.randn_like(latents)
1159
  bsz = latents.shape[0]
1160
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
 
 
1161
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1162
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
1163
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1164
  loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
1165
  final_loss = loss.detach().item()
1166
 
1167
  accelerator.backward(loss)
1168
  if accelerator.sync_gradients:
1169
- params_to_clip = list(unet.parameters())
1170
- if kwargs.get('dreambooth_train_text_encoder', False):
1171
- params_to_clip += list(text_encoder.parameters())
1172
- accelerator.clip_grad_norm_(params_to_clip, float(kwargs.get('max_grad_norm', 1.0)))
1173
 
1174
  optimizer.step()
1175
  lr_scheduler.step()
@@ -1177,21 +1124,16 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1177
 
1178
  if accelerator.is_main_process:
1179
  if global_step % int(kwargs.get('logging_steps', 10)) == 0:
1180
- yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss:.4f}", "Entrenando Difusi贸n")
1181
  global_step += 1
1182
- if global_step >= max_train_steps:
1183
- break
1184
- if global_step >= max_train_steps:
1185
- break
1186
 
 
1187
  accelerator.wait_for_everyone()
1188
  if accelerator.is_main_process:
1189
- pipeline = StableDiffusionText2ImagePipeline.from_pretrained(
1190
- model_name,
1191
- unet=accelerator.unwrap_model(unet),
1192
- text_encoder=accelerator.unwrap_model(text_encoder),
1193
- torch_dtype=torch_dtype_auto,
1194
- )
1195
  pipeline.save_pretrained(output_dir)
1196
 
1197
  with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
@@ -1206,6 +1148,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, mode
1206
  torch.cuda.empty_cache()
1207
  return output_dir, {"final_loss": final_loss}
1208
 
 
1209
  def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1210
  if device == 'cpu':
1211
  raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
@@ -1220,7 +1163,7 @@ def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, mo
1220
 
1221
  train_dataset = train_dataset.map(add_prompt)
1222
 
1223
- yield update_logs_fn(f"Usando el prompt de instancia para todas las im谩genes: '{dreambooth_prompt}'", "DreamBooth LoRA")
1224
 
1225
  final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
1226
  return final_model_path, final_metrics
@@ -1251,8 +1194,9 @@ def _get_data_processing_pipeline(**kwargs):
1251
  if train_dataset is None:
1252
  train_dataset = hf_train_dataset
1253
  else:
 
1254
  all_streams = [train_dataset, hf_train_dataset]
1255
- all_probs = [0.5, 0.5]
1256
  train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
1257
 
1258
  if train_dataset is None:
@@ -1262,8 +1206,7 @@ def _get_data_processing_pipeline(**kwargs):
1262
  text_col, image_col, audio_col, label_col = _guess_columns(first_example)
1263
  kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
1264
 
1265
- is_text_task = kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)", "Image Classification (Vision)", "Audio Classification (Speech)"]
1266
- if is_text_task:
1267
  if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
1268
  clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
1269
  train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
@@ -1286,13 +1229,35 @@ def _get_data_processing_pipeline(**kwargs):
1286
 
1287
  dedup_method = kwargs.get('deduplication_method')
1288
  if dedup_method != 'Ninguna':
1289
- train_dataset = DeduplicatedIterableDataset(
1290
- dataset=train_dataset,
1291
- text_col=text_col,
1292
- method=dedup_method,
1293
- threshold=kwargs.get('minhash_threshold', 0.85),
1294
- num_perm=int(kwargs.get('minhash_num_perm', 128))
1295
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1296
 
1297
  return train_dataset, kwargs
1298
 
@@ -1385,6 +1350,7 @@ def _train_and_upload(**kwargs):
1385
  raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
1386
  base_model_id_for_training = temp_model_dir
1387
  kwargs["peft"] = False
 
1388
  kwargs['tokenizer_name'] = temp_model_dir
1389
  yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
1390
 
@@ -1397,6 +1363,7 @@ def _train_and_upload(**kwargs):
1397
  os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
1398
  os.environ["WANDB_LOG_MODEL"] = "checkpoint"
1399
 
 
1400
  model_card_content = MODEL_CARD_TEMPLATE.format(
1401
  repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
1402
  training_mode=kwargs.get('training_mode'),
@@ -1423,11 +1390,8 @@ def _train_and_upload(**kwargs):
1423
  train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
1424
  while True:
1425
  try:
1426
- update = next(train_generator)
1427
- if isinstance(update, tuple) and len(update) == 4:
1428
- yield update + (gr.update(), gr.update())
1429
- else:
1430
- pass
1431
  except StopIteration as e:
1432
  final_model_path, final_metrics = e.value
1433
  break
@@ -1441,7 +1405,7 @@ def _train_and_upload(**kwargs):
1441
  eval_dataset_perp = None
1442
  eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
1443
  for update in eval_gen:
1444
- if isinstance(update, dict):
1445
  yield update + (gr.update(), gr.update())
1446
  else:
1447
  eval_dataset_perp = update
@@ -1561,6 +1525,7 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
1561
  for item in all_data:
1562
  f.write(json.dumps(item, ensure_ascii=False) + "\n")
1563
 
 
1564
  readme_content = DATASET_CARD_TEMPLATE.format(
1565
  repo_id=repo_id,
1566
  creation_type=creation_type,
@@ -1598,13 +1563,9 @@ def gradio_preview_data_wrapper(*args):
1598
  dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
1599
  text_col = processed_kwargs.get('text_col')
1600
 
1601
- model_id_for_tokenizer = kwargs.get('model_base_input')
1602
- if not model_id_for_tokenizer:
1603
- raise ValueError("Se necesita un ID de modelo base para cargar el tokenizer para la vista previa.")
1604
-
1605
- tokenizer_id = kwargs.get('tokenizer_name') or model_id_for_tokenizer
1606
  tokenizer = AutoTokenizer.from_pretrained(
1607
- tokenizer_id, trust_remote_code=True, use_fast=False
 
1608
  )
1609
  if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
1610
  if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
@@ -1613,15 +1574,15 @@ def gradio_preview_data_wrapper(*args):
1613
  for i, example in enumerate(islice(dataset, 5)):
1614
  formatted_text = ""
1615
  if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
1616
- formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2, ensure_ascii=False)
1617
  else:
1618
  formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
1619
 
1620
  preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
1621
 
1622
  preview_text = "\n".join(preview_samples)
1623
- if not preview_samples:
1624
- preview_text = "No se pudieron generar muestras. Revisa la configuraci贸n del dataset, los filtros y el formato."
1625
  yield preview_text
1626
 
1627
  except Exception as e:
@@ -1643,7 +1604,6 @@ def toggle_task_specific_ui(training_mode):
1643
  is_sft = "Causal" in training_mode
1644
  is_ner = "Token Classification" in training_mode
1645
  is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
1646
- is_streaming = not is_diffusion
1647
 
1648
  return (
1649
  gr.update(visible=is_classification or is_ner),
@@ -1653,10 +1613,10 @@ def toggle_task_specific_ui(training_mode):
1653
  gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
1654
  gr.update(visible=not is_diffusion),
1655
  gr.update(visible=is_diffusion),
1656
- gr.update(visible=not is_streaming),
1657
- gr.update(visible=is_streaming)
1658
  )
1659
 
 
1660
  def toggle_auto_modules_ui(is_auto):
1661
  return gr.update(visible=not is_auto)
1662
 
@@ -1690,7 +1650,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1690
  dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
1691
  dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
1692
  with gr.Column(scale=2):
1693
- dset_status_output = gr.Textbox(label="Estado", lines=10, interactive=False)
1694
  dset_link_output = gr.Markdown()
1695
 
1696
  dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
@@ -1746,7 +1706,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1746
  with gr.Accordion("Avanzados", open=False):
1747
  warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
1748
  weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
1749
- max_grad_norm = gr.Textbox(label="Norma M谩xima de Gradiente", value="1.0")
1750
  logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1751
  save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1752
  save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
@@ -1806,6 +1766,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1806
  diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resoluci贸n")
1807
  with gr.Group(visible=False) as dreambooth_ui:
1808
  dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
 
 
 
1809
  dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
1810
  with gr.Group(visible=False) as classification_labels_ui:
1811
  classification_labels = gr.Textbox(label="Etiquetas de Clasificaci贸n (csv)", placeholder="p.ej. positivo,negativo")
@@ -1824,6 +1787,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1824
  enable_cda = gr.Checkbox(label="Habilitar Aumentaci贸n Contrafactual (CDA)", value=False)
1825
  cda_json_config = gr.Textbox(label="Configuraci贸n CDA (JSON)", placeholder='[["ella", "茅l"], ["mujer", "hombre"]]')
1826
 
 
1827
  with gr.Accordion("馃攲 Integraciones", open=False):
1828
  wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
1829
  wandb_project_input = gr.Textbox(label="Proyecto W&B")
@@ -1868,7 +1832,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1868
  "diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
1869
  "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
1870
  "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
1871
- "dreambooth_instance_prompt": dreambooth_instance_prompt,
 
1872
  "dreambooth_train_text_encoder": dreambooth_train_text_encoder
1873
  }
1874
 
@@ -1940,4 +1905,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1940
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1941
  )
1942
 
1943
- demo.queue().launch(debug=True)
 
1
  import os
2
  os.system("pip install -U transformers peft accelerate trl bitsandbytes datasets diffusers")
3
  os.system("pip install spaces-0.1.0-py3-none-any.whl")
 
4
  import io
5
  import json
6
  import tempfile
 
16
  import ast
17
  from itertools import islice
18
  from pathlib import Path
 
 
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch.utils.data import DataLoader
 
29
  import textstat
30
  from datasketch import MinHash, MinHashLSH
31
  import gradio as gr
32
+ import spaces
33
+ from datasets import load_dataset, IterableDataset, Dataset, DatasetDict
34
  from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi
35
  from transformers import (
36
  AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
37
  AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
38
+ AutoModelForImageClassification,
39
  AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
40
+ DataCollatorForTokenClassification, AutoModelForQuestionAnswering,
41
  AutoProcessor, DataCollatorWithPadding, pipeline, CLIPTextModel, CLIPTokenizer,
42
  DataCollatorForSeq2Seq, AutoModelForSequenceClassification, BitsAndBytesConfig,
43
  LlamaConfig, LlamaForCausalLM, MistralConfig, MistralForCausalLM, GemmaConfig, GemmaForCausalLM, GPT2Config, GPT2LMHeadModel,
 
54
  )
55
  import evaluate as hf_evaluate
56
  from jinja2 import Template
57
+ from collections import defaultdict
58
 
 
59
  logger = logging.getLogger(__name__)
60
 
61
  if torch.cuda.is_available():
 
94
  "DreamBooth LoRA (Text-to-Image)": "text-to-image",
95
  }
96
 
97
+ MODEL_CARD_TEMPLATE = """---
 
98
  language: es
99
  license: apache-2.0
100
  tags:
 
132
  - Gradio
133
  """
134
 
135
+ DATASET_CARD_TEMPLATE = """---
 
136
  license: mit
137
  ---
138
 
 
164
  break
165
  return (loss, outputs) if return_outputs else loss
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def hf_login(token):
168
  if not token:
169
  return "Por favor, introduce un token."
 
279
  if probabilities and len(probabilities) != len(streams):
280
  logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
281
  probabilities = None
282
+
283
+ from datasets import interleave_datasets
284
  return interleave_datasets(streams, probabilities=probabilities)
285
 
286
  def _load_uploaded_stream(files):
 
468
  "save_strategy": "steps",
469
  "logging_steps": int(kwargs.get('logging_steps', 10)),
470
  "save_steps": int(kwargs.get('save_steps', 50)),
 
471
  "eval_steps": int(kwargs.get('save_steps', 50)) if kwargs.get('run_evaluation', False) else None,
472
  "learning_rate": float(kwargs.get('learning_rate', 2e-5)),
473
  "fp16": kwargs.get('mixed_precision') == 'fp16' and device == 'cuda',
 
677
  eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
678
  if eval_logs:
679
  final_metrics = eval_logs[-1]
 
680
 
681
  yield update_logs_fn("Entrenamiento finalizado.", "Guardando")
682
  output_dir = trainer.args.output_dir
 
726
  if kwargs.get('run_evaluation'):
727
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
728
  for update in eval_dataset_gen:
729
+ if isinstance(update, tuple):
730
  yield update
731
  else:
732
  eval_dataset = update
733
 
734
  TrainerClass = DPOTrainer if is_dpo else (DebiasingSFTTrainer if kwargs.get('enable_loss_reweighting') else SFTTrainer)
735
+ trainer_kwargs = {"model": model, "args": training_args, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "peft_config": peft_config, "tokenizer": tokenizer}
736
 
737
  if is_dpo:
738
+ trainer_kwargs.update({"beta": 0.1, "max_length": int(kwargs.get('block_size')), "max_prompt_length": int(kwargs.get('block_size')) // 2})
739
  if eval_dataset:
740
  eval_dataset = eval_dataset.map(lambda ex: _dpo_formatting_func(ex, **kwargs))
741
  else:
 
743
  trainer_kwargs.update({"formatting_func": lambda ex: _sft_formatting_func(example=ex, tokenizer=tokenizer, text_col=text_col, **sft_kwargs)})
744
  if kwargs.get('enable_loss_reweighting'):
745
  trainer_kwargs.update({'reweighting_terms': kwargs.get('reweighting_terms', '').split(','), 'reweighting_factor': kwargs.get('reweighting_factor', 2.0)})
746
+
747
+ try:
748
+ trainer = TrainerClass(**trainer_kwargs)
749
+ except TypeError as e:
750
+ if "unexpected keyword argument 'tokenizer'" in str(e):
751
+ logger.warning("Caught TypeError for tokenizer argument. Retrying without it for TRL compatibility.")
752
+ trainer_kwargs.pop("tokenizer", None)
753
+ trainer = TrainerClass(**trainer_kwargs)
754
+ trainer.tokenizer = tokenizer
755
+ else:
756
+ raise e
757
+
758
  final_model_path, final_metrics = yield from _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs)
759
  return final_model_path, final_metrics
760
 
 
771
  tokenizer_id = kwargs.get('tokenizer_name') or model_name
772
  yield update_logs_fn(f"Cargando tokenizer '{tokenizer_id}'...", "Configuraci贸n")
773
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
 
 
774
 
775
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuraci贸n")
776
  model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
 
777
 
778
def preprocess(examples):
    """Tokenize the configured text column, truncating to 512 tokens."""
    text_batch = examples[kwargs['text_col']]
    return tokenizer(
        text_batch,
        truncation=True,
        max_length=512,
    )
780
+ train_dataset = train_dataset.map(preprocess)
781
 
782
  eval_dataset = None
783
  if kwargs.get('run_evaluation'):
784
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
785
  for update in eval_dataset_gen:
786
+ if isinstance(update, tuple):
787
  yield update
788
  else:
789
  eval_dataset = update
790
+ if eval_dataset: eval_dataset = eval_dataset.map(preprocess)
791
 
792
  metric = hf_evaluate.load("accuracy")
793
  def compute_metrics(eval_pred):
 
843
  if kwargs.get('run_evaluation'):
844
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
845
  for update in eval_dataset_gen:
846
+ if isinstance(update, tuple):
847
  yield update
848
  else:
849
  eval_dataset = update
 
935
  eval_dataset_raw_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
936
  eval_dataset_raw = None
937
  for update in eval_dataset_raw_gen:
938
+ if isinstance(update, tuple):
939
  yield update
940
  else:
941
  eval_dataset_raw = update
 
979
  if kwargs.get('run_evaluation'):
980
  eval_dataset_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), update_logs_fn)
981
  for update in eval_dataset_gen:
982
+ if isinstance(update, tuple):
983
  yield update
984
  else:
985
  eval_dataset = update
 
1023
 
1024
  yield update_logs_fn("Configurando componentes de Diffusers...", "Text-to-Image (LoRA)")
1025
  tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
1026
+ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
1027
+ vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
1028
+ unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
1029
  noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
1030
 
1031
  vae.requires_grad_(False)
1032
  text_encoder.requires_grad_(False)
1033
  unet.train()
1034
 
1035
+ yield update_logs_fn("Agregando adaptadores LoRA al modelo...", "Text-to-Image (LoRA)")
1036
  unet_lora_config = LoraConfig(
1037
  r=int(kwargs.get('lora_r', 16)), lora_alpha=int(kwargs.get('lora_alpha', 32)),
1038
  target_modules=["to_q", "to_k", "to_v", "to_out.0"],
1039
  )
1040
  unet.add_adapter(unet_lora_config)
1041
 
 
 
 
 
 
 
 
 
1042
  yield update_logs_fn("Procesando dataset de im谩genes...", "Text-to-Image (LoRA)")
1043
  resolution = int(kwargs.get('diffusion_resolution', 512))
1044
 
 
1050
  ])
1051
 
1052
def preprocess_train(examples):
    """Prepare one batch for diffusion training.

    Produces ``pixel_values`` (transformed RGB image tensors) and
    ``input_ids`` (tokenized captions, padded to the tokenizer max length).

    Fix: the image column may contain either file paths / file-like
    objects OR already-decoded PIL images — HF ``datasets`` image
    features yield PIL objects directly, on which ``Image.open`` raises.
    Both forms are now accepted; path inputs behave exactly as before.
    """
    raw_images = examples[kwargs.get('image_col', 'image')]
    images = [
        (img if isinstance(img, Image.Image) else Image.open(img)).convert("RGB")
        for img in raw_images
    ]
    examples["pixel_values"] = [train_transforms(image) for image in images]
    examples["input_ids"] = tokenizer(
        examples[kwargs.get('text_col', 'text')],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids
    return examples
 
1064
 
1065
def collate_fn(examples):
    """Stack per-example tensors into a single training batch dict."""
    stacked_pixels = torch.stack([ex["pixel_values"] for ex in examples])
    stacked_ids = torch.stack([ex["input_ids"] for ex in examples])
    return {"pixel_values": stacked_pixels, "input_ids": stacked_ids}
1069
 
1070
  train_dataloader = DataLoader(processed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=int(kwargs.get('batch_size', 1)))
 
 
 
 
1071
 
1072
+ yield update_logs_fn("Configurando optimizador y planificador...", "Text-to-Image (LoRA)")
1073
  optimizer = torch.optim.AdamW(
1074
+ unet.parameters(), lr=float(kwargs.get('learning_rate', 2e-5)),
1075
  betas=(float(kwargs.get('adam_beta1', 0.9)), float(kwargs.get('adam_beta2', 0.999))),
1076
  weight_decay=float(kwargs.get('weight_decay', 0.01)),
1077
  eps=float(kwargs.get('adam_epsilon', 1e-8)),
 
1087
  num_training_steps=max_train_steps,
1088
  )
1089
 
1090
+ unet, optimizer, train_dataloader, lr_scheduler, text_encoder, vae = accelerator.prepare(
1091
+ unet, optimizer, train_dataloader, lr_scheduler, text_encoder, vae
1092
  )
1093
 
1094
+ text_encoder.to(accelerator.device, dtype=torch_dtype_auto)
1095
  vae.to(accelerator.device, dtype=torch_dtype_auto)
1096
 
1097
+ yield update_logs_fn("Iniciando bucle de entrenamiento de difusi贸n...", "Text-to-Image (LoRA)")
1098
  global_step = 0
1099
  final_loss = 0
1100
  for epoch in range(num_epochs):
1101
  for step, batch in enumerate(train_dataloader):
1102
  with accelerator.accumulate(unet):
1103
+ latents = vae.encode(batch["pixel_values"].to(torch_dtype_auto)).latent_dist.sample()
1104
  latents = latents * vae.config.scaling_factor
1105
  noise = torch.randn_like(latents)
1106
  bsz = latents.shape[0]
1107
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1108
+ timesteps = timesteps.long()
1109
+
1110
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1111
+ encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]
1112
+
1113
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1114
  loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
1115
  final_loss = loss.detach().item()
1116
 
1117
  accelerator.backward(loss)
1118
  if accelerator.sync_gradients:
1119
+ accelerator.clip_grad_norm_(unet.parameters(), float(kwargs.get('max_grad_norm', 1.0)))
 
 
 
1120
 
1121
  optimizer.step()
1122
  lr_scheduler.step()
 
1124
 
1125
  if accelerator.is_main_process:
1126
  if global_step % int(kwargs.get('logging_steps', 10)) == 0:
1127
+ yield update_logs_fn(f"Epoch {epoch}, Step {step}, Loss: {final_loss}", "Text-to-Image (LoRA)")
1128
  global_step += 1
 
 
 
 
1129
 
1130
+ yield update_logs_fn("Entrenamiento completado, guardando modelo...", "Text-to-Image (LoRA)")
1131
  accelerator.wait_for_everyone()
1132
  if accelerator.is_main_process:
1133
+ unwrapped_unet = accelerator.unwrap_model(unet)
1134
+
1135
+ pipeline = StableDiffusionText2ImagePipeline.from_pretrained(model_name, torch_dtype=torch_dtype_auto)
1136
+ pipeline.unet.load_state_dict(unwrapped_unet.state_dict())
 
 
1137
  pipeline.save_pretrained(output_dir)
1138
 
1139
  with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
 
1148
  torch.cuda.empty_cache()
1149
  return output_dir, {"final_loss": final_loss}
1150
 
1151
+
1152
  def train_dreambooth_lora(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1153
  if device == 'cpu':
1154
  raise ValueError("El entrenamiento de DreamBooth solo es compatible con GPU CUDA.")
 
1163
 
1164
  train_dataset = train_dataset.map(add_prompt)
1165
 
1166
+ yield update_logs_fn(f"Usando el prompt de instancia para todas las im谩genes: '{dreambooth_prompt}'", "DreamBooth LoRA (Text-to-Image)")
1167
 
1168
  final_model_path, final_metrics = yield from train_text_to_image(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs)
1169
  return final_model_path, final_metrics
 
1194
  if train_dataset is None:
1195
  train_dataset = hf_train_dataset
1196
  else:
1197
+ from datasets import interleave_datasets
1198
  all_streams = [train_dataset, hf_train_dataset]
1199
+ all_probs = [0.5, 0.5] if not probabilities else [probabilities] + probabilities[1:]
1200
  train_dataset = interleave_datasets(all_streams, probabilities=all_probs)
1201
 
1202
  if train_dataset is None:
 
1206
  text_col, image_col, audio_col, label_col = _guess_columns(first_example)
1207
  kwargs.update({'text_col': text_col, 'image_col': image_col, 'audio_col': audio_col, 'label_col': label_col, 'uploaded_val_data': uploaded_val_data})
1208
 
1209
+ if kwargs['training_mode'] not in ["DreamBooth LoRA (Text-to-Image)", "Text-to-Image (LoRA)"]:
 
1210
  if any([kwargs.get('remove_html_tags'), kwargs.get('normalize_whitespace'), kwargs.get('remove_urls_emails'), kwargs.get('redact_pii')]):
1211
  clean_kwargs = {k:v for k,v in kwargs.items() if k in ['remove_html_tags', 'normalize_whitespace', 'remove_urls_emails', 'redact_pii']}
1212
  train_dataset = train_dataset.map(lambda ex: _clean_text(ex, text_col, **clean_kwargs))
 
1229
 
1230
  dedup_method = kwargs.get('deduplication_method')
1231
  if dedup_method != 'Ninguna':
1232
+ base_iterator = train_dataset
1233
+ if dedup_method == 'Exacta':
1234
def dedup_generator_exact():
    """Yield examples, dropping exact repeats of the text column.

    Non-string and empty texts always pass through; only non-empty
    strings are remembered for deduplication.
    """
    seen_texts = set()
    for example in base_iterator:
        text = example.get(text_col, "")
        if isinstance(text, str):
            if text in seen_texts:
                continue  # exact duplicate: skip
            if text:
                seen_texts.add(text)
        yield example
1242
+ train_dataset = IterableDataset.from_generator(dedup_generator_exact)
1243
+ elif dedup_method == 'Sem谩ntica (MinHash)':
1244
+ threshold = kwargs.get('minhash_threshold', 0.85)
1245
+ num_perm = int(kwargs.get('minhash_num_perm', 128))
1246
def dedup_generator_minhash():
    """Yield examples, dropping near-duplicates via MinHash LSH.

    Each non-empty text is hashed token-by-token; an example is kept
    only when the LSH index has no match above the threshold. Blank or
    non-string texts pass through untouched.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    for idx, example in enumerate(base_iterator):
        text = example.get(text_col, "")
        if not (text and isinstance(text, str) and text.strip()):
            yield example
            continue
        signature = MinHash(num_perm=num_perm)
        for token in text.split():
            signature.update(token.encode('utf8'))
        if not lsh.query(signature):
            lsh.insert(f"key_{idx}", signature)
            yield example
1259
+ train_dataset = IterableDataset.from_generator(dedup_generator_minhash)
1260
+
1261
 
1262
  return train_dataset, kwargs
1263
 
 
1350
  raise Exception(f"No se pudo cargar el tokenizer base '{tokenizer_id}' para el modelo desde cero: {e}")
1351
  base_model_id_for_training = temp_model_dir
1352
  kwargs["peft"] = False
1353
+ kwargs["merge_adapter"] = False
1354
  kwargs['tokenizer_name'] = temp_model_dir
1355
  yield update_logs(f"Modelo {architecture} inicializado en {temp_model_dir}.", "Modelo Cero") + (gr.update(), gr.update())
1356
 
 
1363
  os.environ["WANDB_PROJECT"] = kwargs.get('wandb_project_input') or f"{repo_base}"
1364
  os.environ["WANDB_LOG_MODEL"] = "checkpoint"
1365
 
1366
+ from datetime import datetime
1367
  model_card_content = MODEL_CARD_TEMPLATE.format(
1368
  repo_id=repo_id, base_model=model_name, base_model_name=model_name.split('/')[-1],
1369
  training_mode=kwargs.get('training_mode'),
 
1390
  train_generator = train_func(base_model_id_for_training, train_dataset, repo_id, update_logs, model_card_content, **kwargs)
1391
  while True:
1392
  try:
1393
+ update_tuple = next(train_generator)
1394
+ yield update_tuple + (gr.update(), gr.update())
 
 
 
1395
  except StopIteration as e:
1396
  final_model_path, final_metrics = e.value
1397
  break
 
1405
  eval_dataset_perp = None
1406
  eval_gen = _get_eval_dataset(kwargs.get('datasets_hf_text').split(","), kwargs.get('eval_dataset_hf'), kwargs.get('uploaded_val_data'), lambda m, p: update_logs(m, p))
1407
  for update in eval_gen:
1408
+ if isinstance(update, tuple):
1409
  yield update + (gr.update(), gr.update())
1410
  else:
1411
  eval_dataset_perp = update
 
1525
  for item in all_data:
1526
  f.write(json.dumps(item, ensure_ascii=False) + "\n")
1527
 
1528
+ from datetime import datetime
1529
  readme_content = DATASET_CARD_TEMPLATE.format(
1530
  repo_id=repo_id,
1531
  creation_type=creation_type,
 
1563
  dataset, processed_kwargs = _get_data_processing_pipeline(**kwargs)
1564
  text_col = processed_kwargs.get('text_col')
1565
 
 
 
 
 
 
1566
  tokenizer = AutoTokenizer.from_pretrained(
1567
+ kwargs.get('tokenizer_name') or kwargs.get('model_base_input') or 'gpt2',
1568
+ trust_remote_code=True, use_fast=False
1569
  )
1570
  if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
1571
  if kwargs.get('chat_template_jinja', '').strip(): tokenizer.chat_template = kwargs['chat_template_jinja']
 
1574
  for i, example in enumerate(islice(dataset, 5)):
1575
  formatted_text = ""
1576
  if kwargs['training_mode'] == "DPO (Direct Preference Optimization)":
1577
+ formatted_text = json.dumps(_dpo_formatting_func(example, **kwargs), indent=2)
1578
  else:
1579
  formatted_text = _sft_formatting_func(example, text_col, tokenizer, **kwargs)
1580
 
1581
  preview_samples.append(f"--- MUESTRA {i+1} ---\n{formatted_text}\n")
1582
 
1583
  preview_text = "\n".join(preview_samples)
1584
+ if not preview_text:
1585
+ preview_text = "No se pudieron generar muestras. Revisa la configuraci贸n del dataset y el formato."
1586
  yield preview_text
1587
 
1588
  except Exception as e:
 
1604
  is_sft = "Causal" in training_mode
1605
  is_ner = "Token Classification" in training_mode
1606
  is_diffusion = training_mode in ["Text-to-Image (LoRA)", "DreamBooth LoRA (Text-to-Image)"]
 
1607
 
1608
  return (
1609
  gr.update(visible=is_classification or is_ner),
 
1613
  gr.update(visible=training_mode == "DreamBooth LoRA (Text-to-Image)"),
1614
  gr.update(visible=not is_diffusion),
1615
  gr.update(visible=is_diffusion),
1616
+ gr.update(visible=not is_diffusion),
 
1617
  )
1618
 
1619
+
1620
def toggle_auto_modules_ui(is_auto):
    """Hide the manual module-selection controls when auto mode is on."""
    show_manual = not is_auto
    return gr.update(visible=show_manual)
1622
 
 
1650
  dset_file_uploads = gr.File(label="Subir Archivos (.jsonl, .csv, .txt)", file_count="multiple")
1651
  dset_create_button = gr.Button("Crear y Subir Dataset", variant="primary")
1652
  with gr.Column(scale=2):
1653
+ dset_status_output = gr.Textbox(label="Estado", lines=10)
1654
  dset_link_output = gr.Markdown()
1655
 
1656
  dset_creation_type.change(toggle_dataset_creator_ui, inputs=[dset_creation_type], outputs=[dset_synth_group, dset_file_group])
 
1706
  with gr.Accordion("Avanzados", open=False):
1707
  warmup_ratio = gr.Slider(0.0, 0.5, 0.03, label="Ratio de Calentamiento")
1708
  weight_decay = gr.Textbox(label="Decaimiento de Peso", value="0.01")
1709
+ max_grad_norm = gr.Textbox(label="Norma M谩xima de Gradiente", value="0.3")
1710
  logging_steps = gr.Textbox(label="Pasos de Registro", value="10")
1711
  save_steps = gr.Textbox(label="Pasos de Guardado", value="50")
1712
  save_total_limit = gr.Textbox(label="L铆mite Total de Guardado", value="1")
 
1766
  diffusion_resolution = gr.Slider(256, 1024, 512, step=64, label="Resoluci贸n")
1767
  with gr.Group(visible=False) as dreambooth_ui:
1768
  dreambooth_instance_prompt = gr.Textbox(label="Prompt de Instancia", placeholder="p.ej. 'foto de perro sks'")
1769
+ dreambooth_class_prompt = gr.Textbox(label="Prompt de Clase (Opcional)", placeholder="p.ej. 'foto de perro'")
1770
+ dreambooth_num_class_images = gr.Slider(0, 1000, 100, step=10, label="N潞 de Im谩genes de Clase")
1771
+ dreambooth_prior_loss_weight = gr.Slider(0.0, 2.0, 1.0, label="Peso de P茅rdida a Priori")
1772
  dreambooth_train_text_encoder = gr.Checkbox(label="Entrenar Text Encoder", value=True)
1773
  with gr.Group(visible=False) as classification_labels_ui:
1774
  classification_labels = gr.Textbox(label="Etiquetas de Clasificaci贸n (csv)", placeholder="p.ej. positivo,negativo")
 
1787
  enable_cda = gr.Checkbox(label="Habilitar Aumentaci贸n Contrafactual (CDA)", value=False)
1788
  cda_json_config = gr.Textbox(label="Configuraci贸n CDA (JSON)", placeholder='[["ella", "茅l"], ["mujer", "hombre"]]')
1789
 
1790
+
1791
  with gr.Accordion("馃攲 Integraciones", open=False):
1792
  wandb_api_key_input = gr.Textbox(label="Clave API de W&B", type="password")
1793
  wandb_project_input = gr.Textbox(label="Proyecto W&B")
 
1832
  "diffusion_resolution": diffusion_resolution, "run_evaluation": run_evaluation, "run_perplexity_evaluation": run_perplexity_evaluation,
1833
  "enable_loss_reweighting": enable_loss_reweighting, "reweighting_terms": reweighting_terms,
1834
  "wandb_api_key_input": wandb_api_key_input, "wandb_project_input": wandb_project_input,
1835
+ "dreambooth_instance_prompt": dreambooth_instance_prompt, "dreambooth_class_prompt": dreambooth_class_prompt,
1836
+ "dreambooth_num_class_images": dreambooth_num_class_images, "dreambooth_prior_loss_weight": dreambooth_prior_loss_weight,
1837
  "dreambooth_train_text_encoder": dreambooth_train_text_encoder
1838
  }
1839
 
 
1905
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1906
  )
1907
 
1908
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)